LLM 추론 최적화 완전 가이드: KV Cache, Speculative Decoding, Continuous Batching
들어가며
대형 언어 모델(LLM)을 프로덕션에 배포하면 즉시 직면하는 문제가 있습니다. 바로 추론 속도와 비용입니다. GPT-4 수준의 모델을 처음 쿼리했을 때 수 초의 지연이 발생하고, 동시 사용자가 늘어나면 처리량이 급격히 저하됩니다.
이 가이드는 LLM 추론 최적화의 핵심 기술을 완전히 파헤칩니다. KV Cache의 작동 원리부터 PagedAttention, Speculative Decoding, FlashAttention, 그리고 최신 vLLM과 TensorRT-LLM 엔진까지 — 단순한 사용법이 아니라 왜 작동하는지 원리를 이해합니다.
1. LLM 추론 과정 이해
1.1 두 단계: Prefill과 Decode
LLM 텍스트 생성은 두 단계로 나뉩니다.
Prefill 단계 (프롬프트 처리)
- 입력 프롬프트의 모든 토큰을 동시에 처리
- 각 레이어에서 Key/Value 캐시를 생성하여 저장
- 연산 집약적(Compute-Bound): GPU 연산 능력이 병목
- TTFT (Time To First Token)에 직접 영향
Decode 단계 (토큰 생성)
- 한 번에 하나의 토큰을 자기회귀 방식으로 생성
- 이전에 생성된 모든 토큰의 KV Cache를 참조
- 메모리 대역폭 집약적(Memory-Bound): HBM 읽기 속도가 병목
- TPOT (Time Per Output Token)에 직접 영향
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
def measure_prefill_decode_time(model, tokenizer, prompt: str, max_new_tokens: int = 100):
    """Measure prefill (TTFT) and decode (TPOT) latencies for one prompt.

    Args:
        model: a Hugging Face causal LM already placed on CUDA.
        tokenizer: the matching tokenizer.
        prompt: input text to process.
        max_new_tokens: tokens to generate for the full-run measurement.

    Returns:
        (ttft_seconds, tpot_seconds)

    NOTE(review): TTFT and the full generation are two *separate*
    ``generate()`` calls, so ``decode_time = total_time - ttft`` subtracts
    timings taken from different runs — treat TPOT as an approximation.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    input_len = inputs["input_ids"].size(1)

    # --- Prefill measurement: time until the first generated token.
    torch.cuda.synchronize()  # drain pending kernels so the timer is honest
    prefill_start = time.perf_counter()
    with torch.no_grad():
        # One new token = full prefill + a single decode step.
        first_output = model.generate(
            **inputs,
            max_new_tokens=1,
            do_sample=False
        )
    torch.cuda.synchronize()
    ttft = time.perf_counter() - prefill_start

    # --- Full generation measurement (prefill repeated + all decode steps).
    torch.cuda.synchronize()
    total_start = time.perf_counter()
    with torch.no_grad():
        full_output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )
    torch.cuda.synchronize()
    total_time = time.perf_counter() - total_start

    # Decode-only time per token; the first token is attributed to prefill.
    output_tokens = full_output.size(1) - input_len
    decode_time = total_time - ttft
    tpot = decode_time / max(output_tokens - 1, 1)

    print(f"입력 토큰 수: {input_len}")
    print(f"생성 토큰 수: {output_tokens}")
    print(f"TTFT (첫 토큰 지연): {ttft * 1000:.1f} ms")
    print(f"TPOT (토큰당 시간): {tpot * 1000:.1f} ms")
    print(f"처리량: {output_tokens / total_time:.1f} 토큰/초")
    return ttft, tpot
1.2 메모리 병목 분석
Decode 단계가 왜 메모리 집약적인지 이해합니다.
def analyze_memory_bandwidth():
    """Analyze whether LLM decode is memory-bound or compute-bound.

    Uses Llama-2-7B-like dimensions to compare the token throughput implied
    by A100 HBM bandwidth against the one implied by raw FP16 compute.

    Returns:
        dict with the weight size, per-token KV-cache size, both throughput
        ceilings, and which resource is the bottleneck ("memory"/"compute").
    """
    # Example: Llama-2-7B configuration.
    model_params = {
        "num_layers": 32,
        "hidden_size": 4096,
        "num_heads": 32,
        "head_dim": 128,
        "vocab_size": 32000,
    }
    dtype_bytes = 2  # FP16: 2 bytes per parameter

    # Weight memory (embeddings and norms ignored for simplicity).
    attn_weight = 4 * model_params["hidden_size"] ** 2  # Q, K, V, O projections
    ffn_weight = 8 * model_params["hidden_size"] ** 2   # Up, Gate, Down projections (SwiGLU)
    layer_params = attn_weight + ffn_weight
    total_params = layer_params * model_params["num_layers"]
    total_weight_bytes = total_params * dtype_bytes
    total_weight_gb = total_weight_bytes / 1e9
    print(f"모델 가중치: {total_weight_gb:.2f} GB")

    # KV-cache memory (grows linearly with sequence length).
    seq_len = 2048
    kv_cache_per_token = (
        2 *  # K and V
        model_params["num_layers"] *
        model_params["num_heads"] *
        model_params["head_dim"] *
        dtype_bytes
    )
    kv_cache_total = kv_cache_per_token * seq_len / 1e6
    print(f"KV Cache ({seq_len} 토큰): {kv_cache_total:.2f} MB")
    print(f"토큰당 KV Cache: {kv_cache_per_token} bytes")

    # A100 HBM bandwidth: ~2 TB/s.
    memory_bandwidth_tbs = 2.0
    # Decode: every weight is read once per generated token, so small
    # batches have a very high memory-read-to-FLOP ratio.
    batch_size = 1
    # ~2 FLOPs per parameter per token (multiply + add).
    # BUGFIX: the previous version used 2 * total_weight_bytes, conflating
    # bytes with parameter count and overcounting FLOPs by dtype_bytes.
    flops_per_token = 2 * total_params
    # A100 FP16: 312 TFLOPS.
    compute_throughput = 312e12

    # Throughput ceiling implied by each resource.
    memory_bound_tps = memory_bandwidth_tbs * 1e12 / total_weight_bytes
    compute_bound_tps = compute_throughput / flops_per_token

    print(f"\n배치 크기 {batch_size}일 때:")
    print(f"메모리 병목 처리량: {memory_bound_tps:.1f} 토큰/초")
    print(f"연산 병목 처리량: {compute_bound_tps:.1f} 토큰/초")
    print(f"실제 병목: {'메모리' if memory_bound_tps < compute_bound_tps else '연산'}")

    return {
        "total_weight_gb": total_weight_gb,
        "kv_cache_per_token_bytes": kv_cache_per_token,
        "memory_bound_tps": memory_bound_tps,
        "compute_bound_tps": compute_bound_tps,
        "bottleneck": "memory" if memory_bound_tps < compute_bound_tps else "compute",
    }


analyze_memory_bandwidth()
1.3 추론 비용 분석
def estimate_inference_cost(
    model_size_b: float,
    tokens_per_request: int,
    requests_per_day: int,
    gpu_cost_per_hour: float = 3.0  # hourly A100 price (USD)
) -> dict:
    """Back-of-the-envelope daily GPU cost for serving an LLM.

    Args:
        model_size_b: model size in billions of parameters.
        tokens_per_request: average generated tokens per request.
        requests_per_day: daily request volume.
        gpu_cost_per_hour: GPU rental price in USD.

    Returns:
        dict with throughput, GPU hours, daily cost, and cost per 1K tokens.
    """
    # Empirical throughput scaling on A100: ~100 tok/s at 7B, ~20 tok/s at 70B.
    throughput_tps = 100 / (model_size_b / 7) ** 0.6

    total_tokens_per_day = tokens_per_request * requests_per_day
    seconds_needed = total_tokens_per_day / throughput_tps
    hours_needed = seconds_needed / 3600
    daily_cost = hours_needed * gpu_cost_per_hour
    # Guard against zero traffic (previously a ZeroDivisionError).
    cost_per_1k_tokens = (
        daily_cost / (total_tokens_per_day / 1000) if total_tokens_per_day else 0.0
    )

    print(f"모델: {model_size_b}B 파라미터")
    print(f"일일 요청: {requests_per_day:,}건")
    print(f"요청당 토큰: {tokens_per_request}")
    print(f"총 일일 토큰: {total_tokens_per_day:,}")
    print(f"예상 처리량: {throughput_tps:.1f} 토큰/초")
    print(f"필요 GPU 시간: {hours_needed:.2f}시간")
    print(f"일일 비용: ${daily_cost:.2f}")
    print(f"1K 토큰당 비용: ${cost_per_1k_tokens:.4f}")

    return {
        "throughput_tps": throughput_tps,
        "gpu_hours": hours_needed,
        "daily_cost_usd": daily_cost,
        "cost_per_1k_tokens_usd": cost_per_1k_tokens,
    }


# 예시
estimate_inference_cost(
    model_size_b=7.0,
    tokens_per_request=500,
    requests_per_day=10000
)
2. KV Cache: 핵심 최적화 기술
2.1 KV Cache의 필요성
트랜스포머 어텐션에서 각 토큰은 이전 모든 토큰과 어텐션을 계산합니다. 이미 처리한 토큰을 재계산하지 않으려면 K와 V 행렬을 캐싱합니다.
import torch
import torch.nn as nn
import math
class MultiHeadAttentionWithKVCache(nn.Module):
    """Multi-head attention with a preallocated, in-place KV cache.

    Keys/values of already-processed tokens are stored in fixed buffers so
    each decode step only computes Q/K/V for the new tokens while attending
    over the full cached prefix.

    Args:
        d_model: model (embedding) dimension.
        num_heads: number of attention heads; must divide ``d_model``.
        max_seq_len: cache capacity in tokens.
        max_batch_size: cache capacity in sequences. BUGFIX: the buffer was
            previously hard-coded to batch size 1, so any forward pass with
            a larger batch crashed on the cache assignment.
    """

    def __init__(self, d_model: int, num_heads: int, max_seq_len: int = 4096,
                 max_batch_size: int = 1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        # KV cache buffers: [max_batch_size, max_seq_len, num_heads, d_head].
        self.register_buffer(
            'k_cache',
            torch.zeros(max_batch_size, max_seq_len, num_heads, self.d_head)
        )
        self.register_buffer(
            'v_cache',
            torch.zeros(max_batch_size, max_seq_len, num_heads, self.d_head)
        )
        self.cache_pos = 0  # next free slot when positions are implicit

    def forward(
        self,
        x: torch.Tensor,
        use_cache: bool = True,
        position: int = None
    ):
        """Attend over the cached prefix plus the new tokens.

        Args:
            x: [batch, seq, d_model] embeddings of the new tokens.
            use_cache: store K/V of ``x`` and attend over the whole cache.
            position: explicit write offset; when given, ``cache_pos`` is
                not advanced (the caller manages positions).

        Returns:
            [batch, seq, d_model] attention output.

        NOTE: no causal mask is applied; during single-token decode the
        cached prefix is entirely in the past, so this is safe there.
        """
        batch_size, seq_len, _ = x.shape
        q = self.W_q(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        k = self.W_k(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        v = self.W_v(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)

        if use_cache:
            # Write the new K/V into the cache at the proper offset.
            start_pos = self.cache_pos if position is None else position
            self.k_cache[:batch_size, start_pos:start_pos + seq_len] = k
            self.v_cache[:batch_size, start_pos:start_pos + seq_len] = v
            if position is None:
                self.cache_pos += seq_len
            # Attend over everything cached so far.
            total_len = self.cache_pos if position is None else start_pos + seq_len
            k = self.k_cache[:batch_size, :total_len]
            v = self.v_cache[:batch_size, :total_len]

        scale = math.sqrt(self.d_head)
        q = q.transpose(1, 2)  # [B, H, S, D]
        k = k.transpose(1, 2)  # [B, H, T, D] (T: total cached length)
        v = v.transpose(1, 2)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale  # [B, H, S, T]
        attn_weights = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, v)  # [B, H, S, D]

        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        return self.W_o(output)

    def clear_cache(self):
        """Zero the cache buffers and rewind the write position."""
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.cache_pos = 0
2.2 KV Cache 메모리 계산
def calculate_kv_cache_memory(
    model_config: dict,
    batch_size: int,
    seq_len: int,
    dtype_bytes: int = 2  # FP16
) -> dict:
    """Compute the KV-cache footprint for one batch of sequences.

    Args:
        model_config: needs ``num_layers``, ``head_dim`` and either
            ``num_kv_heads`` (GQA/MQA) or ``num_heads`` (MHA fallback).
        batch_size: number of concurrent sequences.
        seq_len: tokens per sequence.
        dtype_bytes: bytes per cached element (2 for FP16).

    Returns:
        dict with the total size in bytes/MB/GB and the per-token bytes.
    """
    layers = model_config["num_layers"]
    # GQA models cache fewer KV heads; fall back to full MHA head count.
    kv_heads = model_config.get("num_kv_heads", model_config["num_heads"])
    dim = model_config["head_dim"]

    # total = 2 (K+V) * layers * kv_heads * head_dim * seq_len * batch * dtype
    total_bytes = 2 * layers * kv_heads * dim * seq_len * batch_size * dtype_bytes

    return {
        "kv_cache_bytes": total_bytes,
        "kv_cache_mb": total_bytes / 1e6,
        "kv_cache_gb": total_bytes / 1e9,
        "per_token_bytes": total_bytes // seq_len,
    }
# Compare KV-cache footprints across model families (MHA vs. GQA).
models = {
    "Llama-2-7B": {
        "num_layers": 32, "num_heads": 32,
        "num_kv_heads": 32, "head_dim": 128
    },
    "Llama-2-13B": {
        "num_layers": 40, "num_heads": 40,
        "num_kv_heads": 40, "head_dim": 128
    },
    "Llama-2-70B (GQA)": {
        "num_layers": 80, "num_heads": 64,
        "num_kv_heads": 8, "head_dim": 128  # GQA: 8 KV heads
    },
    "Mistral-7B (GQA)": {
        "num_layers": 32, "num_heads": 32,
        "num_kv_heads": 8, "head_dim": 128  # GQA: 8 KV heads
    },
}

print("KV Cache 메모리 사용량 (배치=1, seq=4096)")
print("=" * 70)
for model_name, model_cfg in models.items():
    stats = calculate_kv_cache_memory(model_cfg, batch_size=1, seq_len=4096)
    print(f"{model_name:<25} {stats['kv_cache_gb']:.2f} GB "
          f"(토큰당 {stats['per_token_bytes']:,} bytes)")
2.3 Grouped Query Attention (GQA)
GQA는 KV Cache를 줄이는 핵심 기술입니다. 여러 Query 헤드가 더 적은 수의 KV 헤드를 공유합니다.
import torch
import torch.nn as nn
import math
class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention: several query heads share one KV head.

    Shrinks the KV cache by a factor of ``num_q_heads / num_kv_heads``
    while keeping the query resolution of full multi-head attention.
    """

    def __init__(
        self,
        d_model: int,
        num_q_heads: int,
        num_kv_heads: int,
    ):
        super().__init__()
        assert num_q_heads % num_kv_heads == 0
        self.d_model = d_model
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.num_groups = num_q_heads // num_kv_heads
        self.d_head = d_model // num_q_heads
        self.W_q = nn.Linear(d_model, num_q_heads * self.d_head, bias=False)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, kv_cache=None):
        bsz, q_len, _ = x.shape

        # Project and split into heads (K/V use the reduced head count).
        queries = self.W_q(x).reshape(bsz, q_len, self.num_q_heads, self.d_head)
        keys = self.W_k(x).reshape(bsz, q_len, self.num_kv_heads, self.d_head)
        values = self.W_v(x).reshape(bsz, q_len, self.num_kv_heads, self.d_head)

        # Prepend any cached K/V along the sequence axis.
        if kv_cache is not None:
            keys = torch.cat([kv_cache["k"], keys], dim=1)
            values = torch.cat([kv_cache["v"], values], dim=1)
        new_kv_cache = {"k": keys, "v": values}

        # Switch to [B, heads, seq, d_head] layout for the matmuls.
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # GQA: replicate each KV head across its query group.
        keys = keys.repeat_interleave(self.num_groups, dim=1)
        values = values.repeat_interleave(self.num_groups, dim=1)

        # Scaled dot-product attention.
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.d_head)
        weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(weights, values)
        context = context.transpose(1, 2).reshape(bsz, q_len, self.d_model)
        return self.W_o(context), new_kv_cache
# MHA vs GQA vs MQA memory comparison
def compare_attention_variants():
    """Print KV-cache sizes for MHA/GQA/MQA at 70B scale."""
    # Llama-2-70B-like geometry.
    num_layers, d_head = 80, 128
    seq_len, batch_size = 4096, 1
    dtype_bytes = 2  # FP16

    variants = {
        "MHA (32 KV heads)": 32,
        "GQA (8 KV heads)": 8,
        "MQA (1 KV head)": 1,
    }

    print("70B 모델 어텐션 변형별 KV Cache 비교")
    print(f"(seq_len={seq_len}, batch={batch_size})")
    print("=" * 55)
    for label, kv_heads in variants.items():
        size_gb = (2 * num_layers * kv_heads * d_head * seq_len
                   * batch_size * dtype_bytes) / 1e9
        print(f"{label:<25} {size_gb:.2f} GB")


compare_attention_variants()
2.4 DeepSeek MLA (Multi-Head Latent Attention)
DeepSeek-V2에서 도입된 MLA는 KV Cache를 저차원 잠재 벡터로 압축합니다.
class MultiHeadLatentAttention(nn.Module):
    """DeepSeek MLA: compress the KV cache into a low-rank latent vector.

    Instead of caching full per-head K/V tensors, only the down-projected
    latent ``c_kv`` (plus a small RoPE slice) is cached:
        cache per token = kv_lora_rank + qk_rope_head_dim
        vs. MHA:          2 * num_heads * head_dim
    K and V are reconstructed from the latent via an up-projection.

    NOTE: the attention computation itself is intentionally omitted — this
    class demonstrates only the cache-compression data path, so ``forward``
    returns ``None`` as the attention output.
    """

    def __init__(
        self,
        d_model: int = 5120,
        num_heads: int = 128,
        kv_lora_rank: int = 512,       # low-rank latent dimension for KV
        qk_nope_head_dim: int = 128,
        qk_rope_head_dim: int = 64,
        v_head_dim: int = 128,
        q_lora_rank: int = 1536,       # latent dimension of the factored Q projection
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim

        # LoRA-style factored Q projection. (q_lora_rank was previously a
        # hard-coded 1536; it is now a backward-compatible parameter.)
        self.q_a_proj = nn.Linear(d_model, q_lora_rank, bias=False)  # down-projection
        self.q_b_proj = nn.Linear(q_lora_rank, num_heads * (qk_nope_head_dim + qk_rope_head_dim), bias=False)

        # KV down-projection: d_model -> kv_lora_rank (+ rope part).
        # Only this output is stored in the KV cache!
        self.kv_a_proj = nn.Linear(
            d_model,
            kv_lora_rank + qk_rope_head_dim,
            bias=False
        )
        # KV up-projection: reconstruct K (nope part) and V from the latent.
        self.kv_b_proj = nn.Linear(
            kv_lora_rank,
            num_heads * (qk_nope_head_dim + v_head_dim),
            bias=False
        )
        self.o_proj = nn.Linear(num_heads * v_head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor, compressed_kv_cache=None):
        """
        Args:
            x: [batch, seq, d_model]
            compressed_kv_cache: [batch, cache_len, kv_lora_rank + rope_dim]

        Returns:
            (None, updated_cache). BUGFIX: the updated cache now includes
            the previous entries; the old version returned only the new
            tokens' compressed KV, so the cache never accumulated across
            successive calls.
        """
        batch_size, seq_len, _ = x.shape

        # Q computation (factored).
        q = self.q_b_proj(self.q_a_proj(x))

        # KV compression — only this is cached.
        kv_compressed = self.kv_a_proj(x)  # [B, S, kv_lora_rank + rope_dim]
        if compressed_kv_cache is not None:
            kv_compressed_total = torch.cat([compressed_kv_cache, kv_compressed], dim=1)
        else:
            kv_compressed_total = kv_compressed

        # Reconstruct full K/V from the cached latent (rope slice excluded).
        kv_content = kv_compressed_total[:, :, :self.kv_lora_rank]
        kv_full = self.kv_b_proj(kv_content)  # [B, T, num_heads * (nope + v_dim)]

        # Final attention computation omitted.
        return None, kv_compressed_total
# KV cache size comparison (DeepSeek-V2 geometry)
def compare_mla_vs_mha():
    """Print MHA vs. MLA KV-cache sizes for a DeepSeek-V2-scale model."""
    seq_len = 4096
    dtype_bytes = 2  # BF16
    num_layers = 60  # DeepSeek-V2

    # Conventional MHA: cache full K and V for every head.
    heads, head_dim = 128, 128
    mha_kv_gb = 2 * num_layers * heads * head_dim * seq_len * dtype_bytes / 1e9

    # MLA: cache only the latent vector plus the RoPE slice.
    latent_rank, rope_dim = 512, 64
    mla_kv_gb = (latent_rank + rope_dim) * num_layers * seq_len * dtype_bytes / 1e9

    print(f"MHA KV Cache: {mha_kv_gb:.2f} GB")
    print(f"MLA KV Cache: {mla_kv_gb:.2f} GB")
    print(f"절감 비율: {mha_kv_gb / mla_kv_gb:.1f}x")


compare_mla_vs_mha()
3. PagedAttention: vLLM의 핵심 혁신
3.1 기존 KV Cache의 문제점
기존 LLM 서빙 시스템은 요청마다 최대 시퀀스 길이를 위한 메모리를 미리 할당합니다.
요청 1: [PROMPT=200 tokens] [KV_CACHE=최대 2048-200=1848 tokens 예약] → 내부 단편화
요청 2: [PROMPT=100 tokens] [KV_CACHE=1948 tokens 예약]
요청 3: 메모리 부족으로 대기 (외부 단편화)
이로 인해 실제 GPU 메모리의 60~80%가 낭비됩니다.
3.2 PagedAttention 원리
OS의 가상 메모리에서 영감을 받아 KV Cache를 고정 크기의 물리 블록으로 관리합니다.
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set
import torch
@dataclass
class PhysicalBlock:
    """One fixed-size physical block of KV-cache memory."""
    block_id: int
    block_size: int       # tokens a block can hold (e.g. 16)
    device: str = "cuda"
    ref_count: int = 0    # reference count (enables copy-on-write sharing)

    def __post_init__(self):
        # Real KV tensors would be allocated here:
        # [2, num_layers, block_size, num_heads, head_dim]
        pass
@dataclass
class LogicalBlock:
    """Per-request logical block mapped onto a physical block."""
    physical_block_id: int
    num_filled: int = 0  # tokens currently written into this block
class PagedKVCacheManager:
    """Paged KV-cache manager in the style of vLLM's PagedAttention.

    Physical blocks of ``block_size`` tokens are handed out on demand, and
    reference counting lets requests share blocks (copy-on-write prefix
    caching) while freed blocks return to a common pool.
    """

    def __init__(
        self,
        num_physical_blocks: int,
        block_size: int,
        num_layers: int,
        num_kv_heads: int,
        head_dim: int,
        device: str = "cuda"
    ):
        self.block_size = block_size
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.device = device

        # Free list plus a registry of every physical block.
        self.free_blocks: List[int] = list(range(num_physical_blocks))
        self.all_blocks: Dict[int, PhysicalBlock] = {
            idx: PhysicalBlock(block_id=idx, block_size=block_size)
            for idx in range(num_physical_blocks)
        }
        # Per-request block table (logical -> physical mapping).
        self.block_tables: Dict[int, List[LogicalBlock]] = {}
        # Backing KV storage:
        # [num_blocks, 2(K/V), num_layers, block_size, num_kv_heads, head_dim]
        self.kv_cache = torch.zeros(
            num_physical_blocks, 2, num_layers, block_size, num_kv_heads, head_dim,
            dtype=torch.float16,
            device=device
        )

    def allocate_blocks_for_request(self, request_id: int, num_tokens: int):
        """Reserve enough physical blocks to hold ``num_tokens`` tokens."""
        needed = (num_tokens + self.block_size - 1) // self.block_size
        if len(self.free_blocks) < needed:
            raise RuntimeError(f"메모리 부족: {needed} 블록 필요, {len(self.free_blocks)} 가용")
        table = []
        for _ in range(needed):
            phys = self.free_blocks.pop(0)
            self.all_blocks[phys].ref_count = 1
            table.append(LogicalBlock(physical_block_id=phys))
        self.block_tables[request_id] = table
        print(f"요청 {request_id}: {needed} 블록 할당, "
              f"잔여 블록: {len(self.free_blocks)}")

    def append_token(self, request_id: int, layer: int, token_pos: int, k: torch.Tensor, v: torch.Tensor):
        """Write one token's K/V into its slot within the owning block."""
        which_block, offset = divmod(token_pos, self.block_size)
        logical = self.block_tables[request_id][which_block]
        phys = logical.physical_block_id
        self.kv_cache[phys, 0, layer, offset] = k  # K
        self.kv_cache[phys, 1, layer, offset] = v  # V
        logical.num_filled = offset + 1

    def get_physical_block_ids(self, request_id: int) -> List[int]:
        """Return the request's physical block ids, in logical order."""
        return [blk.physical_block_id for blk in self.block_tables[request_id]]

    def free_request(self, request_id: int):
        """Release a finished request's blocks; recycle unshared ones."""
        table = self.block_tables.pop(request_id, None)
        if table is None:
            return
        for logical in table:
            phys = logical.physical_block_id
            self.all_blocks[phys].ref_count -= 1
            if self.all_blocks[phys].ref_count == 0:
                self.free_blocks.append(phys)

    def copy_on_write(self, src_request_id: int, dst_request_id: int):
        """Share src's blocks with dst for prefix caching (CoW).

        Only the reference count is bumped; a real copy would happen lazily
        at write time.
        """
        shared = []
        for logical in self.block_tables[src_request_id]:
            phys = logical.physical_block_id
            self.all_blocks[phys].ref_count += 1
            shared.append(
                LogicalBlock(physical_block_id=phys, num_filled=logical.num_filled)
            )
        self.block_tables[dst_request_id] = shared
# Usage example (requires a CUDA device for the cache tensor).
manager = PagedKVCacheManager(
    num_physical_blocks=1000,
    block_size=16,
    num_layers=32,
    num_kv_heads=32,
    head_dim=128
)

# Serve three concurrent requests of different prompt lengths.
for rid, prompt_tokens in ((1, 200), (2, 500), (3, 100)):
    manager.allocate_blocks_for_request(request_id=rid, num_tokens=prompt_tokens)

# Request 1 finishes; its blocks return to the free pool.
manager.free_request(request_id=1)
print(f"\n요청 1 완료 후 가용 블록: {len(manager.free_blocks)}")
4. Continuous Batching
4.1 정적 배치의 문제점
기존 배치 처리 방식은 요청들이 모두 완료될 때까지 기다립니다.
시간 t=0: [요청A: 500토큰] [요청B: 100토큰] [요청C: 300토큰]
시간 t=1: 요청B 완료, 하지만 A, C 대기 중이므로 GPU에 여유 있어도 새 요청 불가
시간 t=2: 요청C 완료
시간 t=3: 요청A 완료 → 이제야 새 배치 시작!
4.2 Continuous Batching (Iteration-level Scheduling)
from typing import List, Optional, Tuple
from dataclasses import dataclass
import asyncio
import torch
from queue import Queue
import threading
@dataclass
class Request:
    """A single generation request tracked by the scheduler.

    BUGFIX: ``generated_ids`` previously defaulted to ``None`` and was
    unconditionally reset to ``[]`` in ``__post_init__``, which both broke
    the dataclass mutable-default idiom and silently discarded any
    caller-supplied value. ``field(default_factory=list)`` gives every
    instance its own fresh list while preserving an explicitly passed one.
    """
    request_id: str
    input_ids: List[int]
    max_new_tokens: int
    generated_ids: List[int] = field(default_factory=list)
    is_finished: bool = False
class ContinuousBatchingScheduler:
    """Iteration-level (continuous) batching scheduler.

    Finished sequences leave the running batch at every decode step, and
    waiting requests are admitted immediately into the freed slots, so the
    GPU never idles waiting for the whole batch to complete.
    """

    def __init__(
        self,
        max_batch_size: int = 32,
        max_seq_len: int = 4096
    ):
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.waiting_queue: "List[Request]" = []
        self.running_requests: "List[Request]" = []
        self.finished_requests: "List[Request]" = []

    def add_request(self, request: "Request"):
        """Enqueue a newly arrived request."""
        self.waiting_queue.append(request)

    def _can_add_request(self, request: "Request") -> bool:
        """Admission check: batch-size cap plus a simplified KV-memory budget."""
        if len(self.running_requests) + 1 > self.max_batch_size:
            return False
        occupied = sum(
            len(r.input_ids) + len(r.generated_ids)
            for r in self.running_requests
        )
        occupied += len(request.input_ids)
        return occupied < self.max_seq_len * self.max_batch_size

    def schedule_iteration(self) -> "Tuple[List[Request], List[str]]":
        """Plan one iteration: retire finished requests, admit waiting ones.

        Returns:
            (requests to run this step, ids of requests retired just now)
        """
        retired_ids = []
        survivors = []
        for req in self.running_requests:
            if req.is_finished:
                self.finished_requests.append(req)
                retired_ids.append(req.request_id)
            else:
                survivors.append(req)
        self.running_requests = survivors

        # Key idea: refill freed slots from the waiting queue immediately.
        while self.waiting_queue and self._can_add_request(self.waiting_queue[0]):
            admitted = self.waiting_queue.pop(0)
            self.running_requests.append(admitted)
            print(f"배치에 요청 {admitted.request_id} 추가 "
                  f"(현재 배치 크기: {len(self.running_requests)})")
        return self.running_requests, retired_ids

    def simulate_one_step(self, model_forward_fn):
        """Run one decode iteration over the current batch (simulation).

        Prefill requests contribute their whole prompt; decode requests
        contribute only their most recent token (a real engine would use
        PagedAttention over the cached prefix here).
        """
        active, retired = self.schedule_iteration()
        if not active:
            return []

        step_inputs = [
            req.input_ids if not req.generated_ids else [req.generated_ids[-1]]
            for req in active
        ]
        next_tokens = model_forward_fn(step_inputs)

        for req, tok in zip(active, next_tokens):
            req.generated_ids.append(tok)
            # Stop on EOS (token id 2) or when the token budget is reached.
            if tok == 2 or len(req.generated_ids) >= req.max_new_tokens:
                req.is_finished = True
        return retired
# Throughput comparison simulation
def simulate_throughput_comparison():
    """Contrast static batching with continuous batching (rough model)."""
    import random

    workload = [
        Request(
            request_id=str(i),
            input_ids=list(range(random.randint(50, 200))),
            max_new_tokens=random.randint(50, 500)
        )
        for i in range(20)
    ]

    # Static batching: every batch waits for its slowest member.
    longest = max(r.max_new_tokens for r in workload)
    static_iterations = longest * 4  # batches of 4

    # Continuous batching: slots are refilled the moment they free up.
    token_total = sum(r.max_new_tokens for r in workload)
    cb_iterations = token_total  # rough estimate

    print(f"정적 배치 예상 iteration 수: {static_iterations}")
    print(f"Continuous Batching 예상 iteration 수: {cb_iterations}")
    print(f"처리량 향상: {static_iterations / cb_iterations:.2f}x")
5. Speculative Decoding
5.1 아이디어: 초안 + 검증
Speculative Decoding의 핵심은 작은 드래프트 모델이 여러 토큰을 빠르게 생성하고, 큰 검증 모델이 한번에 검증하는 것입니다.
기존: [큰 모델] → 토큰1 → 토큰2 → 토큰3 → 토큰4 → 토큰5
투기적: [작은 모델] → (토큰1, 토큰2, 토큰3, 토큰4, 토큰5)를 병렬 생성
[큰 모델] → 5개 토큰을 한번에 검증 (Prefill처럼 병렬!)
수락된 토큰들만 사용
5.2 수락률 기반 속도 향상 분석
import numpy as np
import torch
from typing import List, Tuple
def speculative_decode_step(
    draft_model,
    target_model,
    input_ids: torch.Tensor,
    draft_steps: int = 4,
    temperature: float = 1.0
) -> Tuple[torch.Tensor, int, int]:
    """One speculative-decoding step: draft K tokens, verify them in parallel.

    Args:
        draft_model: small, fast proposal model (HF-style, returns ``.logits``).
        target_model: large verification model.
        input_ids: [batch, seq] current context.
        draft_steps: number of tokens K proposed per cycle.
        temperature: sampling temperature; 0 means greedy (one-hot) sampling.

    Returns:
        (accepted tokens [batch, num_accepted + 1], num_accepted, draft_steps)

    NOTE(review): the accept/reject loop breaks on the first rejection across
    the whole batch (``accepted.all()``), so with batch > 1 acceptance is
    effectively batch-global rather than per-sequence.
    """
    batch_size = input_ids.size(0)

    # 1. Draft model proposes candidate tokens autoregressively.
    draft_tokens = []
    draft_probs = []
    current_ids = input_ids.clone()
    for _ in range(draft_steps):
        with torch.no_grad():
            draft_output = draft_model(current_ids)
            draft_logits = draft_output.logits[:, -1, :]  # [B, vocab_size]
        # Draft distribution (one-hot argmax when temperature == 0).
        if temperature > 0:
            draft_prob = torch.softmax(draft_logits / temperature, dim=-1)
        else:
            draft_prob = torch.zeros_like(draft_logits)
            draft_prob.scatter_(1, draft_logits.argmax(dim=-1, keepdim=True), 1.0)
        # Sample a proposal token and extend the context for the next step.
        draft_token = torch.multinomial(draft_prob, num_samples=1)  # [B, 1]
        draft_tokens.append(draft_token)
        draft_probs.append(draft_prob)
        current_ids = torch.cat([current_ids, draft_token], dim=1)

    # Stack the proposals: [B, draft_steps].
    draft_sequence = torch.cat(draft_tokens, dim=1)
    candidate_ids = torch.cat([input_ids, draft_sequence], dim=1)

    # 2. Target model scores ALL draft positions in one forward pass —
    #    this prefill-like parallelism is the entire source of the speedup.
    with torch.no_grad():
        target_output = target_model(candidate_ids)
        target_logits = target_output.logits[:, input_ids.size(1) - 1:-1, :]  # [B, draft_steps, vocab_size]
    # Target distribution at every draft position.
    if temperature > 0:
        target_probs = torch.softmax(target_logits / temperature, dim=-1)
    else:
        target_probs = torch.zeros_like(target_logits)
        target_probs.scatter_(2, target_logits.argmax(dim=-1, keepdim=True), 1.0)

    # 3. Accept each draft token with probability min(1, p_target / p_draft)
    #    — the rejection-sampling rule that preserves the target distribution.
    accepted_tokens = []
    num_accepted = 0
    for step in range(draft_steps):
        token = draft_sequence[:, step]  # [B]
        p_draft = draft_probs[step].gather(1, token.unsqueeze(1)).squeeze(1)
        p_target = target_probs[:, step, :].gather(1, token.unsqueeze(1)).squeeze(1)
        acceptance_prob = torch.clamp(p_target / (p_draft + 1e-8), max=1.0)
        # Randomized accept/reject.
        random_val = torch.rand_like(acceptance_prob)
        accepted = random_val < acceptance_prob  # [B]
        if not accepted.all():
            # Stop at the first rejected position.
            break
        accepted_tokens.append(token)
        num_accepted += 1

    # 4. Bonus token from the target model at the first unverified position
    #    (sampled from the corrected residual distribution after a rejection).
    last_target_logits = target_output.logits[:, input_ids.size(1) + num_accepted - 1, :]
    if temperature > 0:
        last_prob = torch.softmax(last_target_logits / temperature, dim=-1)
    else:
        last_prob = torch.zeros_like(last_target_logits)
        last_prob.scatter_(1, last_target_logits.argmax(dim=-1, keepdim=True), 1.0)
    # Corrected distribution when a draft token was rejected:
    # sample from max(0, p_target - p_draft), renormalized.
    if num_accepted < draft_steps:
        correction = torch.clamp(
            last_prob - draft_probs[num_accepted],
            min=0
        )
        correction = correction / (correction.sum(dim=-1, keepdim=True) + 1e-8)
        last_token = torch.multinomial(correction, num_samples=1)
    else:
        last_token = torch.multinomial(last_prob, num_samples=1)
    accepted_tokens.append(last_token.squeeze(1))

    final_tokens = torch.stack(accepted_tokens, dim=1)
    return final_tokens, num_accepted, draft_steps
def analyze_speedup(
    acceptance_rate: float,
    draft_steps: int = 4,
    draft_cost_ratio: float = 0.1,
) -> dict:
    """Expected speculative-decoding speedup for a given acceptance rate.

    Args:
        acceptance_rate: probability ``alpha`` that the target model accepts
            a single draft token.
        draft_steps: number of tokens K drafted per verification cycle.
        draft_cost_ratio: cost of one draft-model step relative to one
            target-model step (generalized from the previously hard-coded
            1/10 assumption; default preserves the old behavior).

    Returns:
        dict with the expected accepted tokens per cycle and the speedup.
    """
    # E[tokens per cycle] = sum_{k=0}^{K} alpha^k = (1 - alpha^{K+1}) / (1 - alpha)
    expected_accepted = sum(
        acceptance_rate ** k for k in range(draft_steps + 1)
    )

    # One cycle costs K draft steps plus a single (parallel) verification
    # step, measured in units of target-model steps.
    cycle_cost = draft_steps * draft_cost_ratio + 1
    speedup = expected_accepted / cycle_cost

    return {
        "acceptance_rate": acceptance_rate,
        "draft_steps": draft_steps,
        "expected_accepted_tokens": expected_accepted,
        "speedup": speedup
    }
# Speedup as a function of the acceptance rate
print("Speculative Decoding 수락률별 속도 향상 (드래프트 K=4)")
print("=" * 55)
for rate in (0.5, 0.6, 0.7, 0.8, 0.9, 0.95):
    stats = analyze_speedup(rate, draft_steps=4)
    print(f"수락률 {rate:.0%}: 기대 수락 {stats['expected_accepted_tokens']:.2f}토큰, "
          f"속도향상 {stats['speedup']:.2f}x")
5.3 Medusa: 멀티 드래프트 헤드
import torch
import torch.nn as nn
class MedusaHead(nn.Module):
    """Medusa draft heads layered on a base model's hidden state.

    Head i predicts the token at position t+i+1 from the final hidden state:
      - Head 1: t+1
      - Head 2: t+2
      - Head N: t+N

    BUGFIX: the hidden layers were previously built with
    ``[nn.Linear(...), nn.SiLU()] * hidden_layers``, which repeats the
    *same* module objects and therefore silently ties the weights of all
    hidden layers. Each layer is now constructed independently.
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        num_heads: int = 4,
        hidden_layers: int = 1
    ):
        super().__init__()
        self.num_heads = num_heads

        def build_head() -> nn.Sequential:
            # Fresh Linear/SiLU modules per hidden layer (no weight sharing).
            layers = []
            for _ in range(hidden_layers):
                layers.append(nn.Linear(hidden_size, hidden_size, bias=False))
                layers.append(nn.SiLU())
            layers.append(nn.Linear(hidden_size, vocab_size, bias=False))
            return nn.Sequential(*layers)

        # One independent head per future token position.
        self.heads = nn.ModuleList([build_head() for _ in range(num_heads)])

    def forward(self, hidden_states: torch.Tensor):
        """
        Args:
            hidden_states: [batch, seq, hidden_size] — last hidden state of
                the base model.

        Returns:
            List of logits, one tensor [batch, seq, vocab] per future position.
        """
        return [head(hidden_states) for head in self.heads]
class MedusaModel(nn.Module):
    """Wraps a base causal LM with Medusa draft heads.

    NOTE(review): ``base_model`` is an external object; this class assumes
    it exposes ``config.hidden_size``, accepts ``output_hidden_states=True``
    and returns ``.logits`` / ``.hidden_states`` — confirm against the
    actual caller.
    """
    def __init__(self, base_model, vocab_size: int, num_medusa_heads: int = 4):
        super().__init__()
        self.base_model = base_model
        hidden_size = base_model.config.hidden_size
        # Draft heads predicting positions t+1 ... t+num_medusa_heads.
        self.medusa_heads = MedusaHead(
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            num_heads=num_medusa_heads
        )
    def forward(self, input_ids: torch.Tensor, use_medusa: bool = False):
        """Run the base model; optionally add Medusa head predictions.

        Returns:
            (base_logits, medusa_logits) where ``medusa_logits`` is ``None``
            unless ``use_medusa`` is True.
        """
        # Base forward pass; hidden states are needed for the Medusa heads.
        base_output = self.base_model(
            input_ids,
            output_hidden_states=True
        )
        base_logits = base_output.logits
        if not use_medusa:
            return base_logits, None
        # Feed the final-layer hidden state to every draft head.
        last_hidden_state = base_output.hidden_states[-1]
        medusa_logits = self.medusa_heads(last_hidden_state)
        return base_logits, medusa_logits
    def generate_with_medusa(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        medusa_choices: int = 16,  # candidate tokens per position
        threshold: float = 0.09    # acceptance threshold
    ):
        """Generation loop using Medusa candidates (simplified).

        Returns a list of generated token ids.

        NOTE(review): the tree-attention verification step is stubbed out —
        only the base model's top-1 token is ever appended, so
        ``candidates`` and ``threshold`` currently have no effect on output.
        """
        current_ids = input_ids.clone()
        all_accepted = []
        while len(all_accepted) < max_new_tokens:
            # Predict candidate tokens with the Medusa heads.
            base_logits, medusa_logits = self.forward(current_ids, use_medusa=True)
            # Collect top candidates at each future position.
            candidates = []
            base_probs = torch.softmax(base_logits[:, -1, :] / temperature, dim=-1)
            top_tokens = torch.topk(base_probs, medusa_choices)[1]
            for head_logits in medusa_logits:
                head_probs = torch.softmax(head_logits[:, -1, :] / temperature, dim=-1)
                candidates.append(torch.topk(head_probs, medusa_choices)[1])
            # Verify the candidates with tree attention (simplified here);
            # a real implementation uses a tree mask for efficient batching.
            best_token = top_tokens[0, 0]
            all_accepted.append(best_token.item())
            current_ids = torch.cat([current_ids, best_token.unsqueeze(0).unsqueeze(0)], dim=1)
            if best_token.item() == 2:  # EOS
                break
        return all_accepted
6. FlashAttention: 메모리 효율적인 어텐션
6.1 표준 어텐션의 HBM 병목
표준 어텐션은 HBM(High Bandwidth Memory)에 중간 결과를 자주 쓰고 읽습니다.
표준 Attention 메모리 연산:
1. Q, K를 HBM에서 읽음 → 읽기: O(N * d)
2. S = Q @ K.T 계산 → 쓰기: O(N^2) ← 병목!
3. S를 HBM에서 읽어 소프트맥스 → 읽기: O(N^2)
4. P = softmax(S) 저장 → 쓰기: O(N^2)
5. P를 읽어 P @ V 계산 → 읽기: O(N^2)
6. 최종 결과 저장 → 쓰기: O(N * d)
총 HBM 접근: O(N^2) (시퀀스 길이 제곱에 비례!)
6.2 FlashAttention의 타일링 전략
import torch
import math
def flash_attention_v1(Q, K, V, block_size=64):
    """Simplified FlashAttention v1: tiled attention with online softmax.

    K/V and Q are processed in SRAM-sized tiles so the full [N, N] attention
    matrix is never materialized in HBM. A running maximum (for numerical
    stability) and a running softmax denominator are kept per query row, and
    previously accumulated partial outputs are rescaled whenever a new key
    tile raises the maximum.
    """
    bsz, heads, n_tokens, d_head = Q.shape
    Q = Q * (1.0 / math.sqrt(d_head))

    # Running per-row state (conceptually resident in SRAM).
    out = torch.zeros_like(Q)
    denom = torch.zeros(bsz, heads, n_tokens, 1, device=Q.device)  # softmax denominator
    run_max = torch.full((bsz, heads, n_tokens, 1), float('-inf'), device=Q.device)

    n_blocks = (n_tokens + block_size - 1) // block_size
    for kv_idx in range(n_blocks):
        # Load one K/V tile (HBM -> SRAM).
        lo = kv_idx * block_size
        hi = min(lo + block_size, n_tokens)
        K_tile = K[:, :, lo:hi, :]
        V_tile = V[:, :, lo:hi, :]

        for q_idx in range(n_blocks):
            # Load the matching Q tile and its accumulated state.
            qlo = q_idx * block_size
            qhi = min(qlo + block_size, n_tokens)
            Q_tile = Q[:, :, qlo:qhi, :]
            out_tile = out[:, :, qlo:qhi, :]
            denom_tile = denom[:, :, qlo:qhi, :]
            max_tile = run_max[:, :, qlo:qhi, :]

            # Tile-local attention scores (stay in SRAM).
            scores = torch.matmul(Q_tile, K_tile.transpose(-2, -1))  # [B, H, Br, Bc]

            # Online-softmax update: refresh the running max, rescale the
            # old accumulators, then fold in this tile's contribution.
            new_max = torch.maximum(max_tile, scores.max(dim=-1, keepdim=True)[0])
            probs = torch.exp(scores - new_max)
            new_denom = torch.exp(max_tile - new_max) * denom_tile + probs.sum(dim=-1, keepdim=True)
            new_out = (
                torch.exp(max_tile - new_max) * out_tile +
                torch.matmul(probs, V_tile)
            )

            # Write the tile state back (SRAM -> HBM).
            out[:, :, qlo:qhi, :] = new_out
            denom[:, :, qlo:qhi, :] = new_denom
            run_max[:, :, qlo:qhi, :] = new_max

    # Normalization is deferred until every key tile has been seen.
    return out / denom
def compare_attention_implementations():
    """Compare PyTorch's FlashAttention kernel with naive attention.

    CONSISTENCY FIX: uses the modern ``torch.nn.attention.sdpa_kernel``
    context manager (as in the section 6.3 examples) instead of the
    deprecated ``torch.backends.cuda.sdp_kernel``. Requires a CUDA device
    and a FlashAttention-capable GPU.
    """
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    batch_size = 2
    num_heads = 32
    seq_len = 4096
    d_head = 128
    Q = torch.randn(batch_size, num_heads, seq_len, d_head, device='cuda', dtype=torch.float16)
    K = torch.randn_like(Q)
    V = torch.randn_like(Q)

    # FlashAttention path via PyTorch SDPA, restricted to the flash kernel.
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        flash_output = F.scaled_dot_product_attention(Q, K, V)

    # Naive attention: materializes the full [seq, seq] score matrix.
    scale = 1.0 / math.sqrt(d_head)
    attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
    attn_weights = torch.softmax(attn_scores, dim=-1)
    standard_output = torch.matmul(attn_weights, V)

    # Numerical agreement between the two paths.
    max_diff = (flash_output - standard_output).abs().max().item()
    print(f"FlashAttention vs 표준 어텐션 최대 차이: {max_diff:.6f}")

    # Memory footprint of the materialized attention matrix (FP16).
    standard_attn_matrix_size = batch_size * num_heads * seq_len * seq_len * 2
    print(f"표준 어텐션 행렬 메모리: {standard_attn_matrix_size / 1e9:.2f} GB")
    print(f"FlashAttention 행렬 메모리: ~0 GB (타일링으로 저장 불필요)")
6.3 PyTorch SDPA 사용법
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
def modern_attention(q, k, v, is_causal=True, dropout_p=0.0):
"""
PyTorch 2.0+ scaled_dot_product_attention 사용
FlashAttention 2/3를 자동으로 선택
"""
# 자동 백엔드 선택 (Flash, Memory-efficient, Math)
output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=dropout_p,
is_causal=is_causal, # 인과적 마스킹
scale=None # None이면 1/sqrt(d_head) 사용
)
return output
# 특정 백엔드 강제 선택
def attention_with_flash_backend(q, k, v):
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
return F.scaled_dot_product_attention(q, k, v, is_causal=True)
def attention_with_efficient_backend(q, k, v):
    """Run causal SDPA with backend selection restricted to the
    memory-efficient kernel."""
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
# Key characteristics of each FlashAttention generation.
flash_versions = {
    "FlashAttention 1": {
        "paper": "arXiv:2205.14135",
        "key_innovation": "타일링 + Online Softmax",
        "memory": "O(N) (어텐션 행렬 저장 불필요)",
        "speedup": "2-4x vs 표준 어텐션",
    },
    "FlashAttention 2": {
        "paper": "arXiv:2307.08691",
        "key_innovation": "작업 분할 최적화, FP16/BF16 지원",
        "memory": "O(N)",
        "speedup": "5-9x vs 표준 어텐션 (H100에서)",
    },
    "FlashAttention 3": {
        "paper": "arXiv:2407.08608",
        "key_innovation": "H100 특화, FP8 지원, 비동기 파이프라인",
        "memory": "O(N)",
        "speedup": "1.5-2x vs FA2 (H100에서)",
    },
}

for version_name, version_info in flash_versions.items():
    print(f"\n{version_name}")
    for field_name, field_value in version_info.items():
        print(f"  {field_name}: {field_value}")
7. 멀티 GPU 추론
7.1 Tensor Parallelism
가중치 행렬을 여러 GPU에 분산하여 각 GPU가 일부를 처리합니다.
import torch
import torch.distributed as dist
class TensorParallelLinear(torch.nn.Module):
    """Column-parallel linear layer for tensor parallelism.

    Each rank owns a contiguous slice of the output neurons
    (out_features // world_size rows of the weight matrix).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        world_size: int,
        rank: int
    ):
        super().__init__()
        self.world_size = world_size
        self.rank = rank
        # This rank's share of the output dimension.
        self.local_out_features = out_features // world_size
        # Scaled-normal initialization (1/sqrt(fan_in)).
        init = torch.randn(self.local_out_features, in_features)
        self.weight = torch.nn.Parameter(init / (in_features ** 0.5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Local partial matmul. In a real distributed setup the per-rank
        # outputs would then be combined across GPUs:
        #   dist.all_gather(output_list, local_output)
        return torch.nn.functional.linear(x, self.weight)
def setup_tensor_parallel_llm(model_name: str, tp_size: int):
    """
    Example tensor-parallel LLM setup (vLLM style).

    vLLM shards the model across GPUs internally, along the lines of the
    TensorParallelLinear sketch above.

    Args:
        model_name: HuggingFace model id or local path.
        tp_size: number of GPUs to shard the model across.

    Returns:
        A ready-to-use vllm.LLM engine instance.
    """
    from vllm import LLM, SamplingParams
    # vLLM's tensor_parallel_size setting
    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,  # number of GPUs
        gpu_memory_utilization=0.9
    )
    return llm
7.2 vLLM 완전 활용
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
import asyncio
import time
# Basic vLLM usage
def vllm_basic_usage():
    """Offline batched inference with the vLLM engine.

    Builds an LLM engine, generates completions for a small batch of
    prompts, prints a preview of each, and returns the raw outputs.
    """
    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        tensor_parallel_size=1,  # number of GPUs
        gpu_memory_utilization=0.90,  # fraction of GPU memory to use
        max_model_len=4096,  # maximum sequence length
        quantization=None,  # "awq", "gptq", "squeezellm"
        dtype="auto",  # "float16", "bfloat16"
        max_num_seqs=256,  # maximum number of concurrent sequences
        enable_prefix_caching=True,  # enable prefix caching
        use_v2_block_manager=True,  # PagedAttention v2 — NOTE(review): this flag was removed in newer vLLM releases; confirm against the installed version
    )
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=200,
        presence_penalty=0.0,
        frequency_penalty=0.0,
    )
    prompts = [
        "Explain quantum computing in simple terms",
        "What is the future of artificial intelligence?",
        "How does the human brain work?",
    ]
    # Batched generation
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        print(f"Prompt: {output.prompt[:50]}...")
        print(f"Output: {output.outputs[0].text[:100]}...")
        print(f"Tokens generated: {len(output.outputs[0].token_ids)}")
        print()
    return outputs
# vLLM async server
async def vllm_async_server():
    """Demonstrate the vLLM async engine with concurrent streaming requests."""
    engine_args = AsyncEngineArgs(
        model="meta-llama/Llama-2-7b-hf",
        tensor_parallel_size=1,
        gpu_memory_utilization=0.90,
        max_model_len=4096,
        enable_prefix_caching=True,
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    async def generate_stream(prompt: str, request_id: str):
        # Stream one request's tokens, printing only newly added text.
        sampling_params = SamplingParams(
            temperature=0.8,
            max_tokens=200
        )
        full_text = ""
        async for output in engine.generate(prompt, sampling_params, request_id):
            if output.outputs:
                # Each engine update carries the full text so far; slice
                # off what we have already printed.
                delta = output.outputs[0].text[len(full_text):]
                full_text = output.outputs[0].text
                if delta:
                    print(f"[{request_id}] {delta}", end="", flush=True)
            if output.finished:
                print(f"\n[{request_id}] 완료")

    # Handle several requests concurrently
    await asyncio.gather(
        generate_stream("What is AI?", "req_1"),
        generate_stream("Explain machine learning", "req_2"),
        generate_stream("What is deep learning?", "req_3"),
    )
8. 추론 엔진 비교
8.1 주요 추론 엔진 특징
| 엔진 | 개발사 | 핵심 기능 | 적합한 용도 |
|---|---|---|---|
| vLLM | UC Berkeley | PagedAttention, Continuous Batching | 범용 LLM 서빙 |
| TGI | HuggingFace | Flash Attention 2, Speculative | HF 모델 서빙 |
| TensorRT-LLM | NVIDIA | NVIDIA GPU 최적화, FP8 | NVIDIA 최대 성능 |
| DeepSpeed-MII | Microsoft | ZeRO 추론, 초대형 모델 | 다중 GPU 대형 모델 |
| llama.cpp | Georgi Gerganov | CPU 최적화, GGUF | 로컬 실행 |
8.2 벤치마크 비교
import subprocess
import json
import time
import requests
def benchmark_vllm_server(
    model: str,
    num_requests: int = 100,
    max_tokens: int = 100,
    concurrency: int = 10
):
    """Benchmark a running vLLM OpenAI-compatible completion endpoint.

    Streams `num_requests` completion requests against
    http://localhost:8000, keeping at most `concurrency` requests in
    flight, and reports latency/TTFT/throughput statistics.

    Args:
        model: model name as served by the endpoint.
        num_requests: total number of requests to send.
        max_tokens: completion length per request.
        concurrency: maximum number of concurrent in-flight requests.

    Returns:
        dict with avg/p50/p99 latency and average TTFT (all in ms), plus
        overall throughput in tokens/sec.
    """
    results = {
        "total_requests": num_requests,
        "concurrency": concurrency,
        "latencies": [],
        "ttfts": [],
        "throughputs": []
    }
    import asyncio
    import aiohttp

    async def send_request(session, prompt, request_id):
        # Stream one completion; record total latency and the time to the
        # first streamed data chunk (TTFT).
        start = time.perf_counter()
        first_token_time = None
        payload = {
            "model": model,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "stream": True,
            "temperature": 0.0
        }
        async with session.post(
            "http://localhost:8000/v1/completions",
            json=payload
        ) as response:
            async for line in response.content:
                if line.startswith(b"data: "):
                    data = line[6:].decode()
                    if data.strip() == "[DONE]":
                        break
                    if first_token_time is None:
                        first_token_time = time.perf_counter() - start
        end = time.perf_counter()
        return {
            "latency": end - start,
            "ttft": first_token_time,
        }

    def _record(result):
        # Shared bookkeeping for a completed request.
        results["latencies"].append(result["latency"])
        if result["ttft"]:
            results["ttfts"].append(result["ttft"])

    async def run_benchmark():
        prompts = [
            f"Tell me about topic number {i}." for i in range(num_requests)
        ]
        start = time.perf_counter()
        async with aiohttp.ClientSession() as session:
            # BUG FIX: this was a list, but asyncio.wait() returns the
            # pending tasks as a set and .add() was called below —
            # raising AttributeError on the very first request.
            tasks = set()
            for i, prompt in enumerate(prompts):
                if len(tasks) >= concurrency:
                    # Wait for at least one in-flight request to finish
                    # before admitting a new one.
                    done, tasks = await asyncio.wait(
                        tasks, return_when=asyncio.FIRST_COMPLETED
                    )
                    for task in done:
                        _record(await task)
                tasks.add(asyncio.ensure_future(
                    send_request(session, prompt, i)
                ))
            # Drain remaining in-flight requests (BUG FIX: TTFT samples
            # were previously dropped for this tail).
            for coro in asyncio.as_completed(tasks):
                _record(await coro)
        total_time = time.perf_counter() - start
        total_tokens = num_requests * max_tokens
        results["throughput"] = total_tokens / total_time

    asyncio.run(run_benchmark())

    # Aggregate statistics.
    import statistics
    latencies = results["latencies"]
    return {
        "avg_latency_ms": statistics.mean(latencies) * 1000,
        "p50_latency_ms": statistics.median(latencies) * 1000,
        "p99_latency_ms": sorted(latencies)[int(len(latencies) * 0.99)] * 1000,
        "avg_ttft_ms": statistics.mean(results["ttfts"]) * 1000 if results["ttfts"] else 0,
        "throughput_tps": results.get("throughput", 0),
    }
9. 프롬프트 캐싱
9.1 프리픽스 캐싱
동일한 시스템 프롬프트나 문서를 반복적으로 처리할 때 KV Cache를 재사용합니다.
from vllm import LLM, SamplingParams
def demonstrate_prefix_caching():
    """Show the effect of prefix caching with a shared system prompt.

    Runs the same batch twice: the second pass can reuse the KV cache of
    the shared prompt prefix, so only the per-question suffix needs a
    fresh prefill.
    """
    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        enable_prefix_caching=True,  # enable prefix caching
        max_model_len=4096,
    )
    # Long system prompt shared by every request
    system_prompt = """You are a helpful AI assistant with expertise in:
- Python programming and software development
- Machine learning and deep learning
- Data science and statistics
- Cloud computing and DevOps
[... 긴 시스템 프롬프트 ...]""" * 10  # 1000+ tokens
    questions = [
        "How do I optimize a Python loop?",
        "What is gradient descent?",
        "Explain containerization.",
        "What is a neural network?",
    ]
    sampling_params = SamplingParams(temperature=0.7, max_tokens=100)
    # First batch: no cache yet (cold start)
    import time
    cold_prompts = [f"{system_prompt}\n\nQuestion: {q}" for q in questions]
    cold_start = time.time()
    llm.generate(cold_prompts, sampling_params)
    cold_time = time.time() - cold_start
    # Second batch: identical system prompt (cache hit!)
    warm_start = time.time()
    llm.generate(cold_prompts, sampling_params)
    warm_time = time.time() - warm_start
    print(f"첫 번째 (캐시 없음): {cold_time:.2f}초")
    print(f"두 번째 (캐시 히트): {warm_time:.2f}초")
    print(f"속도 향상: {cold_time / warm_time:.2f}x")
def radix_tree_prefix_cache():
    """Build a radix-tree-based prefix cache for KV-cache block sharing."""

    class RadixNode:
        """A single trie node keyed by token id."""

        def __init__(self):
            self.children: dict = {}
            self.kv_cache_block_id: int = None

    class RadixTreeCache:
        """Manage token sequences in a radix tree so requests sharing a
        common prefix can reuse the same KV cache blocks."""

        def __init__(self):
            self.root = RadixNode()
            self.cache_hits = 0
            self.cache_misses = 0

        def insert(self, token_ids: list, block_id: int):
            """Insert a token sequence and its KV cache block id."""
            cursor = self.root
            for tok in token_ids:
                cursor = cursor.children.setdefault(tok, RadixNode())
            cursor.kv_cache_block_id = block_id

        def lookup(self, token_ids: list) -> tuple:
            """Return (longest matched prefix length, deepest cached block id)."""
            cursor = self.root
            matched = 0
            found_block = None
            for idx, tok in enumerate(token_ids):
                child = cursor.children.get(tok)
                if child is None:
                    break
                cursor = child
                matched = idx + 1
                if cursor.kv_cache_block_id is not None:
                    found_block = cursor.kv_cache_block_id
            # A lookup counts as a hit iff it found at least one cached block.
            if found_block is not None:
                self.cache_hits += 1
            else:
                self.cache_misses += 1
            return matched, found_block

        def get_hit_rate(self) -> float:
            """Fraction of lookups that resolved to a cached block."""
            total = self.cache_hits + self.cache_misses
            return self.cache_hits / total if total > 0 else 0.0

    return RadixTreeCache()
10. 실전 최적화 체크리스트
10.1 단계별 최적화 가이드
class LLMOptimizationChecklist:
    """Tiered LLM inference optimization checklist.

    `optimizations` is grouped by category; `level` orders the tiers from
    quick wins (1) to larger engineering efforts (5).
    """
    # Each item carries a name, a qualitative impact/effort rating, a short
    # description and (optionally) an illustrative code snippet string.
    optimizations = [
        {
            "category": "기본 설정",
            "level": 1,
            "items": [
                {
                    "name": "FP16/BF16 사용",
                    "impact": "높음",
                    "effort": "낮음",
                    "description": "FP32 → FP16으로 메모리 2배 절약, 속도 향상",
                    "code": """
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16, # 또는 float16
    device_map="auto"
)"""
                },
                {
                    "name": "Flash Attention 2 활성화",
                    "impact": "높음",
                    "effort": "낮음",
                    "description": "어텐션 연산 2-4x 속도 향상, 메모리 절약",
                    "code": """
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16
)"""
                },
            ]
        },
        {
            "category": "KV Cache 최적화",
            "level": 2,
            "items": [
                {
                    "name": "GQA/MQA 모델 선택",
                    "impact": "높음",
                    "effort": "중간",
                    "description": "KV Cache 4-8배 감소, 더 많은 배치 처리 가능"
                },
                {
                    "name": "프리픽스 캐싱",
                    "impact": "중간",
                    "effort": "낮음",
                    "description": "공통 시스템 프롬프트 KV Cache 재사용"
                },
            ]
        },
        {
            "category": "배치 최적화",
            "level": 3,
            "items": [
                {
                    "name": "Continuous Batching (vLLM)",
                    "impact": "매우 높음",
                    "effort": "낮음",
                    "description": "처리량 2-5x 향상",
                    "code": """
from vllm import LLM, SamplingParams
llm = LLM(
    model=model_name,
    gpu_memory_utilization=0.90,
    enable_prefix_caching=True,
)"""
                },
            ]
        },
        {
            "category": "모델 최적화",
            "level": 4,
            "items": [
                {
                    "name": "AWQ 4-bit 양자화",
                    "impact": "높음",
                    "effort": "중간",
                    "description": "메모리 4x 감소, 속도 1.5-2x 향상",
                    "code": """
from awq import AutoAWQForCausalLM
model = AutoAWQForCausalLM.from_quantized(
    "model-awq-4bit",
    fuse_layers=True
)"""
                },
                {
                    "name": "Speculative Decoding",
                    "impact": "중간",
                    "effort": "높음",
                    "description": "2-3x 속도 향상 (적합한 드래프트 모델 필요)"
                },
            ]
        },
        {
            "category": "하드웨어 최적화",
            "level": 5,
            "items": [
                {
                    "name": "Tensor Parallelism",
                    "impact": "매우 높음",
                    "effort": "중간",
                    "description": "다중 GPU로 선형적 처리량 향상"
                },
                {
                    "name": "CUDA 그래프 캡처",
                    "impact": "중간",
                    "effort": "높음",
                    "description": "커널 론치 오버헤드 제거"
                },
            ]
        }
    ]

    @classmethod
    def print_checklist(cls):
        """Pretty-print the checklist grouped by level, followed by the
        recommended order of adoption."""
        print("=" * 70)
        print("LLM 추론 최적화 단계별 체크리스트")
        print("=" * 70)
        for category in cls.optimizations:
            print(f"\n[레벨 {category['level']}] {category['category']}")
            print("-" * 50)
            for item in category['items']:
                # Map the impact label to a star rating for display.
                impact_emoji = {"매우 높음": "★★★", "높음": "★★", "중간": "★", "낮음": "☆"}
                print(f"  ✓ {item['name']}")
                print(f"    효과: {impact_emoji.get(item['impact'], '?')} {item['impact']}")
                print(f"    설명: {item['description']}")
        print("\n추천 최적화 순서:")
        print("1. BF16/FP16 전환 (즉시, 무료)")
        print("2. Flash Attention 2 (즉시, 패키지 설치만)")
        print("3. vLLM으로 서빙 (처리량 극대화)")
        print("4. AWQ/GPTQ 4비트 양자화 (메모리 4배 절약)")
        print("5. Speculative Decoding (레이턴시 개선)")
        print("6. 멀티 GPU Tensor Parallelism (규모 확장)")

LLMOptimizationChecklist.print_checklist()
마무리
LLM 추론 최적화는 계층적 접근이 필요합니다.
핵심 요점 정리:
-
KV Cache 이해: 메모리 사용량 공식
2 * layers * kv_heads * d_head * seq_len * dtype_bytes를 외우고, GQA/MQA로 KV Cache를 4-8배 줄이세요. -
PagedAttention: vLLM의 핵심 혁신으로, OS의 가상 메모리에서 영감 받아 KV Cache 단편화를 해결합니다.
-
Continuous Batching: 요청이 완료되는 즉시 새 요청을 삽입하여 GPU 활용률을 극대화합니다.
-
Speculative Decoding: 작은 드래프트 모델 + 큰 검증 모델 조합으로 2-3x 속도 향상이 가능합니다.
-
FlashAttention: 어텐션 계산의 메모리 효율을 O(N^2)에서 O(N)으로 줄여 긴 컨텍스트를 가능하게 합니다.
프로덕션 배포 권고:
- 소규모 서비스: vLLM + AWQ 4bit + 프리픽스 캐싱
- 대규모 서비스: TensorRT-LLM 또는 vLLM + Tensor Parallelism
- 최저 레이턴시 요구: Speculative Decoding + CUDA 그래프
참고 자료
- vLLM/PagedAttention: arXiv:2309.06180
- Speculative Decoding: arXiv:2211.17192
- FlashAttention: arXiv:2205.14135
- FlashAttention-2: arXiv:2307.08691
- Medusa: arXiv:2401.10774
- Continuous Batching: How continuous batching enables 23x throughput
- DeepSeek-V2 MLA: arXiv:2405.04434
LLM Inference Optimization Complete Guide: KV Cache, Speculative Decoding, Continuous Batching
Introduction
When you deploy a large language model (LLM) to production, you immediately face a challenge: inference speed and cost. A first query to a GPT-4 class model can take several seconds, and throughput degrades sharply as concurrent users grow.
This guide thoroughly explores the core techniques for LLM inference optimization. From KV Cache internals to PagedAttention, Speculative Decoding, FlashAttention, and the latest vLLM and TensorRT-LLM engines — we don't just cover how to use them, we understand why they work.
1. Understanding the LLM Inference Pipeline
1.1 Two Phases: Prefill and Decode
LLM text generation is split into two distinct phases.
Prefill Phase (Prompt Processing)
- Processes all tokens of the input prompt simultaneously (parallel)
- Generates and stores Key/Value caches at each layer
- Compute-Bound: GPU compute throughput is the bottleneck
- Directly affects TTFT (Time To First Token)
Decode Phase (Token Generation)
- Generates one token at a time in an autoregressive manner
- References KV Cache from all previously generated tokens
- Memory-Bandwidth-Bound: HBM read speed is the bottleneck
- Directly affects TPOT (Time Per Output Token)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
def measure_prefill_decode_time(model, tokenizer, prompt: str, max_new_tokens: int = 100):
    """Measure prefill and decode phase timing.

    Generation is run twice: once capped at a single new token (its latency
    is the TTFT: prefill + one decode step), then once for the full budget.
    TPOT is derived from the difference between the two runs.

    NOTE(review): `decode_time = total_time - ttft` assumes the second run's
    prefill costs the same as the first run's TTFT — an approximation, since
    the prompt is re-processed from scratch in the second run.

    Args:
        model: causal LM whose generate() is measured — assumed to live on
            CUDA (inputs are moved to "cuda"; not checked here).
        tokenizer: tokenizer matching the model.
        prompt: input text to prefill.
        max_new_tokens: token budget for the full-generation run.

    Returns:
        Tuple (ttft_seconds, tpot_seconds).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    input_len = inputs["input_ids"].size(1)
    # Measure TTFT
    torch.cuda.synchronize()
    prefill_start = time.perf_counter()
    with torch.no_grad():
        first_output = model.generate(
            **inputs,
            max_new_tokens=1,
            do_sample=False
        )
    torch.cuda.synchronize()
    ttft = time.perf_counter() - prefill_start
    # Measure total generation
    torch.cuda.synchronize()
    total_start = time.perf_counter()
    with torch.no_grad():
        full_output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )
    torch.cuda.synchronize()
    total_time = time.perf_counter() - total_start
    output_tokens = full_output.size(1) - input_len
    decode_time = total_time - ttft
    # max(..., 1) guards against division by zero for 1-token outputs.
    tpot = decode_time / max(output_tokens - 1, 1)
    print(f"Input tokens: {input_len}")
    print(f"Output tokens: {output_tokens}")
    print(f"TTFT (first token latency): {ttft * 1000:.1f} ms")
    print(f"TPOT (per token time): {tpot * 1000:.1f} ms")
    print(f"Throughput: {output_tokens / total_time:.1f} tokens/sec")
    return ttft, tpot
1.2 Memory Bandwidth Analysis
Understanding why the decode phase is memory-bound:
def analyze_memory_bandwidth():
    """Analyze whether single-stream LLM decoding is memory- or compute-bound.

    Uses Llama-2-7B-like geometry and A100 specs (2 TB/s HBM bandwidth,
    312 TFLOPS FP16) to estimate the two throughput ceilings.
    """
    # Example: Llama-2-7B config
    model_params = {
        "num_layers": 32,
        "hidden_size": 4096,
        "num_heads": 32,
        "head_dim": 128,
        "vocab_size": 32000,
    }
    dtype_bytes = 2  # FP16: 2 bytes

    # Weight memory (per-layer matrices only; embeddings/norms ignored).
    attn_weight = 4 * model_params["hidden_size"] ** 2  # Q, K, V, O projections
    ffn_weight = 8 * model_params["hidden_size"] ** 2  # SwiGLU: Up, Gate, Down
    layer_weight = (attn_weight + ffn_weight) * dtype_bytes
    total_weight_bytes = layer_weight * model_params["num_layers"]
    total_weight_gb = total_weight_bytes / 1e9
    print(f"Model weights: {total_weight_gb:.2f} GB")

    # KV Cache memory (proportional to sequence length)
    seq_len = 2048
    kv_cache_per_token = (
        2 *
        model_params["num_layers"] *
        model_params["num_heads"] *
        model_params["head_dim"] *
        dtype_bytes
    )
    kv_cache_total = kv_cache_per_token * seq_len / 1e6
    print(f"KV Cache ({seq_len} tokens): {kv_cache_total:.2f} MB")
    print(f"KV Cache per token: {kv_cache_per_token} bytes")

    # A100 memory bandwidth: 2 TB/s
    memory_bandwidth_tbs = 2.0
    # Decode: all weights must be streamed from HBM once per token.
    memory_bound_tps = memory_bandwidth_tbs * 1e12 / total_weight_bytes

    # A100 FP16: 312 TFLOPS
    compute_throughput = 312e12
    # BUG FIX: a forward pass costs ~2 FLOPs per *parameter* per token
    # (one multiply + one add). The previous code used 2 * bytes, which
    # double-counted because FP16 stores 2 bytes per parameter.
    total_params = total_weight_bytes / dtype_bytes
    flops_per_token = 2 * total_params
    compute_bound_tps = compute_throughput / flops_per_token

    print(f"\nFor batch_size=1:")
    print(f"Memory-bound throughput: {memory_bound_tps:.1f} tokens/sec")
    print(f"Compute-bound throughput: {compute_bound_tps:.1f} tokens/sec")
    print(f"Actual bottleneck: {'memory' if memory_bound_tps < compute_bound_tps else 'compute'}")

analyze_memory_bandwidth()
1.3 Inference Cost Analysis
def estimate_inference_cost(
    model_size_b: float,
    tokens_per_request: int,
    requests_per_day: int,
    gpu_cost_per_hour: float = 3.0  # A100 hourly price (USD)
) -> dict:
    """Estimate GPU serving cost for a given model size and traffic.

    Args:
        model_size_b: model size in billions of parameters.
        tokens_per_request: average generated tokens per request.
        requests_per_day: expected daily request volume.
        gpu_cost_per_hour: GPU rental price in USD/hour (default: A100).

    Returns:
        dict with the computed throughput, daily token count, GPU hours,
        daily cost and cost per 1K tokens. (IMPROVEMENT: previously these
        values were only printed and then discarded.)
    """
    # Empirical throughput scaling anchored at ~100 tok/s for a 7B model
    # on A100, decaying sublinearly with model size.
    throughput_tps = 100 / (model_size_b / 7) ** 0.6
    total_tokens_per_day = tokens_per_request * requests_per_day
    seconds_needed = total_tokens_per_day / throughput_tps
    hours_needed = seconds_needed / 3600
    daily_cost = hours_needed * gpu_cost_per_hour
    cost_per_1k_tokens = daily_cost / (total_tokens_per_day / 1000)
    print(f"Model: {model_size_b}B parameters")
    print(f"Daily requests: {requests_per_day:,}")
    print(f"Tokens per request: {tokens_per_request}")
    print(f"Total daily tokens: {total_tokens_per_day:,}")
    print(f"Estimated throughput: {throughput_tps:.1f} tokens/sec")
    print(f"GPU hours needed: {hours_needed:.2f}")
    print(f"Daily cost: ${daily_cost:.2f}")
    print(f"Cost per 1K tokens: ${cost_per_1k_tokens:.4f}")
    return {
        "throughput_tps": throughput_tps,
        "total_tokens_per_day": total_tokens_per_day,
        "gpu_hours": hours_needed,
        "daily_cost": daily_cost,
        "cost_per_1k_tokens": cost_per_1k_tokens,
    }

estimate_inference_cost(
    model_size_b=7.0,
    tokens_per_request=500,
    requests_per_day=10000
)
2. KV Cache: The Core Optimization
2.1 Why KV Cache is Necessary
In transformer attention, each token computes attention against all previous tokens. To avoid recomputing already-processed tokens, we cache the K and V matrices.
import torch
import torch.nn as nn
import math
class MultiHeadAttentionWithKVCache(nn.Module):
    """Multi-head attention with KV Cache support.

    Demo implementation: the cache buffers are allocated for batch size 1
    and up to max_seq_len tokens. Prefill (seq_len > 1) is causally masked
    (BUG FIX below), so a multi-token prefill now produces the same
    per-position outputs as incremental single-token decoding.
    """

    def __init__(self, d_model: int, num_heads: int, max_seq_len: int = 4096):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        # KV Cache buffers (batch dim fixed to 1 in this demo).
        self.register_buffer(
            'k_cache',
            torch.zeros(1, max_seq_len, num_heads, self.d_head)
        )
        self.register_buffer(
            'v_cache',
            torch.zeros(1, max_seq_len, num_heads, self.d_head)
        )
        self.cache_pos = 0

    def forward(self, x: torch.Tensor, use_cache: bool = True, position: int = None):
        """Attend over cached + current tokens.

        Args:
            x: input of shape [batch, seq_len, d_model].
            use_cache: write this step's K/V into the cache and attend
                over all cached positions.
            position: explicit cache write offset; when None, the internal
                cache_pos cursor is used and advanced.

        Returns:
            Attention output of shape [batch, seq_len, d_model].
        """
        batch_size, seq_len, _ = x.shape
        q = self.W_q(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        k = self.W_k(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        v = self.W_v(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        if use_cache:
            start_pos = self.cache_pos if position is None else position
            self.k_cache[:, start_pos:start_pos + seq_len] = k
            self.v_cache[:, start_pos:start_pos + seq_len] = v
            if position is None:
                self.cache_pos += seq_len
            total_len = self.cache_pos if position is None else start_pos + seq_len
            k = self.k_cache[:, :total_len]
            v = self.v_cache[:, :total_len]
        scale = math.sqrt(self.d_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        # BUG FIX: multi-token (prefill) calls previously attended to
        # future positions. Mask so that query i only sees keys <= i.
        # Single-token decode (seq_len == 1) needs no mask: every cached
        # key is in the past.
        if seq_len > 1:
            total_keys = k.size(2)
            q_pos = torch.arange(total_keys - seq_len, total_keys, device=x.device)
            k_pos = torch.arange(total_keys, device=x.device)
            causal_mask = k_pos[None, :] > q_pos[:, None]  # [seq_len, total_keys]
            attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))
        attn_weights = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        output = self.W_o(output)
        return output

    def clear_cache(self):
        """Reset the cache buffers and the write cursor."""
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.cache_pos = 0
2.2 KV Cache Memory Calculation
def calculate_kv_cache_memory(
    model_config: dict,
    batch_size: int,
    seq_len: int,
    dtype_bytes: int = 2  # FP16
) -> dict:
    """Compute the KV Cache footprint for a model/batch/sequence combination.

    Falls back to num_heads when num_kv_heads is absent (plain MHA models).
    """
    layers = model_config["num_layers"]
    kv_heads = model_config.get("num_kv_heads", model_config["num_heads"])
    head_dim = model_config["head_dim"]
    # Two tensors (K and V) per layer, each [kv_heads, head_dim] per token.
    total_bytes = (
        2 * layers * kv_heads * head_dim * seq_len * batch_size * dtype_bytes
    )
    return {
        "kv_cache_bytes": total_bytes,
        "kv_cache_mb": total_bytes / 1e6,
        "kv_cache_gb": total_bytes / 1e9,
        "per_token_bytes": total_bytes // seq_len,
    }
# Reference head geometries: MHA models keep K/V for every attention head,
# GQA models for a reduced set of KV heads.
models = {
    "Llama-2-7B (MHA)": {
        "num_layers": 32, "num_heads": 32,
        "num_kv_heads": 32, "head_dim": 128
    },
    "Llama-2-70B (GQA)": {
        "num_layers": 80, "num_heads": 64,
        "num_kv_heads": 8, "head_dim": 128
    },
    "Mistral-7B (GQA)": {
        "num_layers": 32, "num_heads": 32,
        "num_kv_heads": 8, "head_dim": 128
    },
}
print("KV Cache memory (batch=1, seq=4096)")
print("=" * 65)
for name, config in models.items():
    result = calculate_kv_cache_memory(config, batch_size=1, seq_len=4096)
    print(f"{name:<25} {result['kv_cache_gb']:.2f} GB "
          f"({result['per_token_bytes']:,} bytes/token)")
2.3 Grouped Query Attention (GQA)
GQA is a key technique for reducing KV Cache — multiple query heads share fewer KV heads.
import torch
import torch.nn as nn
import math
class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (GQA).

    num_q_heads query heads share num_kv_heads key/value heads, shrinking
    the KV projections (and hence the KV cache) by the group factor.
    """

    def __init__(self, d_model: int, num_q_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_q_heads % num_kv_heads == 0
        self.d_model = d_model
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.num_groups = num_q_heads // num_kv_heads
        self.d_head = d_model // num_q_heads
        self.W_q = nn.Linear(d_model, num_q_heads * self.d_head, bias=False)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, kv_cache=None):
        bsz, q_len, _ = x.shape
        query = self.W_q(x).reshape(bsz, q_len, self.num_q_heads, self.d_head)
        key = self.W_k(x).reshape(bsz, q_len, self.num_kv_heads, self.d_head)
        value = self.W_v(x).reshape(bsz, q_len, self.num_kv_heads, self.d_head)
        # Prepend cached K/V along the sequence dim; cache the concatenated
        # (pre-transpose) tensors for the next step.
        if kv_cache is not None:
            key = torch.cat([kv_cache["k"], key], dim=1)
            value = torch.cat([kv_cache["v"], value], dim=1)
        new_kv_cache = {"k": key, "v": value}
        query = query.transpose(1, 2)  # [B, Q_heads, S, d_head]
        key = key.transpose(1, 2)      # [B, KV_heads, T, d_head]
        value = value.transpose(1, 2)
        # Broadcast each KV head across its query-head group.
        key = key.repeat_interleave(self.num_groups, dim=1)
        value = value.repeat_interleave(self.num_groups, dim=1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_head)
        weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(weights, value)
        context = context.transpose(1, 2).reshape(bsz, q_len, self.d_model)
        return self.W_o(context), new_kv_cache
def compare_attention_variants():
    """Print KV Cache sizes for MHA/GQA/MQA on a 70B-class model."""
    # Llama-2-70B-like geometry.
    num_layers = 80
    d_head = 128
    seq_len = 4096
    dtype_bytes = 2
    variants = {
        "MHA (32 KV heads)": 32,
        "GQA (8 KV heads)": 8,
        "MQA (1 KV head)": 1,
    }
    print("70B model attention variant KV Cache comparison")
    print(f"(seq_len={seq_len}, batch=1)")
    print("=" * 55)
    for label, kv_heads in variants.items():
        total_bytes = 2 * num_layers * kv_heads * d_head * seq_len * dtype_bytes
        print(f"{label:<25} {total_bytes / 1e9:.2f} GB")

compare_attention_variants()
2.4 DeepSeek MLA (Multi-Head Latent Attention)
MLA, introduced in DeepSeek-V2, compresses KV Cache into low-dimensional latent vectors.
class MultiHeadLatentAttention(nn.Module):
    """
    DeepSeek MLA — caches a low-rank latent instead of full K/V.

    Only the output of kv_a_proj (kv_lora_rank + rope dims per token)
    needs to live in the KV cache; full K/V are recovered on the fly via
    kv_b_proj, so cache size scales with kv_lora_rank instead of
    2 * num_kv_heads * d_head.

    NOTE: this demo forward only exercises the KV-compression path and
    returns None in place of an attention output.
    """

    def __init__(
        self,
        d_model: int = 5120,
        num_heads: int = 128,
        kv_lora_rank: int = 512,
        qk_nope_head_dim: int = 128,
        qk_rope_head_dim: int = 64,
        v_head_dim: int = 128,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.kv_lora_rank = kv_lora_rank
        # LoRA-style two-stage query projection.
        self.q_a_proj = nn.Linear(d_model, 1536, bias=False)
        self.q_b_proj = nn.Linear(
            1536,
            num_heads * (qk_nope_head_dim + qk_rope_head_dim),
            bias=False
        )
        # Down-projection: d_model -> latent. ONLY its output is cached.
        self.kv_a_proj = nn.Linear(
            d_model,
            kv_lora_rank + qk_rope_head_dim,
            bias=False
        )
        # Up-projection recovering per-head K (nope part) and V.
        self.kv_b_proj = nn.Linear(
            kv_lora_rank,
            num_heads * (qk_nope_head_dim + v_head_dim),
            bias=False
        )
        self.o_proj = nn.Linear(num_heads * v_head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor, compressed_kv_cache=None):
        bsz, seqlen, _ = x.shape
        # Compress this step's hidden states; only this tensor is cached.
        latent = self.kv_a_proj(x)  # [B, S, kv_lora_rank + rope_dim]
        if compressed_kv_cache is not None:
            latent_all = torch.cat([compressed_kv_cache, latent], dim=1)
        else:
            latent_all = latent
        # Recover full K/V from the cached latent (content part only).
        content = latent_all[:, :, :self.kv_lora_rank]
        recovered_kv = self.kv_b_proj(content)
        # Attention itself is omitted in this demo.
        return None, latent
def compare_mla_vs_mha():
    """Contrast the KV cache footprint of standard MHA with DeepSeek MLA."""
    seq_len, dtype_bytes = 4096, 2  # BF16
    num_layers, num_heads, head_dim = 60, 128, 128  # DeepSeek-V2 geometry
    kv_lora_rank, rope_dim = 512, 64
    # MHA keeps full per-head K and V; MLA keeps one latent (+ RoPE dims).
    per_token_mha = 2 * num_heads * head_dim
    per_token_mla = kv_lora_rank + rope_dim
    scale = num_layers * seq_len * dtype_bytes / 1e9
    mha_kv_gb = per_token_mha * scale
    mla_kv_gb = per_token_mla * scale
    print(f"MHA KV Cache: {mha_kv_gb:.2f} GB")
    print(f"MLA KV Cache: {mla_kv_gb:.2f} GB")
    print(f"Reduction: {mha_kv_gb / mla_kv_gb:.1f}x")

compare_mla_vs_mha()
3. PagedAttention: vLLM's Core Innovation
3.1 The Problem with Conventional KV Cache
Traditional LLM serving systems pre-allocate memory for the maximum sequence length per request:
Request 1: [PROMPT=200 tokens] [KV_CACHE=up to 1848 tokens reserved] → internal fragmentation
Request 2: [PROMPT=100 tokens] [KV_CACHE=1948 tokens reserved]
Request 3: Blocked waiting (external fragmentation)
This wastes 60–80% of GPU memory.
3.2 PagedAttention: How It Works
Inspired by OS virtual memory, PagedAttention manages KV Cache in fixed-size physical blocks.
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import torch
@dataclass
class PhysicalBlock:
    """One physical block of the shared KV cache pool."""
    block_id: int
    block_size: int  # tokens per block (e.g., 16)
    ref_count: int = 0  # how many logical blocks currently map here

@dataclass
class LogicalBlock:
    """A request-local block mapped onto a physical one."""
    physical_block_id: int
    num_filled: int = 0  # tokens written into this block so far

class PagedKVCacheManager:
    """PagedAttention-style KV cache manager.

    Keeps a pool of fixed-size physical blocks plus a per-request block
    table mapping logical blocks to physical ones, mirroring OS
    virtual-memory paging.
    """

    def __init__(
        self,
        num_physical_blocks: int,
        block_size: int,
        num_layers: int,
        num_kv_heads: int,
        head_dim: int,
        device: str = "cuda"
    ):
        self.block_size = block_size
        self.num_layers = num_layers
        # Free list of physical block ids.
        self.free_blocks: List[int] = list(range(num_physical_blocks))
        self.all_blocks: Dict[int, PhysicalBlock] = {
            bid: PhysicalBlock(block_id=bid, block_size=block_size)
            for bid in range(num_physical_blocks)
        }
        # request_id -> ordered list of logical blocks.
        self.block_tables: Dict[int, List[LogicalBlock]] = {}
        # Backing tensor pool: [block, K/V, layer, slot, kv_head, head_dim].
        self.kv_cache = torch.zeros(
            num_physical_blocks, 2, num_layers, block_size, num_kv_heads, head_dim,
            dtype=torch.float16,
            device=device
        )

    def allocate_blocks_for_request(self, request_id: int, num_tokens: int):
        """Reserve ceil(num_tokens / block_size) blocks for a request."""
        needed = (num_tokens + self.block_size - 1) // self.block_size
        if len(self.free_blocks) < needed:
            raise RuntimeError(
                f"OOM: need {needed} blocks, {len(self.free_blocks)} available"
            )
        table = []
        for _ in range(needed):
            phys = self.free_blocks.pop(0)
            self.all_blocks[phys].ref_count = 1
            table.append(LogicalBlock(physical_block_id=phys))
        self.block_tables[request_id] = table
        print(f"Request {request_id}: {needed} blocks allocated, "
              f"{len(self.free_blocks)} remaining")

    def append_token(self, request_id: int, layer: int, token_pos: int,
                     k: torch.Tensor, v: torch.Tensor):
        """Write one token's K/V into its slot inside the mapped block."""
        which_block, slot = divmod(token_pos, self.block_size)
        logical = self.block_tables[request_id][which_block]
        phys = logical.physical_block_id
        self.kv_cache[phys, 0, layer, slot] = k
        self.kv_cache[phys, 1, layer, slot] = v
        logical.num_filled = slot + 1

    def free_request(self, request_id: int):
        """Drop the request's table; return unreferenced blocks to the pool."""
        table = self.block_tables.pop(request_id, None)
        if table is None:
            return
        for logical in table:
            phys = logical.physical_block_id
            self.all_blocks[phys].ref_count -= 1
            if self.all_blocks[phys].ref_count == 0:
                self.free_blocks.append(phys)

    def copy_on_write(self, src_request_id: int, dst_request_id: int):
        """Share src's blocks with dst by bumping ref counts (prefix sharing)."""
        shared = []
        for logical in self.block_tables[src_request_id]:
            phys = logical.physical_block_id
            self.all_blocks[phys].ref_count += 1
            shared.append(
                LogicalBlock(
                    physical_block_id=phys,
                    num_filled=logical.num_filled
                )
            )
        self.block_tables[dst_request_id] = shared
# Demo
# NOTE(review): allocates the full tensor pool on CUDA by default —
# pass device="cpu" to run this without a GPU.
manager = PagedKVCacheManager(
    num_physical_blocks=1000,
    block_size=16,
    num_layers=32,
    num_kv_heads=32,
    head_dim=128
)
# At block_size=16: 200/500/100 tokens -> 13/32/7 blocks respectively.
manager.allocate_blocks_for_request(request_id=1, num_tokens=200)
manager.allocate_blocks_for_request(request_id=2, num_tokens=500)
manager.allocate_blocks_for_request(request_id=3, num_tokens=100)
# Freeing request 1 returns its blocks to the pool.
manager.free_request(request_id=1)
print(f"\nAfter request 1 completes — free blocks: {len(manager.free_blocks)}")
4. Continuous Batching
4.1 The Problem with Static Batching
Static batching waits for all requests in a batch to complete before starting the next:
t=0: [Request A: 500 tokens] [Request B: 100 tokens] [Request C: 300 tokens]
t=1: Request B done — but can't start new requests until A and C finish
t=2: Request C done
t=3: Request A done → only now can a new batch start
4.2 Continuous Batching (Iteration-level Scheduling)
from dataclasses import dataclass, field
from typing import List, Tuple
import time
@dataclass
class Request:
    """A single inference request tracked by the scheduler."""
    request_id: str
    input_ids: List[int]
    max_new_tokens: int
    generated_ids: List[int] = field(default_factory=list)
    is_finished: bool = False

class ContinuousBatchingScheduler:
    """Iteration-level (continuous) batching scheduler.

    After every decoding iteration, finished requests leave the batch and
    waiting requests are admitted immediately — instead of waiting for
    the whole batch to drain as in static batching.
    """

    def __init__(self, max_batch_size: int = 32, max_seq_len: int = 4096):
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.waiting_queue: List[Request] = []
        self.running_requests: List[Request] = []
        self.finished_requests: List[Request] = []

    def add_request(self, request: Request):
        """Queue a new request for admission."""
        self.waiting_queue.append(request)

    def _can_add_request(self, request: Request) -> bool:
        """Check both the batch-size cap and the total token budget."""
        if len(self.running_requests) + 1 > self.max_batch_size:
            return False
        running_tokens = sum(
            len(r.input_ids) + len(r.generated_ids)
            for r in self.running_requests
        )
        budget = self.max_seq_len * self.max_batch_size
        return running_tokens + len(request.input_ids) < budget

    def schedule_iteration(self) -> Tuple[List[Request], List[str]]:
        """Retire finished requests and admit waiting ones.

        Returns:
            (active requests, ids of requests that just completed)
        """
        completed_ids = []
        survivors = []
        for req in self.running_requests:
            if req.is_finished:
                self.finished_requests.append(req)
                completed_ids.append(req.request_id)
            else:
                survivors.append(req)
        self.running_requests = survivors
        # Backfill freed slots right away — the core of continuous batching.
        while self.waiting_queue and self._can_add_request(self.waiting_queue[0]):
            admitted = self.waiting_queue.pop(0)
            self.running_requests.append(admitted)
            print(f"Added request {admitted.request_id} to batch "
                  f"(batch size: {len(self.running_requests)})")
        return self.running_requests, completed_ids

    def simulate_one_step(self, model_forward_fn):
        """Run one scheduling + forward iteration with a stand-in model."""
        active, completed = self.schedule_iteration()
        if not active:
            return []
        # Fresh requests contribute their full prompt (prefill); running
        # ones contribute only the last generated token (decode).
        batch_inputs = [
            req.input_ids if not req.generated_ids else [req.generated_ids[-1]]
            for req in active
        ]
        next_tokens = model_forward_fn(batch_inputs)
        for req, tok in zip(active, next_tokens):
            req.generated_ids.append(tok)
            # Token id 2 plays the role of EOS in this simulation.
            if tok == 2 or len(req.generated_ids) >= req.max_new_tokens:
                req.is_finished = True
        return completed
5. Speculative Decoding
5.1 The Idea: Draft + Verify
Speculative Decoding's core: a small draft model generates several tokens in parallel, and the large target model verifies them all at once (like a prefill).
Standard: [large model] → token1 → token2 → token3 → token4 → token5
Speculative: [small model] → generate (t1, t2, t3, t4, t5) in parallel
[large model] → verify all 5 tokens at once (parallel, like prefill)
only accepted tokens are kept
5.2 Acceptance Rate and Speedup Analysis
import torch
from typing import List, Tuple
def speculative_decode_step(
    draft_model,
    target_model,
    input_ids: torch.Tensor,
    draft_steps: int = 4,
    temperature: float = 1.0
) -> Tuple[torch.Tensor, int, int]:
    """
    One step of Speculative Decoding.

    The draft model proposes `draft_steps` tokens autoregressively; the
    target model scores the whole candidate sequence in a single forward
    pass (prefill-style). Each proposal is accepted with probability
    min(1, p_target / p_draft); on rejection the next token is sampled
    from the corrected residual max(0, p_target - p_draft), renormalized.

    Returns:
        (generated tokens, num_accepted, num_drafted)
    """
    # 1. Draft model generates candidate tokens
    draft_tokens = []
    draft_probs = []
    current_ids = input_ids.clone()
    for _ in range(draft_steps):
        with torch.no_grad():
            draft_logits = draft_model(current_ids).logits[:, -1, :]
        draft_prob = torch.softmax(draft_logits / (temperature + 1e-8), dim=-1)
        draft_token = torch.multinomial(draft_prob, num_samples=1)
        draft_tokens.append(draft_token)
        draft_probs.append(draft_prob)
        current_ids = torch.cat([current_ids, draft_token], dim=1)
    draft_sequence = torch.cat(draft_tokens, dim=1)
    candidate_ids = torch.cat([input_ids, draft_sequence], dim=1)
    # 2. Target model verifies all draft tokens in ONE pass. Keep the full
    #    logits so the final/bonus token below needs no second forward pass
    #    (the original re-ran the target model for the last position).
    with torch.no_grad():
        all_target_logits = target_model(candidate_ids).logits
    target_logits = all_target_logits[:, input_ids.size(1) - 1:-1, :]
    target_probs = torch.softmax(target_logits / (temperature + 1e-8), dim=-1)
    # 3. Accept/reject each draft token
    accepted_tokens = []
    num_accepted = 0
    for step in range(draft_steps):
        token = draft_sequence[:, step]
        p_draft = draft_probs[step].gather(1, token.unsqueeze(1)).squeeze(1)
        p_target = target_probs[:, step, :].gather(1, token.unsqueeze(1)).squeeze(1)
        acceptance_prob = torch.clamp(p_target / (p_draft + 1e-8), max=1.0)
        accepted = torch.rand_like(acceptance_prob) < acceptance_prob
        if not accepted.all():
            # Simplification: stop at the first rejection anywhere in the batch
            break
        accepted_tokens.append(token)
        num_accepted += 1
    # 4. Final token from the target distribution at the first unaccepted
    #    position — reuses the logits computed in step 2.
    last_logits = all_target_logits[:, input_ids.size(1) + num_accepted - 1, :]
    last_prob = torch.softmax(last_logits / (temperature + 1e-8), dim=-1)
    if num_accepted < draft_steps:
        # Rejection: sample from the corrected residual distribution
        correction = torch.clamp(last_prob - draft_probs[num_accepted], min=0)
        correction = correction / (correction.sum(dim=-1, keepdim=True) + 1e-8)
        last_token = torch.multinomial(correction, num_samples=1)
    else:
        # All drafts accepted: bonus token straight from the target model
        last_token = torch.multinomial(last_prob, num_samples=1)
    accepted_tokens.append(last_token.squeeze(1))
    final_tokens = torch.stack(accepted_tokens, dim=1)
    return final_tokens, num_accepted, draft_steps
def analyze_speedup(acceptance_rate: float, draft_steps: int = 4) -> dict:
    """Speedup analysis based on acceptance rate.

    Expected accepted tokens per verification cycle is the geometric
    series sum_{k=0}^{K} alpha^k. The cost model charges each draft step
    10% of a target-model step, plus one full target step to verify.
    """
    expected_accepted = 0.0
    for exponent in range(draft_steps + 1):
        expected_accepted += acceptance_rate ** exponent
    # One target step + K draft steps at 1/10 the cost each
    cycle_cost = draft_steps * 0.1 + 1
    return {
        "acceptance_rate": acceptance_rate,
        "expected_accepted": expected_accepted,
        "speedup": expected_accepted / cycle_cost,
    }
# Tabulate expected speedup over a range of acceptance rates (K = 4 drafts).
print("Speculative Decoding speedup by acceptance rate (K=4)")
print("=" * 55)
for alpha in [0.5, 0.6, 0.7, 0.8, 0.9, 0.95]:
    result = analyze_speedup(alpha, draft_steps=4)
    print(f"Accept rate {alpha:.0%}: expected {result['expected_accepted']:.2f} tokens, "
          f"speedup {result['speedup']:.2f}x")
5.3 Medusa: Multiple Draft Heads
import torch
import torch.nn as nn
class MedusaHead(nn.Module):
    """
    Medusa: attach multiple draft heads to a single model
    Each head predicts a future token:
    - Head 1: predicts t+1
    - Head 2: predicts t+2
    - Head N: predicts t+N

    Each head is a small MLP: `hidden_layers` (Linear + SiLU) blocks
    followed by a Linear projection to the vocabulary.
    """
    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        num_heads: int = 4,
        hidden_layers: int = 1
    ):
        super().__init__()
        self.num_heads = num_heads
        heads = []
        for _ in range(num_heads):
            layers = []
            for _ in range(hidden_layers):
                # Bug fix: the original built the layer list with
                # [Linear(...), SiLU()] * hidden_layers, which re-registers
                # the SAME Linear instance when hidden_layers > 1
                # (unintended weight sharing). Create fresh modules instead.
                layers.append(nn.Linear(hidden_size, hidden_size, bias=False))
                layers.append(nn.SiLU())
            layers.append(nn.Linear(hidden_size, vocab_size, bias=False))
            heads.append(nn.Sequential(*layers))
        self.heads = nn.ModuleList(heads)

    def forward(self, hidden_states: torch.Tensor):
        """
        Returns list of logits for each future position
        """
        return [head(hidden_states) for head in self.heads]
class MedusaModel(nn.Module):
    """Medusa full model: a base LM plus extra draft heads.

    The heads consume the base model's last hidden states to predict
    several future tokens from a single base forward pass.
    """
    def __init__(self, base_model, vocab_size: int, num_medusa_heads: int = 4):
        super().__init__()
        self.base_model = base_model
        # Head width follows the base model's hidden size
        hidden_size = base_model.config.hidden_size
        self.medusa_heads = MedusaHead(hidden_size, vocab_size, num_medusa_heads)

    def forward(self, input_ids: torch.Tensor, use_medusa: bool = False):
        """Return (base_logits, medusa_logits); medusa_logits is None
        unless use_medusa is True."""
        # output_hidden_states=True so the heads can read the final layer
        base_output = self.base_model(input_ids, output_hidden_states=True)
        base_logits = base_output.logits
        if not use_medusa:
            return base_logits, None
        last_hidden = base_output.hidden_states[-1]
        medusa_logits = self.medusa_heads(last_hidden)
        return base_logits, medusa_logits
6. FlashAttention: Memory-Efficient Attention
6.1 Standard Attention's HBM Bottleneck
Standard attention repeatedly writes and reads intermediate results to HBM:
Standard Attention memory ops:
1. Read Q, K from HBM → read: O(N * d)
2. Compute S = Q @ K.T → write: O(N^2) ← bottleneck!
3. Read S from HBM for softmax → read: O(N^2)
4. Store P = softmax(S) → write: O(N^2)
5. Read P for P @ V → read: O(N^2)
6. Store final output → write: O(N * d)
Total HBM accesses: O(N^2) — quadratic in sequence length!
6.2 FlashAttention's Tiling Strategy
import torch
import math
def flash_attention_v1(Q, K, V, block_size=64):
    """
    Simplified FlashAttention v1.

    Processes attention in (query-block, key-block) tiles so the full
    N x N score matrix is never materialized in HBM. A running maximum
    and running denominator per query row (online softmax) keep the
    result mathematically exact; the final division normalizes.
    """
    bsz, heads, n, d = Q.shape
    # Fold the 1/sqrt(d) scaling into Q once, up front
    Qs = Q * (1.0 / math.sqrt(d))
    out = torch.zeros_like(Qs)                       # unnormalized accumulator
    denom = torch.zeros(bsz, heads, n, 1, device=Q.device)   # running softmax denominator
    run_max = torch.full((bsz, heads, n, 1), float('-inf'), device=Q.device)
    n_blocks = (n + block_size - 1) // block_size
    for qi in range(n_blocks):
        qs, qe = qi * block_size, min((qi + 1) * block_size, n)
        q_blk = Qs[:, :, qs:qe, :]
        for kj in range(n_blocks):
            ks, ke = kj * block_size, min((kj + 1) * block_size, n)
            k_blk = K[:, :, ks:ke, :]
            v_blk = V[:, :, ks:ke, :]
            # Block of attention scores (lives in fast memory in real FA)
            scores = torch.matmul(q_blk, k_blk.transpose(-2, -1))
            # Online softmax: refresh running max, rescale prior state
            m_old = run_max[:, :, qs:qe, :]
            m_new = torch.maximum(m_old, scores.max(dim=-1, keepdim=True)[0])
            p = torch.exp(scores - m_new)
            rescale = torch.exp(m_old - m_new)
            denom[:, :, qs:qe, :] = rescale * denom[:, :, qs:qe, :] + p.sum(dim=-1, keepdim=True)
            out[:, :, qs:qe, :] = rescale * out[:, :, qs:qe, :] + torch.matmul(p, v_blk)
            run_max[:, :, qs:qe, :] = m_new
    # Normalize once at the end
    return out / denom
def compare_attention_implementations():
    """Compare FlashAttention vs standard attention (requires a CUDA GPU)."""
    batch_size, num_heads, seq_len, d_head = 2, 32, 4096, 128
    Q = torch.randn(batch_size, num_heads, seq_len, d_head, device='cuda', dtype=torch.float16)
    K = torch.randn_like(Q)
    V = torch.randn_like(Q)
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel
    # Fused FlashAttention kernel via PyTorch SDPA
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        flash_output = F.scaled_dot_product_attention(Q, K, V)
    # Naive reference: materializes the full seq_len x seq_len matrix
    attn_weights = torch.softmax(
        torch.matmul(Q, K.transpose(-2, -1)) * (1.0 / math.sqrt(d_head)), dim=-1
    )
    standard_output = torch.matmul(attn_weights, V)
    max_diff = (flash_output - standard_output).abs().max().item()
    print(f"FlashAttention vs standard max diff: {max_diff:.6f}")
    standard_attn_bytes = batch_size * num_heads * seq_len * seq_len * 2 # FP16
    print(f"Standard attention matrix memory: {standard_attn_bytes / 1e9:.2f} GB")
    print(f"FlashAttention matrix memory: ~0 GB (tiled, never fully materialized)")
6.3 Using PyTorch SDPA
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
def modern_attention(q, k, v, is_causal=True, dropout_p=0.0):
"""
PyTorch 2.0+ scaled_dot_product_attention
Automatically selects FlashAttention 2/3
"""
return F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=dropout_p,
is_causal=is_causal,
scale=None # defaults to 1/sqrt(d_head)
)
# Flash Attention version highlights
# Static reference table: per-version key innovation, memory complexity,
# and reported speedup; printed below for quick comparison.
flash_versions = {
    "FlashAttention 1 (arXiv:2205.14135)": {
        "key_innovation": "Tiling + Online Softmax",
        "memory": "O(N) — no attention matrix stored",
        "speedup": "2-4x vs standard"
    },
    "FlashAttention 2 (arXiv:2307.08691)": {
        "key_innovation": "Work partitioning, FP16/BF16",
        "memory": "O(N)",
        "speedup": "5-9x vs standard on H100"
    },
    "FlashAttention 3 (arXiv:2407.08608)": {
        "key_innovation": "H100-specific, FP8, async pipeline",
        "memory": "O(N)",
        "speedup": "1.5-2x vs FA2 on H100"
    },
}
for name, info in flash_versions.items():
    print(f"\n{name}")
    for k, v in info.items():
        print(f" {k}: {v}")
7. Multi-GPU Inference
7.1 Tensor Parallelism
Weight matrices are split across GPUs; each GPU handles a shard.
import torch
import torch.nn as nn
class TensorParallelLinear(nn.Module):
    """
    Column-parallel Linear layer for tensor parallelism.

    The weight matrix is split along the output dimension: each rank owns
    out_features // world_size output neurons and computes only its shard
    of the output. A real deployment would all-gather the shards.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        world_size: int,
        rank: int
    ):
        super().__init__()
        self.world_size = world_size
        self.rank = rank
        self.local_out_features = out_features // world_size
        # Scaled random init: variance ~ 1/in_features
        init = torch.randn(self.local_out_features, in_features) / (in_features ** 0.5)
        self.weight = nn.Parameter(init)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each rank produces its shard only.
        # In a real distributed setup: dist.all_gather(output_list, local_output)
        return nn.functional.linear(x, self.weight)
def setup_vllm_multiGPU(model_name: str, tp_size: int):
    """Set up multi-GPU vLLM inference.

    tp_size is forwarded as tensor_parallel_size, sharding the model's
    weights across that many GPUs.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,
        gpu_memory_utilization=0.9  # fraction of each GPU's memory vLLM may claim
    )
    return llm
7.2 Full vLLM Usage
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
import asyncio
import time
def vllm_basic_usage():
    """vLLM basic usage: offline batched generation with the synchronous API.

    Builds an engine, generates completions for three prompts in one call,
    prints a preview of each result, and returns the raw outputs.
    """
    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        tensor_parallel_size=1,       # single GPU
        gpu_memory_utilization=0.90,  # fraction of VRAM vLLM may reserve
        max_model_len=4096,           # context window cap
        quantization=None, # "awq", "gptq", "squeezellm"
        dtype="auto",
        max_num_seqs=256,             # max concurrently scheduled sequences
        enable_prefix_caching=True,   # reuse KV cache across shared prefixes
        use_v2_block_manager=True,
    )
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=200,
    )
    prompts = [
        "Explain quantum computing in simple terms",
        "What is the future of artificial intelligence?",
        "How does the human brain work?",
    ]
    # One call batches all prompts through the engine (continuous batching)
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        print(f"Prompt: {output.prompt[:50]}...")
        print(f"Output: {output.outputs[0].text[:100]}...")
        print(f"Tokens: {len(output.outputs[0].token_ids)}")
        print()
    return outputs
async def vllm_async_server():
    """vLLM async engine usage: stream three requests concurrently.

    The async engine interleaves all three generations in one batch
    (continuous batching); each stream prints its incremental deltas.
    """
    engine_args = AsyncEngineArgs(
        model="meta-llama/Llama-2-7b-hf",
        tensor_parallel_size=1,
        gpu_memory_utilization=0.90,
        max_model_len=4096,
        enable_prefix_caching=True,
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    async def generate_stream(prompt: str, request_id: str):
        # engine.generate yields cumulative outputs; diff against the text
        # seen so far to print only the newly generated delta.
        sampling_params = SamplingParams(temperature=0.8, max_tokens=200)
        full_text = ""
        async for output in engine.generate(prompt, sampling_params, request_id):
            if output.outputs:
                delta = output.outputs[0].text[len(full_text):]
                full_text = output.outputs[0].text
                if delta:
                    print(f"[{request_id}] {delta}", end="", flush=True)
            if output.finished:
                print(f"\n[{request_id}] done")

    # Three concurrent streams sharing one engine
    await asyncio.gather(
        generate_stream("What is AI?", "req_1"),
        generate_stream("Explain machine learning", "req_2"),
        generate_stream("What is deep learning?", "req_3"),
    )
8. Inference Engine Comparison
8.1 Major Engine Features
| Engine | Author | Key Features | Best For |
|---|---|---|---|
| vLLM | UC Berkeley | PagedAttention, Continuous Batching | General LLM serving |
| TGI | HuggingFace | Flash Attention 2, Speculative | HF model serving |
| TensorRT-LLM | NVIDIA | NVIDIA-optimized, FP8 | Max NVIDIA perf |
| DeepSpeed-MII | Microsoft | ZeRO inference, huge models | Multi-GPU giant models |
| llama.cpp | G. Gerganov | CPU-optimized, GGUF | Local execution |
8.2 Benchmark Results
import time
def run_inference_benchmark(engine_name: str, model, tokenizer, prompts, max_tokens=100):
    """Simple inference benchmark.

    Warms the model up on a few prompts, then times greedy generation over
    up to 50 prompts and reports throughput plus mean per-request latency.
    NOTE: requires a CUDA device (inputs are moved with .cuda()).
    """
    num_warmup = 5
    num_runs = 50

    def _generate(prompt):
        # Greedy decode of max_tokens continuation tokens on the GPU
        return model.generate(
            tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
            max_new_tokens=max_tokens,
            do_sample=False
        )

    # Warmup
    for prompt in prompts[:num_warmup]:
        _generate(prompt)

    # Timed runs
    import torch
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    total_tokens = 0
    for prompt in prompts[:num_runs]:
        _generate(prompt)
        total_tokens += max_tokens
    torch.cuda.synchronize()
    wall = time.perf_counter() - t0

    throughput = total_tokens / wall
    latency_ms = wall / num_runs * 1000
    print(f"\n{engine_name}")
    print(f" Throughput: {throughput:.1f} tokens/sec")
    print(f" Avg latency: {latency_ms:.1f} ms/request")
    return throughput, latency_ms
# Example comparison (A100 80GB, Llama-2-7B, batch=1, 100 output tokens)
# Static illustrative numbers (not measured here) rendered as a table.
benchmark_results = {
    "HuggingFace (FP16)": {"throughput": 52, "latency_ms": 1924},
    "HuggingFace (Flash Attn 2)": {"throughput": 78, "latency_ms": 1280},
    "vLLM": {"throughput": 120, "latency_ms": 832},
    "vLLM + AWQ 4bit": {"throughput": 165, "latency_ms": 606},
    "TensorRT-LLM": {"throughput": 180, "latency_ms": 555},
}
print("LLM inference engine benchmark (A100 80GB, Llama-2-7B)")
print("=" * 65)
print(f"{'Engine':<30} {'Throughput (tok/s)':<22} {'Latency (ms)':<15}")
print("-" * 65)
for engine, stats in benchmark_results.items():
    print(f"{engine:<30} {stats['throughput']:<22} {stats['latency_ms']:<15}")
9. Prompt Caching
9.1 Prefix Caching
Reuse KV Cache when the same system prompt or document is processed repeatedly.
from vllm import LLM, SamplingParams
import time
def demonstrate_prefix_caching():
    """Demonstrate prefix caching benefit.

    All prompts share a long system prompt; with enable_prefix_caching=True
    the second identical batch can reuse the cached prefix KV blocks, so
    the warm run is expected to be faster than the cold run.
    """
    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        enable_prefix_caching=True,
        max_model_len=4096,
    )
    # Long system prompt common to all requests (1000+ tokens)
    system_prompt = (
        "You are a helpful AI assistant with expertise in Python, "
        "machine learning, data science, and cloud computing. "
    ) * 50
    questions = [
        "How do I optimize a Python loop?",
        "What is gradient descent?",
        "Explain containerization.",
        "What is a neural network?",
    ]
    sampling_params = SamplingParams(temperature=0.7, max_tokens=100)
    cold_prompts = [f"{system_prompt}\n\nQuestion: {q}" for q in questions]
    # Cold start (no cache)
    cold_start = time.time()
    llm.generate(cold_prompts, sampling_params)
    cold_time = time.time() - cold_start
    # Warm start (cache hit)
    warm_start = time.time()
    llm.generate(cold_prompts, sampling_params)
    warm_time = time.time() - warm_start
    print(f"Cold start (no cache): {cold_time:.2f}s")
    print(f"Warm start (cache hit): {warm_time:.2f}s")
    print(f"Speedup: {cold_time / warm_time:.2f}x")
def radix_tree_prefix_cache():
    """Radix Tree prefix cache implementation.

    Returns a fresh RadixTreeCache that maps token-id sequences to KV
    cache block ids, sharing common prefixes, with hit/miss accounting.
    """
    class RadixNode:
        # One node per token id; kv_cache_block_id marks sequence ends
        # that own a cached KV block.
        def __init__(self):
            self.children: dict = {}
            self.kv_cache_block_id: int = None

    class RadixTreeCache:
        """
        Manages token sequences in a Radix Tree to share
        common-prefix KV caches
        """
        def __init__(self):
            self.root = RadixNode()
            self.cache_hits = 0
            self.cache_misses = 0

        def insert(self, token_ids: list, block_id: int):
            """Record that `token_ids` owns KV cache block `block_id`."""
            cursor = self.root
            for tid in token_ids:
                cursor = cursor.children.setdefault(tid, RadixNode())
            cursor.kv_cache_block_id = block_id

        def lookup(self, token_ids: list) -> tuple:
            """Find the longest matching prefix"""
            cursor = self.root
            matched = 0
            best_block = None
            for tid in token_ids:
                nxt = cursor.children.get(tid)
                if nxt is None:
                    break
                cursor = nxt
                matched += 1
                if cursor.kv_cache_block_id is not None:
                    best_block = cursor.kv_cache_block_id
            if best_block is None:
                self.cache_misses += 1
            else:
                self.cache_hits += 1
            return matched, best_block

        def get_hit_rate(self) -> float:
            total = self.cache_hits + self.cache_misses
            return self.cache_hits / total if total > 0 else 0.0

    return RadixTreeCache()
10. Practical Optimization Checklist
10.1 Step-by-Step Optimization Guide
class LLMOptimizationChecklist:
    """LLM inference optimization checklist.

    `optimizations` is a static table ordered by level: apply level 1
    first (cheapest, highest leverage) and work down the list.
    """

    # Each entry: one category with its level and a list of optimization
    # items (name, impact, effort, description, optional code snippet).
    optimizations = [
        {
            "category": "Baseline",
            "level": 1,
            "items": [
                {
                    "name": "Use FP16/BF16",
                    "impact": "High",
                    "effort": "Low",
                    "description": "FP32 → FP16: 2x memory saving, speed improvement",
                    "code": """
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)"""
                },
                {
                    "name": "Enable Flash Attention 2",
                    "impact": "High",
                    "effort": "Low",
                    "description": "2-4x attention speedup, memory saving",
                    "code": """
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16
)"""
                },
            ]
        },
        {
            "category": "KV Cache Optimization",
            "level": 2,
            "items": [
                {
                    "name": "Choose GQA/MQA model",
                    "impact": "High",
                    "effort": "Medium",
                    "description": "4-8x KV Cache reduction, larger effective batch"
                },
                {
                    "name": "Prefix caching",
                    "impact": "Medium",
                    "effort": "Low",
                    "description": "Reuse KV Cache for common system prompts"
                },
            ]
        },
        {
            "category": "Batching Optimization",
            "level": 3,
            "items": [
                {
                    "name": "Continuous Batching with vLLM",
                    "impact": "Very High",
                    "effort": "Low",
                    "description": "2-5x throughput improvement",
                    "code": """
from vllm import LLM, SamplingParams
llm = LLM(
    model=model_name,
    gpu_memory_utilization=0.90,
    enable_prefix_caching=True,
)"""
                },
            ]
        },
        {
            "category": "Model Optimization",
            "level": 4,
            "items": [
                {
                    "name": "AWQ 4-bit quantization",
                    "impact": "High",
                    "effort": "Medium",
                    "description": "4x memory reduction, 1.5-2x speed",
                    "code": """
from awq import AutoAWQForCausalLM
model = AutoAWQForCausalLM.from_quantized(
    "model-awq-4bit",
    fuse_layers=True
)"""
                },
                {
                    "name": "Speculative Decoding",
                    "impact": "Medium",
                    "effort": "High",
                    "description": "2-3x speedup (requires suitable draft model)"
                },
            ]
        },
        {
            "category": "Hardware Optimization",
            "level": 5,
            "items": [
                {
                    "name": "Tensor Parallelism",
                    "impact": "Very High",
                    "effort": "Medium",
                    "description": "Linear throughput scaling with multiple GPUs"
                },
                {
                    "name": "CUDA graph capture",
                    "impact": "Medium",
                    "effort": "High",
                    "description": "Eliminate kernel launch overhead"
                },
            ]
        }
    ]

    @classmethod
    def print_checklist(cls):
        """Print the checklist grouped by level, with star-rated impact."""
        print("=" * 70)
        print("LLM Inference Optimization — Step-by-Step Checklist")
        print("=" * 70)
        for category in cls.optimizations:
            print(f"\n[Level {category['level']}] {category['category']}")
            print("-" * 50)
            for item in category['items']:
                # Map the textual impact rating onto star glyphs
                impact_stars = {"Very High": "★★★", "High": "★★", "Medium": "★", "Low": "☆"}
                print(f" ✓ {item['name']}")
                print(f" Impact: {impact_stars.get(item['impact'], '?')} {item['impact']}")
                print(f" Note: {item['description']}")
        print("\nRecommended optimization order:")
        print("1. Switch to BF16/FP16 (immediate, free)")
        print("2. Enable Flash Attention 2 (immediate, just install)")
        print("3. Serve with vLLM (max throughput)")
        print("4. AWQ/GPTQ 4-bit quantization (4x memory reduction)")
        print("5. Speculative Decoding (latency improvement)")
        print("6. Multi-GPU Tensor Parallelism (scale out)")

# Render the checklist when the script runs
LLMOptimizationChecklist.print_checklist()
10.2 End-to-End Production Setup
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from fastapi import FastAPI
from pydantic import BaseModel
import asyncio
import uvicorn
app = FastAPI(title="LLM Inference API")

class GenerateRequest(BaseModel):
    """Request schema for /generate."""
    prompt: str
    max_tokens: int = 200
    temperature: float = 0.7
    top_p: float = 0.95
    stream: bool = False  # declared but not consumed by the /generate handler

class GenerateResponse(BaseModel):
    """Response schema for /generate."""
    text: str
    tokens_generated: int
    finish_reason: str

# Global engine instance (populated by the startup hook)
engine: AsyncLLMEngine = None
def create_optimized_engine(model_name: str, **kwargs) -> AsyncLLMEngine:
    """Create a production-optimized vLLM engine.

    Recognized kwargs (with defaults): tp_size (1), gpu_util (0.90),
    max_model_len (4096), quantization (None), max_num_seqs (256),
    draft_model (None), num_spec_tokens (5).
    """
    engine_args = AsyncEngineArgs(
        model=model_name,
        tensor_parallel_size=kwargs.get("tp_size", 1),
        gpu_memory_utilization=kwargs.get("gpu_util", 0.90),
        max_model_len=kwargs.get("max_model_len", 4096),
        quantization=kwargs.get("quantization", None), # "awq" or "gptq"
        dtype="auto",
        max_num_seqs=kwargs.get("max_num_seqs", 256),
        enable_prefix_caching=True,
        use_v2_block_manager=True,
        speculative_model=kwargs.get("draft_model", None), # optional draft model
        num_speculative_tokens=kwargs.get("num_spec_tokens", 5),
    )
    return AsyncLLMEngine.from_engine_args(engine_args)
@app.on_event("startup")
async def startup():
    """Initialize the global vLLM engine once at server startup."""
    global engine
    engine = create_optimized_engine(
        model_name="meta-llama/Llama-2-7b-hf",
        tp_size=1,
        gpu_util=0.90,
        max_model_len=4096,
    )
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Non-streaming generation endpoint.

    Drains the engine's async generator and returns only the final output.
    """
    import uuid  # local import: only needed for request-id minting
    sampling_params = SamplingParams(
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens,
    )
    # Bug fix: the original used f"req_{id(request)}" — CPython recycles
    # object ids after garbage collection, so concurrent/successive requests
    # could collide inside the engine. uuid4 is collision-free in practice.
    request_id = f"req_{uuid.uuid4().hex}"
    final_output = None
    async for output in engine.generate(request.prompt, sampling_params, request_id):
        final_output = output
    if final_output and final_output.outputs:
        result = final_output.outputs[0]
        return GenerateResponse(
            text=result.text,
            tokens_generated=len(result.token_ids),
            finish_reason=result.finish_reason or "length"
        )
    return GenerateResponse(text="", tokens_generated=0, finish_reason="error")
# Run with: uvicorn script:app --host 0.0.0.0 --port 8000 --workers 1
Conclusion
LLM inference optimization requires a layered approach.
Key takeaways:
- Understand KV Cache: memorize 2 * layers * kv_heads * d_head * seq_len * dtype_bytes; use GQA/MQA to cut KV Cache by 4–8x.
- PagedAttention: vLLM's core innovation — borrowed from OS virtual memory to eliminate KV Cache fragmentation.
- Continuous Batching: immediately insert new requests as others complete, maximizing GPU utilization.
- Speculative Decoding: small draft model + large verifier = 2–3x speedup at the right acceptance rate.
- FlashAttention: reduces attention's memory from O(N^2) to O(N), enabling long contexts.
Production deployment recommendations:
- Small services: vLLM + AWQ 4-bit + prefix caching
- Large services: TensorRT-LLM or vLLM + Tensor Parallelism
- Lowest latency: Speculative Decoding + CUDA graphs
References
- vLLM/PagedAttention: arXiv:2309.06180
- Speculative Decoding: arXiv:2211.17192
- FlashAttention: arXiv:2205.14135
- FlashAttention-2: arXiv:2307.08691
- Medusa: arXiv:2401.10774
- Continuous Batching: Anyscale Blog
- DeepSeek-V2 MLA: arXiv:2405.04434