Skip to content
Published on

KV Cache 最適化 深層分析: GQA・MLA・MHA アテンションメカニズムとメモリ効率化戦略

Authors
  • Name
    Twitter
KV Cache Optimization: GQA, MLA, MHA Attention Mechanisms

はじめに

大規模言語モデル(LLM)の推論コストにおける最大のボトルネックは、KV Cache(Key-Value Cache) のメモリ消費である。GPT-4クラスのモデルが128Kコンテキスト長で同時に数百のリクエストを処理する場合、KV Cacheだけで数百GBのGPUメモリを消費する可能性がある。このメモリ制約は、スループットとレスポンスレイテンシの直接的な制限要因となる。

TransformerのSelf-Attentionメカニズムにおいて、デコーディングの各ステップで以前のすべてのトークンのKeyとValueベクトルを再計算することは非効率的であるため、これをキャッシュに保存して再利用するのがKV Cacheの基本原理である。問題は、シーケンス長が長くなるほどこのキャッシュのサイズが線形に増加することである。

この問題を解決するために、様々なアテンションメカニズムが提案されてきた。**Multi-Query Attention(MQA)**はすべてのQueryヘッドが1つのKVヘッドを共有してメモリを劇的に削減するが、品質低下が発生した。**Grouped Query Attention(GQA)**はMHAとMQAの折衷案としてLlama 2/3に採用された。**Multi-Head Latent Attention(MLA)**はDeepSeek-V2/V3でKVを低次元潜在空間に圧縮して驚くべき効率を達成した。

本記事では、各アテンションメカニズムの数学的原理とメモリ分析、KV Cache圧縮技法、PagedAttention(vLLM)、PyTorch実装例、実際のOOM障害事例と復旧、最適化チェックリストを総合的に解説する。

Transformer Self-AttentionとKV Cache基礎

Self-Attention演算

TransformerのScaled Dot-Product Attentionは以下のように定義される。

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

ここで、Q(Query)、K(Key)、V(Value)は入力トークンの埋め込みを線形変換して得られる。

import torch
import torch.nn as nn
import math

class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        # Q, K, V プロジェクション
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        # KV Cache 活用
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)

        # Attention 演算
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn, v)

        output = output.transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(output), (k, v)

KV Cache メモリ分析

KV Cacheのメモリ使用量は以下の式で計算される。

KV Cache メモリ = 2(KとV)x レイヤー数(L)x ヘッド数(H)x シーケンス長(S)x ヘッド次元(d_k)x バイトサイズ

例えば、Llama 2 70Bモデルの場合:

  • レイヤー数:80
  • ヘッド数:64(GQA適用前)
  • ヘッド次元:128
  • シーケンス長:4096
  • FP16(2バイト)

KV Cache = 2 x 80 x 64 x 4096 x 128 x 2 = 10.7 GB(単一リクエスト)

バッチサイズ32の場合、342.4 GBとなり、モデル重みのメモリをはるかに超過する。

Multi-Head Attention(MHA)メモリ分析

MHA構造

標準のMulti-Head Attentionでは、各アテンションヘッドが独立したQ、K、Vプロジェクションを持つ。H個のヘッドがそれぞれd_k次元のK、Vを維持するため、KV Cacheは最大サイズとなる。

class MultiHeadAttention(nn.Module):
    """標準 Multi-Head Attention(MHA)
    各ヘッドが独立したKVペアを維持"""
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)  # H * d_k 次元
        self.W_v = nn.Linear(d_model, d_model)  # H * d_k 次元
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, _ = x.shape

        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        if kv_cache is not None:
            k = torch.cat([kv_cache[0], k], dim=2)
            v = torch.cat([kv_cache[1], v], dim=2)

        # KV Cache サイズ: [B, n_heads, S, d_k] x 2
        # メモリ: 2 * B * n_heads * S * d_k * sizeof(dtype)
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)

        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out), (k, v)

MHA KV Cache per token per layer = 2 x H x d_k x sizeof(dtype)

Multi-Query Attention(MQA)

MQA構造と削減効果

MQAはShazeer(2019)が提案した方法で、すべてのQueryヘッドが単一のKVヘッドを共有する。KV CacheがH倍削減されるが、品質低下が観察される。

class MultiQueryAttention(nn.Module):
    """Multi-Query Attention(MQA)
    すべてのQueryヘッドが単一KVヘッドを共有"""
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)       # H * d_k
        self.W_k = nn.Linear(d_model, self.d_k)       # 1 * d_k(単一ヘッド)
        self.W_v = nn.Linear(d_model, self.d_k)       # 1 * d_k(単一ヘッド)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, _ = x.shape

        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, 1, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, 1, self.d_k).transpose(1, 2)

        if kv_cache is not None:
            k = torch.cat([kv_cache[0], k], dim=2)
            v = torch.cat([kv_cache[1], v], dim=2)

        # K、Vをすべてのヘッドにブロードキャスト
        k = k.expand(-1, self.n_heads, -1, -1)
        v = v.expand(-1, self.n_heads, -1, -1)

        # KV Cache サイズ: [B, 1, S, d_k] x 2
        # MHA比 1/H に縮小
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)

        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out), (k[:, :1], v[:, :1])

MQA KV Cache per token per layer = 2 x 1 x d_k x sizeof(dtype)、MHA比1/H縮小

Grouped Query Attention(GQA)

GQAアーキテクチャ(Llama 2/3)

GQAはAinslie et al.(2023)が提案し、Llama 2、Llama 3に採用された方法で、QueryヘッドをG個のグループに分け、各グループが1つのKVヘッドを共有する。G=1ならMQA、G=HならMHAと同等である。

class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention(GQA)
    QueryヘッドをG個のグループに分け、各グループがKVを共有
    Llama 2/3で採用"""
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        self.n_heads = n_heads        # Queryヘッド数
        self.n_kv_heads = n_kv_heads  # KVヘッド数(グループ数)
        self.d_k = d_model // n_heads
        self.n_rep = n_heads // n_kv_heads  # グループあたりのQueryヘッド数

        self.W_q = nn.Linear(d_model, n_heads * self.d_k)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.d_k)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.d_k)
        self.W_o = nn.Linear(d_model, d_model)

    def repeat_kv(self, x):
        """KVヘッドをQueryヘッド数に合わせて繰り返す"""
        B, H_kv, S, D = x.shape
        if self.n_rep == 1:
            return x
        return (
            x[:, :, None, :, :]
            .expand(B, H_kv, self.n_rep, S, D)
            .reshape(B, self.n_heads, S, D)
        )

    def forward(self, x, kv_cache=None):
        B, T, _ = x.shape

        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)

        if kv_cache is not None:
            k = torch.cat([kv_cache[0], k], dim=2)
            v = torch.cat([kv_cache[1], v], dim=2)

        # KV Cache サイズ: [B, n_kv_heads, S, d_k] x 2
        new_cache = (k, v)

        # ブロードキャストのためにKVヘッドを繰り返す
        k = self.repeat_kv(k)
        v = self.repeat_kv(v)

        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)

        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out), new_cache

Llama 3 70BのGQA設定:n_heads=64、n_kv_heads=8。KV CacheがMHA比1/8に縮小されながらも品質低下がほとんどない。

Multi-Head Latent Attention(MLA)

MLAアーキテクチャ(DeepSeek-V2/V3)

MLAはDeepSeek-V2(2024)で提案された革新的な方法で、KVベクトルを低次元の**潜在ベクトル(latent vector)**に圧縮してキャッシュに保存する。デコーディング時に潜在ベクトルからK、Vを復元する。

核心アイデアは以下の通りである:

  1. 入力xを低次元潜在ベクトルcに圧縮:c = W_down * x
  2. 潜在ベクトルからK、Vを復元:K = W_uk _ c、V = W_uv _ c
  3. キャッシュには低次元cのみ保存
class MultiHeadLatentAttention(nn.Module):
    """Multi-Head Latent Attention(MLA)
    DeepSeek-V2/V3で使用
    KVを低次元潜在空間に圧縮してキャッシュ"""
    def __init__(self, d_model, n_heads, d_latent, rope_dim=64):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.d_latent = d_latent  # 潜在次元(d_modelよりはるかに小さい)
        self.rope_dim = rope_dim

        # Queryプロジェクション
        self.W_q = nn.Linear(d_model, n_heads * self.d_k)

        # KV圧縮(Down-projection)
        self.W_dkv = nn.Linear(d_model, d_latent)

        # KV復元(Up-projection)
        self.W_uk = nn.Linear(d_latent, n_heads * self.d_k)
        self.W_uv = nn.Linear(d_latent, n_heads * self.d_k)

        # RoPE用の別途プロジェクション
        self.W_kr = nn.Linear(d_model, rope_dim)

        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, _ = x.shape

        # Query
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        # KV圧縮: d_model -> d_latent
        c_kv = self.W_dkv(x)  # [B, T, d_latent]

        # RoPEキー
        k_rope = self.W_kr(x)  # [B, T, rope_dim]

        if kv_cache is not None:
            c_kv_cached, k_rope_cached = kv_cache
            c_kv = torch.cat([c_kv_cached, c_kv], dim=1)
            k_rope = torch.cat([k_rope_cached, k_rope], dim=1)

        # KV Cache: 低次元c_kvとk_ropeのみ保存
        # メモリ: B * S * (d_latent + rope_dim) * sizeof(dtype)
        # MHA比: d_latent / (2 * n_heads * d_k) 比率で縮小
        new_cache = (c_kv, k_rope)

        # KV復元: d_latent -> n_heads * d_k
        S = c_kv.shape[1]
        k = self.W_uk(c_kv).view(B, S, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_uv(c_kv).view(B, S, self.n_heads, self.d_k).transpose(1, 2)

        # Attention 演算
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)

        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out), new_cache

DeepSeek-V2のMLA設定:d_model=5120、d_latent=512、n_heads=128。KV CacheがMHA比約93.7%縮小されながらもMHAと同等の性能を達成した。

MHA vs MQA vs GQA vs MLA 比較表

項目MHAMQAGQAMLA
KVヘッド数H1G (1 < G < H)-(潜在ベクトル)
キャッシュ per token per layer2Hd_k2d_k2Gd_kd_latent + d_rope
MHA比キャッシュ比率100%1/HG/Hd_latent/(2Hd_k)
Llama 3 70B例(H=64, G=8)16,384B256B2,048B-
DeepSeek-V2例163,840B--約4,608B
品質への影響ベースライン若干低下ほぼ同等ほぼ同等
推論速度ベースライン最速良好良好(復元コスト)
学習安定性普通
代表モデルGPT-3, BERTPaLM, FalconLlama 2/3, MistralDeepSeek-V2/V3
実装複雑度

KV Cache 圧縮技法

量子化(Quantization)

KV CacheをFP16からINT8やINT4に量子化すると、メモリを2-4倍削減できる。

class QuantizedKVCache:
    """KV Cache 量子化
    FP16 -> INT8変換でメモリ50%削減"""
    def __init__(self, n_layers, n_heads, max_seq_len, d_k, dtype=torch.int8):
        self.n_layers = n_layers
        self.scales = {}  # 量子化スケールファクター
        self.zero_points = {}
        self.cache_k = {}
        self.cache_v = {}

    def quantize(self, tensor):
        """Per-channel INT8量子化"""
        min_val = tensor.min(dim=-1, keepdim=True)[0]
        max_val = tensor.max(dim=-1, keepdim=True)[0]
        scale = (max_val - min_val) / 255.0
        zero_point = (-min_val / scale).round().clamp(0, 255)
        quantized = ((tensor / scale) + zero_point).round().clamp(0, 255).to(torch.uint8)
        return quantized, scale, zero_point

    def dequantize(self, quantized, scale, zero_point):
        """INT8 -> FP16 逆量子化"""
        return (quantized.float() - zero_point) * scale

    def update(self, layer_idx, k, v):
        q_k, s_k, z_k = self.quantize(k)
        q_v, s_v, z_v = self.quantize(v)
        if layer_idx in self.cache_k:
            self.cache_k[layer_idx] = torch.cat([self.cache_k[layer_idx], q_k], dim=2)
            self.cache_v[layer_idx] = torch.cat([self.cache_v[layer_idx], q_v], dim=2)
        else:
            self.cache_k[layer_idx] = q_k
            self.cache_v[layer_idx] = q_v
        self.scales[layer_idx] = (s_k, s_v)
        self.zero_points[layer_idx] = (z_k, z_v)

    def get(self, layer_idx):
        k = self.dequantize(
            self.cache_k[layer_idx],
            self.scales[layer_idx][0],
            self.zero_points[layer_idx][0]
        )
        v = self.dequantize(
            self.cache_v[layer_idx],
            self.scales[layer_idx][1],
            self.zero_points[layer_idx][1]
        )
        return k, v

エビクションポリシー(Eviction Policy)

シーケンスが長くなると古いトークンのKVを選択的に削除する戦略である。

**H2O(Heavy-Hitter Oracle)**方式はアテンションスコアが高いトークンのKVのみを保持する。

class H2OKVCache:
    """Heavy-Hitter Oracle(H2O)KV Cache エビクション
    アテンションスコアが高いトークンのみ保持"""
    def __init__(self, max_cache_size, n_heads, d_k):
        self.max_cache_size = max_cache_size
        self.attention_scores = None

    def update(self, k, v, attn_weights):
        """アテンション重みに基づいて重要度を蓄積しエビクションを実行"""
        B, H, S_new, _ = k.shape

        if self.attention_scores is None:
            self.attention_scores = attn_weights.sum(dim=2)  # [B, H, S]
        else:
            # 新しいトークンに対するアテンションスコアを蓄積
            new_scores = attn_weights.sum(dim=2)
            self.attention_scores = torch.cat(
                [self.attention_scores, new_scores[:, :, -S_new:]], dim=2
            )

        # キャッシュサイズ超過時にエビクション
        current_size = self.attention_scores.shape[2]
        if current_size > self.max_cache_size:
            # 重要度が低いトークンを識別(最初のトークンは常に保持)
            scores = self.attention_scores[:, :, 1:]  # 最初のトークンを除外
            _, indices = scores.topk(
                self.max_cache_size - 1, dim=2, sorted=False
            )
            indices = indices + 1  # オフセット補正
            # 最初のトークンインデックスを追加
            first_token = torch.zeros(B, H, 1, dtype=torch.long, device=k.device)
            keep_indices = torch.cat([first_token, indices], dim=2)

            # 選択されたトークンのみ保持
            k = k.gather(2, keep_indices.unsqueeze(-1).expand(-1, -1, -1, k.shape[-1]))
            v = v.gather(2, keep_indices.unsqueeze(-1).expand(-1, -1, -1, v.shape[-1]))
            self.attention_scores = self.attention_scores.gather(2, keep_indices)

        return k, v

Sliding Window Attention

Mistralなどのモデルで使用される固定サイズウィンドウベースのアテンションである。

class SlidingWindowAttention(nn.Module):
    """Sliding Window Attention
    最近W個のトークンにのみアテンションを適用
    Mistral、Gemmaなどで使用"""
    def __init__(self, d_model, n_heads, window_size=4096):
        super().__init__()
        self.window_size = window_size
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, _ = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        if kv_cache is not None:
            k = torch.cat([kv_cache[0], k], dim=2)
            v = torch.cat([kv_cache[1], v], dim=2)

        # ウィンドウサイズでKV Cacheを制限
        if k.shape[2] > self.window_size:
            k = k[:, :, -self.window_size:]
            v = v[:, :, -self.window_size:]

        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)

        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out), (k, v)

PagedAttention(vLLM)

仮想メモリベースKV Cache管理

vLLMのPagedAttentionは、OSの仮想メモリシステムからインスピレーションを得て、KV Cacheを固定サイズブロック(ページ) 単位で管理する。これにより、連続メモリ割り当ての非効率性(内部断片化、外部断片化)を解消する。

class PagedKVCache:
    """vLLM PagedAttention 概念実装
    KV Cacheをページ単位で管理"""
    def __init__(self, block_size=16, n_blocks=1024, n_heads=32, d_k=128):
        self.block_size = block_size
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.d_k = d_k

        # 物理ブロックプール(事前割り当て)
        self.k_blocks = torch.zeros(n_blocks, n_heads, block_size, d_k)
        self.v_blocks = torch.zeros(n_blocks, n_heads, block_size, d_k)

        # 空きブロック管理
        self.free_blocks = list(range(n_blocks))

        # シーケンスごとのブロックテーブル(仮想 -> 物理マッピング)
        self.block_tables = {}  # seq_id -> list of physical block indices
        self.seq_lengths = {}

    def allocate(self, seq_id):
        """新しいシーケンスに最初のブロックを割り当て"""
        if not self.free_blocks:
            raise RuntimeError("OOM: No free blocks available")
        block_idx = self.free_blocks.pop(0)
        self.block_tables[seq_id] = [block_idx]
        self.seq_lengths[seq_id] = 0

    def append_token(self, seq_id, k, v):
        """シーケンスに新しいトークンのKVを追加"""
        seq_len = self.seq_lengths[seq_id]
        block_idx_in_seq = seq_len // self.block_size
        offset = seq_len % self.block_size

        # 現在のブロックが満杯の場合、新しいブロックを割り当て
        if offset == 0 and block_idx_in_seq > 0:
            if not self.free_blocks:
                raise RuntimeError("OOM: No free blocks available")
            new_block = self.free_blocks.pop(0)
            self.block_tables[seq_id].append(new_block)

        # 物理ブロックにKVを保存
        physical_block = self.block_tables[seq_id][block_idx_in_seq]
        self.k_blocks[physical_block, :, offset] = k
        self.v_blocks[physical_block, :, offset] = v
        self.seq_lengths[seq_id] += 1

    def get_kv(self, seq_id):
        """シーケンスの全KV Cacheを組み立て"""
        blocks = self.block_tables[seq_id]
        seq_len = self.seq_lengths[seq_id]

        k_list, v_list = [], []
        for i, block_idx in enumerate(blocks):
            if i == len(blocks) - 1:
                # 最後のブロックは実際のトークン数分のみ
                remaining = seq_len % self.block_size or self.block_size
                k_list.append(self.k_blocks[block_idx, :, :remaining])
                v_list.append(self.v_blocks[block_idx, :, :remaining])
            else:
                k_list.append(self.k_blocks[block_idx])
                v_list.append(self.v_blocks[block_idx])

        return torch.cat(k_list, dim=1), torch.cat(v_list, dim=1)

    def free(self, seq_id):
        """シーケンス完了後にブロックを返却"""
        for block_idx in self.block_tables[seq_id]:
            self.free_blocks.append(block_idx)
        del self.block_tables[seq_id]
        del self.seq_lengths[seq_id]

PagedAttentionの主要な利点は以下の通りである。

  • メモリ断片化の解消: 連続メモリが不要なため、外部断片化がない
  • メモリ使用効率の向上: 実際に使用する分だけブロックを割り当てる。vLLMは従来比最大24倍のスループットを達成
  • Copy-on-Write: ビームサーチやパラレルサンプリングでKV Cacheを共有してメモリ節約
  • プレフィックスキャッシング: 共通プレフィックス(システムプロンプト)のKV Cacheを複数リクエスト間で共有

障害事例と復旧手順

事例1: Long Context推論中のOOM発生

状況: Llama 3 70Bモデルで128Kコンテキスト長のドキュメント要約を試みたが、バッチサイズ4でOOMが発生し、推論サービス全体が停止した。

メモリ分析:

  • モデル重み(FP16):約140 GB
  • リクエストあたりKV Cache(GQA, H=64, G=8):2 x 80 x 8 x 128K x 128 x 2 = 約33.5 GB
  • バッチ4合計KV Cache:約134 GB
  • 必要総メモリ:約274 GB(8x A100 80GB = 640 GBでも活性化メモリと一時バッファを考慮すると危険)

復旧手順:

# 1. vLLMサーバー再起動時にメモリ制限を設定
python -m vllm.entrypoints.openai.api_server \
  --model meta-llama/Llama-3-70B \
  --tensor-parallel-size 8 \
  --max-model-len 65536 \
  --gpu-memory-utilization 0.9 \
  --max-num-seqs 8 \
  --enforce-eager

# 2. KV Cache量子化を有効化
python -m vllm.entrypoints.openai.api_server \
  --model meta-llama/Llama-3-70B \
  --tensor-parallel-size 8 \
  --kv-cache-dtype fp8_e5m2 \
  --max-model-len 128000

# 3. 動的バッチサイズ制限
# max-num-seqsを減らして同時処理リクエスト数を制限

事例2: KV Cache量子化による品質低下

状況: 推論コスト削減のためKV CacheをINT4で量子化したが、長い会話で応答品質が急激に低下した。特に数学計算とコード生成タスクでエラー率が増加した。

症状:

  • 会話10ターン以上でハルシネーション増加
  • 数学問題の精度15%低下
  • コード生成時に構文エラーが頻発

復旧手順:

# 1. INT4 -> INT8に量子化精度を引き上げ
# vLLMではkv-cache-dtypeを変更
# --kv-cache-dtype fp8_e5m2  (FP8を使用)

# 2. 重要レイヤーのみ高精度を維持(混合量子化)
# 初期レイヤーと最終レイヤーはFP16、中間レイヤーはINT8
kv_cache_config = {
    "default_dtype": "int8",
    "high_precision_layers": [0, 1, 2, 77, 78, 79],  # 最初/最後の3レイヤー
    "high_precision_dtype": "fp16"
}

# 3. ベンチマーク再実行で品質確認
# MMLU、HumanEval、GSM8Kなど標準ベンチマークで検証

事例3: Prefix Cachingエラーによる応答混線

状況: vLLMのプレフィックスキャッシング機能を有効化した後、異なるユーザーのリクエストに対して誤ったコンテキストが適用される問題が発生した。システムプロンプトが同一の異なるセッションでKV Cacheが誤って共有された。

復旧手順:

# 1. プレフィックスキャッシングを一時無効化
python -m vllm.entrypoints.openai.api_server \
  --model meta-llama/Llama-3-70B \
  --enable-prefix-caching false

# 2. プレフィックスハッシュキーにセッションIDを含めるよう修正
# 同一プレフィックスでも異なるセッションは別キャッシュを使用

# 3. プレフィックスキャッシング再有効化時にTTLを設定
# --prefix-cache-ttl 300  (5分後に自動期限切れ)

最適化チェックリスト

モデル選択と設定

  • モデルのアテンションメカニズムを確認(MHA/MQA/GQA/MLA)
  • GQAモデル使用時にn_kv_headsの最適値を検証(通常n_heads/8)
  • 最大シーケンス長を実際の使用パターンに合わせて制限

KV Cacheメモリ管理

  • GPUメモリに対するKV Cache割り当て比率を計算(全体の60-80%推奨)
  • KV Cache量子化の適用可否を決定(品質/メモリトレードオフ:FP8 > INT8 > INT4)
  • PagedAttention導入可否を決定(vLLM、TensorRT-LLMなど)
  • プレフィックスキャッシング有効化時のキー衝突防止を検証

推論パフォーマンス最適化

  • Continuous Batchingを有効化(リクエスト完了時に即座にスロット再利用)
  • Speculative Decoding導入可否を検討(ドラフトモデルでKV Cache活用を最大化)
  • Sliding Window Attention適用可否を評価(長文書要約など)
  • Tensor Parallelism vs Pipeline Parallelismの選択

モニタリング

  • KV Cache使用率メトリクスの収集(vLLM: vllm:cache_usage_percent
  • リクエストごとのKV Cacheメモリ使用量の追跡
  • OOMイベント自動アラートの設定
  • バッチサイズごとのスループットとレイテンシのプロファイリング

運用上の注意事項

  • 同時リクエスト数の上限を設定(max-num-seqs)
  • GPUメモリ使用率の上限を設定(gpu-memory-utilization)
  • OOM発生時のグレースフルデグラデーションを実装(リクエストキューイング、バッチ縮小)
  • 定期的なKV Cacheプロファイリングでメモリリークを確認

まとめ

KV Cache最適化はLLM推論効率の核心である。MHAからMQA、GQA、MLAへと続くアテンションメカニズムの発展は、モデル品質を維持しながらKV Cacheメモリを劇的に削減することに成功した。GQAはLlamaシリーズで実証された実用的なアプローチであり、MLAはDeepSeek-V2/V3で示されたように、さらに積極的な圧縮を可能にする。

KV Cache量子化、エビクションポリシー、Sliding Window Attentionはメモリ制約下で長いシーケンスを処理するための追加技法であり、PagedAttentionはメモリ管理自体を革新して推論システムのスループットを最大化する。

実際のプロダクション環境では、これらの技法を組み合わせて使用する必要があり、モデル特性、ワークロードパターン、GPUリソースに合った最適な設定を見つけるために継続的なプロファイリングとベンチマーキングが不可欠である。

参考資料