KV Cache Optimization Deep Dive: GQA, MLA, and MHA Attention Mechanisms with Memory Efficiency Strategies
- Introduction
- Transformer Self-Attention and KV Cache Fundamentals
- Multi-Head Attention (MHA) Memory Analysis
- Multi-Query Attention (MQA)
- Grouped Query Attention (GQA)
- Multi-Head Latent Attention (MLA)
- MHA vs MQA vs GQA vs MLA Comparison
- KV Cache Compression Techniques
- PagedAttention (vLLM)
- Failure Cases and Recovery Procedures
- Optimization Checklist
- Conclusion
- References

Introduction
In LLM (Large Language Model) serving, the memory consumed by the KV Cache (Key-Value Cache) is one of the largest contributors to inference cost. When a GPT-4 class model processes hundreds of concurrent requests with a 128K context length, the KV Cache alone can consume hundreds of GB of GPU memory. Because every in-flight request needs its own cache, this memory budget caps the batch size, and with it throughput and response latency.
In Transformer self-attention, recomputing the Key and Value vectors of every previous token at each decoding step is wasteful. The KV Cache stores them once and reuses them, so each step only projects K and V for the new token. The price is that the cache grows linearly with sequence length and with batch size.
Various attention mechanisms have been proposed to address this problem. Multi-Query Attention (MQA) dramatically reduces memory by having all Query heads share a single KV head, at the cost of some quality degradation. Grouped Query Attention (GQA) is a compromise between MHA and MQA and was adopted in Llama 2 (the 34B and 70B models) and Llama 3. Multi-Head Latent Attention (MLA), used in DeepSeek-V2/V3, compresses K and V into a low-dimensional latent vector and caches only that vector.
This article comprehensively covers the mathematical principles and memory analysis of each attention mechanism, KV Cache compression techniques, PagedAttention (vLLM), PyTorch implementation examples, real-world OOM failure cases and recovery, and an optimization checklist.
Transformer Self-Attention and KV Cache Fundamentals
Self-Attention Operation
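Transformer's Scaled Dot-Product Attention is defined as follows:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

Here, Q (Query), K (Key), and V (Value) are obtained by linearly transforming the input token embeddings, and d_k is the per-head dimension.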
import torch
import torch.nn as nn
import math


class ScaledDotProductAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional KV cache.

    Causal masking is omitted for brevity, so when decoding, call this with
    one new token at a time (T == 1); a prefill with T > 1 needs a mask."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        # Q, K, V projections: [B, T, C] -> [B, n_heads, T, d_k]
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        # Use KV Cache: append the new K, V along the sequence axis
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)
        # Attention computation
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn, v)
        output = output.transpose(1, 2).contiguous().view(B, T, C)
        # Return the updated (k, v) so the caller can pass it to the next step
        return self.W_o(output), (k, v)
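To make the caching idea concrete, here is a minimal single-head sketch, separate from the class above, that decodes five tokens twice: once re-projecting K and V for the whole prefix at every step, and once appending only the new token's K and V to a cache. The weight matrices, sizes, and the helper name `attend` are illustrative assumptions, not part of any library.

```python
import math
import torch

torch.manual_seed(0)
d = 8  # toy model width (single head)
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))


def attend(q, k, v):
    """Scaled dot-product attention of one query row over all given keys."""
    scores = q @ k.T / math.sqrt(d)
    return torch.softmax(scores, dim=-1) @ v


x = torch.randn(5, d)  # embeddings of 5 tokens in one sequence

# (a) No cache: every step re-projects K and V for the entire prefix.
full = [attend(x[t:t + 1] @ W_q, x[:t + 1] @ W_k, x[:t + 1] @ W_v) for t in range(5)]

# (b) With a cache: project only the new token and append it.
k_cache = torch.empty(0, d)
v_cache = torch.empty(0, d)
cached = []
for t in range(5):
    x_t = x[t:t + 1]
    k_cache = torch.cat([k_cache, x_t @ W_k])
    v_cache = torch.cat([v_cache, x_t @ W_v])
    cached.append(attend(x_t @ W_q, k_cache, v_cache))

# Both strategies should give the same outputs up to floating-point noise.
match = all(torch.allclose(a, b, atol=1e-4) for a, b in zip(full, cached))
print("cached decoding matches full recompute:", match)
```

Variant (a) does work proportional to the prefix length for the projections at every step, while variant (b) does a constant amount of projection work per step and pays for it with memory that grows by one K row and one V row per token.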
KV Cache Memory Analysis
The memory usage of KV Cache is calculated using the following formula:
KV Cache Memory = 2 (K and V) x Number of layers (L) x Number of KV heads (H) x Sequence length (S) x Head dimension (d_k) x Bytes per element x Batch size (B)
For example, take the dimensions of Llama 2 70B with full MHA:
- Number of layers: 80
- Number of heads: 64 (a hypothetical MHA baseline; the released model uses GQA with 8 KV heads)
- Head dimension: 128
- Sequence length: 4096
- FP16 (2 bytes)
KV Cache = 2 x 80 x 64 x 4096 x 128 x 2 = 10,737,418,240 bytes, or about 10.7 GB (single request)
With batch size 32, this becomes about 343.6 GB, far exceeding the roughly 140 GB needed for the FP16 model weights.
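The formula is easy to wrap in a small helper for capacity planning. The sketch below is plain Python; the function name `kv_cache_bytes` and its parameter names are assumptions chosen for this example.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_elem=2):
    """Bytes needed to cache K and V for `batch_size` sequences of `seq_len` tokens."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


# The 70B-sized example from the text: 80 layers, 64 heads of dimension 128, FP16.
mha_single = kv_cache_bytes(80, 64, 128, 4096)
mha_batch32 = kv_cache_bytes(80, 64, 128, 4096, batch_size=32)
# Same model shape with 8 KV heads (GQA) instead of 64.
gqa_single = kv_cache_bytes(80, 8, 128, 4096)

print(f"MHA, 1 request:   {mha_single / 1e9:.1f} GB")
print(f"MHA, 32 requests: {mha_batch32 / 1e9:.1f} GB")
print(f"GQA, 1 request:   {gqa_single / 1e9:.1f} GB")
```

Sizes are reported in decimal GB (1e9 bytes). Changing `bytes_per_elem` to 1 models an 8-bit cache, and `n_kv_heads` is the only argument that differs between the MHA, GQA, and MQA cases.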
Multi-Head Attention (MHA) Memory Analysis
MHA Structure
In standard Multi-Head Attention, each attention head has independent Q, K, V projections. With H heads each maintaining K, V of d_k dimensions, the KV Cache reaches its maximum size.
class MultiHeadAttention(nn.Module):
    """Standard Multi-Head Attention (MHA)
    Each head maintains an independent KV pair"""

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)  # H * d_k dimensions
        self.W_v = nn.Linear(d_model, d_model)  # H * d_k dimensions
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, _ = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        if kv_cache is not None:
            k = torch.cat([kv_cache[0], k], dim=2)
            v = torch.cat([kv_cache[1], v], dim=2)
        # KV Cache size: [B, n_heads, S, d_k] x 2
        # Memory: 2 * B * n_heads * S * d_k * sizeof(dtype)
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out), (k, v)
MHA KV Cache per token per layer = 2 x H x d_k x sizeof(dtype)
Multi-Query Attention (MQA)
MQA Structure and Savings
MQA, proposed by Shazeer (2019), has all Query heads share a single KV head. This reduces KV Cache by a factor of H but introduces quality degradation.
class MultiQueryAttention(nn.Module):
    """Multi-Query Attention (MQA)
    All Query heads share a single KV head"""

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)   # H * d_k
        self.W_k = nn.Linear(d_model, self.d_k)  # 1 * d_k (single head)
        self.W_v = nn.Linear(d_model, self.d_k)  # 1 * d_k (single head)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, _ = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, 1, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, 1, self.d_k).transpose(1, 2)
        if kv_cache is not None:
            k = torch.cat([kv_cache[0], k], dim=2)
            v = torch.cat([kv_cache[1], v], dim=2)
        # KV Cache size: [B, 1, S, d_k] x 2, i.e. 1/H of MHA.
        # Keep the single-head tensors as the cache before broadcasting.
        new_cache = (k, v)
        # Broadcast K, V to all heads (expand returns a view, not a copy)
        k = k.expand(-1, self.n_heads, -1, -1)
        v = v.expand(-1, self.n_heads, -1, -1)
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out), new_cache
MQA KV Cache per token per layer = 2 x 1 x d_k x sizeof(dtype), which is 1/H of the MHA size.
Grouped Query Attention (GQA)
GQA Architecture (Llama 2/3)
GQA, proposed by Ainslie et al. (2023) and adopted in Llama 2 and Llama 3, divides Query heads into G groups where each group shares one KV head. When G=1, it is equivalent to MQA; when G=H, it is equivalent to MHA.
class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (GQA)
    Divides Query heads into G groups, each sharing KV
    Adopted in Llama 2/3"""

    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        self.n_heads = n_heads        # Number of Query heads
        self.n_kv_heads = n_kv_heads  # Number of KV heads (groups)
        self.d_k = d_model // n_heads
        self.n_rep = n_heads // n_kv_heads  # Query heads per group
        self.W_q = nn.Linear(d_model, n_heads * self.d_k)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.d_k)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.d_k)
        self.W_o = nn.Linear(d_model, d_model)

    def repeat_kv(self, x):
        """Repeat KV heads to match Query head count"""
        B, H_kv, S, D = x.shape
        if self.n_rep == 1:
            return x
        return (
            x[:, :, None, :, :]
            .expand(B, H_kv, self.n_rep, S, D)
            .reshape(B, self.n_heads, S, D)
        )

    def forward(self, x, kv_cache=None):
        B, T, _ = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
        if kv_cache is not None:
            k = torch.cat([kv_cache[0], k], dim=2)
            v = torch.cat([kv_cache[1], v], dim=2)
        # KV Cache size: [B, n_kv_heads, S, d_k] x 2
        new_cache = (k, v)
        # Repeat KV heads for broadcast
        k = self.repeat_kv(k)
        v = self.repeat_kv(v)
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out), new_cache
Llama 3 70B GQA configuration: n_heads=64, n_kv_heads=8. KV Cache is reduced to 1/8 compared to MHA with virtually no quality loss.
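Repeating the KV heads is the simplest way to line them up with the query heads, but the repeated tensor does not have to be materialized. As a rough sketch (the function names and toy shapes are assumptions for this example), the query heads can instead be folded into groups so that each group multiplies against its single KV head directly:

```python
import torch


def gqa_scores_repeat(q, k):
    """Baseline: copy each KV head n_rep times, then compute per-head scores."""
    n_rep = q.shape[1] // k.shape[1]
    k_rep = k.repeat_interleave(n_rep, dim=1)   # [B, H, S, D]
    return q @ k_rep.transpose(-2, -1)          # [B, H, T, S]


def gqa_scores_grouped(q, k):
    """Same scores without repeated keys: fold query heads into KV groups."""
    B, H, T, D = q.shape
    G = k.shape[1]
    q_grouped = q.view(B, G, H // G, T, D)      # [B, G, n_rep, T, D]
    scores = torch.einsum("bgrtd,bgsd->bgrts", q_grouped, k)
    return scores.reshape(B, H, T, k.shape[2])


torch.manual_seed(0)
q = torch.randn(2, 8, 3, 16)  # 8 query heads
k = torch.randn(2, 2, 5, 16)  # 2 KV heads, so 4 query heads per group
same = torch.allclose(gqa_scores_repeat(q, k), gqa_scores_grouped(q, k), atol=1e-5)
print("grouped scores match repeated-KV scores:", same)
```

Both functions assign query head h to KV head h // n_rep, which is the same mapping `repeat_kv` produces. Either way, only the n_kv_heads-sized tensors need to be stored in the cache.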
Multi-Head Latent Attention (MLA)
MLA Architecture (DeepSeek-V2/V3)
MLA, proposed in DeepSeek-V2 (2024), is an innovative method that compresses KV vectors into a low-dimensional latent vector for cache storage. During decoding, K and V are reconstructed from the latent vector.
The core idea is:
- Compress input x into a low-dimensional latent vector c: c = W_down * x
- Reconstruct K, V from the latent vector: K = W_uk * c, V = W_uv * c
- Only the low-dimensional c is stored in the cache
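Written out per token, and following the notation of the DeepSeek-V2 paper as closely as this summary allows, the three steps look roughly like this. Here h_t is the hidden state of token t; the matrices correspond to `W_dkv`, `W_uk`, `W_uv`, and `W_kr` in the implementation, and the last line is the separate positional key that is cached alongside the latent vector.

```latex
\begin{aligned}
c_t^{KV} &= W^{DKV} h_t                      && \text{(down-projection, cached)} \\
k_t^{C}  &= W^{UK} c_t^{KV}, \quad
v_t^{C}   = W^{UV} c_t^{KV}                  && \text{(up-projection, recomputed)} \\
k_t^{R}  &= \mathrm{RoPE}\!\left(W^{KR} h_t\right) && \text{(shared RoPE key, cached)}
\end{aligned}
```

The per-token, per-layer cache therefore holds d_latent + d_rope values, independent of the number of heads.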
class MultiHeadLatentAttention(nn.Module):
    """Multi-Head Latent Attention (MLA)
    Used in DeepSeek-V2/V3
    Compresses KV into low-dimensional latent space for caching.
    Simplified sketch: the rotary embedding itself is not applied."""

    def __init__(self, d_model, n_heads, d_latent, rope_dim=64):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.d_latent = d_latent  # Latent dimension (much smaller than d_model)
        self.rope_dim = rope_dim
        # Query projection
        self.W_q = nn.Linear(d_model, n_heads * self.d_k)
        # Per-head query part that pairs with the shared RoPE key
        # (added here so that the cached k_rope is actually used)
        self.W_qr = nn.Linear(d_model, n_heads * rope_dim)
        # KV compression (Down-projection)
        self.W_dkv = nn.Linear(d_model, d_latent)
        # KV reconstruction (Up-projection)
        self.W_uk = nn.Linear(d_latent, n_heads * self.d_k)
        self.W_uv = nn.Linear(d_latent, n_heads * self.d_k)
        # Separate projection for RoPE (one key shared by all heads)
        self.W_kr = nn.Linear(d_model, rope_dim)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, _ = x.shape
        # Query: content part and RoPE part
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        q_rope = self.W_qr(x).view(B, T, self.n_heads, self.rope_dim).transpose(1, 2)
        # KV compression: d_model -> d_latent
        c_kv = self.W_dkv(x)    # [B, T, d_latent]
        # RoPE key
        k_rope = self.W_kr(x)   # [B, T, rope_dim]
        if kv_cache is not None:
            c_kv_cached, k_rope_cached = kv_cache
            c_kv = torch.cat([c_kv_cached, c_kv], dim=1)
            k_rope = torch.cat([k_rope_cached, k_rope], dim=1)
        # KV Cache: only stores low-dimensional c_kv and k_rope
        # Memory: B * S * (d_latent + rope_dim) * sizeof(dtype)
        # Compared to MHA: (d_latent + rope_dim) / (2 * n_heads * d_k)
        new_cache = (c_kv, k_rope)
        # KV reconstruction: d_latent -> n_heads * d_k
        S = c_kv.shape[1]
        k = self.W_uk(c_kv).view(B, S, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_uv(c_kv).view(B, S, self.n_heads, self.d_k).transpose(1, 2)
        # Attention: content scores plus scores against the shared RoPE key
        scores = torch.matmul(q, k.transpose(-2, -1))
        scores = scores + torch.matmul(q_rope, k_rope.unsqueeze(1).transpose(-2, -1))
        attn = torch.softmax(scores / math.sqrt(self.d_k + self.rope_dim), dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out), new_cache
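DeepSeek-V2 MLA configuration: d_model=5120, n_heads=128 with a per-head dimension of 128, d_latent=512, rope_dim=64. Note that the head dimension is set independently of d_model, unlike the d_model // n_heads shortcut in the class above. Per token per layer, the cache holds 512 + 64 = 576 values instead of the 2 x 128 x 128 = 32,768 values an MHA with the same heads would need, a reduction of about 98%. The DeepSeek-V2 paper reports a 93.3% smaller KV cache than DeepSeek 67B, with quality on par with MHA.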
MHA vs MQA vs GQA vs MLA Comparison
| Feature | MHA | MQA | GQA | MLA |
|---|---|---|---|---|
| KV Heads | H | 1 | G (1 < G < H) | - (latent vector) |
| Cache per token per layer (elements) | 2Hd_k | 2d_k | 2Gd_k | d_latent + d_rope |
| Cache ratio vs MHA | 100% | 1/H | G/H | (d_latent + d_rope)/(2Hd_k) |
| Llama 3 70B example (H=64, G=8, d_k=128, FP16) | 32,768 B | 512 B | 4,096 B | - |
| DeepSeek-V2 example (H=128, d_k=128, d_latent=512, d_rope=64, FP16) | 65,536 B | - | - | 1,152 B |
| Quality Impact | Baseline | Slight degradation | Nearly identical | Nearly identical |
| Inference Speed | Baseline | Fastest | Good | Good (reconstruction cost) |
| Training Stability | High | Moderate | High | High |
| Representative Models | GPT-3, BERT | PaLM, Falcon | Llama 2/3, Mistral | DeepSeek-V2/V3 |
| Implementation Complexity | Low | Low | Medium | High |
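The per-token rows of the table follow directly from the element-count formulas. The helper below is a small sketch for recomputing them; its name and parameters are assumptions made for this example, and it reports FP16 bytes by default (pass `bytes_per_elem=1` to get raw element counts instead).

```python
def cache_bytes_per_token_layer(mechanism, n_heads, head_dim, n_kv_heads=None,
                                d_latent=None, d_rope=0, bytes_per_elem=2):
    """Cache size for one token in one layer, in bytes."""
    if mechanism == "MHA":
        elems = 2 * n_heads * head_dim
    elif mechanism == "MQA":
        elems = 2 * head_dim
    elif mechanism == "GQA":
        elems = 2 * n_kv_heads * head_dim
    elif mechanism == "MLA":
        elems = d_latent + d_rope
    else:
        raise ValueError(f"unknown mechanism: {mechanism}")
    return elems * bytes_per_elem


# Llama 3 70B-like shape: 64 query heads, 8 KV heads, head dimension 128.
llama = {m: cache_bytes_per_token_layer(m, n_heads=64, head_dim=128, n_kv_heads=8)
         for m in ("MHA", "MQA", "GQA")}

# DeepSeek-V2-like shape: 128 heads of dimension 128, latent 512, RoPE key 64.
deepseek_mha = cache_bytes_per_token_layer("MHA", n_heads=128, head_dim=128)
deepseek_mla = cache_bytes_per_token_layer("MLA", n_heads=128, head_dim=128,
                                           d_latent=512, d_rope=64)

print(llama)
print(deepseek_mha, deepseek_mla, f"{deepseek_mla / deepseek_mha:.2%}")
```

Multiplying any of these values by the number of layers, the sequence length, and the batch size gives the total cache footprint for a deployment.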
KV Cache Compression Techniques
Quantization
Quantizing KV Cache from FP16 to INT8 or INT4 can reduce memory by 2-4x.
class QuantizedKVCache:
    """KV Cache quantization.
    FP16 -> 8-bit conversion for roughly 50% memory reduction
    (plus a small overhead for the per-token scales and zero points)"""

    def __init__(self, n_layers, n_heads, max_seq_len, d_k, dtype=torch.uint8):
        self.n_layers = n_layers
        self.dtype = dtype  # storage dtype; the 0..255 range below assumes uint8
        # layer_idx -> (quantized, scale, zero_point), each growing along dim=2
        self.cache_k = {}
        self.cache_v = {}

    def quantize(self, tensor):
        """Asymmetric 8-bit quantization with one scale per token and head
        (min/max are taken over the last dimension, d_k)"""
        tensor = tensor.float()
        min_val = tensor.min(dim=-1, keepdim=True)[0]
        max_val = tensor.max(dim=-1, keepdim=True)[0]
        # clamp avoids division by zero when a row is constant
        scale = ((max_val - min_val) / 255.0).clamp(min=1e-8)
        # Not clamped: the zero point may fall outside 0..255 when a row
        # does not contain zero, and it is stored as a float anyway
        zero_point = (-min_val / scale).round()
        quantized = ((tensor / scale) + zero_point).round().clamp(0, 255).to(self.dtype)
        return quantized, scale, zero_point

    def dequantize(self, quantized, scale, zero_point):
        """8-bit -> float32 dequantization (cast to the model dtype as needed)"""
        return (quantized.float() - zero_point) * scale

    def _append(self, store, layer_idx, tensor):
        entry = self.quantize(tensor)
        if layer_idx in store:
            # Concatenate along the sequence dimension, including the scales
            # and zero points of the new tokens (older ones must be kept)
            entry = tuple(torch.cat([old, new], dim=2)
                          for old, new in zip(store[layer_idx], entry))
        store[layer_idx] = entry

    def update(self, layer_idx, k, v):
        self._append(self.cache_k, layer_idx, k)
        self._append(self.cache_v, layer_idx, v)

    def get(self, layer_idx):
        k = self.dequantize(*self.cache_k[layer_idx])
        v = self.dequantize(*self.cache_v[layer_idx])
        return k, v
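To make the precision cost of 8-bit quantization concrete, here is a dependency-free round trip over a single row of values. It is only an illustration of asymmetric min/max quantization; the helper names `quantize_row` and `dequantize_row` and the sample values are made up for this example.

```python
def quantize_row(values, levels=255):
    """Map floats to integer codes in 0..levels using the row's min and max."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / levels or 1.0  # constant row: any non-zero scale works
    codes = [round((x - lo) / scale) for x in values]
    return codes, scale, lo


def dequantize_row(codes, scale, lo):
    """Invert quantize_row up to rounding error."""
    return [c * scale + lo for c in codes]


row = [-1.5, -0.25, 0.0, 0.75, 2.0]
codes, scale, lo = quantize_row(row)
restored = dequantize_row(codes, scale, lo)
max_err = max(abs(a - b) for a, b in zip(row, restored))

print(codes)
print(max_err, "<=", scale / 2)
```

The reconstruction error of each value is bounded by half the scale, so rows with a wide value range (outliers) lose the most precision. This is one reason lower bit widths such as INT4 can hurt quality on long contexts.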
Eviction Policy
A strategy that selectively removes old token KV pairs as the sequence grows longer.
The H2O (Heavy-Hitter Oracle) approach retains the KV pairs of tokens with high accumulated attention scores (the "heavy hitters") together with the most recent tokens, and evicts the rest.
class H2OKVCache:
    """Heavy-Hitter Oracle (H2O) style KV Cache eviction.
    Keeps the first token plus the tokens with the highest accumulated
    attention scores. (H2O additionally keeps a window of the most recent
    tokens; that part is omitted here for brevity.)"""

    def __init__(self, max_cache_size, n_heads, d_k):
        self.max_cache_size = max_cache_size
        self.attention_scores = None  # [B, H, S] accumulated attention mass

    def update(self, k, v, attn_weights):
        """k, v: full cached tensors [B, H, S, d_k];
        attn_weights: [B, H, T, S] for the T newest queries.
        Accumulates importance per cached token and evicts if over budget."""
        B, H, S, _ = k.shape
        new_scores = attn_weights.sum(dim=2)  # [B, H, S]
        if self.attention_scores is not None:
            # Older tokens keep their accumulated score; new tokens start at 0
            S_old = self.attention_scores.shape[2]
            new_scores[:, :, :S_old] += self.attention_scores
        self.attention_scores = new_scores
        # Evict when cache size is exceeded
        if S > self.max_cache_size:
            # Select the highest-scoring tokens (always keep first token)
            scores = self.attention_scores[:, :, 1:]  # Exclude first token
            _, indices = scores.topk(self.max_cache_size - 1, dim=2, sorted=False)
            indices = indices + 1  # Offset correction
            # Add first token index, then sort so survivors stay in positional order
            first_token = torch.zeros(B, H, 1, dtype=torch.long, device=k.device)
            keep_indices = torch.cat([first_token, indices], dim=2).sort(dim=2).values
            # Keep only selected tokens
            idx = keep_indices.unsqueeze(-1).expand(-1, -1, -1, k.shape[-1])
            k = k.gather(2, idx)
            v = v.gather(2, idx)
            self.attention_scores = self.attention_scores.gather(2, keep_indices)
        return k, v
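The selection rule itself can be shown without tensors. The sketch below is a simplified, single-head illustration of an H2O-style keep set that combines heavy hitters with a recent window; the function name `h2o_keep_indices`, the `n_recent` parameter, and the scores are assumptions for this example, not the reference implementation.

```python
def h2o_keep_indices(scores, budget, n_recent):
    """Return the cache positions to keep, in positional order.

    scores:   accumulated attention mass per cached token
    budget:   maximum number of tokens to keep
    n_recent: how many of the newest tokens are always kept
    """
    assert 0 <= n_recent <= budget
    n = len(scores)
    if n <= budget:
        return list(range(n))
    recent = set(range(n - n_recent, n))
    older = [i for i in range(n) if i not in recent]
    # Fill the remaining budget with the highest-scoring older tokens
    heavy = sorted(older, key=lambda i: scores[i], reverse=True)[:budget - n_recent]
    return sorted(recent | set(heavy))


scores = [5.0, 0.1, 0.2, 3.0, 0.1, 0.4, 0.3, 0.2]
keep = h2o_keep_indices(scores, budget=4, n_recent=2)
print(keep)  # [0, 3, 6, 7]
```

Positions 0 and 3 survive because of their accumulated attention, and positions 6 and 7 survive because they are recent and have not yet had a chance to accumulate score.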
Sliding Window Attention
A fixed-size window-based attention used in models like Mistral.
class SlidingWindowAttention(nn.Module):
    """Sliding Window Attention
    Applies attention only to the most recent W tokens
    Used in Mistral, Gemma, etc."""

    def __init__(self, d_model, n_heads, window_size=4096):
        super().__init__()
        self.window_size = window_size
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, _ = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        if kv_cache is not None:
            k = torch.cat([kv_cache[0], k], dim=2)
            v = torch.cat([kv_cache[1], v], dim=2)
        S = k.shape[2]
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        # Causal + window mask: each query sees itself and the W-1 tokens
        # before it (this matters when T > 1, e.g. during prefill)
        q_pos = torch.arange(S - T, S, device=x.device).unsqueeze(1)
        k_pos = torch.arange(S, device=x.device).unsqueeze(0)
        mask = (k_pos > q_pos) | (k_pos <= q_pos - self.window_size)
        attn = attn.masked_fill(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        # Limit the returned KV Cache to the window size
        k = k[:, :, -self.window_size:]
        v = v[:, :, -self.window_size:]
        return self.W_o(out), (k, v)
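Slicing and re-concatenating the cache on every step copies memory. The Mistral 7B paper describes a rolling buffer cache instead, where the entry for position i is written to slot i mod W. The following is a minimal sketch of that idea with plain Python objects standing in for KV tensors; the class name `RollingKVBuffer` is hypothetical.

```python
class RollingKVBuffer:
    """Fixed-size cache: position i is stored in slot i % window_size."""

    def __init__(self, window_size):
        self.window_size = window_size
        self.slots = [None] * window_size
        self.n_seen = 0  # total tokens appended so far

    def append(self, kv):
        # Overwrites the oldest entry once the buffer is full
        self.slots[self.n_seen % self.window_size] = kv
        self.n_seen += 1

    def ordered(self):
        """Cached entries from oldest to newest."""
        n = min(self.n_seen, self.window_size)
        start = self.n_seen - n
        return [self.slots[p % self.window_size] for p in range(start, self.n_seen)]


buf = RollingKVBuffer(window_size=4)
for pos in range(6):
    buf.append(f"kv{pos}")
window = buf.ordered()
print(window)  # ['kv2', 'kv3', 'kv4', 'kv5']
```

Memory stays at W entries no matter how long the sequence grows, and each append is a single in-place write.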
PagedAttention (vLLM)
Virtual Memory-Based KV Cache Management
vLLM's PagedAttention, inspired by OS virtual memory systems, manages KV Cache in fixed-size blocks (pages). Because a sequence no longer needs one contiguous region reserved for its maximum length, external fragmentation disappears and internal fragmentation is limited to the unused slots in each sequence's last block.
class PagedKVCache:
    """vLLM PagedAttention concept implementation
    Manages KV Cache in page units"""

    def __init__(self, block_size=16, n_blocks=1024, n_heads=32, d_k=128):
        self.block_size = block_size
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.d_k = d_k
        # Physical block pool (pre-allocated)
        self.k_blocks = torch.zeros(n_blocks, n_heads, block_size, d_k)
        self.v_blocks = torch.zeros(n_blocks, n_heads, block_size, d_k)
        # Free block management
        self.free_blocks = list(range(n_blocks))
        # Per-sequence block tables (virtual -> physical mapping)
        self.block_tables = {}  # seq_id -> list of physical block indices
        self.seq_lengths = {}

    def _take_free_block(self):
        if not self.free_blocks:
            raise RuntimeError("OOM: No free blocks available")
        return self.free_blocks.pop()

    def allocate(self, seq_id):
        """Allocate first block for a new sequence"""
        self.block_tables[seq_id] = [self._take_free_block()]
        self.seq_lengths[seq_id] = 0

    def append_token(self, seq_id, k, v):
        """Append new token KV (each of shape [n_heads, d_k]) to a sequence"""
        seq_len = self.seq_lengths[seq_id]
        block_idx_in_seq = seq_len // self.block_size
        offset = seq_len % self.block_size
        # Allocate new block when current block is full
        if offset == 0 and block_idx_in_seq > 0:
            self.block_tables[seq_id].append(self._take_free_block())
        # Store KV in physical block
        physical_block = self.block_tables[seq_id][block_idx_in_seq]
        self.k_blocks[physical_block, :, offset] = k
        self.v_blocks[physical_block, :, offset] = v
        self.seq_lengths[seq_id] += 1

    def get_kv(self, seq_id):
        """Assemble complete KV Cache for a sequence: [n_heads, seq_len, d_k]"""
        blocks = self.block_tables[seq_id]
        seq_len = self.seq_lengths[seq_id]
        # Concatenate blocks along the token dimension, then drop the
        # unused tail of the last block (also correct for seq_len == 0)
        k = torch.cat([self.k_blocks[b] for b in blocks], dim=1)[:, :seq_len]
        v = torch.cat([self.v_blocks[b] for b in blocks], dim=1)[:, :seq_len]
        return k, v

    def free(self, seq_id):
        """Return blocks after sequence completion"""
        self.free_blocks.extend(self.block_tables[seq_id])
        del self.block_tables[seq_id]
        del self.seq_lengths[seq_id]
The key benefits of PagedAttention are:
- Eliminates memory fragmentation: No contiguous memory required, so no external fragmentation
- Improved memory utilization: Blocks are allocated only as needed. The vLLM project reports up to 24x higher throughput than HuggingFace Transformers, and the PagedAttention paper reports 2-4x over FasterTransformer and Orca
- Copy-on-Write: Shares KV Cache in beam search and parallel sampling to save memory
- Prefix caching: Shares common prefix (system prompt) KV Cache across multiple requests
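The fragmentation argument can be quantified with a few lines of arithmetic. This sketch compares the unused cache slots under block-based allocation with those under a contiguous reservation sized for the maximum length; the function names and the example lengths are illustrative assumptions.

```python
import math


def blocks_needed(seq_len, block_size=16):
    """Number of fixed-size blocks required to hold seq_len tokens."""
    return math.ceil(seq_len / block_size)


def wasted_slots_paged(seq_len, block_size=16):
    """Unused token slots with paging: only the tail of the last block."""
    return blocks_needed(seq_len, block_size) * block_size - seq_len


def wasted_slots_contiguous(seq_len, max_seq_len):
    """Unused token slots when max_seq_len is reserved up front."""
    return max_seq_len - seq_len


seq_len, max_seq_len = 1000, 4096
paged_waste = wasted_slots_paged(seq_len)
contig_waste = wasted_slots_contiguous(seq_len, max_seq_len)
print(paged_waste, contig_waste)  # 8 3096
```

With paging, waste per sequence never exceeds block_size - 1 slots, whereas a contiguous reservation wastes everything between the actual and the maximum length.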
Failure Cases and Recovery Procedures
Case 1: OOM During Long Context Inference
Situation: Attempted document summarization with a 128K context length using the Llama 3.1 70B model (the original Llama 3 70B is limited to an 8K context), but OOM occurred at batch size 4, shutting down the entire inference service.
Memory Analysis:
- Model weights (FP16): ~140 GB
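- KV Cache per request (GQA, H=64, G=8): 2 (K and V) x 80 layers x 8 KV heads x 131,072 tokens x 128 dims x 2 bytes = ~42.9 GB (40 GiB)
- Total KV Cache for batch 4: ~172 GB
- Total memory needed: ~312 GB before activation memory and temporary buffers. 8x A100 80GB provides 640 GB in aggregate, but prefill activations for 128K-token prompts, fragmentation, and uneven per-GPU usage can still exhaust an individual GPU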
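The KV Cache product in the list above is worth re-running whenever the context length, batch size, or cache dtype changes. A small helper, assuming the Llama 70B shape used in this case (80 layers, 8 KV heads, head dimension 128) and 128K = 131,072 tokens; the function name is hypothetical:

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_k, seq_len,
                   bytes_per_value=2, batch_size=1):
    """Total KV Cache size in bytes: K and V for every layer, head, and token."""
    return 2 * n_layers * n_kv_heads * d_k * seq_len * bytes_per_value * batch_size


GIB = 2 ** 30
per_request = kv_cache_bytes(80, 8, 128, 128 * 1024)
batch_of_4 = kv_cache_bytes(80, 8, 128, 128 * 1024, batch_size=4)
fp8_per_request = kv_cache_bytes(80, 8, 128, 128 * 1024, bytes_per_value=1)

print(per_request / GIB)      # 40.0 GiB per request in FP16
print(batch_of_4 / GIB)       # 160.0 GiB for batch size 4
print(fp8_per_request / GIB)  # 20.0 GiB per request with an 8-bit cache
```

Halving the bytes per value or the maximum sequence length each halves the KV Cache, which is the reasoning behind the recovery steps that follow.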
Recovery Procedure:
# 1. Set memory limits when restarting the vLLM server
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Llama-3.1-70B \
    --tensor-parallel-size 8 \
    --max-model-len 65536 \
    --gpu-memory-utilization 0.9 \
    --max-num-seqs 8 \
    --enforce-eager
# 2. Enable KV Cache quantization (FP8 halves the cache size relative to FP16)
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Llama-3.1-70B \
    --tensor-parallel-size 8 \
    --kv-cache-dtype fp8_e5m2 \
    --max-model-len 128000
# 3. Dynamic batch size limiting
# Reduce max-num-seqs to limit concurrent request count
Case 2: Quality Degradation from KV Cache Quantization
Situation: KV Cache was quantized to INT4 for inference cost reduction, but response quality degraded significantly in long conversations. Error rates increased particularly in math calculations and code generation tasks.
Symptoms:
- Increased hallucinations after 10+ conversation turns
- 15% accuracy drop on math problems
- Frequent syntax errors in code generation
Recovery Procedure:
# 1. Raise KV Cache precision from 4-bit to 8-bit
# In vLLM, the 8-bit option is an FP8 format:
# --kv-cache-dtype fp8_e5m2
# 2. Keep only critical layers at high precision (mixed quantization)
# Initial and final layers at FP16, middle layers at INT8
# (illustrative config for a custom cache implementation, not a vLLM option)
kv_cache_config = {
    "default_dtype": "int8",
    "high_precision_layers": [0, 1, 2, 77, 78, 79],  # First/last 3 layers
    "high_precision_dtype": "fp16",
}
# 3. Re-run benchmarks to verify quality
# Validate with standard benchmarks like MMLU, HumanEval, GSM8K
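To see what the mixed-precision layout costs in memory, the config can be resolved per layer and averaged. This is a sketch for a custom cache implementation under the assumptions above (80 layers, FP16 = 2 bytes, INT8 = 1 byte); `layer_kv_dtype` is a hypothetical helper, not a vLLM API.

```python
BYTES_PER_VALUE = {"fp16": 2, "int8": 1}

kv_cache_config = {
    "default_dtype": "int8",
    "high_precision_layers": [0, 1, 2, 77, 78, 79],  # first/last 3 layers
    "high_precision_dtype": "fp16",
}


def layer_kv_dtype(layer_idx, config):
    """Pick the KV Cache dtype for one layer from the mixed-precision config."""
    if layer_idx in config["high_precision_layers"]:
        return config["high_precision_dtype"]
    return config["default_dtype"]


n_layers = 80
dtypes = [layer_kv_dtype(i, kv_cache_config) for i in range(n_layers)]
avg_bytes = sum(BYTES_PER_VALUE[d] for d in dtypes) / n_layers

print(dtypes.count("fp16"), dtypes.count("int8"))  # 6 74
print(avg_bytes)  # average bytes per cached value across layers
```

Keeping six of 80 layers at FP16 raises the average from 1 byte to about 1.075 bytes per value, a 7.5% memory increase over uniform INT8.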
Case 3: Response Mix-up from Prefix Caching Error
Situation: After enabling vLLM's prefix caching feature, incorrect contexts were applied to different users' requests. KV Cache was incorrectly shared between different sessions with the same system prompt.
Recovery Procedure:
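# 1. Temporarily disable prefix caching
# (recent vLLM versions accept --no-enable-prefix-caching; on older
#  versions, simply omit --enable-prefix-caching)
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Llama-3.1-70B \
    --no-enable-prefix-caching
# 2. Modify prefix hash key to include session ID
# Different sessions should use separate caches even with the same prefix
# 3. Expire cached prefixes when re-enabling prefix caching
# A time-based expiry such as --prefix-cache-ttl 300 (auto-expire after 5 minutes)
# is shown as a hypothetical setting, not a documented vLLM flag; check what
# your vLLM version supports or enforce the expiry in the serving layer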
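Step 2 can be sketched as a block hash that mixes in a per-session or per-tenant salt. The code below is a conceptual illustration, not vLLM's hashing scheme; the function name `prefix_block_key`, the `salt` parameter, and the token IDs are assumptions made for the example.

```python
import hashlib


def prefix_block_key(token_ids, parent_key="", salt=""):
    """Hash a block of prompt tokens, chained to the previous block's key.

    An empty salt lets every request with the same tokens share the cache;
    a per-session or per-tenant salt keeps their caches separate.
    """
    h = hashlib.sha256()
    h.update(salt.encode("utf-8"))
    h.update(b"|")
    h.update(parent_key.encode("utf-8"))
    h.update(b"|")
    h.update(",".join(map(str, token_ids)).encode("utf-8"))
    return h.hexdigest()


system_prompt_tokens = [101, 2023, 2003, 1037]  # made-up token IDs

shared_a = prefix_block_key(system_prompt_tokens)
shared_b = prefix_block_key(system_prompt_tokens)
tenant_a = prefix_block_key(system_prompt_tokens, salt="tenant-a")
tenant_b = prefix_block_key(system_prompt_tokens, salt="tenant-b")

print(shared_a == shared_b)  # True: identical prefixes share one cache entry
print(tenant_a == tenant_b)  # False: salted keys never collide across tenants
```

Salting trades cache hit rate for isolation, so it is usually applied per tenant or trust boundary rather than per request.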
Optimization Checklist
Model Selection and Configuration
- Verify the model's attention mechanism (MHA/MQA/GQA/MLA)
- Validate optimal n_kv_heads value when using GQA models (commonly 8 KV heads, i.e. n_heads/8 in Llama 3 70B and n_heads/4 in Llama 3 8B and Mistral 7B)
- Limit maximum sequence length to match actual usage patterns
KV Cache Memory Management
- Calculate KV Cache allocation ratio relative to GPU memory (60-80% of total recommended)
- Decide on KV Cache quantization (quality/memory tradeoff: FP8 > INT8 > INT4)
- Decide on PagedAttention adoption (vLLM, TensorRT-LLM, etc.)
- Verify key collision prevention when enabling prefix caching
Inference Performance Optimization
- Enable Continuous Batching (immediately recycle slots upon request completion)
- Evaluate Speculative Decoding adoption (reduces decode latency; budget additional KV Cache for the draft model)
- Assess Sliding Window Attention applicability (long document summarization, etc.)
- Choose between Tensor Parallelism vs Pipeline Parallelism
Monitoring
- Collect KV Cache utilization metrics (vLLM: vllm:gpu_cache_usage_perc, renamed vllm:kv_cache_usage_perc in newer releases)
- Track per-request KV Cache memory usage
- Set up automatic OOM event alerts
- Profile throughput and latency by batch size
Operational Notes
- Set concurrent request upper limits (max-num-seqs)
- Set GPU memory utilization upper limits (gpu-memory-utilization)
- Implement graceful degradation on OOM (request queuing, batch reduction)
- Perform regular KV Cache profiling to check for memory leaks
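Graceful degradation through request queuing can be as simple as admitting a request only when its worst-case KV Cache fits in the remaining block budget. The following is a minimal sketch under that assumption; `BlockBudget` and its methods are hypothetical and not part of vLLM, which has its own scheduler.

```python
import math
from collections import deque


class BlockBudget:
    """Tracks KV Cache blocks reserved per request (hypothetical helper)."""

    def __init__(self, total_blocks, block_size=16):
        self.block_size = block_size
        self.free_blocks = total_blocks
        self.reserved = {}      # request_id -> blocks held
        self.waiting = deque()  # (request_id, max_tokens) queued for later

    def _blocks_for(self, max_tokens):
        return math.ceil(max_tokens / self.block_size)

    def submit(self, request_id, max_tokens):
        """Admit the request if its worst-case KV Cache fits, else queue it."""
        need = self._blocks_for(max_tokens)
        if need > self.free_blocks:
            self.waiting.append((request_id, max_tokens))
            return False
        self.free_blocks -= need
        self.reserved[request_id] = need
        return True

    def release(self, request_id):
        """Return a finished request's blocks and admit queued requests in order."""
        self.free_blocks += self.reserved.pop(request_id)
        admitted = []
        while self.waiting:
            rid, max_tokens = self.waiting[0]
            need = self._blocks_for(max_tokens)
            if need > self.free_blocks:
                break
            self.waiting.popleft()
            self.free_blocks -= need
            self.reserved[rid] = need
            admitted.append(rid)
        return admitted


budget = BlockBudget(total_blocks=100)
first = budget.submit("a", max_tokens=1024)   # needs 64 blocks, admitted
second = budget.submit("b", max_tokens=1024)  # needs 64, only 36 free: queued
resumed = budget.release("a")                 # "b" is admitted once "a" finishes
print(first, second, resumed)
```

Reserving for the worst case is conservative. A production scheduler would typically allocate blocks incrementally and preempt or swap sequences instead, but the queue-instead-of-crash behavior is the same.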
Conclusion
KV Cache optimization is the key to LLM inference efficiency. The evolution of attention mechanisms from MHA to MQA, GQA, and MLA has successfully reduced KV Cache memory dramatically while maintaining model quality. GQA is a practical approach proven in the Llama series, and MLA enables even more aggressive compression as demonstrated in DeepSeek-V2/V3.
KV Cache quantization, eviction policies, and Sliding Window Attention are additional techniques for handling long sequences under memory constraints, while PagedAttention innovates memory management itself to maximize inference system throughput.
In real production environments, these techniques must be used in combination, and continuous profiling and benchmarking are essential to find optimal configurations that match model characteristics, workload patterns, and GPU resources.
References
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (Ainslie et al., 2023)
- DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model (2024)
- KV Caching Explained - Hugging Face Blog
- KV Cache Optimization via Multi-Head Latent Attention - PyImageSearch
- Efficient Memory Management for Large Language Model Serving with PagedAttention (Kwon et al., 2023)