- Published on
LLM Inference Optimization Complete Guide: KV Cache, Speculative Decoding, Continuous Batching
- Authors
- Youngju Kim (@fjvbn20031)
Introduction
When you deploy a large language model (LLM) to production, you immediately face a challenge: inference speed and cost. A first query to a GPT-4 class model can take several seconds, and throughput degrades sharply as concurrent users grow.
This guide explores the core techniques of LLM inference optimization in depth. From KV Cache internals to PagedAttention, Speculative Decoding, FlashAttention, and the latest vLLM and TensorRT-LLM engines, we cover not just how to use each technique but why it works.
1. Understanding the LLM Inference Pipeline
1.1 Two Phases: Prefill and Decode
LLM text generation is split into two distinct phases.
Prefill Phase (Prompt Processing)
- Processes all tokens of the input prompt simultaneously (parallel)
- Generates and stores Key/Value caches at each layer
- Compute-Bound: GPU compute throughput is the bottleneck
- Directly affects TTFT (Time To First Token)
Decode Phase (Token Generation)
- Generates one token at a time in an autoregressive manner
- References KV Cache from all previously generated tokens
- Memory-Bandwidth-Bound: HBM read speed is the bottleneck
- Directly affects TPOT (Time Per Output Token)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
def measure_prefill_decode_time(model, tokenizer, prompt: str, max_new_tokens: int = 100):
"""Measure prefill and decode phase timing"""
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
input_len = inputs["input_ids"].size(1)
# Measure TTFT
torch.cuda.synchronize()
prefill_start = time.perf_counter()
with torch.no_grad():
first_output = model.generate(
**inputs,
max_new_tokens=1,
do_sample=False
)
torch.cuda.synchronize()
ttft = time.perf_counter() - prefill_start
# Measure total generation
torch.cuda.synchronize()
total_start = time.perf_counter()
with torch.no_grad():
full_output = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False
)
torch.cuda.synchronize()
total_time = time.perf_counter() - total_start
output_tokens = full_output.size(1) - input_len
decode_time = total_time - ttft
tpot = decode_time / max(output_tokens - 1, 1)
print(f"Input tokens: {input_len}")
print(f"Output tokens: {output_tokens}")
print(f"TTFT (first token latency): {ttft * 1000:.1f} ms")
print(f"TPOT (per token time): {tpot * 1000:.1f} ms")
print(f"Throughput: {output_tokens / total_time:.1f} tokens/sec")
return ttft, tpot
1.2 Memory Bandwidth Analysis
Understanding why the decode phase is memory-bound:
def analyze_memory_bandwidth():
"""LLM inference memory bandwidth analysis"""
# Example: Llama-2-7B config
model_params = {
"num_layers": 32,
"hidden_size": 4096,
"num_heads": 32,
"head_dim": 128,
"vocab_size": 32000,
}
dtype_bytes = 2 # FP16: 2 bytes
# Weight memory
attn_weight = 4 * model_params["hidden_size"] ** 2 # Q, K, V, O projections
    ffn_weight = 8 * model_params["hidden_size"] ** 2  # SwiGLU Up/Gate/Down: 3 * 4096 * 11008 ≈ 8 * hidden^2
layer_weight = (attn_weight + ffn_weight) * dtype_bytes
total_weight_bytes = layer_weight * model_params["num_layers"]
total_weight_gb = total_weight_bytes / 1e9
print(f"Model weights: {total_weight_gb:.2f} GB")
# KV Cache memory (proportional to sequence length)
seq_len = 2048
kv_cache_per_token = (
2 *
model_params["num_layers"] *
model_params["num_heads"] *
model_params["head_dim"] *
dtype_bytes
)
kv_cache_total = kv_cache_per_token * seq_len / 1e6
print(f"KV Cache ({seq_len} tokens): {kv_cache_total:.2f} MB")
print(f"KV Cache per token: {kv_cache_per_token} bytes")
# A100 memory bandwidth: 2 TB/s
memory_bandwidth_tbs = 2.0
# Decode: weights loaded once per token
memory_bound_tps = memory_bandwidth_tbs * 1e12 / total_weight_bytes
    # A100 FP16: 312 TFLOPS
    compute_throughput = 312e12
    # ~2 FLOPs per parameter per token (multiply + add)
    num_params = total_weight_bytes / dtype_bytes
    flops_per_token = 2 * num_params
    compute_bound_tps = compute_throughput / flops_per_token
print(f"\nFor batch_size=1:")
print(f"Memory-bound throughput: {memory_bound_tps:.1f} tokens/sec")
print(f"Compute-bound throughput: {compute_bound_tps:.1f} tokens/sec")
print(f"Actual bottleneck: {'memory' if memory_bound_tps < compute_bound_tps else 'compute'}")
analyze_memory_bandwidth()
1.3 Inference Cost Analysis
def estimate_inference_cost(
model_size_b: float,
tokens_per_request: int,
requests_per_day: int,
gpu_cost_per_hour: float = 3.0 # A100 hourly price (USD)
):
"""Estimate inference cost"""
# Empirical throughput estimates
# 7B model: ~100 tok/s, 70B model: ~20 tok/s (A100)
throughput_tps = 100 / (model_size_b / 7) ** 0.6
total_tokens_per_day = tokens_per_request * requests_per_day
seconds_needed = total_tokens_per_day / throughput_tps
hours_needed = seconds_needed / 3600
daily_cost = hours_needed * gpu_cost_per_hour
cost_per_1k_tokens = daily_cost / (total_tokens_per_day / 1000)
print(f"Model: {model_size_b}B parameters")
print(f"Daily requests: {requests_per_day:,}")
print(f"Tokens per request: {tokens_per_request}")
print(f"Total daily tokens: {total_tokens_per_day:,}")
print(f"Estimated throughput: {throughput_tps:.1f} tokens/sec")
print(f"GPU hours needed: {hours_needed:.2f}")
print(f"Daily cost: ${daily_cost:.2f}")
print(f"Cost per 1K tokens: ${cost_per_1k_tokens:.4f}")
estimate_inference_cost(
model_size_b=7.0,
tokens_per_request=500,
requests_per_day=10000
)
2. KV Cache: The Core Optimization
2.1 Why KV Cache is Necessary
In transformer attention, each token computes attention against all previous tokens. To avoid recomputing already-processed tokens, we cache the K and V matrices.
import torch
import torch.nn as nn
import math
class MultiHeadAttentionWithKVCache(nn.Module):
"""Multi-head attention with KV Cache support"""
def __init__(self, d_model: int, num_heads: int, max_seq_len: int = 4096):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_head = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
# KV Cache initialization
self.register_buffer(
'k_cache',
torch.zeros(1, max_seq_len, num_heads, self.d_head)
)
self.register_buffer(
'v_cache',
torch.zeros(1, max_seq_len, num_heads, self.d_head)
)
self.cache_pos = 0
    def forward(self, x: torch.Tensor, use_cache: bool = True, position: int = None):
        batch_size, seq_len, _ = x.shape
        q = self.W_q(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        k = self.W_k(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        v = self.W_v(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        if use_cache:
            start_pos = self.cache_pos if position is None else position
            self.k_cache[:, start_pos:start_pos + seq_len] = k
            self.v_cache[:, start_pos:start_pos + seq_len] = v
            if position is None:
                self.cache_pos += seq_len
            total_len = self.cache_pos if position is None else start_pos + seq_len
            k = self.k_cache[:, :total_len]
            v = self.v_cache[:, :total_len]
        else:
            start_pos, total_len = 0, seq_len
        scale = math.sqrt(self.d_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        # Causal mask: the query at position start_pos + i may only attend to keys at positions <= start_pos + i
        q_pos = torch.arange(start_pos, start_pos + seq_len, device=x.device).unsqueeze(1)
        k_pos = torch.arange(total_len, device=x.device).unsqueeze(0)
        attn_scores = attn_scores.masked_fill(k_pos > q_pos, float('-inf'))
        attn_weights = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        output = self.W_o(output)
        return output
def clear_cache(self):
self.k_cache.zero_()
self.v_cache.zero_()
self.cache_pos = 0
2.2 KV Cache Memory Calculation
def calculate_kv_cache_memory(
model_config: dict,
batch_size: int,
seq_len: int,
dtype_bytes: int = 2 # FP16
) -> dict:
"""Calculate KV Cache memory usage"""
num_layers = model_config["num_layers"]
num_kv_heads = model_config.get("num_kv_heads", model_config["num_heads"])
head_dim = model_config["head_dim"]
# Formula: 2 (K+V) * layers * kv_heads * head_dim * seq_len * batch * dtype
kv_cache_bytes = (
2 *
num_layers *
num_kv_heads *
head_dim *
seq_len *
batch_size *
dtype_bytes
)
return {
"kv_cache_bytes": kv_cache_bytes,
"kv_cache_mb": kv_cache_bytes / 1e6,
"kv_cache_gb": kv_cache_bytes / 1e9,
"per_token_bytes": kv_cache_bytes // seq_len,
}
models = {
"Llama-2-7B (MHA)": {
"num_layers": 32, "num_heads": 32,
"num_kv_heads": 32, "head_dim": 128
},
"Llama-2-70B (GQA)": {
"num_layers": 80, "num_heads": 64,
"num_kv_heads": 8, "head_dim": 128
},
"Mistral-7B (GQA)": {
"num_layers": 32, "num_heads": 32,
"num_kv_heads": 8, "head_dim": 128
},
}
print("KV Cache memory (batch=1, seq=4096)")
print("=" * 65)
for name, config in models.items():
result = calculate_kv_cache_memory(config, batch_size=1, seq_len=4096)
print(f"{name:<25} {result['kv_cache_gb']:.2f} GB "
f"({result['per_token_bytes']:,} bytes/token)")
2.3 Grouped Query Attention (GQA)
GQA is a key technique for reducing KV Cache — multiple query heads share fewer KV heads.
import torch
import torch.nn as nn
import math
class GroupedQueryAttention(nn.Module):
"""Grouped Query Attention (GQA) implementation"""
def __init__(self, d_model: int, num_q_heads: int, num_kv_heads: int):
super().__init__()
assert num_q_heads % num_kv_heads == 0
self.d_model = d_model
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.num_groups = num_q_heads // num_kv_heads
self.d_head = d_model // num_q_heads
self.W_q = nn.Linear(d_model, num_q_heads * self.d_head, bias=False)
self.W_k = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
self.W_v = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
def forward(self, x: torch.Tensor, kv_cache=None):
batch_size, seq_len, _ = x.shape
q = self.W_q(x).reshape(batch_size, seq_len, self.num_q_heads, self.d_head)
k = self.W_k(x).reshape(batch_size, seq_len, self.num_kv_heads, self.d_head)
v = self.W_v(x).reshape(batch_size, seq_len, self.num_kv_heads, self.d_head)
if kv_cache is not None:
k = torch.cat([kv_cache["k"], k], dim=1)
v = torch.cat([kv_cache["v"], v], dim=1)
new_kv_cache = {"k": k, "v": v}
q = q.transpose(1, 2) # [B, Q_heads, S, d_head]
k = k.transpose(1, 2) # [B, KV_heads, T, d_head]
v = v.transpose(1, 2)
# GQA: expand KV heads to match Q heads
k = k.repeat_interleave(self.num_groups, dim=1)
v = v.repeat_interleave(self.num_groups, dim=1)
scale = math.sqrt(self.d_head)
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
attn_weights = torch.softmax(attn_scores, dim=-1)
output = torch.matmul(attn_weights, v)
output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
return self.W_o(output), new_kv_cache
def compare_attention_variants():
"""Compare KV Cache memory across attention variants"""
# 70B model (Llama-2-70B)
num_layers = 80
d_head = 128
seq_len = 4096
dtype_bytes = 2
variants = {
"MHA (32 KV heads)": 32,
"GQA (8 KV heads)": 8,
"MQA (1 KV head)": 1,
}
print("70B model attention variant KV Cache comparison")
print(f"(seq_len={seq_len}, batch=1)")
print("=" * 55)
for name, num_kv_heads in variants.items():
kv_bytes = 2 * num_layers * num_kv_heads * d_head * seq_len * dtype_bytes
kv_gb = kv_bytes / 1e9
print(f"{name:<25} {kv_gb:.2f} GB")
compare_attention_variants()
2.4 DeepSeek MLA (Multi-Head Latent Attention)
MLA, introduced in DeepSeek-V2, compresses KV Cache into low-dimensional latent vectors.
class MultiHeadLatentAttention(nn.Module):
"""
DeepSeek MLA — compresses KV Cache to low-rank latent vectors
Core idea:
- Instead of storing high-dim K,V, store low-dim latent c_kv
- Recover K, V from c_kv via up-projection
- KV Cache size: num_layers * kv_lora_rank * seq_len
(vs standard: 2 * num_layers * num_kv_heads * d_head * seq_len)
"""
def __init__(
self,
d_model: int = 5120,
num_heads: int = 128,
kv_lora_rank: int = 512,
qk_nope_head_dim: int = 128,
qk_rope_head_dim: int = 64,
v_head_dim: int = 128,
):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.kv_lora_rank = kv_lora_rank
# Q projection (LoRA-style)
self.q_a_proj = nn.Linear(d_model, 1536, bias=False)
self.q_b_proj = nn.Linear(
1536,
num_heads * (qk_nope_head_dim + qk_rope_head_dim),
bias=False
)
# KV down-projection: d_model -> kv_lora_rank
# Only THIS is stored in KV Cache!
self.kv_a_proj = nn.Linear(
d_model,
kv_lora_rank + qk_rope_head_dim,
bias=False
)
# KV up-projection: kv_lora_rank -> K, V
self.kv_b_proj = nn.Linear(
kv_lora_rank,
num_heads * (qk_nope_head_dim + v_head_dim),
bias=False
)
self.o_proj = nn.Linear(num_heads * v_head_dim, d_model, bias=False)
    def forward(self, x: torch.Tensor, compressed_kv_cache=None):
        batch_size, seq_len, _ = x.shape
        # KV compression (only this result goes in the cache)
        kv_compressed = self.kv_a_proj(x)  # [B, S, kv_lora_rank + rope_dim]
        if compressed_kv_cache is not None:
            kv_compressed_total = torch.cat([compressed_kv_cache, kv_compressed], dim=1)
        else:
            kv_compressed_total = kv_compressed
        # Recover full K, V from the cached compressed representation
        kv_content = kv_compressed_total[:, :, :self.kv_lora_rank]
        kv_full = self.kv_b_proj(kv_content)  # [B, T, num_heads * (nope_dim + v_dim)]
        # Attention over the recovered K, V (plus RoPE handling) is omitted for brevity;
        # return the concatenated cache so the next decode step can reuse it
        return None, kv_compressed_total
def compare_mla_vs_mha():
"""Compare MLA vs MHA KV Cache"""
seq_len = 4096
dtype_bytes = 2 # BF16
num_layers = 60 # DeepSeek-V2
num_heads = 128
head_dim = 128
mha_kv_gb = 2 * num_layers * num_heads * head_dim * seq_len * dtype_bytes / 1e9
kv_lora_rank = 512
rope_dim = 64
mla_kv_gb = (kv_lora_rank + rope_dim) * num_layers * seq_len * dtype_bytes / 1e9
print(f"MHA KV Cache: {mha_kv_gb:.2f} GB")
print(f"MLA KV Cache: {mla_kv_gb:.2f} GB")
print(f"Reduction: {mha_kv_gb / mla_kv_gb:.1f}x")
compare_mla_vs_mha()
3. PagedAttention: vLLM's Core Innovation
3.1 The Problem with Conventional KV Cache
Traditional LLM serving systems pre-allocate memory for the maximum sequence length per request:
Request 1: [PROMPT=200 tokens] [KV_CACHE=up to 1848 tokens reserved] → internal fragmentation
Request 2: [PROMPT=100 tokens] [KV_CACHE=1948 tokens reserved]
Request 3: Blocked waiting (external fragmentation)
In conventional serving systems, this wastes 60–80% of the memory reserved for KV Cache.
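The waste is easy to quantify. A minimal sketch (the 2048-token budget and per-request token counts are illustrative):

```python
def preallocation_waste(max_seq_len: int, actual_lens: list[int]) -> float:
    """Fraction of reserved KV Cache slots left unused when every
    request pre-allocates max_seq_len token slots up front."""
    reserved = max_seq_len * len(actual_lens)
    used = sum(actual_lens)
    return 1.0 - used / reserved

# Three requests that finish well short of the 2048-token budget
waste = preallocation_waste(2048, [256, 410, 820])
print(f"Wasted KV Cache memory: {waste:.0%}")
```

Even generous request lengths leave most reserved slots idle, which is exactly the internal fragmentation PagedAttention eliminates.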
3.2 PagedAttention: How It Works
Inspired by OS virtual memory, PagedAttention manages KV Cache in fixed-size physical blocks.
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import torch
@dataclass
class PhysicalBlock:
"""Physical memory block"""
block_id: int
block_size: int # tokens per block (e.g., 16)
ref_count: int = 0
@dataclass
class LogicalBlock:
"""Logical block (mapped to a request)"""
physical_block_id: int
num_filled: int = 0
class PagedKVCacheManager:
"""PagedAttention KV Cache manager"""
def __init__(
self,
num_physical_blocks: int,
block_size: int,
num_layers: int,
num_kv_heads: int,
head_dim: int,
device: str = "cuda"
):
self.block_size = block_size
self.num_layers = num_layers
self.free_blocks: List[int] = list(range(num_physical_blocks))
self.all_blocks: Dict[int, PhysicalBlock] = {
i: PhysicalBlock(block_id=i, block_size=block_size)
for i in range(num_physical_blocks)
}
self.block_tables: Dict[int, List[LogicalBlock]] = {}
# Actual KV Cache tensor pool
self.kv_cache = torch.zeros(
num_physical_blocks, 2, num_layers, block_size, num_kv_heads, head_dim,
dtype=torch.float16,
device=device
)
def allocate_blocks_for_request(self, request_id: int, num_tokens: int):
"""Allocate blocks for a request"""
num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
if len(self.free_blocks) < num_blocks_needed:
raise RuntimeError(
f"OOM: need {num_blocks_needed} blocks, {len(self.free_blocks)} available"
)
logical_blocks = []
for _ in range(num_blocks_needed):
physical_id = self.free_blocks.pop(0)
self.all_blocks[physical_id].ref_count = 1
logical_blocks.append(LogicalBlock(physical_block_id=physical_id))
self.block_tables[request_id] = logical_blocks
print(f"Request {request_id}: {num_blocks_needed} blocks allocated, "
f"{len(self.free_blocks)} remaining")
def append_token(self, request_id: int, layer: int, token_pos: int,
k: torch.Tensor, v: torch.Tensor):
"""Store KV for a new token in the cache"""
block_idx = token_pos // self.block_size
token_in_block = token_pos % self.block_size
logical_block = self.block_tables[request_id][block_idx]
physical_id = logical_block.physical_block_id
self.kv_cache[physical_id, 0, layer, token_in_block] = k
self.kv_cache[physical_id, 1, layer, token_in_block] = v
logical_block.num_filled = token_in_block + 1
def free_request(self, request_id: int):
"""Free blocks after request completes"""
if request_id in self.block_tables:
for logical_block in self.block_tables[request_id]:
phys_id = logical_block.physical_block_id
self.all_blocks[phys_id].ref_count -= 1
if self.all_blocks[phys_id].ref_count == 0:
self.free_blocks.append(phys_id)
del self.block_tables[request_id]
def copy_on_write(self, src_request_id: int, dst_request_id: int):
"""Copy-on-Write for prefix caching"""
src_blocks = self.block_tables[src_request_id]
dst_blocks = []
for logical_block in src_blocks:
phys_id = logical_block.physical_block_id
self.all_blocks[phys_id].ref_count += 1
dst_blocks.append(
LogicalBlock(
physical_block_id=phys_id,
num_filled=logical_block.num_filled
)
)
self.block_tables[dst_request_id] = dst_blocks
# Demo
manager = PagedKVCacheManager(
num_physical_blocks=1000,
block_size=16,
num_layers=32,
num_kv_heads=32,
head_dim=128
)
manager.allocate_blocks_for_request(request_id=1, num_tokens=200)
manager.allocate_blocks_for_request(request_id=2, num_tokens=500)
manager.allocate_blocks_for_request(request_id=3, num_tokens=100)
manager.free_request(request_id=1)
print(f"\nAfter request 1 completes — free blocks: {len(manager.free_blocks)}")
4. Continuous Batching
4.1 The Problem with Static Batching
Static batching waits for all requests in a batch to complete before starting the next:
t=0: [Request A: 500 tokens] [Request B: 100 tokens] [Request C: 300 tokens]
t=1: Request B done — but can't start new requests until A and C finish
t=2: Request C done
t=3: Request A done → only now can a new batch start
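Counting decode steps per batch slot makes the idle time concrete. A minimal sketch using the three requests above (one decode step per generated token; slot-steps are a proxy for GPU time):

```python
def static_batch_slot_steps(output_lens: list[int]) -> int:
    """Slot-steps consumed when the whole batch runs until
    its longest request finishes (static batching)."""
    return max(output_lens) * len(output_lens)

def continuous_batch_slot_steps(output_lens: list[int]) -> int:
    """Slot-steps doing useful work when a slot is released (and
    refillable) the moment its request finishes (continuous batching)."""
    return sum(output_lens)

lens = [500, 100, 300]
static = static_batch_slot_steps(lens)
useful = continuous_batch_slot_steps(lens)
print(f"Static: {static} slot-steps for {useful} useful steps "
      f"-> {1 - useful / static:.0%} idle")
```

With continuous batching those idle slot-steps are immediately filled by waiting requests, which is where the throughput gains come from.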
4.2 Continuous Batching (Iteration-level Scheduling)
from dataclasses import dataclass, field
from typing import List, Tuple
import time
@dataclass
class Request:
"""Inference request"""
request_id: str
input_ids: List[int]
max_new_tokens: int
generated_ids: List[int] = field(default_factory=list)
is_finished: bool = False
class ContinuousBatchingScheduler:
"""Continuous Batching scheduler"""
def __init__(self, max_batch_size: int = 32, max_seq_len: int = 4096):
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.waiting_queue: List[Request] = []
self.running_requests: List[Request] = []
self.finished_requests: List[Request] = []
def add_request(self, request: Request):
self.waiting_queue.append(request)
def _can_add_request(self, request: Request) -> bool:
if len(self.running_requests) + 1 > self.max_batch_size:
return False
total_tokens = sum(
len(r.input_ids) + len(r.generated_ids)
for r in self.running_requests
) + len(request.input_ids)
return total_tokens < self.max_seq_len * self.max_batch_size
def schedule_iteration(self) -> Tuple[List[Request], List[str]]:
"""
Schedule one iteration's batch
Returns:
(active requests, list of just-completed request IDs)
"""
completed_ids = []
still_running = []
for req in self.running_requests:
if req.is_finished:
self.finished_requests.append(req)
completed_ids.append(req.request_id)
else:
still_running.append(req)
self.running_requests = still_running
# Fill empty slots immediately with waiting requests
while self.waiting_queue and self._can_add_request(self.waiting_queue[0]):
new_request = self.waiting_queue.pop(0)
self.running_requests.append(new_request)
print(f"Added request {new_request.request_id} to batch "
f"(batch size: {len(self.running_requests)})")
return self.running_requests, completed_ids
def simulate_one_step(self, model_forward_fn):
"""Simulate one step"""
active_requests, completed = self.schedule_iteration()
if not active_requests:
return []
batch_input_ids = []
for req in active_requests:
if len(req.generated_ids) == 0:
batch_input_ids.append(req.input_ids) # Prefill
else:
batch_input_ids.append([req.generated_ids[-1]]) # Decode
outputs = model_forward_fn(batch_input_ids)
for req, next_token_id in zip(active_requests, outputs):
req.generated_ids.append(next_token_id)
            if next_token_id == 2 or len(req.generated_ids) >= req.max_new_tokens:  # 2 = Llama EOS token id
req.is_finished = True
return completed
5. Speculative Decoding
5.1 The Idea: Draft + Verify
Speculative Decoding's core: a small draft model generates several tokens in parallel, and the large target model verifies them all at once (like a prefill).
Standard: [large model] → token1 → token2 → token3 → token4 → token5
Speculative: [small model] → generate (t1, t2, t3, t4, t5) in parallel
[large model] → verify all 5 tokens at once (parallel, like prefill)
only accepted tokens are kept
5.2 Acceptance Rate and Speedup Analysis
import torch
from typing import List, Tuple
def speculative_decode_step(
draft_model,
target_model,
input_ids: torch.Tensor,
draft_steps: int = 4,
temperature: float = 1.0
) -> Tuple[torch.Tensor, int, int]:
"""
One step of Speculative Decoding
Returns:
(generated tokens, num_accepted, num_drafted)
"""
# 1. Draft model generates candidate tokens
draft_tokens = []
draft_probs = []
current_ids = input_ids.clone()
for _ in range(draft_steps):
with torch.no_grad():
draft_logits = draft_model(current_ids).logits[:, -1, :]
draft_prob = torch.softmax(draft_logits / (temperature + 1e-8), dim=-1)
draft_token = torch.multinomial(draft_prob, num_samples=1)
draft_tokens.append(draft_token)
draft_probs.append(draft_prob)
current_ids = torch.cat([current_ids, draft_token], dim=1)
draft_sequence = torch.cat(draft_tokens, dim=1)
candidate_ids = torch.cat([input_ids, draft_sequence], dim=1)
    # 2. Target model scores the full candidate sequence in one pass
    with torch.no_grad():
        all_logits = target_model(candidate_ids).logits
    target_logits = all_logits[:, input_ids.size(1) - 1:-1, :]
    target_probs = torch.softmax(target_logits / (temperature + 1e-8), dim=-1)
    # 3. Accept/reject each draft token (rejection sampling)
    accepted_tokens = []
    num_accepted = 0
    for step in range(draft_steps):
        token = draft_sequence[:, step]
        p_draft = draft_probs[step].gather(1, token.unsqueeze(1)).squeeze(1)
        p_target = target_probs[:, step, :].gather(1, token.unsqueeze(1)).squeeze(1)
        acceptance_prob = torch.clamp(p_target / (p_draft + 1e-8), max=1.0)
        accepted = torch.rand_like(acceptance_prob) < acceptance_prob
        if not accepted.all():
            break
        accepted_tokens.append(token)
        num_accepted += 1
    # 4. Final token: reuse the logits already computed above (no second forward pass)
    last_logits = all_logits[:, input_ids.size(1) + num_accepted - 1, :]
    last_prob = torch.softmax(last_logits / (temperature + 1e-8), dim=-1)
if num_accepted < draft_steps:
correction = torch.clamp(last_prob - draft_probs[num_accepted], min=0)
correction = correction / (correction.sum(dim=-1, keepdim=True) + 1e-8)
last_token = torch.multinomial(correction, num_samples=1)
else:
last_token = torch.multinomial(last_prob, num_samples=1)
accepted_tokens.append(last_token.squeeze(1))
final_tokens = torch.stack(accepted_tokens, dim=1)
return final_tokens, num_accepted, draft_steps
def analyze_speedup(acceptance_rate: float, draft_steps: int = 4) -> dict:
"""Speedup analysis based on acceptance rate"""
# E[accepted tokens] = sum_{k=0}^{K} alpha^k
expected_accepted = sum(
acceptance_rate ** k for k in range(draft_steps + 1)
)
# Draft model is 1/10 the size of target model
draft_model_ratio = 0.1
steps_with_speculative = draft_steps * draft_model_ratio + 1
speedup = expected_accepted / steps_with_speculative
return {
"acceptance_rate": acceptance_rate,
"expected_accepted": expected_accepted,
"speedup": speedup
}
print("Speculative Decoding speedup by acceptance rate (K=4)")
print("=" * 55)
for alpha in [0.5, 0.6, 0.7, 0.8, 0.9, 0.95]:
result = analyze_speedup(alpha, draft_steps=4)
print(f"Accept rate {alpha:.0%}: expected {result['expected_accepted']:.2f} tokens, "
f"speedup {result['speedup']:.2f}x")
5.3 Medusa: Multiple Draft Heads
import torch
import torch.nn as nn
class MedusaHead(nn.Module):
"""
Medusa: attach multiple draft heads to a single model
Each head predicts a future token:
- Head 1: predicts t+1
- Head 2: predicts t+2
- Head N: predicts t+N
"""
def __init__(
self,
hidden_size: int,
vocab_size: int,
num_heads: int = 4,
hidden_layers: int = 1
):
super().__init__()
self.num_heads = num_heads
        self.heads = nn.ModuleList([
            nn.Sequential(
                # Build a fresh Linear + SiLU pair per repeat; multiplying a
                # list of modules would reuse the same objects (shared weights)
                *[m for _ in range(hidden_layers)
                  for m in (nn.Linear(hidden_size, hidden_size, bias=False), nn.SiLU())],
                nn.Linear(hidden_size, vocab_size, bias=False)
            )
            for _ in range(num_heads)
        ])
def forward(self, hidden_states: torch.Tensor):
"""
Returns list of logits for each future position
"""
return [head(hidden_states) for head in self.heads]
class MedusaModel(nn.Module):
"""Medusa full model"""
def __init__(self, base_model, vocab_size: int, num_medusa_heads: int = 4):
super().__init__()
self.base_model = base_model
hidden_size = base_model.config.hidden_size
self.medusa_heads = MedusaHead(hidden_size, vocab_size, num_medusa_heads)
def forward(self, input_ids: torch.Tensor, use_medusa: bool = False):
base_output = self.base_model(input_ids, output_hidden_states=True)
base_logits = base_output.logits
if not use_medusa:
return base_logits, None
last_hidden = base_output.hidden_states[-1]
medusa_logits = self.medusa_heads(last_hidden)
return base_logits, medusa_logits
6. FlashAttention: Memory-Efficient Attention
6.1 Standard Attention's HBM Bottleneck
Standard attention repeatedly writes and reads intermediate results to HBM:
Standard Attention memory ops:
1. Read Q, K from HBM → read: O(N * d)
2. Compute S = Q @ K.T → write: O(N^2) ← bottleneck!
3. Read S from HBM for softmax → read: O(N^2)
4. Store P = softmax(S) → write: O(N^2)
5. Read P for P @ V → read: O(N^2)
6. Store final output → write: O(N * d)
Total HBM accesses: O(N^2) — quadratic in sequence length!
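Plugging Llama-2-7B-shaped numbers into these access counts shows the scale of the problem. A rough sketch (FP16; counts only the O(N^2) intermediate writes/reads for standard attention and only Q/K/V/O traffic for FlashAttention, ignoring constant factors):

```python
def standard_attention_hbm_bytes(seq_len: int, num_heads: int,
                                 dtype_bytes: int = 2) -> int:
    """Approximate HBM traffic for the O(N^2) intermediates alone:
    write S, read S, write P, read P."""
    attn_matrix = num_heads * seq_len * seq_len * dtype_bytes
    return 4 * attn_matrix

def flash_attention_hbm_bytes(seq_len: int, d_head: int, num_heads: int,
                              dtype_bytes: int = 2) -> int:
    """FlashAttention only touches Q, K, V, O in HBM: O(N * d)."""
    return 4 * num_heads * seq_len * d_head * dtype_bytes

for n in (1024, 4096, 16384):
    std = standard_attention_hbm_bytes(n, num_heads=32)
    fa = flash_attention_hbm_bytes(n, d_head=128, num_heads=32)
    print(f"N={n:>6}: standard {std / 1e9:7.2f} GB vs "
          f"FlashAttention {fa / 1e9:.3f} GB ({std // fa}x)")
```

The gap grows linearly with N/d_head, which is why the benefit of tiling is largest at long sequence lengths.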
6.2 FlashAttention's Tiling Strategy
import torch
import math
def flash_attention_v1(Q, K, V, block_size=64):
"""
FlashAttention v1 simplified implementation
Avoids storing the full attention matrix in HBM via tiling
Key: Online Softmax for block-wise processing
"""
batch_size, num_heads, seq_len, d_head = Q.shape
scale = 1.0 / math.sqrt(d_head)
Q = Q * scale
O = torch.zeros_like(Q)
L = torch.zeros(batch_size, num_heads, seq_len, 1, device=Q.device)
M = torch.full((batch_size, num_heads, seq_len, 1), float('-inf'), device=Q.device)
num_blocks = (seq_len + block_size - 1) // block_size
for j in range(num_blocks):
k_start = j * block_size
k_end = min((j + 1) * block_size, seq_len)
K_j = K[:, :, k_start:k_end, :]
V_j = V[:, :, k_start:k_end, :]
for i in range(num_blocks):
q_start = i * block_size
q_end = min((i + 1) * block_size, seq_len)
Q_i = Q[:, :, q_start:q_end, :]
O_i = O[:, :, q_start:q_end, :]
L_i = L[:, :, q_start:q_end, :]
M_i = M[:, :, q_start:q_end, :]
# Compute attention scores in SRAM
S_ij = torch.matmul(Q_i, K_j.transpose(-2, -1))
# Online Softmax update
M_new = torch.maximum(M_i, S_ij.max(dim=-1, keepdim=True)[0])
P_ij = torch.exp(S_ij - M_new)
L_new = torch.exp(M_i - M_new) * L_i + P_ij.sum(dim=-1, keepdim=True)
# Rescale and update output
O_new = torch.exp(M_i - M_new) * O_i + torch.matmul(P_ij, V_j)
O[:, :, q_start:q_end, :] = O_new
L[:, :, q_start:q_end, :] = L_new
M[:, :, q_start:q_end, :] = M_new
return O / L
def compare_attention_implementations():
"""Compare FlashAttention vs standard attention"""
batch_size, num_heads, seq_len, d_head = 2, 32, 4096, 128
Q = torch.randn(batch_size, num_heads, seq_len, d_head, device='cuda', dtype=torch.float16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
flash_output = F.scaled_dot_product_attention(Q, K, V)
scale = 1.0 / math.sqrt(d_head)
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
attn_weights = torch.softmax(attn_scores, dim=-1)
standard_output = torch.matmul(attn_weights, V)
max_diff = (flash_output - standard_output).abs().max().item()
print(f"FlashAttention vs standard max diff: {max_diff:.6f}")
standard_attn_bytes = batch_size * num_heads * seq_len * seq_len * 2 # FP16
print(f"Standard attention matrix memory: {standard_attn_bytes / 1e9:.2f} GB")
print(f"FlashAttention matrix memory: ~0 GB (tiled, never fully materialized)")
6.3 Using PyTorch SDPA
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
def modern_attention(q, k, v, is_causal=True, dropout_p=0.0):
"""
PyTorch 2.0+ scaled_dot_product_attention
Automatically selects FlashAttention 2/3
"""
return F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=dropout_p,
is_causal=is_causal,
scale=None # defaults to 1/sqrt(d_head)
)
# Flash Attention version highlights
flash_versions = {
"FlashAttention 1 (arXiv:2205.14135)": {
"key_innovation": "Tiling + Online Softmax",
"memory": "O(N) — no attention matrix stored",
"speedup": "2-4x vs standard"
},
"FlashAttention 2 (arXiv:2307.08691)": {
"key_innovation": "Work partitioning, FP16/BF16",
"memory": "O(N)",
"speedup": "5-9x vs standard on H100"
},
"FlashAttention 3 (arXiv:2407.08608)": {
"key_innovation": "H100-specific, FP8, async pipeline",
"memory": "O(N)",
"speedup": "1.5-2x vs FA2 on H100"
},
}
for name, info in flash_versions.items():
print(f"\n{name}")
for k, v in info.items():
print(f" {k}: {v}")
7. Multi-GPU Inference
7.1 Tensor Parallelism
Weight matrices are split across GPUs; each GPU handles a shard.
import torch
import torch.nn as nn
class TensorParallelLinear(nn.Module):
"""
Tensor Parallel Linear layer (column-parallel)
Each GPU owns out_features // world_size output neurons
"""
def __init__(
self,
in_features: int,
out_features: int,
world_size: int,
rank: int
):
super().__init__()
self.world_size = world_size
self.rank = rank
self.local_out_features = out_features // world_size
self.weight = nn.Parameter(
torch.randn(self.local_out_features, in_features) / (in_features ** 0.5)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
local_output = nn.functional.linear(x, self.weight)
# In a real distributed setup: dist.all_gather(output_list, local_output)
return local_output
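In real engines the column-parallel layer above is paired with a row-parallel layer, so an MLP or attention block needs just one all-reduce at its output. A single-process sketch of the row-parallel side (RowParallelLinear is an illustrative name, not a vLLM API; the all-reduce is simulated by summing the per-rank partial outputs):

```python
import torch
import torch.nn as nn

class RowParallelLinear(nn.Module):
    """Row-parallel Linear: each rank owns in_features // world_size input
    columns and produces a partial output that is summed across ranks."""
    def __init__(self, in_features: int, out_features: int,
                 world_size: int, rank: int):
        super().__init__()
        self.local_in = in_features // world_size
        self.weight = nn.Parameter(
            torch.randn(out_features, self.local_in) / (in_features ** 0.5)
        )

    def forward(self, x_shard: torch.Tensor) -> torch.Tensor:
        # x_shard: this rank's slice of the input features.
        # In a real distributed setup: dist.all_reduce(partial_output)
        return nn.functional.linear(x_shard, self.weight)

# Single-process check: summed per-rank partials equal the full matmul
torch.manual_seed(0)
world_size, d_in, d_out = 4, 64, 32
li = d_in // world_size
full_w = torch.randn(d_out, d_in)
shards = []
for r in range(world_size):
    layer = RowParallelLinear(d_in, d_out, world_size, r)
    with torch.no_grad():
        layer.weight.copy_(full_w[:, r * li:(r + 1) * li])
    shards.append(layer)
x = torch.randn(2, d_in)
out = sum(shards[r](x[:, r * li:(r + 1) * li]) for r in range(world_size))
assert torch.allclose(out, x @ full_w.T, atol=1e-5)
```

Stacking column-parallel then row-parallel means the intermediate activations never need to be gathered, only the final partial sums are reduced.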
def setup_vllm_multiGPU(model_name: str, tp_size: int):
"""Set up multi-GPU vLLM inference"""
from vllm import LLM, SamplingParams
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
gpu_memory_utilization=0.9
)
return llm
7.2 Full vLLM Usage
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
import asyncio
import time
def vllm_basic_usage():
"""vLLM basic usage"""
llm = LLM(
model="meta-llama/Llama-2-7b-hf",
tensor_parallel_size=1,
gpu_memory_utilization=0.90,
max_model_len=4096,
quantization=None, # "awq", "gptq", "squeezellm"
dtype="auto",
max_num_seqs=256,
enable_prefix_caching=True,
        use_v2_block_manager=True,  # older vLLM releases only; removed once v2 became the default
)
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=200,
)
prompts = [
"Explain quantum computing in simple terms",
"What is the future of artificial intelligence?",
"How does the human brain work?",
]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
print(f"Prompt: {output.prompt[:50]}...")
print(f"Output: {output.outputs[0].text[:100]}...")
print(f"Tokens: {len(output.outputs[0].token_ids)}")
print()
return outputs
async def vllm_async_server():
"""vLLM async engine usage"""
engine_args = AsyncEngineArgs(
model="meta-llama/Llama-2-7b-hf",
tensor_parallel_size=1,
gpu_memory_utilization=0.90,
max_model_len=4096,
enable_prefix_caching=True,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
async def generate_stream(prompt: str, request_id: str):
sampling_params = SamplingParams(temperature=0.8, max_tokens=200)
full_text = ""
async for output in engine.generate(prompt, sampling_params, request_id):
if output.outputs:
delta = output.outputs[0].text[len(full_text):]
full_text = output.outputs[0].text
if delta:
print(f"[{request_id}] {delta}", end="", flush=True)
if output.finished:
print(f"\n[{request_id}] done")
await asyncio.gather(
generate_stream("What is AI?", "req_1"),
generate_stream("Explain machine learning", "req_2"),
generate_stream("What is deep learning?", "req_3"),
)
8. Inference Engine Comparison
8.1 Major Engine Features
| Engine | Author | Key Features | Best For |
|---|---|---|---|
| vLLM | UC Berkeley | PagedAttention, Continuous Batching | General LLM serving |
| TGI | HuggingFace | Flash Attention 2, Speculative | HF model serving |
| TensorRT-LLM | NVIDIA | NVIDIA-optimized, FP8 | Max NVIDIA perf |
| DeepSpeed-MII | Microsoft | ZeRO inference, huge models | Multi-GPU giant models |
| llama.cpp | Georgi Gerganov | CPU-optimized, GGUF | Local execution |
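The table above can be condensed into a rough decision rule. The sketch below is this guide's shorthand, not an official recommendation from any of these projects; `pick_inference_engine` and its category strings are illustrative:

```python
def pick_inference_engine(hardware: str, priority: str) -> str:
    """Rough engine-selection heuristic based on the comparison table."""
    if hardware == "cpu":
        return "llama.cpp"        # GGUF models, local/CPU execution
    if priority == "max_perf":
        return "TensorRT-LLM"     # NVIDIA-optimized kernels, FP8
    if priority == "huge_model":
        return "DeepSpeed-MII"    # ZeRO inference across many GPUs
    if priority == "hf_ecosystem":
        return "TGI"              # tight HuggingFace integration
    return "vLLM"                 # strong general-purpose default

print(pick_inference_engine("gpu", "throughput"))  # vLLM
```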
8.2 Benchmark Results
import time
import torch

def run_inference_benchmark(engine_name: str, model, tokenizer, prompts, max_tokens=100):
    """Simple inference benchmark"""
    num_warmup = 5
    num_runs = 50
    # Warmup: the first runs include CUDA kernel compilation and allocator churn
    for prompt in prompts[:num_warmup]:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
        _ = model.generate(input_ids, max_new_tokens=max_tokens, do_sample=False)
    # Timed runs
    torch.cuda.synchronize()
    start = time.perf_counter()
    total_tokens = 0
    for prompt in prompts[:num_runs]:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
        output = model.generate(input_ids, max_new_tokens=max_tokens, do_sample=False)
        # Count tokens actually generated (generation may stop early at EOS)
        total_tokens += output.size(1) - input_ids.size(1)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    throughput = total_tokens / elapsed
    latency_ms = elapsed / num_runs * 1000
    print(f"\n{engine_name}")
    print(f"  Throughput: {throughput:.1f} tokens/sec")
    print(f"  Avg latency: {latency_ms:.1f} ms/request")
    return throughput, latency_ms
# Example comparison (A100 80GB, Llama-2-7B, batch=1, 100 output tokens)
benchmark_results = {
"HuggingFace (FP16)": {"throughput": 52, "latency_ms": 1924},
"HuggingFace (Flash Attn 2)": {"throughput": 78, "latency_ms": 1280},
"vLLM": {"throughput": 120, "latency_ms": 832},
"vLLM + AWQ 4bit": {"throughput": 165, "latency_ms": 606},
"TensorRT-LLM": {"throughput": 180, "latency_ms": 555},
}
print("LLM inference engine benchmark (A100 80GB, Llama-2-7B)")
print("=" * 65)
print(f"{'Engine':<30} {'Throughput (tok/s)':<22} {'Latency (ms)':<15}")
print("-" * 65)
for engine, stats in benchmark_results.items():
print(f"{engine:<30} {stats['throughput']:<22} {stats['latency_ms']:<15}")
9. Prompt Caching
9.1 Prefix Caching
Reuse KV Cache when the same system prompt or document is processed repeatedly.
from vllm import LLM, SamplingParams
import time
def demonstrate_prefix_caching():
"""Demonstrate prefix caching benefit"""
llm = LLM(
model="meta-llama/Llama-2-7b-hf",
enable_prefix_caching=True,
max_model_len=4096,
)
# Long system prompt common to all requests (1000+ tokens)
system_prompt = (
"You are a helpful AI assistant with expertise in Python, "
"machine learning, data science, and cloud computing. "
) * 50
questions = [
"How do I optimize a Python loop?",
"What is gradient descent?",
"Explain containerization.",
"What is a neural network?",
]
sampling_params = SamplingParams(temperature=0.7, max_tokens=100)
cold_prompts = [f"{system_prompt}\n\nQuestion: {q}" for q in questions]
# Cold start (no cache)
cold_start = time.time()
llm.generate(cold_prompts, sampling_params)
cold_time = time.time() - cold_start
# Warm start (cache hit)
warm_start = time.time()
llm.generate(cold_prompts, sampling_params)
warm_time = time.time() - warm_start
print(f"Cold start (no cache): {cold_time:.2f}s")
print(f"Warm start (cache hit): {warm_time:.2f}s")
print(f"Speedup: {cold_time / warm_time:.2f}x")
def radix_tree_prefix_cache():
    """Simplified radix-tree-style prefix cache implementation"""
    from typing import Optional

    class RadixNode:
        def __init__(self):
            self.children: dict = {}
            self.kv_cache_block_id: Optional[int] = None

    class RadixTreeCache:
        """
        Manages token sequences in a prefix tree to share
        common-prefix KV caches. (Simplified: one token per node;
        a true radix tree would compress single-child chains.)
        """
        def __init__(self):
            self.root = RadixNode()
            self.cache_hits = 0
            self.cache_misses = 0
def insert(self, token_ids: list, block_id: int):
node = self.root
for token_id in token_ids:
if token_id not in node.children:
node.children[token_id] = RadixNode()
node = node.children[token_id]
node.kv_cache_block_id = block_id
def lookup(self, token_ids: list) -> tuple:
"""Find the longest matching prefix"""
node = self.root
matched_len = 0
last_block_id = None
for i, token_id in enumerate(token_ids):
if token_id in node.children:
node = node.children[token_id]
matched_len = i + 1
if node.kv_cache_block_id is not None:
last_block_id = node.kv_cache_block_id
else:
break
if last_block_id is not None:
self.cache_hits += 1
else:
self.cache_misses += 1
return matched_len, last_block_id
def get_hit_rate(self) -> float:
total = self.cache_hits + self.cache_misses
return self.cache_hits / total if total > 0 else 0.0
return RadixTreeCache()
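The benefit measured in the cold/warm demonstration above is easy to reason about: with the prefix cached, only the per-request suffix (the question) needs a fresh prefill. A back-of-the-envelope helper, assuming prefill cost scales roughly linearly with the number of tokens actually processed (`prefill_speedup` is an illustrative name, not a library function):

```python
def prefill_speedup(prefix_tokens: int, suffix_tokens: int) -> float:
    """Approximate prefill speedup from a cached common prefix.

    Assumes prefill cost is roughly linear in tokens processed;
    on a cache hit only the suffix must be computed.
    """
    return (prefix_tokens + suffix_tokens) / suffix_tokens

# A 1000-token system prompt with a 50-token question:
print(f"{prefill_speedup(1000, 50):.0f}x prefill speedup")  # 21x
```

In practice the end-to-end speedup is smaller, since the decode phase is unaffected by prefix caching.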
10. Practical Optimization Checklist
10.1 Step-by-Step Optimization Guide
class LLMOptimizationChecklist:
"""LLM inference optimization checklist"""
optimizations = [
{
"category": "Baseline",
"level": 1,
"items": [
{
"name": "Use FP16/BF16",
"impact": "High",
"effort": "Low",
"description": "FP32 → FP16: 2x memory saving, speed improvement",
"code": """
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto"
)"""
},
{
"name": "Enable Flash Attention 2",
"impact": "High",
"effort": "Low",
"description": "2-4x attention speedup, memory saving",
"code": """
model = AutoModelForCausalLM.from_pretrained(
model_name,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16
)"""
},
]
},
{
"category": "KV Cache Optimization",
"level": 2,
"items": [
{
"name": "Choose GQA/MQA model",
"impact": "High",
"effort": "Medium",
"description": "4-8x KV Cache reduction, larger effective batch"
},
{
"name": "Prefix caching",
"impact": "Medium",
"effort": "Low",
"description": "Reuse KV Cache for common system prompts"
},
]
},
{
"category": "Batching Optimization",
"level": 3,
"items": [
{
"name": "Continuous Batching with vLLM",
"impact": "Very High",
"effort": "Low",
"description": "2-5x throughput improvement",
"code": """
from vllm import LLM, SamplingParams
llm = LLM(
model=model_name,
gpu_memory_utilization=0.90,
enable_prefix_caching=True,
)"""
},
]
},
{
"category": "Model Optimization",
"level": 4,
"items": [
{
"name": "AWQ 4-bit quantization",
"impact": "High",
"effort": "Medium",
"description": "4x memory reduction, 1.5-2x speed",
"code": """
from awq import AutoAWQForCausalLM
model = AutoAWQForCausalLM.from_quantized(
"model-awq-4bit",
fuse_layers=True
)"""
},
{
"name": "Speculative Decoding",
"impact": "Medium",
"effort": "High",
"description": "2-3x speedup (requires suitable draft model)"
},
]
},
{
"category": "Hardware Optimization",
"level": 5,
"items": [
{
"name": "Tensor Parallelism",
"impact": "Very High",
"effort": "Medium",
"description": "Linear throughput scaling with multiple GPUs"
},
{
"name": "CUDA graph capture",
"impact": "Medium",
"effort": "High",
"description": "Eliminate kernel launch overhead"
},
]
}
]
@classmethod
def print_checklist(cls):
print("=" * 70)
print("LLM Inference Optimization — Step-by-Step Checklist")
print("=" * 70)
        impact_stars = {"Very High": "★★★", "High": "★★", "Medium": "★", "Low": "☆"}
        for category in cls.optimizations:
            print(f"\n[Level {category['level']}] {category['category']}")
            print("-" * 50)
            for item in category['items']:
                print(f"  ✓ {item['name']}")
                print(f"    Impact: {impact_stars.get(item['impact'], '?')} {item['impact']}")
                print(f"    Note: {item['description']}")
print("\nRecommended optimization order:")
print("1. Switch to BF16/FP16 (immediate, free)")
print("2. Enable Flash Attention 2 (immediate, just install)")
print("3. Serve with vLLM (max throughput)")
print("4. AWQ/GPTQ 4-bit quantization (4x memory reduction)")
print("5. Speculative Decoding (latency improvement)")
print("6. Multi-GPU Tensor Parallelism (scale out)")
LLMOptimizationChecklist.print_checklist()
10.2 End-to-End Production Setup
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from fastapi import FastAPI
from pydantic import BaseModel
import asyncio
import uvicorn
app = FastAPI(title="LLM Inference API")
class GenerateRequest(BaseModel):
prompt: str
max_tokens: int = 200
temperature: float = 0.7
top_p: float = 0.95
stream: bool = False
class GenerateResponse(BaseModel):
text: str
tokens_generated: int
finish_reason: str
# Global engine instance (initialized at startup)
engine: AsyncLLMEngine | None = None
def create_optimized_engine(model_name: str, **kwargs) -> AsyncLLMEngine:
"""Create a production-optimized vLLM engine"""
engine_args = AsyncEngineArgs(
model=model_name,
tensor_parallel_size=kwargs.get("tp_size", 1),
gpu_memory_utilization=kwargs.get("gpu_util", 0.90),
max_model_len=kwargs.get("max_model_len", 4096),
quantization=kwargs.get("quantization", None), # "awq" or "gptq"
dtype="auto",
max_num_seqs=kwargs.get("max_num_seqs", 256),
enable_prefix_caching=True,
use_v2_block_manager=True,
speculative_model=kwargs.get("draft_model", None), # optional draft model
num_speculative_tokens=kwargs.get("num_spec_tokens", 5),
)
return AsyncLLMEngine.from_engine_args(engine_args)
@app.on_event("startup")
async def startup():
global engine
engine = create_optimized_engine(
model_name="meta-llama/Llama-2-7b-hf",
tp_size=1,
gpu_util=0.90,
max_model_len=4096,
)
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
sampling_params = SamplingParams(
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens,
)
request_id = f"req_{id(request)}"
final_output = None
async for output in engine.generate(request.prompt, sampling_params, request_id):
final_output = output
if final_output and final_output.outputs:
result = final_output.outputs[0]
return GenerateResponse(
text=result.text,
tokens_generated=len(result.token_ids),
finish_reason=result.finish_reason or "length"
)
return GenerateResponse(text="", tokens_generated=0, finish_reason="error")
# Run with: uvicorn script:app --host 0.0.0.0 --port 8000 --workers 1
Conclusion
LLM inference optimization requires a layered approach.
Key takeaways:
- Understand KV Cache: memorize `2 * layers * kv_heads * d_head * seq_len * dtype_bytes`. Use GQA/MQA to cut KV Cache by 4–8x.
- PagedAttention: vLLM's core innovation, borrowed from OS virtual memory to eliminate KV Cache fragmentation.
- Continuous Batching: immediately slot in new requests as others finish, maximizing GPU utilization.
- Speculative Decoding: small draft model + large verifier = 2–3x speedup at a good acceptance rate.
- FlashAttention: reduces attention's memory from O(N^2) to O(N), enabling long contexts.
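Plugging Llama-2-7B's shape into that formula (32 layers, 32 KV heads since it uses full multi-head attention, head dim 128, FP16) shows why KV Cache dominates memory at long contexts; `kv_cache_bytes` is just the formula as a function:

```python
def kv_cache_bytes(layers, kv_heads, d_head, seq_len, dtype_bytes=2):
    # Leading 2 = one Key tensor + one Value tensor per layer
    return 2 * layers * kv_heads * d_head * seq_len * dtype_bytes

# Llama-2-7B at a 4096-token context in FP16: 2 GiB of KV Cache per sequence
size = kv_cache_bytes(layers=32, kv_heads=32, d_head=128, seq_len=4096)
print(f"{size / 2**30:.1f} GiB")  # 2.0 GiB

# The same shape with GQA-style 8 KV heads would need a quarter of that
size_gqa = kv_cache_bytes(layers=32, kv_heads=8, d_head=128, seq_len=4096)
print(f"{size_gqa / 2**30:.2f} GiB")  # 0.50 GiB
```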
Production deployment recommendations:
- Small services: vLLM + AWQ 4-bit + prefix caching
- Large services: TensorRT-LLM or vLLM + Tensor Parallelism
- Lowest latency: Speculative Decoding + CUDA graphs
References
- vLLM/PagedAttention: arXiv:2309.06180
- Speculative Decoding: arXiv:2211.17192
- FlashAttention: arXiv:2205.14135
- FlashAttention-2: arXiv:2307.08691
- Medusa: arXiv:2401.10774
- Continuous Batching: Anyscale Blog
- DeepSeek-V2 MLA: arXiv:2405.04434