- Published on
KV Cache 最適化 深層分析: GQA・MLA・MHA アテンションメカニズムとメモリ効率化戦略
- Authors
- Name
- はじめに
- Transformer Self-AttentionとKV Cache基礎
- Multi-Head Attention(MHA)メモリ分析
- Multi-Query Attention(MQA)
- Grouped Query Attention(GQA)
- Multi-Head Latent Attention(MLA)
- MHA vs MQA vs GQA vs MLA 比較表
- KV Cache 圧縮技法
- PagedAttention(vLLM)
- 障害事例と復旧手順
- 最適化チェックリスト
- まとめ
- 参考資料

はじめに
大規模言語モデル(LLM)の推論コストにおける最大のボトルネックは、KV Cache(Key-Value Cache) のメモリ消費である。GPT-4クラスのモデルが128Kコンテキスト長で同時に数百のリクエストを処理する場合、KV Cacheだけで数百GBのGPUメモリを消費する可能性がある。このメモリ制約は、スループットとレスポンスレイテンシの直接的な制限要因となる。
TransformerのSelf-Attentionメカニズムにおいて、デコーディングの各ステップで以前のすべてのトークンのKeyとValueベクトルを再計算することは非効率的であるため、これをキャッシュに保存して再利用するのがKV Cacheの基本原理である。問題は、シーケンス長が長くなるほどこのキャッシュのサイズが線形に増加することである。
この問題を解決するために、様々なアテンションメカニズムが提案されてきた。**Multi-Query Attention(MQA)**はすべてのQueryヘッドが1つのKVヘッドを共有してメモリを劇的に削減するが、品質低下が発生した。**Grouped Query Attention(GQA)**はMHAとMQAの折衷案としてLlama 2/3に採用された。**Multi-Head Latent Attention(MLA)**はDeepSeek-V2/V3でKVを低次元潜在空間に圧縮して驚くべき効率を達成した。
本記事では、各アテンションメカニズムの数学的原理とメモリ分析、KV Cache圧縮技法、PagedAttention(vLLM)、PyTorch実装例、実際のOOM障害事例と復旧、最適化チェックリストを総合的に解説する。
Transformer Self-AttentionとKV Cache基礎
Self-Attention演算
TransformerのScaled Dot-Product Attentionは以下のように定義される。
ここで、Q(Query)、K(Key)、V(Value)は入力トークンの埋め込みを線形変換して得られる。
import torch
import torch.nn as nn
import math
class ScaledDotProductAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.d_k = d_model // n_heads
self.n_heads = n_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x, kv_cache=None):
B, T, C = x.shape
# Q, K, V プロジェクション
q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
k = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
v = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
# KV Cache 活用
if kv_cache is not None:
k_cache, v_cache = kv_cache
k = torch.cat([k_cache, k], dim=2)
v = torch.cat([v_cache, v], dim=2)
# Attention 演算
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
attn = torch.softmax(scores, dim=-1)
output = torch.matmul(attn, v)
output = output.transpose(1, 2).contiguous().view(B, T, C)
return self.W_o(output), (k, v)
KV Cache メモリ分析
KV Cacheのメモリ使用量は以下の式で計算される。
KV Cache メモリ = 2(KとV)x レイヤー数(L)x ヘッド数(H)x シーケンス長(S)x ヘッド次元(d_k)x バイトサイズ
例えば、Llama 2 70Bモデルの場合:
- レイヤー数:80
- ヘッド数:64(GQA適用前)
- ヘッド次元:128
- シーケンス長:4096
- FP16(2バイト)
KV Cache = 2 x 80 x 64 x 4096 x 128 x 2 = 10.7 GB(単一リクエスト)
バッチサイズ32の場合、342.4 GBとなり、モデル重みのメモリをはるかに超過する。
Multi-Head Attention(MHA)メモリ分析
MHA構造
標準のMulti-Head Attentionでは、各アテンションヘッドが独立したQ、K、Vプロジェクションを持つ。H個のヘッドがそれぞれd_k次元のK、Vを維持するため、KV Cacheは最大サイズとなる。
class MultiHeadAttention(nn.Module):
"""標準 Multi-Head Attention(MHA)
各ヘッドが独立したKVペアを維持"""
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model) # H * d_k 次元
self.W_v = nn.Linear(d_model, d_model) # H * d_k 次元
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x, kv_cache=None):
B, T, _ = x.shape
q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
k = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
v = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
if kv_cache is not None:
k = torch.cat([kv_cache[0], k], dim=2)
v = torch.cat([kv_cache[1], v], dim=2)
# KV Cache サイズ: [B, n_heads, S, d_k] x 2
# メモリ: 2 * B * n_heads * S * d_k * sizeof(dtype)
attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
attn = torch.softmax(attn, dim=-1)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).contiguous().view(B, T, -1)
return self.W_o(out), (k, v)
MHA KV Cache per token per layer = 2 x H x d_k x sizeof(dtype)
Multi-Query Attention(MQA)
MQA構造と削減効果
MQAはShazeer(2019)が提案した方法で、すべてのQueryヘッドが単一のKVヘッドを共有する。KV CacheがH倍削減されるが、品質低下が観察される。
class MultiQueryAttention(nn.Module):
"""Multi-Query Attention(MQA)
すべてのQueryヘッドが単一KVヘッドを共有"""
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model) # H * d_k
self.W_k = nn.Linear(d_model, self.d_k) # 1 * d_k(単一ヘッド)
self.W_v = nn.Linear(d_model, self.d_k) # 1 * d_k(単一ヘッド)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x, kv_cache=None):
B, T, _ = x.shape
q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
k = self.W_k(x).view(B, T, 1, self.d_k).transpose(1, 2)
v = self.W_v(x).view(B, T, 1, self.d_k).transpose(1, 2)
if kv_cache is not None:
k = torch.cat([kv_cache[0], k], dim=2)
v = torch.cat([kv_cache[1], v], dim=2)
# K、Vをすべてのヘッドにブロードキャスト
k = k.expand(-1, self.n_heads, -1, -1)
v = v.expand(-1, self.n_heads, -1, -1)
# KV Cache サイズ: [B, 1, S, d_k] x 2
# MHA比 1/H に縮小
attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
attn = torch.softmax(attn, dim=-1)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).contiguous().view(B, T, -1)
return self.W_o(out), (k[:, :1], v[:, :1])
MQA KV Cache per token per layer = 2 x 1 x d_k x sizeof(dtype)、MHA比1/H縮小。
Grouped Query Attention(GQA)
GQAアーキテクチャ(Llama 2/3)
GQAはAinslie et al.(2023)が提案し、Llama 2、Llama 3に採用された方法で、QueryヘッドをG個のグループに分け、各グループが1つのKVヘッドを共有する。G=1ならMQA、G=HならMHAと同等である。
class GroupedQueryAttention(nn.Module):
"""Grouped Query Attention(GQA)
QueryヘッドをG個のグループに分け、各グループがKVを共有
Llama 2/3で採用"""
def __init__(self, d_model, n_heads, n_kv_heads):
super().__init__()
self.n_heads = n_heads # Queryヘッド数
self.n_kv_heads = n_kv_heads # KVヘッド数(グループ数)
self.d_k = d_model // n_heads
self.n_rep = n_heads // n_kv_heads # グループあたりのQueryヘッド数
self.W_q = nn.Linear(d_model, n_heads * self.d_k)
self.W_k = nn.Linear(d_model, n_kv_heads * self.d_k)
self.W_v = nn.Linear(d_model, n_kv_heads * self.d_k)
self.W_o = nn.Linear(d_model, d_model)
def repeat_kv(self, x):
"""KVヘッドをQueryヘッド数に合わせて繰り返す"""
B, H_kv, S, D = x.shape
if self.n_rep == 1:
return x
return (
x[:, :, None, :, :]
.expand(B, H_kv, self.n_rep, S, D)
.reshape(B, self.n_heads, S, D)
)
def forward(self, x, kv_cache=None):
B, T, _ = x.shape
q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
k = self.W_k(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
v = self.W_v(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
if kv_cache is not None:
k = torch.cat([kv_cache[0], k], dim=2)
v = torch.cat([kv_cache[1], v], dim=2)
# KV Cache サイズ: [B, n_kv_heads, S, d_k] x 2
new_cache = (k, v)
# ブロードキャストのためにKVヘッドを繰り返す
k = self.repeat_kv(k)
v = self.repeat_kv(v)
attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
attn = torch.softmax(attn, dim=-1)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).contiguous().view(B, T, -1)
return self.W_o(out), new_cache
Llama 3 70BのGQA設定:n_heads=64、n_kv_heads=8。KV CacheがMHA比1/8に縮小されながらも品質低下がほとんどない。
Multi-Head Latent Attention(MLA)
MLAアーキテクチャ(DeepSeek-V2/V3)
MLAはDeepSeek-V2(2024)で提案された革新的な方法で、KVベクトルを低次元の**潜在ベクトル(latent vector)**に圧縮してキャッシュに保存する。デコーディング時に潜在ベクトルからK、Vを復元する。
核心アイデアは以下の通りである:
- 入力xを低次元潜在ベクトルcに圧縮:c = W_down * x
- 潜在ベクトルからK、Vを復元:K = W_uk _ c、V = W_uv _ c
- キャッシュには低次元cのみ保存
class MultiHeadLatentAttention(nn.Module):
"""Multi-Head Latent Attention(MLA)
DeepSeek-V2/V3で使用
KVを低次元潜在空間に圧縮してキャッシュ"""
def __init__(self, d_model, n_heads, d_latent, rope_dim=64):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.d_latent = d_latent # 潜在次元(d_modelよりはるかに小さい)
self.rope_dim = rope_dim
# Queryプロジェクション
self.W_q = nn.Linear(d_model, n_heads * self.d_k)
# KV圧縮(Down-projection)
self.W_dkv = nn.Linear(d_model, d_latent)
# KV復元(Up-projection)
self.W_uk = nn.Linear(d_latent, n_heads * self.d_k)
self.W_uv = nn.Linear(d_latent, n_heads * self.d_k)
# RoPE用の別途プロジェクション
self.W_kr = nn.Linear(d_model, rope_dim)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x, kv_cache=None):
B, T, _ = x.shape
# Query
q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
# KV圧縮: d_model -> d_latent
c_kv = self.W_dkv(x) # [B, T, d_latent]
# RoPEキー
k_rope = self.W_kr(x) # [B, T, rope_dim]
if kv_cache is not None:
c_kv_cached, k_rope_cached = kv_cache
c_kv = torch.cat([c_kv_cached, c_kv], dim=1)
k_rope = torch.cat([k_rope_cached, k_rope], dim=1)
# KV Cache: 低次元c_kvとk_ropeのみ保存
# メモリ: B * S * (d_latent + rope_dim) * sizeof(dtype)
# MHA比: d_latent / (2 * n_heads * d_k) 比率で縮小
new_cache = (c_kv, k_rope)
# KV復元: d_latent -> n_heads * d_k
S = c_kv.shape[1]
k = self.W_uk(c_kv).view(B, S, self.n_heads, self.d_k).transpose(1, 2)
v = self.W_uv(c_kv).view(B, S, self.n_heads, self.d_k).transpose(1, 2)
# Attention 演算
attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
attn = torch.softmax(attn, dim=-1)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).contiguous().view(B, T, -1)
return self.W_o(out), new_cache
DeepSeek-V2のMLA設定:d_model=5120、d_latent=512、n_heads=128。KV CacheがMHA比約93.7%縮小されながらもMHAと同等の性能を達成した。
MHA vs MQA vs GQA vs MLA 比較表
| 項目 | MHA | MQA | GQA | MLA |
|---|---|---|---|---|
| KVヘッド数 | H | 1 | G (1 < G < H) | -(潜在ベクトル) |
| キャッシュ per token per layer | 2Hd_k | 2d_k | 2Gd_k | d_latent + d_rope |
| MHA比キャッシュ比率 | 100% | 1/H | G/H | d_latent/(2Hd_k) |
| Llama 3 70B例(H=64, G=8) | 16,384B | 256B | 2,048B | - |
| DeepSeek-V2例 | 163,840B | - | - | 約4,608B |
| 品質への影響 | ベースライン | 若干低下 | ほぼ同等 | ほぼ同等 |
| 推論速度 | ベースライン | 最速 | 良好 | 良好(復元コスト) |
| 学習安定性 | 高 | 普通 | 高 | 高 |
| 代表モデル | GPT-3, BERT | PaLM, Falcon | Llama 2/3, Mistral | DeepSeek-V2/V3 |
| 実装複雑度 | 低 | 低 | 中 | 高 |
KV Cache 圧縮技法
量子化(Quantization)
KV CacheをFP16からINT8やINT4に量子化すると、メモリを2-4倍削減できる。
class QuantizedKVCache:
"""KV Cache 量子化
FP16 -> INT8変換でメモリ50%削減"""
def __init__(self, n_layers, n_heads, max_seq_len, d_k, dtype=torch.int8):
self.n_layers = n_layers
self.scales = {} # 量子化スケールファクター
self.zero_points = {}
self.cache_k = {}
self.cache_v = {}
def quantize(self, tensor):
"""Per-channel INT8量子化"""
min_val = tensor.min(dim=-1, keepdim=True)[0]
max_val = tensor.max(dim=-1, keepdim=True)[0]
scale = (max_val - min_val) / 255.0
zero_point = (-min_val / scale).round().clamp(0, 255)
quantized = ((tensor / scale) + zero_point).round().clamp(0, 255).to(torch.uint8)
return quantized, scale, zero_point
def dequantize(self, quantized, scale, zero_point):
"""INT8 -> FP16 逆量子化"""
return (quantized.float() - zero_point) * scale
def update(self, layer_idx, k, v):
q_k, s_k, z_k = self.quantize(k)
q_v, s_v, z_v = self.quantize(v)
if layer_idx in self.cache_k:
self.cache_k[layer_idx] = torch.cat([self.cache_k[layer_idx], q_k], dim=2)
self.cache_v[layer_idx] = torch.cat([self.cache_v[layer_idx], q_v], dim=2)
else:
self.cache_k[layer_idx] = q_k
self.cache_v[layer_idx] = q_v
self.scales[layer_idx] = (s_k, s_v)
self.zero_points[layer_idx] = (z_k, z_v)
def get(self, layer_idx):
k = self.dequantize(
self.cache_k[layer_idx],
self.scales[layer_idx][0],
self.zero_points[layer_idx][0]
)
v = self.dequantize(
self.cache_v[layer_idx],
self.scales[layer_idx][1],
self.zero_points[layer_idx][1]
)
return k, v
エビクションポリシー(Eviction Policy)
シーケンスが長くなると古いトークンのKVを選択的に削除する戦略である。
**H2O(Heavy-Hitter Oracle)**方式はアテンションスコアが高いトークンのKVのみを保持する。
class H2OKVCache:
"""Heavy-Hitter Oracle(H2O)KV Cache エビクション
アテンションスコアが高いトークンのみ保持"""
def __init__(self, max_cache_size, n_heads, d_k):
self.max_cache_size = max_cache_size
self.attention_scores = None
def update(self, k, v, attn_weights):
"""アテンション重みに基づいて重要度を蓄積しエビクションを実行"""
B, H, S_new, _ = k.shape
if self.attention_scores is None:
self.attention_scores = attn_weights.sum(dim=2) # [B, H, S]
else:
# 新しいトークンに対するアテンションスコアを蓄積
new_scores = attn_weights.sum(dim=2)
self.attention_scores = torch.cat(
[self.attention_scores, new_scores[:, :, -S_new:]], dim=2
)
# キャッシュサイズ超過時にエビクション
current_size = self.attention_scores.shape[2]
if current_size > self.max_cache_size:
# 重要度が低いトークンを識別(最初のトークンは常に保持)
scores = self.attention_scores[:, :, 1:] # 最初のトークンを除外
_, indices = scores.topk(
self.max_cache_size - 1, dim=2, sorted=False
)
indices = indices + 1 # オフセット補正
# 最初のトークンインデックスを追加
first_token = torch.zeros(B, H, 1, dtype=torch.long, device=k.device)
keep_indices = torch.cat([first_token, indices], dim=2)
# 選択されたトークンのみ保持
k = k.gather(2, keep_indices.unsqueeze(-1).expand(-1, -1, -1, k.shape[-1]))
v = v.gather(2, keep_indices.unsqueeze(-1).expand(-1, -1, -1, v.shape[-1]))
self.attention_scores = self.attention_scores.gather(2, keep_indices)
return k, v
Sliding Window Attention
Mistralなどのモデルで使用される固定サイズウィンドウベースのアテンションである。
class SlidingWindowAttention(nn.Module):
"""Sliding Window Attention
最近W個のトークンにのみアテンションを適用
Mistral、Gemmaなどで使用"""
def __init__(self, d_model, n_heads, window_size=4096):
super().__init__()
self.window_size = window_size
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x, kv_cache=None):
B, T, _ = x.shape
q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
k = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
v = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
if kv_cache is not None:
k = torch.cat([kv_cache[0], k], dim=2)
v = torch.cat([kv_cache[1], v], dim=2)
# ウィンドウサイズでKV Cacheを制限
if k.shape[2] > self.window_size:
k = k[:, :, -self.window_size:]
v = v[:, :, -self.window_size:]
attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
attn = torch.softmax(attn, dim=-1)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).contiguous().view(B, T, -1)
return self.W_o(out), (k, v)
PagedAttention(vLLM)
仮想メモリベースKV Cache管理
vLLMのPagedAttentionは、OSの仮想メモリシステムからインスピレーションを得て、KV Cacheを固定サイズブロック(ページ) 単位で管理する。これにより、連続メモリ割り当ての非効率性(内部断片化、外部断片化)を解消する。
class PagedKVCache:
"""vLLM PagedAttention 概念実装
KV Cacheをページ単位で管理"""
def __init__(self, block_size=16, n_blocks=1024, n_heads=32, d_k=128):
self.block_size = block_size
self.n_blocks = n_blocks
self.n_heads = n_heads
self.d_k = d_k
# 物理ブロックプール(事前割り当て)
self.k_blocks = torch.zeros(n_blocks, n_heads, block_size, d_k)
self.v_blocks = torch.zeros(n_blocks, n_heads, block_size, d_k)
# 空きブロック管理
self.free_blocks = list(range(n_blocks))
# シーケンスごとのブロックテーブル(仮想 -> 物理マッピング)
self.block_tables = {} # seq_id -> list of physical block indices
self.seq_lengths = {}
def allocate(self, seq_id):
"""新しいシーケンスに最初のブロックを割り当て"""
if not self.free_blocks:
raise RuntimeError("OOM: No free blocks available")
block_idx = self.free_blocks.pop(0)
self.block_tables[seq_id] = [block_idx]
self.seq_lengths[seq_id] = 0
def append_token(self, seq_id, k, v):
"""シーケンスに新しいトークンのKVを追加"""
seq_len = self.seq_lengths[seq_id]
block_idx_in_seq = seq_len // self.block_size
offset = seq_len % self.block_size
# 現在のブロックが満杯の場合、新しいブロックを割り当て
if offset == 0 and block_idx_in_seq > 0:
if not self.free_blocks:
raise RuntimeError("OOM: No free blocks available")
new_block = self.free_blocks.pop(0)
self.block_tables[seq_id].append(new_block)
# 物理ブロックにKVを保存
physical_block = self.block_tables[seq_id][block_idx_in_seq]
self.k_blocks[physical_block, :, offset] = k
self.v_blocks[physical_block, :, offset] = v
self.seq_lengths[seq_id] += 1
def get_kv(self, seq_id):
"""シーケンスの全KV Cacheを組み立て"""
blocks = self.block_tables[seq_id]
seq_len = self.seq_lengths[seq_id]
k_list, v_list = [], []
for i, block_idx in enumerate(blocks):
if i == len(blocks) - 1:
# 最後のブロックは実際のトークン数分のみ
remaining = seq_len % self.block_size or self.block_size
k_list.append(self.k_blocks[block_idx, :, :remaining])
v_list.append(self.v_blocks[block_idx, :, :remaining])
else:
k_list.append(self.k_blocks[block_idx])
v_list.append(self.v_blocks[block_idx])
return torch.cat(k_list, dim=1), torch.cat(v_list, dim=1)
def free(self, seq_id):
"""シーケンス完了後にブロックを返却"""
for block_idx in self.block_tables[seq_id]:
self.free_blocks.append(block_idx)
del self.block_tables[seq_id]
del self.seq_lengths[seq_id]
PagedAttentionの主要な利点は以下の通りである。
- メモリ断片化の解消: 連続メモリが不要なため、外部断片化がない
- メモリ使用効率の向上: 実際に使用する分だけブロックを割り当てる。vLLMは従来比最大24倍のスループットを達成
- Copy-on-Write: ビームサーチやパラレルサンプリングでKV Cacheを共有してメモリ節約
- プレフィックスキャッシング: 共通プレフィックス(システムプロンプト)のKV Cacheを複数リクエスト間で共有
障害事例と復旧手順
事例1: Long Context推論中のOOM発生
状況: Llama 3 70Bモデルで128Kコンテキスト長のドキュメント要約を試みたが、バッチサイズ4でOOMが発生し、推論サービス全体が停止した。
メモリ分析:
- モデル重み(FP16):約140 GB
- リクエストあたりKV Cache(GQA, H=64, G=8):2 x 80 x 8 x 128K x 128 x 2 = 約33.5 GB
- バッチ4合計KV Cache:約134 GB
- 必要総メモリ:約274 GB(8x A100 80GB = 640 GBでも活性化メモリと一時バッファを考慮すると危険)
復旧手順:
# 1. vLLMサーバー再起動時にメモリ制限を設定
python -m vllm.entrypoints.openai.api_server \
--model meta-llama/Llama-3-70B \
--tensor-parallel-size 8 \
--max-model-len 65536 \
--gpu-memory-utilization 0.9 \
--max-num-seqs 8 \
--enforce-eager
# 2. KV Cache量子化を有効化
python -m vllm.entrypoints.openai.api_server \
--model meta-llama/Llama-3-70B \
--tensor-parallel-size 8 \
--kv-cache-dtype fp8_e5m2 \
--max-model-len 128000
# 3. 動的バッチサイズ制限
# max-num-seqsを減らして同時処理リクエスト数を制限
事例2: KV Cache量子化による品質低下
状況: 推論コスト削減のためKV CacheをINT4で量子化したが、長い会話で応答品質が急激に低下した。特に数学計算とコード生成タスクでエラー率が増加した。
症状:
- 会話10ターン以上でハルシネーション増加
- 数学問題の精度15%低下
- コード生成時に構文エラーが頻発
復旧手順:
# 1. INT4 -> INT8に量子化精度を引き上げ
# vLLMではkv-cache-dtypeを変更
# --kv-cache-dtype fp8_e5m2 (FP8を使用)
# 2. 重要レイヤーのみ高精度を維持(混合量子化)
# 初期レイヤーと最終レイヤーはFP16、中間レイヤーはINT8
kv_cache_config = {
"default_dtype": "int8",
"high_precision_layers": [0, 1, 2, 77, 78, 79], # 最初/最後の3レイヤー
"high_precision_dtype": "fp16"
}
# 3. ベンチマーク再実行で品質確認
# MMLU、HumanEval、GSM8Kなど標準ベンチマークで検証
事例3: Prefix Cachingエラーによる応答混線
状況: vLLMのプレフィックスキャッシング機能を有効化した後、異なるユーザーのリクエストに対して誤ったコンテキストが適用される問題が発生した。システムプロンプトが同一の異なるセッションでKV Cacheが誤って共有された。
復旧手順:
# 1. プレフィックスキャッシングを一時無効化
python -m vllm.entrypoints.openai.api_server \
--model meta-llama/Llama-3-70B \
--enable-prefix-caching false
# 2. プレフィックスハッシュキーにセッションIDを含めるよう修正
# 同一プレフィックスでも異なるセッションは別キャッシュを使用
# 3. プレフィックスキャッシング再有効化時にTTLを設定
# --prefix-cache-ttl 300 (5分後に自動期限切れ)
最適化チェックリスト
モデル選択と設定
- モデルのアテンションメカニズムを確認(MHA/MQA/GQA/MLA)
- GQAモデル使用時にn_kv_headsの最適値を検証(通常n_heads/8)
- 最大シーケンス長を実際の使用パターンに合わせて制限
KV Cacheメモリ管理
- GPUメモリに対するKV Cache割り当て比率を計算(全体の60-80%推奨)
- KV Cache量子化の適用可否を決定(品質/メモリトレードオフ:FP8 > INT8 > INT4)
- PagedAttention導入可否を決定(vLLM、TensorRT-LLMなど)
- プレフィックスキャッシング有効化時のキー衝突防止を検証
推論パフォーマンス最適化
- Continuous Batchingを有効化(リクエスト完了時に即座にスロット再利用)
- Speculative Decoding導入可否を検討(ドラフトモデルでKV Cache活用を最大化)
- Sliding Window Attention適用可否を評価(長文書要約など)
- Tensor Parallelism vs Pipeline Parallelismの選択
モニタリング
- KV Cache使用率メトリクスの収集(vLLM:
vllm:cache_usage_percent) - リクエストごとのKV Cacheメモリ使用量の追跡
- OOMイベント自動アラートの設定
- バッチサイズごとのスループットとレイテンシのプロファイリング
運用上の注意事項
- 同時リクエスト数の上限を設定(max-num-seqs)
- GPUメモリ使用率の上限を設定(gpu-memory-utilization)
- OOM発生時のグレースフルデグラデーションを実装(リクエストキューイング、バッチ縮小)
- 定期的なKV Cacheプロファイリングでメモリリークを確認
まとめ
KV Cache最適化はLLM推論効率の核心である。MHAからMQA、GQA、MLAへと続くアテンションメカニズムの発展は、モデル品質を維持しながらKV Cacheメモリを劇的に削減することに成功した。GQAはLlamaシリーズで実証された実用的なアプローチであり、MLAはDeepSeek-V2/V3で示されたように、さらに積極的な圧縮を可能にする。
KV Cache量子化、エビクションポリシー、Sliding Window Attentionはメモリ制約下で長いシーケンスを処理するための追加技法であり、PagedAttentionはメモリ管理自体を革新して推論システムのスループットを最大化する。
実際のプロダクション環境では、これらの技法を組み合わせて使用する必要があり、モデル特性、ワークロードパターン、GPUリソースに合った最適な設定を見つけるために継続的なプロファイリングとベンチマーキングが不可欠である。
参考資料
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (Ainslie et al., 2023)
- DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model (2024)
- KV Caching Explained - Hugging Face Blog
- KV Cache Optimization via Multi-Head Latent Attention - PyImageSearch
- Efficient Memory Management for Large Language Model Serving with PagedAttention (Kwon et al., 2023)