Skip to content

Split View: LLM 추론 최적화 완전 가이드: KV Cache, Speculative Decoding, Continuous Batching

|

LLM 추론 최적화 완전 가이드: KV Cache, Speculative Decoding, Continuous Batching

들어가며

대형 언어 모델(LLM)을 프로덕션에 배포하면 즉시 직면하는 문제가 있습니다. 바로 추론 속도와 비용입니다. GPT-4 수준의 모델을 처음 쿼리했을 때 수 초의 지연이 발생하고, 동시 사용자가 늘어나면 처리량이 급격히 저하됩니다.

이 가이드는 LLM 추론 최적화의 핵심 기술을 완전히 파헤칩니다. KV Cache의 작동 원리부터 PagedAttention, Speculative Decoding, FlashAttention, 그리고 최신 vLLM과 TensorRT-LLM 엔진까지 — 단순한 사용법이 아니라 왜 작동하는지 원리를 이해합니다.


1. LLM 추론 과정 이해

1.1 두 단계: Prefill과 Decode

LLM 텍스트 생성은 두 단계로 나뉩니다.

Prefill 단계 (프롬프트 처리)

  • 입력 프롬프트의 모든 토큰을 동시에 처리
  • 각 레이어에서 Key/Value 캐시를 생성하여 저장
  • 연산 집약적(Compute-Bound): GPU 연산 능력이 병목
  • TTFT (Time To First Token)에 직접 영향

Decode 단계 (토큰 생성)

  • 한 번에 하나의 토큰을 자기회귀 방식으로 생성
  • 이전에 생성된 모든 토큰의 KV Cache를 참조
  • 메모리 대역폭 집약적(Memory-Bound): HBM 읽기 속도가 병목
  • TPOT (Time Per Output Token)에 직접 영향
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time

def measure_prefill_decode_time(model, tokenizer, prompt: str, max_new_tokens: int = 100):
    """Measure TTFT and TPOT for a HuggingFace-style causal LM.

    TTFT is timed with a 1-token generate call (prefill + first decode
    step); TPOT is derived from a full generate call minus TTFT.

    NOTE(review): assumes `model` lives on CUDA — inputs are moved to
    "cuda" and torch.cuda.synchronize() brackets every timing region.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to("cuda")
    prompt_len = encoded["input_ids"].size(1)

    # --- TTFT: prefill plus a single decode step ---
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    with torch.no_grad():
        model.generate(
            **encoded,
            max_new_tokens=1,
            do_sample=False
        )
    torch.cuda.synchronize()
    ttft = time.perf_counter() - t0

    # --- Full generation (re-runs prefill; decode time is derived below) ---
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - t1

    num_generated = generated.size(1) - prompt_len
    # Time left after the first token, spread over the remaining tokens.
    decode_elapsed = elapsed - ttft
    tpot = decode_elapsed / max(num_generated - 1, 1)

    print(f"입력 토큰 수: {prompt_len}")
    print(f"생성 토큰 수: {num_generated}")
    print(f"TTFT (첫 토큰 지연): {ttft * 1000:.1f} ms")
    print(f"TPOT (토큰당 시간): {tpot * 1000:.1f} ms")
    print(f"처리량: {num_generated / elapsed:.1f} 토큰/초")

    return ttft, tpot

1.2 메모리 병목 분석

Decode 단계가 왜 메모리 집약적인지 이해합니다.

def analyze_memory_bandwidth():
    """Analyze why LLM decode is memory-bandwidth bound.

    Uses Llama-2-7B-like dimensions to compare the token throughput
    implied by HBM bandwidth against raw compute on an A100 (batch 1).

    Returns:
        dict with weight size (GB), per-token KV cache size (bytes) and
        the memory-bound / compute-bound throughput estimates (tok/s).
    """
    # Example: Llama-2-7B configuration
    model_params = {
        "num_layers": 32,
        "hidden_size": 4096,
        "num_heads": 32,
        "head_dim": 128,
        "vocab_size": 32000,
    }

    dtype_bytes = 2  # FP16: 2 bytes per parameter

    # Weight memory (read once per generated token during decode)
    attn_weight = 4 * model_params["hidden_size"] ** 2  # Q, K, V, O projections
    ffn_weight = 8 * model_params["hidden_size"] ** 2  # Up, Gate, Down projections (SwiGLU)
    layer_weight = (attn_weight + ffn_weight) * dtype_bytes

    total_weight_bytes = layer_weight * model_params["num_layers"]
    total_weight_gb = total_weight_bytes / 1e9

    print(f"모델 가중치: {total_weight_gb:.2f} GB")

    # KV cache memory (grows linearly with sequence length)
    seq_len = 2048
    kv_cache_per_token = (
        2 *  # K and V
        model_params["num_layers"] *
        model_params["num_heads"] *
        model_params["head_dim"] *
        dtype_bytes
    )

    kv_cache_total = kv_cache_per_token * seq_len / 1e6
    print(f"KV Cache ({seq_len} 토큰): {kv_cache_total:.2f} MB")
    print(f"토큰당 KV Cache: {kv_cache_per_token} bytes")

    # A100 HBM bandwidth: ~2 TB/s
    memory_bandwidth_tbs = 2.0  # TB/s

    # Decode: every generated token reads all weights from HBM once; at
    # small batch sizes the memory traffic dominates the arithmetic.
    batch_size = 1
    # FLOPs per token ~= 2 * parameter count (one multiply-add per weight).
    # BUG FIX: the original used 2 * total_weight_BYTES, overstating FLOPs
    # by the dtype size; divide the byte count by dtype_bytes first.
    num_params = total_weight_bytes // dtype_bytes
    flops_per_token = 2 * num_params

    # A100 FP16 peak: 312 TFLOPS
    compute_throughput = 312e12  # FLOPS

    # Throughput limit from streaming all weights per token
    memory_bound_tps = memory_bandwidth_tbs * 1e12 / total_weight_bytes
    # Throughput limit from raw compute
    compute_bound_tps = compute_throughput / flops_per_token

    print(f"\n배치 크기 {batch_size}일 때:")
    print(f"메모리 병목 처리량: {memory_bound_tps:.1f} 토큰/초")
    print(f"연산 병목 처리량: {compute_bound_tps:.1f} 토큰/초")
    print(f"실제 병목: {'메모리' if memory_bound_tps < compute_bound_tps else '연산'}")

    return {
        "total_weight_gb": total_weight_gb,
        "kv_cache_per_token_bytes": kv_cache_per_token,
        "memory_bound_tps": memory_bound_tps,
        "compute_bound_tps": compute_bound_tps,
    }

analyze_memory_bandwidth()

1.3 추론 비용 분석

def estimate_inference_cost(
    model_size_b: float,
    tokens_per_request: int,
    requests_per_day: int,
    gpu_cost_per_hour: float = 3.0  # A100 hourly price (USD)
) -> dict:
    """Estimate the daily GPU cost of serving an LLM.

    Args:
        model_size_b: model size in billions of parameters.
        tokens_per_request: average tokens generated per request.
        requests_per_day: daily request volume.
        gpu_cost_per_hour: GPU rental price in USD/hour.

    Returns:
        dict with throughput estimate, GPU-hours, daily cost and
        cost per 1K tokens (previously this function only printed).
    """
    # Empirical scaling: ~100 tok/s for 7B, ~20 tok/s for 70B on one A100.
    throughput_tps = 100 / (model_size_b / 7) ** 0.6

    total_tokens_per_day = tokens_per_request * requests_per_day
    seconds_needed = total_tokens_per_day / throughput_tps
    hours_needed = seconds_needed / 3600

    daily_cost = hours_needed * gpu_cost_per_hour

    cost_per_1k_tokens = daily_cost / (total_tokens_per_day / 1000)

    print(f"모델: {model_size_b}B 파라미터")
    print(f"일일 요청: {requests_per_day:,}건")
    print(f"요청당 토큰: {tokens_per_request}")
    print(f"총 일일 토큰: {total_tokens_per_day:,}")
    print(f"예상 처리량: {throughput_tps:.1f} 토큰/초")
    print(f"필요 GPU 시간: {hours_needed:.2f}시간")
    print(f"일일 비용: ${daily_cost:.2f}")
    print(f"1K 토큰당 비용: ${cost_per_1k_tokens:.4f}")

    return {
        "throughput_tps": throughput_tps,
        "hours_needed": hours_needed,
        "daily_cost": daily_cost,
        "cost_per_1k_tokens": cost_per_1k_tokens,
    }

# 예시
estimate_inference_cost(
    model_size_b=7.0,
    tokens_per_request=500,
    requests_per_day=10000
)

2. KV Cache: 핵심 최적화 기술

2.1 KV Cache의 필요성

트랜스포머 어텐션에서 각 토큰은 이전 모든 토큰과 어텐션을 계산합니다. 이미 처리한 토큰을 재계산하지 않으려면 K와 V 행렬을 캐싱합니다.

import torch
import torch.nn as nn
import math

class MultiHeadAttentionWithKVCache(nn.Module):
    """Multi-head attention with a pre-allocated KV cache (batch size 1).

    Prefill writes the whole prompt's K/V into the cache; decode steps
    append one position at a time and attend over everything cached.

    BUG FIX: the original applied no causal mask, so during prefill
    (seq_len > 1) every token could attend to FUTURE tokens. A causal
    mask is now applied whenever more than one query position is
    processed at once; single-token decode steps are causal by
    construction (they only see already-cached positions).
    """

    def __init__(self, d_model: int, num_heads: int, max_seq_len: int = 4096):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        # Cache layout: [batch=1, max_seq_len, num_heads, d_head]
        self.register_buffer(
            'k_cache',
            torch.zeros(1, max_seq_len, num_heads, self.d_head)
        )
        self.register_buffer(
            'v_cache',
            torch.zeros(1, max_seq_len, num_heads, self.d_head)
        )
        self.cache_pos = 0

    def forward(
        self,
        x: torch.Tensor,
        use_cache: bool = True,
        position: int = None
    ):
        """Attend over the cached context.

        Args:
            x: [batch, seq_len, d_model] embeddings of the new tokens.
            use_cache: store K/V of `x` and attend over the full cache.
            position: explicit cache write offset; when None the internal
                cursor (`cache_pos`) is used and advanced.

        Returns:
            [batch, seq_len, d_model] attention output.
        """
        batch_size, seq_len, _ = x.shape

        # Project Q, K, V
        q = self.W_q(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        k = self.W_k(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        v = self.W_v(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)

        if use_cache:
            # Write this chunk's K/V into the cache at the right offset
            start_pos = self.cache_pos if position is None else position
            self.k_cache[:, start_pos:start_pos + seq_len] = k
            self.v_cache[:, start_pos:start_pos + seq_len] = v

            if position is None:
                self.cache_pos += seq_len

            # Attend over everything cached so far
            total_len = self.cache_pos if position is None else start_pos + seq_len
            k = self.k_cache[:, :total_len]
            v = self.v_cache[:, :total_len]
        else:
            start_pos = 0
            total_len = seq_len

        scale = math.sqrt(self.d_head)

        # [batch, num_heads, seq_len, d_head]
        q = q.transpose(1, 2)  # [B, H, S, D]
        k = k.transpose(1, 2)  # [B, H, T, D] (T: total cached length)
        v = v.transpose(1, 2)

        # Attention scores: [B, H, S, T]
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale

        # Causal mask: query at absolute position p may only see keys <= p.
        if seq_len > 1:
            q_pos = torch.arange(start_pos, start_pos + seq_len, device=x.device)
            k_pos = torch.arange(total_len, device=x.device)
            causal = k_pos.unsqueeze(0) <= q_pos.unsqueeze(1)  # [S, T]
            attn_scores = attn_scores.masked_fill(~causal, float('-inf'))

        # Softmax over keys
        attn_weights = torch.softmax(attn_scores, dim=-1)

        # Weighted sum: [B, H, S, D]
        output = torch.matmul(attn_weights, v)

        # Back to [batch, seq_len, d_model]
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        output = self.W_o(output)

        return output

    def clear_cache(self):
        """Reset the cache cursor and zero the K/V buffers."""
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.cache_pos = 0

2.2 KV Cache 메모리 계산

def calculate_kv_cache_memory(
    model_config: dict,
    batch_size: int,
    seq_len: int,
    dtype_bytes: int = 2  # FP16
) -> dict:
    """Compute KV-cache memory usage for a model configuration.

    Falls back to `num_heads` when `num_kv_heads` is absent (plain MHA).

    Returns:
        dict with total bytes, MB, GB and the per-token byte count.
    """
    layers = model_config["num_layers"]
    kv_heads = model_config.get("num_kv_heads", model_config["num_heads"])
    dim = model_config["head_dim"]

    # K and V each store [layers, kv_heads, head_dim] per token:
    # 2 (K+V) * layers * kv_heads * head_dim * seq_len * batch * dtype
    per_position = 2 * layers * kv_heads * dim * dtype_bytes
    total_bytes = per_position * seq_len * batch_size

    return {
        "kv_cache_bytes": total_bytes,
        "kv_cache_mb": total_bytes / 1e6,
        "kv_cache_gb": total_bytes / 1e9,
        "per_token_bytes": total_bytes // seq_len,
    }

# Compare KV-cache footprints across model families
models = {
    "Llama-2-7B": {
        "num_layers": 32, "num_heads": 32,
        "num_kv_heads": 32, "head_dim": 128
    },
    "Llama-2-13B": {
        "num_layers": 40, "num_heads": 40,
        "num_kv_heads": 40, "head_dim": 128
    },
    "Llama-2-70B (GQA)": {
        "num_layers": 80, "num_heads": 64,
        "num_kv_heads": 8, "head_dim": 128  # GQA: 8 KV heads
    },
    "Mistral-7B (GQA)": {
        "num_layers": 32, "num_heads": 32,
        "num_kv_heads": 8, "head_dim": 128  # GQA: 8 KV heads
    },
}

print("KV Cache 메모리 사용량 (배치=1, seq=4096)")
print("=" * 70)
for model_name, model_cfg in models.items():
    stats = calculate_kv_cache_memory(model_cfg, batch_size=1, seq_len=4096)
    print(f"{model_name:<25} {stats['kv_cache_gb']:.2f} GB  "
          f"(토큰당 {stats['per_token_bytes']:,} bytes)")

2.3 Grouped Query Attention (GQA)

GQA는 KV Cache를 줄이는 핵심 기술입니다. 여러 Query 헤드가 더 적은 수의 KV 헤드를 공유합니다.

import torch
import torch.nn as nn
import math

class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (GQA).

    Several query heads share each KV head, shrinking the KV cache by a
    factor of num_q_heads / num_kv_heads.

    BUG FIX: the original computed full (non-causal) attention, so during
    prefill every token could attend to future tokens. A causal mask is
    now applied whenever multiple query positions are processed together;
    single-token decode steps are causal by construction.
    """

    def __init__(
        self,
        d_model: int,
        num_q_heads: int,
        num_kv_heads: int,
    ):
        super().__init__()
        assert num_q_heads % num_kv_heads == 0

        self.d_model = d_model
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.num_groups = num_q_heads // num_kv_heads
        self.d_head = d_model // num_q_heads

        self.W_q = nn.Linear(d_model, num_q_heads * self.d_head, bias=False)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, kv_cache=None):
        """Run GQA over `x`, optionally continuing from a KV cache.

        Args:
            x: [batch, seq_len, d_model].
            kv_cache: optional {"k", "v"} dict with tensors of shape
                [batch, cache_len, num_kv_heads, d_head].

        Returns:
            (output [batch, seq_len, d_model], updated kv_cache dict)
        """
        batch_size, seq_len, _ = x.shape

        # Q: [B, S, num_q_heads, d_head]; K/V use the smaller KV-head count
        q = self.W_q(x).reshape(batch_size, seq_len, self.num_q_heads, self.d_head)
        k = self.W_k(x).reshape(batch_size, seq_len, self.num_kv_heads, self.d_head)
        v = self.W_v(x).reshape(batch_size, seq_len, self.num_kv_heads, self.d_head)

        # Tokens already cached: query positions start after them.
        past_len = kv_cache["k"].size(1) if kv_cache is not None else 0

        # KV cache update
        if kv_cache is not None:
            k = torch.cat([kv_cache["k"], k], dim=1)
            v = torch.cat([kv_cache["v"], v], dim=1)
        new_kv_cache = {"k": k, "v": v}

        total_len = k.size(1)

        # [B, num_heads, S, d_head]
        q = q.transpose(1, 2)  # [B, Q_heads, S, d_head]
        k = k.transpose(1, 2)  # [B, KV_heads, T, d_head]
        v = v.transpose(1, 2)

        # GQA: replicate each KV head so every query head has a partner
        # k: [B, KV_heads, T, d_head] -> [B, Q_heads, T, d_head]
        k = k.repeat_interleave(self.num_groups, dim=1)
        v = v.repeat_interleave(self.num_groups, dim=1)

        # Attention scores
        scale = math.sqrt(self.d_head)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale

        # Causal mask (only needed when >1 query position is present)
        if seq_len > 1:
            q_pos = torch.arange(past_len, past_len + seq_len, device=x.device)
            k_pos = torch.arange(total_len, device=x.device)
            causal = k_pos.unsqueeze(0) <= q_pos.unsqueeze(1)  # [S, T]
            attn_scores = attn_scores.masked_fill(~causal, float('-inf'))

        attn_weights = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, v)

        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        return self.W_o(output), new_kv_cache

# MHA vs GQA vs MQA memory comparison
def compare_attention_variants() -> dict:
    """Compare KV-cache sizes of MHA / GQA / MQA at 70B scale.

    Returns:
        dict mapping variant label -> KV cache size in GB
        (previously this function only printed).
    """
    # Llama-2-70B-like dimensions
    num_layers = 80
    d_head = 128
    seq_len = 4096
    batch_size = 1
    dtype_bytes = 2  # FP16

    variants = {
        "MHA (32 KV heads)": 32,
        "GQA (8 KV heads)": 8,
        "MQA (1 KV head)": 1,
    }

    print("70B 모델 어텐션 변형별 KV Cache 비교")
    print(f"(seq_len={seq_len}, batch={batch_size})")
    print("=" * 55)

    results = {}
    for name, num_kv_heads in variants.items():
        kv_bytes = 2 * num_layers * num_kv_heads * d_head * seq_len * batch_size * dtype_bytes
        kv_gb = kv_bytes / 1e9
        results[name] = kv_gb
        print(f"{name:<25} {kv_gb:.2f} GB")

    return results

compare_attention_variants()

2.4 DeepSeek MLA (Multi-Head Latent Attention)

DeepSeek-V2에서 도입된 MLA는 KV Cache를 저차원 잠재 벡터로 압축합니다.

class MultiHeadLatentAttention(nn.Module):
    """
    DeepSeek MLA — compresses the KV cache into a low-rank latent.

    Core idea:
    - Instead of caching full high-dimensional K/V, cache only the
      low-dimensional latent vector c_kv (plus a small decoupled RoPE slice).
    - K and V are recovered from c_kv via an up-projection.
    - Cache size: num_layers * (kv_lora_rank + rope_dim) * seq_len
      (vs. 2 * num_layers * num_kv_heads * d_head * seq_len for MHA).

    NOTE: the attention computation itself is intentionally omitted in
    this illustrative version; `forward` returns None for the output.
    """

    def __init__(
        self,
        d_model: int = 5120,
        num_heads: int = 128,
        kv_lora_rank: int = 512,  # low-rank latent dimension
        qk_nope_head_dim: int = 128,
        qk_rope_head_dim: int = 64,
        v_head_dim: int = 128,
    ):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim

        # Q projection (LoRA-style down/up pair)
        self.q_a_proj = nn.Linear(d_model, 1536, bias=False)  # down-projection
        self.q_b_proj = nn.Linear(1536, num_heads * (qk_nope_head_dim + qk_rope_head_dim), bias=False)

        # KV down-projection: d_model -> kv_lora_rank (+ RoPE slice).
        # Only this output is stored in the KV cache!
        self.kv_a_proj = nn.Linear(
            d_model,
            kv_lora_rank + qk_rope_head_dim,
            bias=False
        )

        # KV up-projection: kv_lora_rank -> per-head K(nope) and V
        self.kv_b_proj = nn.Linear(
            kv_lora_rank,
            num_heads * (qk_nope_head_dim + v_head_dim),
            bias=False
        )

        self.o_proj = nn.Linear(num_heads * v_head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor, compressed_kv_cache=None):
        """
        Args:
            x: [batch, seq, d_model]
            compressed_kv_cache: [batch, cache_len, kv_lora_rank + rope_dim]

        Returns:
            (output, updated_cache) — output is None (attention omitted);
            updated_cache is the FULL accumulated compressed cache.
        """
        batch_size, seq_len, _ = x.shape

        # Q computation
        q = self.q_b_proj(self.q_a_proj(x))

        # Compress KV (only this result goes into the cache)
        kv_compressed = self.kv_a_proj(x)  # [B, S, kv_lora_rank + rope_dim]

        # KV cache update
        if compressed_kv_cache is not None:
            kv_compressed_total = torch.cat([compressed_kv_cache, kv_compressed], dim=1)
        else:
            kv_compressed_total = kv_compressed

        # Recover full K, V from the cached latent (up-projection)
        kv_content = kv_compressed_total[:, :, :self.kv_lora_rank]  # drop RoPE slice
        kv_full = self.kv_b_proj(kv_content)  # [B, T, num_heads * (nope + v_dim)]

        # Final attention computation omitted.

        # BUG FIX: return the ACCUMULATED cache, not just the new chunk,
        # so callers can thread it through subsequent decode steps.
        return None, kv_compressed_total

# KV cache size comparison (DeepSeek-V2 scale)
def compare_mla_vs_mha() -> dict:
    """Compare MLA vs. MHA KV-cache size at DeepSeek-V2 scale.

    Returns:
        dict with both sizes in GB and the compression ratio
        (previously this function only printed).
    """
    seq_len = 4096
    dtype_bytes = 2  # BF16
    num_layers = 60  # DeepSeek-V2

    # Standard MHA: 2 (K+V) * layers * heads * head_dim per token
    num_heads = 128
    head_dim = 128
    mha_kv_gb = 2 * num_layers * num_heads * head_dim * seq_len * dtype_bytes / 1e9

    # MLA: only the compressed latent (+ decoupled RoPE key) is cached
    kv_lora_rank = 512
    rope_dim = 64
    mla_kv_gb = (kv_lora_rank + rope_dim) * num_layers * seq_len * dtype_bytes / 1e9

    print(f"MHA KV Cache: {mha_kv_gb:.2f} GB")
    print(f"MLA KV Cache: {mla_kv_gb:.2f} GB")
    print(f"절감 비율: {mha_kv_gb / mla_kv_gb:.1f}x")

    return {
        "mha_kv_gb": mha_kv_gb,
        "mla_kv_gb": mla_kv_gb,
        "ratio": mha_kv_gb / mla_kv_gb,
    }

compare_mla_vs_mha()

3. PagedAttention: vLLM의 핵심 혁신

3.1 기존 KV Cache의 문제점

기존 LLM 서빙 시스템은 요청마다 최대 시퀀스 길이를 위한 메모리를 미리 할당합니다.

요청 1: [PROMPT=200 tokens] [KV_CACHE=최대 2048-200=1848 tokens 예약] → 내부 단편화
요청 2: [PROMPT=100 tokens] [KV_CACHE=1948 tokens 예약]
요청 3: 메모리 부족으로 대기 (외부 단편화)

이로 인해 실제 GPU 메모리의 60~80%가 낭비됩니다.

3.2 PagedAttention 원리

OS의 가상 메모리에서 영감을 받아 KV Cache를 고정 크기의 물리 블록으로 관리합니다.

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set
import torch

@dataclass
class PhysicalBlock:
    """A physical KV-cache block (metadata only; storage lives in the manager)."""
    block_id: int
    block_size: int  # tokens a block can hold (e.g. 16)
    device: str = "cuda"
    ref_count: int = 0  # reference count (enables copy-on-write sharing)

    def __post_init__(self):
        # A real implementation would allocate the KV tensor here:
        # [2, num_layers, block_size, num_heads, head_dim]
        pass


@dataclass
class LogicalBlock:
    """A logical block in a request's block table."""
    physical_block_id: int
    num_filled: int = 0  # tokens currently written into this block

class PagedKVCacheManager:
    """Block-based (paged) KV-cache manager in the spirit of vLLM.

    Physical KV memory is carved into fixed-size blocks; each request
    owns a table of logical blocks mapping onto physical block ids,
    which eliminates per-request max-length pre-allocation.
    """

    def __init__(
        self,
        num_physical_blocks: int,
        block_size: int,
        num_layers: int,
        num_kv_heads: int,
        head_dim: int,
        device: str = "cuda"
    ):
        self.block_size = block_size
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.device = device

        # Pool of free physical block ids plus per-block metadata
        self.free_blocks: List[int] = list(range(num_physical_blocks))
        self.all_blocks: Dict[int, PhysicalBlock] = {
            block_id: PhysicalBlock(block_id=block_id, block_size=block_size)
            for block_id in range(num_physical_blocks)
        }

        # Per-request logical block table
        self.block_tables: Dict[int, List[LogicalBlock]] = {}

        # Backing KV storage:
        # [num_blocks, 2 (K/V), num_layers, block_size, num_kv_heads, head_dim]
        self.kv_cache = torch.zeros(
            num_physical_blocks, 2, num_layers, block_size, num_kv_heads, head_dim,
            dtype=torch.float16,
            device=device
        )

    def allocate_blocks_for_request(self, request_id: int, num_tokens: int):
        """Reserve enough blocks for `num_tokens` tokens of a new request."""
        needed = (num_tokens + self.block_size - 1) // self.block_size  # ceil division

        if len(self.free_blocks) < needed:
            raise RuntimeError(f"메모리 부족: {needed} 블록 필요, {len(self.free_blocks)} 가용")

        table = []
        for _ in range(needed):
            block_id = self.free_blocks.pop(0)
            self.all_blocks[block_id].ref_count = 1
            table.append(LogicalBlock(physical_block_id=block_id))

        self.block_tables[request_id] = table
        print(f"요청 {request_id}: {needed} 블록 할당, "
              f"잔여 블록: {len(self.free_blocks)}")

    def append_token(self, request_id: int, layer: int, token_pos: int, k: torch.Tensor, v: torch.Tensor):
        """Write one token's K/V for `layer` at absolute position `token_pos`."""
        which_block, offset = divmod(token_pos, self.block_size)

        logical = self.block_tables[request_id][which_block]
        slot = logical.physical_block_id

        # Store into the backing tensor (index 0 = K, 1 = V)
        self.kv_cache[slot, 0, layer, offset] = k
        self.kv_cache[slot, 1, layer, offset] = v
        logical.num_filled = offset + 1

    def get_physical_block_ids(self, request_id: int) -> List[int]:
        """Physical block ids backing a request, in logical order."""
        return [block.physical_block_id for block in self.block_tables[request_id]]

    def free_request(self, request_id: int):
        """Drop a finished request; return unreferenced blocks to the pool."""
        if request_id in self.block_tables:
            for logical in self.block_tables[request_id]:
                block = self.all_blocks[logical.physical_block_id]
                block.ref_count -= 1

                if block.ref_count == 0:
                    self.free_blocks.append(logical.physical_block_id)

            del self.block_tables[request_id]

    def copy_on_write(self, src_request_id: int, dst_request_id: int):
        """Share src's blocks with dst by bumping ref counts (CoW for prefix caching)."""
        shared = []

        for logical in self.block_tables[src_request_id]:
            # Bump the reference count (a real copy happens only on write)
            self.all_blocks[logical.physical_block_id].ref_count += 1
            shared.append(LogicalBlock(
                physical_block_id=logical.physical_block_id,
                num_filled=logical.num_filled,
            ))

        self.block_tables[dst_request_id] = shared


# Usage example
manager = PagedKVCacheManager(
    num_physical_blocks=1000,
    block_size=16,
    num_layers=32,
    num_kv_heads=32,
    head_dim=128
)

# Serve three requests
for req_id, token_count in ((1, 200), (2, 500), (3, 100)):
    manager.allocate_blocks_for_request(request_id=req_id, num_tokens=token_count)

# Release request 1 once it completes
manager.free_request(request_id=1)
print(f"\n요청 1 완료 후 가용 블록: {len(manager.free_blocks)}")

4. Continuous Batching

4.1 정적 배치의 문제점

기존 배치 처리 방식은 요청들이 모두 완료될 때까지 기다립니다.

시간 t=0: [요청A: 500토큰] [요청B: 100토큰] [요청C: 300토큰]
시간 t=1: 요청B 완료, 하지만 A, C 대기 중이므로 GPU에 여유 있어도 새 요청 불가
시간 t=2: 요청C 완료
시간 t=3: 요청A 완료 → 이제야 새 배치 시작!

4.2 Continuous Batching (Iteration-level Scheduling)

from typing import List, Optional, Tuple
from dataclasses import dataclass
import asyncio
import torch
from queue import Queue
import threading

@dataclass
class Request:
    """One inference request and its generation state."""
    request_id: str
    input_ids: List[int]
    max_new_tokens: int
    generated_ids: List[int] = None  # None -> fresh empty list in __post_init__
    is_finished: bool = False

    def __post_init__(self):
        # BUG FIX: the original reset generated_ids unconditionally,
        # silently discarding any caller-supplied list.
        if self.generated_ids is None:
            self.generated_ids = []

class ContinuousBatchingScheduler:
    """Continuous-batching (iteration-level) scheduler.

    Finished requests are evicted and waiting requests admitted at every
    iteration, so batch slots never sit idle waiting for the whole batch
    to drain (the static-batching problem).
    """

    def __init__(
        self,
        max_batch_size: int = 32,
        max_seq_len: int = 4096,
        eos_token_id: int = 2
    ):
        # eos_token_id generalizes the EOS id that was previously
        # hard-coded (2) inside the step loop; default is unchanged.
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.eos_token_id = eos_token_id

        self.waiting_queue: List[Request] = []
        self.running_requests: List[Request] = []
        self.finished_requests: List[Request] = []

    def add_request(self, request: Request):
        """Enqueue a new request."""
        self.waiting_queue.append(request)

    def _can_add_request(self, request: Request) -> bool:
        """Check batch-size and (simplified) KV-cache memory headroom."""
        current_batch_size = len(self.running_requests) + 1
        if current_batch_size > self.max_batch_size:
            return False

        # Simplified memory check: total live tokens across the batch
        total_tokens = sum(
            len(r.input_ids) + len(r.generated_ids)
            for r in self.running_requests
        ) + len(request.input_ids)

        return total_tokens < self.max_seq_len * self.max_batch_size

    def schedule_iteration(self) -> Tuple[List[Request], List[str]]:
        """
        Schedule the batch for one iteration.

        Returns:
            (requests to run, ids of requests completed since last call)
        """
        completed_ids = []

        # Evict finished requests
        still_running = []
        for req in self.running_requests:
            if req.is_finished:
                self.finished_requests.append(req)
                completed_ids.append(req.request_id)
            else:
                still_running.append(req)
        self.running_requests = still_running

        # Admit waiting requests into freed slots immediately (the core idea)
        while self.waiting_queue and self._can_add_request(self.waiting_queue[0]):
            new_request = self.waiting_queue.pop(0)
            self.running_requests.append(new_request)
            print(f"배치에 요청 {new_request.request_id} 추가 "
                  f"(현재 배치 크기: {len(self.running_requests)})")

        return self.running_requests, completed_ids

    def simulate_one_step(self, model_forward_fn):
        """Simulate one iteration; returns ids of requests completed this step."""

        active_requests, completed = self.schedule_iteration()

        if not active_requests:
            return []

        # Build per-request inputs:
        # prefill -> full prompt, decode -> only the last generated token
        batch_input_ids = []
        for req in active_requests:
            if len(req.generated_ids) == 0:
                # Prefill
                batch_input_ids.append(req.input_ids)
            else:
                # Decode (most recent token only)
                batch_input_ids.append([req.generated_ids[-1]])

        # Model call (a real engine would use PagedAttention here)
        outputs = model_forward_fn(batch_input_ids)

        # Record next tokens
        for req, next_token_id in zip(active_requests, outputs):
            req.generated_ids.append(next_token_id)

            # Stop on EOS or when the token budget is exhausted
            if (next_token_id == self.eos_token_id or
                    len(req.generated_ids) >= req.max_new_tokens):
                req.is_finished = True

        return completed

# Throughput comparison simulation
def simulate_throughput_comparison():
    """Rough iteration-count comparison: static vs. continuous batching."""

    import random

    workload = [
        Request(
            request_id=str(i),
            input_ids=list(range(random.randint(50, 200))),
            max_new_tokens=random.randint(50, 500)
        )
        for i in range(20)
    ]

    # Static batching: every batch waits for its longest request
    longest = max(r.max_new_tokens for r in workload)
    static_iters = longest * 4  # batches of 4

    # Continuous batching: slots refill as soon as a request finishes
    cb_iters = sum(r.max_new_tokens for r in workload)  # rough estimate

    print(f"정적 배치 예상 iteration 수: {static_iters}")
    print(f"Continuous Batching 예상 iteration 수: {cb_iters}")
    print(f"처리량 향상: {static_iters / cb_iters:.2f}x")

5. Speculative Decoding

5.1 아이디어: 초안 + 검증

Speculative Decoding의 핵심은 작은 드래프트 모델이 여러 토큰을 빠르게 생성하고, 큰 검증 모델이 한번에 검증하는 것입니다.

기존: [큰 모델] → 토큰1 → 토큰2 → 토큰3 → 토큰4 → 토큰5
투기적: [작은 모델]  (토큰1, 토큰2, 토큰3, 토큰4, 토큰5)를 병렬 생성
        [큰 모델] 5개 토큰을 한번에 검증 (Prefill처럼 병렬!)
        수락된 토큰들만 사용

5.2 수락률 기반 속도 향상 분석

import numpy as np
import torch
from typing import List, Tuple

def speculative_decode_step(
    draft_model,
    target_model,
    input_ids: torch.Tensor,
    draft_steps: int = 4,
    temperature: float = 1.0
) -> Tuple[torch.Tensor, int, int]:
    """
    One round of speculative decoding: draft, verify, accept/reject.

    The draft model proposes `draft_steps` tokens autoregressively; the
    target model scores all of them in ONE forward pass, then each
    proposal is accepted with probability min(1, p_target / p_draft)
    (the standard rejection-sampling rule). One extra token is always
    emitted from the target's (possibly corrected) distribution.

    NOTE(review): with batch_size > 1 the whole batch stops at the first
    position where ANY sequence rejects — a simplification of the
    per-sequence algorithm; confirm intended for multi-batch use.

    Args:
        draft_model / target_model: callables returning an object with
            `.logits` of shape [B, S, vocab] (HF-style) — assumed, confirm.
        input_ids: [B, S] prompt token ids.
        draft_steps: number of tokens the draft model proposes.
        temperature: softmax temperature; <= 0 means greedy (one-hot).

    Returns:
        (generated tokens [B, num_accepted + 1], accepted count, draft count)
    """
    batch_size = input_ids.size(0)

    # 1. Propose candidate tokens with the draft model
    draft_tokens = []
    draft_probs = []

    current_ids = input_ids.clone()

    for _ in range(draft_steps):
        with torch.no_grad():
            draft_output = draft_model(current_ids)
            draft_logits = draft_output.logits[:, -1, :]  # [B, vocab_size]

        # Draft distribution (one-hot greedy when temperature <= 0)
        if temperature > 0:
            draft_prob = torch.softmax(draft_logits / temperature, dim=-1)
        else:
            draft_prob = torch.zeros_like(draft_logits)
            draft_prob.scatter_(1, draft_logits.argmax(dim=-1, keepdim=True), 1.0)

        # Sample the draft token
        draft_token = torch.multinomial(draft_prob, num_samples=1)  # [B, 1]
        draft_tokens.append(draft_token)
        draft_probs.append(draft_prob)

        # Extend the context for the next draft step
        current_ids = torch.cat([current_ids, draft_token], dim=1)

    # Stack the draft tokens into a single tensor
    draft_sequence = torch.cat(draft_tokens, dim=1)  # [B, draft_steps]
    candidate_ids = torch.cat([input_ids, draft_sequence], dim=1)

    # 2. Score every draft position with a single target-model forward pass
    with torch.no_grad():
        target_output = target_model(candidate_ids)
        target_logits = target_output.logits[:, input_ids.size(1) - 1:-1, :]  # [B, draft_steps, vocab_size]

    # Target distribution (same temperature handling as the draft)
    if temperature > 0:
        target_probs = torch.softmax(target_logits / temperature, dim=-1)
    else:
        target_probs = torch.zeros_like(target_logits)
        target_probs.scatter_(2, target_logits.argmax(dim=-1, keepdim=True), 1.0)

    # 3. Accept/reject each draft token in order
    accepted_tokens = []
    num_accepted = 0

    for step in range(draft_steps):
        token = draft_sequence[:, step]  # [B]

        # Acceptance probability: min(1, p_target / p_draft)
        p_draft = draft_probs[step].gather(1, token.unsqueeze(1)).squeeze(1)
        p_target = target_probs[:, step, :].gather(1, token.unsqueeze(1)).squeeze(1)

        acceptance_prob = torch.clamp(p_target / (p_draft + 1e-8), max=1.0)

        # Stochastic accept/reject
        random_val = torch.rand_like(acceptance_prob)
        accepted = random_val < acceptance_prob  # [B]

        if not accepted.all():
            # Stop at the first rejected position
            break

        accepted_tokens.append(token)
        num_accepted += 1

    # 4. Final token: sampled from the target model (corrected distribution on rejection)
    last_target_logits = target_output.logits[:, input_ids.size(1) + num_accepted - 1, :]

    if temperature > 0:
        last_prob = torch.softmax(last_target_logits / temperature, dim=-1)
    else:
        last_prob = torch.zeros_like(last_target_logits)
        last_prob.scatter_(1, last_target_logits.argmax(dim=-1, keepdim=True), 1.0)

    # Corrected distribution after a rejection: normalize max(0, p_target - p_draft)
    if num_accepted < draft_steps:
        correction = torch.clamp(
            last_prob - draft_probs[num_accepted],
            min=0
        )
        correction = correction / (correction.sum(dim=-1, keepdim=True) + 1e-8)
        last_token = torch.multinomial(correction, num_samples=1)
    else:
        last_token = torch.multinomial(last_prob, num_samples=1)

    accepted_tokens.append(last_token.squeeze(1))

    final_tokens = torch.stack(accepted_tokens, dim=1)

    return final_tokens, num_accepted, draft_steps


def analyze_speedup(acceptance_rate: float, draft_steps: int = 4) -> dict:
    """Expected speedup of speculative decoding at a given acceptance rate.

    E[tokens per round] = sum_{k=0}^{K} alpha^k = (1 - alpha^(K+1)) / (1 - alpha),
    with alpha = acceptance_rate and K = draft_steps (the "+1" term is the
    bonus token the verifier always emits).
    """
    expected_tokens = sum(
        acceptance_rate ** k for k in range(draft_steps + 1)
    )

    # Cost model: the draft model is assumed ~1/10 the cost of the target.
    draft_cost_ratio = 0.1

    # One speculative round = K cheap draft steps + one full verification
    # step; a vanilla decoder would need K+1 full steps instead.
    round_cost = draft_steps * draft_cost_ratio + 1
    tokens_per_round = expected_tokens

    speedup = tokens_per_round / round_cost

    return {
        "acceptance_rate": acceptance_rate,
        "draft_steps": draft_steps,
        "expected_accepted_tokens": expected_tokens,
        "speedup": speedup
    }

# Print the speedup table over a range of acceptance rates
print("Speculative Decoding 수락률별 속도 향상 (드래프트 K=4)")
print("=" * 55)
for alpha in [0.5, 0.6, 0.7, 0.8, 0.9, 0.95]:
    result = analyze_speedup(alpha, draft_steps=4)
    print(f"수락률 {alpha:.0%}: 기대 수락 {result['expected_accepted_tokens']:.2f}토큰, "
          f"속도향상 {result['speedup']:.2f}x")

5.3 Medusa: 멀티 드래프트 헤드

import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """
    Medusa: multiple draft heads attached to a single base model.

    Each head predicts a different future token from the same hidden state:
    - Head 1: predicts t+1
    - Head 2: predicts t+2
    - Head N: predicts t+N

    BUG FIX: the original built each head with
    `[nn.Linear(...), nn.SiLU()] * hidden_layers`, which repeats the SAME
    module objects — silently sharing weights across the "stacked" layers
    whenever hidden_layers > 1. Each layer is now a fresh module.
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        num_heads: int = 4,
        hidden_layers: int = 1
    ):
        super().__init__()
        self.num_heads = num_heads

        def _build_head() -> nn.Sequential:
            # One fresh Linear+SiLU pair per hidden layer, then the vocab projection.
            layers = []
            for _ in range(hidden_layers):
                layers.append(nn.Linear(hidden_size, hidden_size, bias=False))
                layers.append(nn.SiLU())
            layers.append(nn.Linear(hidden_size, vocab_size, bias=False))
            return nn.Sequential(*layers)

        # Independent head per future token position
        self.heads = nn.ModuleList([_build_head() for _ in range(num_heads)])

    def forward(self, hidden_states: torch.Tensor):
        """
        Args:
            hidden_states: [batch, seq, hidden_size] — the base model's last hidden state.

        Returns:
            List of logits, one per future position, each [batch, seq, vocab].
        """
        return [head(hidden_states) for head in self.heads]


class MedusaModel(nn.Module):
    """Base causal LM wrapped with Medusa draft heads for multi-token drafting."""

    def __init__(self, base_model, vocab_size: int, num_medusa_heads: int = 4):
        # base_model is expected to expose `.config.hidden_size` and return
        # `.logits` / `.hidden_states` from forward (HF-style interface) —
        # TODO confirm against the concrete model class used.
        super().__init__()
        self.base_model = base_model
        hidden_size = base_model.config.hidden_size

        self.medusa_heads = MedusaHead(
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            num_heads=num_medusa_heads
        )

    def forward(self, input_ids: torch.Tensor, use_medusa: bool = False):
        """Run the base model; optionally also return Medusa head logits.

        Returns:
            (base_logits, medusa_logits) — medusa_logits is None when
            use_medusa is False, otherwise a list of per-head logits.
        """
        # Run the base model (hidden states are needed for the Medusa heads)
        base_output = self.base_model(
            input_ids,
            output_hidden_states=True
        )

        base_logits = base_output.logits

        if not use_medusa:
            return base_logits, None

        # Predict future tokens with the Medusa heads from the last layer's
        # hidden state
        last_hidden_state = base_output.hidden_states[-1]
        medusa_logits = self.medusa_heads(last_hidden_state)

        return base_logits, medusa_logits

    def generate_with_medusa(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        medusa_choices: int = 16,  # number of candidate tokens per head
        threshold: float = 0.09    # acceptance threshold (unused in this simplified demo)
    ):
        """Fast generation using Medusa heads (heavily simplified demo).

        NOTE(review): this demo computes head candidates but then always
        accepts only the base model's top-1 token — `candidates` and
        `threshold` are effectively unused here. A real implementation
        verifies the candidate tree with a tree-attention mask.
        """

        current_ids = input_ids.clone()
        all_accepted = []

        while len(all_accepted) < max_new_tokens:
            # Predict candidate tokens with the Medusa heads
            base_logits, medusa_logits = self.forward(current_ids, use_medusa=True)

            # Select the top candidates at each future position
            candidates = []
            base_probs = torch.softmax(base_logits[:, -1, :] / temperature, dim=-1)
            top_tokens = torch.topk(base_probs, medusa_choices)[1]

            for head_logits in medusa_logits:
                head_probs = torch.softmax(head_logits[:, -1, :] / temperature, dim=-1)
                candidates.append(torch.topk(head_probs, medusa_choices)[1])

            # Verify candidates with tree attention (simplified away here);
            # a real implementation uses a tree mask for efficient verification
            best_token = top_tokens[0, 0]
            all_accepted.append(best_token.item())

            current_ids = torch.cat([current_ids, best_token.unsqueeze(0).unsqueeze(0)], dim=1)

            # Assumes EOS token id == 2 (Llama convention) — TODO confirm
            if best_token.item() == 2:  # EOS
                break

        return all_accepted

6. FlashAttention: 메모리 효율적인 어텐션

6.1 표준 어텐션의 HBM 병목

표준 어텐션은 HBM(High Bandwidth Memory)에 중간 결과를 자주 쓰고 읽습니다.

표준 Attention 메모리 연산:
1. Q, K를 HBM에서 읽음       → 읽기: O(N * d)
2. S = Q @ K.T 계산          → 쓰기: O(N^2) ← 병목!
3. S를 HBM에서 읽어 소프트맥스 → 읽기: O(N^2)
4. P = softmax(S) 저장        → 쓰기: O(N^2)
5. P를 읽어 P @ V 계산        → 읽기: O(N^2)
6. 최종 결과 저장             → 쓰기: O(N * d)

HBM 접근: O(N^2) (시퀀스 길이 제곱에 비례!)

6.2 FlashAttention의 타일링 전략

import torch
import math

def flash_attention_v1(Q, K, V, block_size=64):
    """
    Simplified FlashAttention v1.

    Tiles the attention computation so the full [N, N] score matrix is never
    materialized in HBM. A running row-wise max and softmax denominator
    (online softmax) are maintained per query row, and the unnormalized
    output is rescaled as each new key/value block is processed; the final
    division by the denominator yields exact softmax attention.
    """
    bsz, n_heads, n_tokens, d_head = Q.shape

    Q = Q * (1.0 / math.sqrt(d_head))

    # Running output, softmax denominator, and row-wise running max
    # (conceptually the per-row state FlashAttention keeps between tiles).
    out = torch.zeros_like(Q)
    denom = torch.zeros(bsz, n_heads, n_tokens, 1, device=Q.device)
    row_max = torch.full((bsz, n_heads, n_tokens, 1), float('-inf'), device=Q.device)

    n_blocks = (n_tokens + block_size - 1) // block_size
    bounds = [
        (b * block_size, min((b + 1) * block_size, n_tokens))
        for b in range(n_blocks)
    ]

    for qs, qe in bounds:          # query tiles
        for ks, ke in bounds:      # key/value tiles (HBM -> SRAM in the real kernel)
            q_blk = Q[:, :, qs:qe, :]
            k_blk = K[:, :, ks:ke, :]
            v_blk = V[:, :, ks:ke, :]

            # Attention scores for this tile pair  [B, H, Br, Bc]
            scores = q_blk @ k_blk.transpose(-2, -1)

            # Online softmax update of the running max / denominator
            m_prev = row_max[:, :, qs:qe, :]
            m_curr = torch.maximum(m_prev, scores.max(dim=-1, keepdim=True)[0])
            probs = torch.exp(scores - m_curr)
            correction = torch.exp(m_prev - m_curr)

            denom[:, :, qs:qe, :] = (
                correction * denom[:, :, qs:qe, :] + probs.sum(dim=-1, keepdim=True)
            )
            # Rescale previously accumulated output, then add this tile's term
            out[:, :, qs:qe, :] = (
                correction * out[:, :, qs:qe, :] + probs @ v_blk
            )
            row_max[:, :, qs:qe, :] = m_curr

    # Final normalization by the softmax denominator
    return out / denom


def compare_attention_implementations():
    """Compare FlashAttention (via SDPA) against naive attention on CUDA.

    Note: requires a CUDA device; all tensors are allocated on the GPU, and
    the naive path materializes the full [N, N] FP16 attention matrix.
    """
    import torch.nn.functional as F
    # FIX: torch.backends.cuda.sdp_kernel is deprecated; use the
    # torch.nn.attention.sdpa_kernel context manager instead (consistent
    # with the SDPA usage elsewhere in this file).
    from torch.nn.attention import SDPBackend, sdpa_kernel

    batch_size = 2
    num_heads = 32
    seq_len = 4096
    d_head = 128

    Q = torch.randn(batch_size, num_heads, seq_len, d_head, device='cuda', dtype=torch.float16)
    K = torch.randn_like(Q)
    V = torch.randn_like(Q)

    # PyTorch SDPA, forced onto the FlashAttention backend
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        flash_output = F.scaled_dot_product_attention(Q, K, V)

    # Naive attention: materializes the full score matrix in HBM
    scale = 1.0 / math.sqrt(d_head)
    attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
    attn_weights = torch.softmax(attn_scores, dim=-1)
    standard_output = torch.matmul(attn_weights, V)

    # Numerical difference (FP16 summation order differs between kernels)
    max_diff = (flash_output - standard_output).abs().max().item()
    print(f"FlashAttention vs 표준 어텐션 최대 차이: {max_diff:.6f}")

    # Memory footprint of the explicit attention matrix alone
    standard_attn_matrix_size = batch_size * num_heads * seq_len * seq_len * 2  # FP16
    print(f"표준 어텐션 행렬 메모리: {standard_attn_matrix_size / 1e9:.2f} GB")
    print(f"FlashAttention 행렬 메모리: ~0 GB (타일링으로 저장 불필요)")
6.3 PyTorch SDPA 사용법

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def modern_attention(q, k, v, is_causal=True, dropout_p=0.0):
    """
    PyTorch 2.0+ scaled_dot_product_attention 사용

    FlashAttention 2/3를 자동으로 선택
    """

    # 자동 백엔드 선택 (Flash, Memory-efficient, Math)
    output = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=is_causal,  # 인과적 마스킹
        scale=None  # None이면 1/sqrt(d_head) 사용
    )

    return output

# Force a specific SDPA backend
def attention_with_flash_backend(q, k, v):
    # Restrict SDPA to the FlashAttention kernel only; fails if that backend
    # cannot handle the given dtype/shape/device combination.
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

def attention_with_efficient_backend(q, k, v):
    # Restrict SDPA to the memory-efficient attention kernel.
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)


# Feature summary for each FlashAttention generation
flash_versions = {
    "FlashAttention 1": {
        "paper": "arXiv:2205.14135",
        "key_innovation": "타일링 + Online Softmax",
        "memory": "O(N) (어텐션 행렬 저장 불필요)",
        "speedup": "2-4x vs 표준 어텐션",
    },
    "FlashAttention 2": {
        "paper": "arXiv:2307.08691",
        "key_innovation": "작업 분할 최적화, FP16/BF16 지원",
        "memory": "O(N)",
        "speedup": "5-9x vs 표준 어텐션 (H100에서)",
    },
    "FlashAttention 3": {
        "paper": "arXiv:2407.08608",
        "key_innovation": "H100 특화, FP8 지원, 비동기 파이프라인",
        "memory": "O(N)",
        "speedup": "1.5-2x vs FA2 (H100에서)",
    },
}

for version_name, features in flash_versions.items():
    print(f"\n{version_name}")
    for field, value in features.items():
        print(f"  {field}: {value}")

7. 멀티 GPU 추론

7.1 Tensor Parallelism

가중치 행렬을 여러 GPU에 분산하여 각 GPU가 일부를 처리합니다.

import torch
import torch.distributed as dist

class TensorParallelLinear(torch.nn.Module):
    """
    Column-parallel linear layer for tensor parallelism.

    The output dimension is sharded across `world_size` ranks: each rank
    owns `out_features // world_size` output neurons and computes only its
    own slice of the matmul.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        world_size: int,
        rank: int
    ):
        super().__init__()
        self.world_size = world_size
        self.rank = rank

        # This rank's shard of the output dimension
        self.local_out_features = out_features // world_size

        # Scaled random init (~1/sqrt(fan_in))
        init = torch.randn(self.local_out_features, in_features)
        self.weight = torch.nn.Parameter(init / (in_features ** 0.5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the local shard of the output
        shard = torch.nn.functional.linear(x, self.weight)

        # In a real distributed setup the shards would be combined with
        # all-gather, e.g.: dist.all_gather(output_list, shard)

        return shard


def setup_tensor_parallel_llm(model_name: str, tp_size: int):
    """
    Example of configuring a tensor-parallel LLM (vLLM style).

    vLLM shards the model weights across `tp_size` GPUs internally using
    exactly this mechanism.
    """
    from vllm import LLM, SamplingParams

    # tensor_parallel_size = number of GPUs to shard across
    return LLM(
        model=model_name,
        tensor_parallel_size=tp_size,
        gpu_memory_utilization=0.9
    )

7.2 vLLM 완전 활용

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
import asyncio
import time

# Basic vLLM usage
def vllm_basic_usage():
    """Basic vLLM usage: batched offline inference over a list of prompts."""

    engine = LLM(
        model="meta-llama/Llama-2-7b-hf",
        tensor_parallel_size=1,         # number of GPUs
        gpu_memory_utilization=0.90,    # fraction of GPU memory to use
        max_model_len=4096,             # maximum sequence length
        quantization=None,              # "awq", "gptq", "squeezellm"
        dtype="auto",                   # "float16", "bfloat16"
        max_num_seqs=256,               # max concurrent sequences
        enable_prefix_caching=True,     # enable prefix caching
        use_v2_block_manager=True,      # PagedAttention v2
    )

    params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=200,
        presence_penalty=0.0,
        frequency_penalty=0.0,
    )

    prompts = [
        "Explain quantum computing in simple terms",
        "What is the future of artificial intelligence?",
        "How does the human brain work?",
    ]

    # Batched inference over all prompts at once
    results = engine.generate(prompts, params)

    for res in results:
        generated = res.outputs[0]
        print(f"Prompt: {res.prompt[:50]}...")
        print(f"Output: {generated.text[:100]}...")
        print(f"Tokens generated: {len(generated.token_ids)}")
        print()

    return results


# vLLM async server example
async def vllm_async_server():
    """Stream generations from the vLLM async engine.

    Builds an AsyncLLMEngine, then runs three streaming requests
    concurrently, printing each incremental text delta as it arrives.
    """

    engine_args = AsyncEngineArgs(
        model="meta-llama/Llama-2-7b-hf",
        tensor_parallel_size=1,
        gpu_memory_utilization=0.90,
        max_model_len=4096,
        enable_prefix_caching=True,
    )

    engine = AsyncLLMEngine.from_engine_args(engine_args)

    async def generate_stream(prompt: str, request_id: str):
        # engine.generate yields the CUMULATIVE output each step, so we
        # diff against the previously seen text to extract the new delta.
        sampling_params = SamplingParams(
            temperature=0.8,
            max_tokens=200
        )

        full_text = ""
        async for output in engine.generate(prompt, sampling_params, request_id):
            if output.outputs:
                delta = output.outputs[0].text[len(full_text):]
                full_text = output.outputs[0].text

                if delta:
                    print(f"[{request_id}] {delta}", end="", flush=True)

            if output.finished:
                print(f"\n[{request_id}] 완료")

    # Handle multiple requests concurrently
    await asyncio.gather(
        generate_stream("What is AI?", "req_1"),
        generate_stream("Explain machine learning", "req_2"),
        generate_stream("What is deep learning?", "req_3"),
    )

8. 추론 엔진 비교

8.1 주요 추론 엔진 특징

| 엔진 | 개발사 | 핵심 기능 | 적합한 용도 |
|------|--------|-----------|-------------|
| vLLM | UC Berkeley | PagedAttention, Continuous Batching | 범용 LLM 서빙 |
| TGI | HuggingFace | Flash Attention 2, Speculative | HF 모델 서빙 |
| TensorRT-LLM | NVIDIA | NVIDIA GPU 최적화, FP8 | NVIDIA 최대 성능 |
| DeepSpeed-MII | Microsoft | ZeRO 추론, 초대형 모델 | 다중 GPU 대형 모델 |
| llama.cpp | Georgi Gerganov | CPU 최적화, GGUF | 로컬 실행 |

8.2 벤치마크 비교

import subprocess
import json
import time
import requests

def benchmark_vllm_server(
    model: str,
    num_requests: int = 100,
    max_tokens: int = 100,
    concurrency: int = 10
):
    """Benchmark a running vLLM OpenAI-compatible server.

    Sends `num_requests` streaming completion requests with at most
    `concurrency` in flight at a time, then reports latency percentiles,
    average TTFT, and aggregate token throughput.

    Args:
        model: model name to pass in the request payload.
        num_requests: total number of requests to send.
        max_tokens: completion length per request.
        concurrency: maximum number of simultaneous requests.

    Returns:
        dict with avg/p50/p99 latency (ms), avg TTFT (ms) and
        throughput (tokens/sec).
    """

    results = {
        "total_requests": num_requests,
        "concurrency": concurrency,
        "latencies": [],
        "ttfts": [],
    }

    import asyncio
    import aiohttp

    async def send_request(session, prompt, request_id):
        # Stream one completion; record total latency and time to first token.
        start = time.perf_counter()
        first_token_time = None

        payload = {
            "model": model,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "stream": True,
            "temperature": 0.0
        }

        async with session.post(
            "http://localhost:8000/v1/completions",
            json=payload
        ) as response:
            async for line in response.content:
                if line.startswith(b"data: "):
                    data = line[6:].decode()
                    if data.strip() == "[DONE]":
                        break

                    if first_token_time is None:
                        first_token_time = time.perf_counter() - start

        end = time.perf_counter()
        return {
            "latency": end - start,
            "ttft": first_token_time,
        }

    def _record(result):
        # Accumulate one request's measurements.
        results["latencies"].append(result["latency"])
        # `is not None` so a legitimate 0.0 TTFT is not dropped.
        if result["ttft"] is not None:
            results["ttfts"].append(result["ttft"])

    async def run_benchmark():
        prompts = [
            f"Tell me about topic number {i}." for i in range(num_requests)
        ]

        start = time.perf_counter()

        async with aiohttp.ClientSession() as session:
            # Sliding window of at most `concurrency` in-flight tasks.
            # BUG FIX: this must be a set — asyncio.wait() returns
            # (done, pending) as sets, and the original list version
            # crashed on the first `tasks.add(...)` call.
            tasks = set()
            for i, prompt in enumerate(prompts):
                if len(tasks) >= concurrency:
                    done, tasks = await asyncio.wait(
                        tasks, return_when=asyncio.FIRST_COMPLETED
                    )
                    for task in done:
                        _record(await task)

                tasks.add(asyncio.ensure_future(
                    send_request(session, prompt, i)
                ))

            # Drain the remaining in-flight requests (TTFTs included —
            # the original dropped them here).
            for coro in asyncio.as_completed(tasks):
                _record(await coro)

        total_time = time.perf_counter() - start
        total_tokens = num_requests * max_tokens
        results["throughput"] = total_tokens / total_time

    asyncio.run(run_benchmark())

    # Summary statistics
    import statistics
    latencies = results["latencies"]

    return {
        "avg_latency_ms": statistics.mean(latencies) * 1000,
        "p50_latency_ms": statistics.median(latencies) * 1000,
        "p99_latency_ms": sorted(latencies)[int(len(latencies) * 0.99)] * 1000,
        "avg_ttft_ms": statistics.mean(results["ttfts"]) * 1000 if results["ttfts"] else 0,
        "throughput_tps": results.get("throughput", 0),
    }

9. 프롬프트 캐싱

9.1 프리픽스 캐싱

동일한 시스템 프롬프트나 문서를 반복적으로 처리할 때 KV Cache를 재사용합니다.

from vllm import LLM, SamplingParams

def demonstrate_prefix_caching():
    """Show the speedup from vLLM prefix caching on a shared system prompt."""

    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        enable_prefix_caching=True,  # turn on prefix caching
        max_model_len=4096,
    )

    # Long system prompt shared by every request (1000+ tokens)
    system_prompt = """You are a helpful AI assistant with expertise in:
    - Python programming and software development
    - Machine learning and deep learning
    - Data science and statistics
    - Cloud computing and DevOps
    [... 긴 시스템 프롬프트 ...]""" * 10  # repeated to exceed 1000 tokens

    questions = [
        "How do I optimize a Python loop?",
        "What is gradient descent?",
        "Explain containerization.",
        "What is a neural network?",
    ]

    sampling_params = SamplingParams(temperature=0.7, max_tokens=100)

    import time

    prompts = [f"{system_prompt}\n\nQuestion: {q}" for q in questions]

    def _timed_generate():
        begin = time.time()
        llm.generate(prompts, sampling_params)
        return time.time() - begin

    # First batch: cold start, no cache yet
    cold_time = _timed_generate()
    # Second batch: identical shared prefix -> cache hit
    warm_time = _timed_generate()

    print(f"첫 번째 (캐시 없음): {cold_time:.2f}초")
    print(f"두 번째 (캐시 히트): {warm_time:.2f}초")
    print(f"속도 향상: {cold_time / warm_time:.2f}x")


def radix_tree_prefix_cache():
    """Build a Radix-Tree-based prefix cache for KV Cache block sharing."""

    class RadixNode:
        # One node per token; children keyed by token id.
        def __init__(self):
            self.children: dict = {}
            self.kv_cache_block_id: int = None

    class RadixTreeCache:
        """Maps token sequences to KV Cache block ids so requests sharing a
        common prefix can reuse the same cached blocks."""

        def __init__(self):
            self.root = RadixNode()
            self.cache_hits = 0
            self.cache_misses = 0

        def insert(self, token_ids: list, block_id: int):
            """Register a token sequence with its KV Cache block id."""
            current = self.root
            for tok in token_ids:
                current = current.children.setdefault(tok, RadixNode())
            current.kv_cache_block_id = block_id

        def lookup(self, token_ids: list) -> tuple:
            """Return (longest matched prefix length, last block id on the path)."""
            current = self.root
            matched = 0
            found_block = None

            for idx, tok in enumerate(token_ids):
                child = current.children.get(tok)
                if child is None:
                    break
                current = child
                matched = idx + 1
                if current.kv_cache_block_id is not None:
                    found_block = current.kv_cache_block_id

            # A lookup counts as a hit only if some cached block covers
            # part of the prefix.
            if found_block is None:
                self.cache_misses += 1
            else:
                self.cache_hits += 1

            return matched, found_block

        def get_hit_rate(self) -> float:
            attempts = self.cache_hits + self.cache_misses
            return self.cache_hits / attempts if attempts else 0.0

    return RadixTreeCache()

10. 실전 최적화 체크리스트

10.1 단계별 최적화 가이드

class LLMOptimizationChecklist:
    """Staged checklist of LLM inference optimizations.

    `optimizations` is static data: five levels, each with a category and a
    list of items (name, impact, effort, description, optional illustrative
    code snippet). `print_checklist` renders it to stdout.
    """

    # Static checklist data, ordered from the cheapest wins (level 1) to
    # the most involved changes (level 5). The "code" strings are
    # illustrative snippets only — they are printed, never executed.
    optimizations = [
        {
            "category": "기본 설정",
            "level": 1,
            "items": [
                {
                    "name": "FP16/BF16 사용",
                    "impact": "높음",
                    "effort": "낮음",
                    "description": "FP32 → FP16으로 메모리 2배 절약, 속도 향상",
                    "code": """
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # 또는 float16
    device_map="auto"
)"""
                },
                {
                    "name": "Flash Attention 2 활성화",
                    "impact": "높음",
                    "effort": "낮음",
                    "description": "어텐션 연산 2-4x 속도 향상, 메모리 절약",
                    "code": """
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16
)"""
                },
            ]
        },
        {
            "category": "KV Cache 최적화",
            "level": 2,
            "items": [
                {
                    "name": "GQA/MQA 모델 선택",
                    "impact": "높음",
                    "effort": "중간",
                    "description": "KV Cache 4-8배 감소, 더 많은 배치 처리 가능"
                },
                {
                    "name": "프리픽스 캐싱",
                    "impact": "중간",
                    "effort": "낮음",
                    "description": "공통 시스템 프롬프트 KV Cache 재사용"
                },
            ]
        },
        {
            "category": "배치 최적화",
            "level": 3,
            "items": [
                {
                    "name": "Continuous Batching (vLLM)",
                    "impact": "매우 높음",
                    "effort": "낮음",
                    "description": "처리량 2-5x 향상",
                    "code": """
from vllm import LLM, SamplingParams

llm = LLM(
    model=model_name,
    gpu_memory_utilization=0.90,
    enable_prefix_caching=True,
)"""
                },
            ]
        },
        {
            "category": "모델 최적화",
            "level": 4,
            "items": [
                {
                    "name": "AWQ 4-bit 양자화",
                    "impact": "높음",
                    "effort": "중간",
                    "description": "메모리 4x 감소, 속도 1.5-2x 향상",
                    "code": """
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "model-awq-4bit",
    fuse_layers=True
)"""
                },
                {
                    "name": "Speculative Decoding",
                    "impact": "중간",
                    "effort": "높음",
                    "description": "2-3x 속도 향상 (적합한 드래프트 모델 필요)"
                },
            ]
        },
        {
            "category": "하드웨어 최적화",
            "level": 5,
            "items": [
                {
                    "name": "Tensor Parallelism",
                    "impact": "매우 높음",
                    "effort": "중간",
                    "description": "다중 GPU로 선형적 처리량 향상"
                },
                {
                    "name": "CUDA 그래프 캡처",
                    "impact": "중간",
                    "effort": "높음",
                    "description": "커널 론치 오버헤드 제거"
                },
            ]
        }
    ]

    @classmethod
    def print_checklist(cls):
        """Pretty-print the checklist grouped by level, then the recommended order."""
        print("=" * 70)
        print("LLM 추론 최적화 단계별 체크리스트")
        print("=" * 70)

        for category in cls.optimizations:
            print(f"\n[레벨 {category['level']}] {category['category']}")
            print("-" * 50)

            for item in category['items']:
                # Map the Korean impact label to a star rating for display
                impact_emoji = {"매우 높음": "★★★", "높음": "★★", "중간": "★", "낮음": "☆"}
                print(f"  ✓ {item['name']}")
                print(f"    효과: {impact_emoji.get(item['impact'], '?')} {item['impact']}")
                print(f"    설명: {item['description']}")

        print("\n추천 최적화 순서:")
        print("1. BF16/FP16 전환 (즉시, 무료)")
        print("2. Flash Attention 2 (즉시, 패키지 설치만)")
        print("3. vLLM으로 서빙 (처리량 극대화)")
        print("4. AWQ/GPTQ 4비트 양자화 (메모리 4배 절약)")
        print("5. Speculative Decoding (레이턴시 개선)")
        print("6. 멀티 GPU Tensor Parallelism (규모 확장)")

# Render the checklist at import time
LLMOptimizationChecklist.print_checklist()

마무리

LLM 추론 최적화는 계층적 접근이 필요합니다.

핵심 요점 정리:

  1. KV Cache 이해: 메모리 사용량 공식 2 * layers * kv_heads * d_head * seq_len * dtype_bytes를 외우고, GQA/MQA로 KV Cache를 4-8배 줄이세요.

  2. PagedAttention: vLLM의 핵심 혁신으로, OS의 가상 메모리에서 영감 받아 KV Cache 단편화를 해결합니다.

  3. Continuous Batching: 요청이 완료되는 즉시 새 요청을 삽입하여 GPU 활용률을 극대화합니다.

  4. Speculative Decoding: 작은 드래프트 모델 + 큰 검증 모델 조합으로 2-3x 속도 향상이 가능합니다.

  5. FlashAttention: 어텐션 계산의 메모리 효율을 O(N^2)에서 O(N)으로 줄여 긴 컨텍스트를 가능하게 합니다.

프로덕션 배포 권고:

  • 소규모 서비스: vLLM + AWQ 4bit + 프리픽스 캐싱
  • 대규모 서비스: TensorRT-LLM 또는 vLLM + Tensor Parallelism
  • 최저 레이턴시 요구: Speculative Decoding + CUDA 그래프

참고 자료

LLM Inference Optimization Complete Guide: KV Cache, Speculative Decoding, Continuous Batching

Introduction

When you deploy a large language model (LLM) to production, you immediately face a challenge: inference speed and cost. A first query to a GPT-4 class model can take several seconds, and throughput degrades sharply as concurrent users grow.

This guide thoroughly explores the core techniques for LLM inference optimization. From KV Cache internals to PagedAttention, Speculative Decoding, FlashAttention, and the latest vLLM and TensorRT-LLM engines — we don't just cover how to use them, we understand why they work.


1. Understanding the LLM Inference Pipeline

1.1 Two Phases: Prefill and Decode

LLM text generation is split into two distinct phases.

Prefill Phase (Prompt Processing)

  • Processes all tokens of the input prompt simultaneously (parallel)
  • Generates and stores Key/Value caches at each layer
  • Compute-Bound: GPU compute throughput is the bottleneck
  • Directly affects TTFT (Time To First Token)

Decode Phase (Token Generation)

  • Generates one token at a time in an autoregressive manner
  • References KV Cache from all previously generated tokens
  • Memory-Bandwidth-Bound: HBM read speed is the bottleneck
  • Directly affects TPOT (Time Per Output Token)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time

def measure_prefill_decode_time(model, tokenizer, prompt: str, max_new_tokens: int = 100):
    """Measure TTFT and TPOT for a HF causal LM on CUDA.

    Runs generation twice: once for a single token (approximates TTFT =
    prefill + first decode step) and once for the full completion, then
    derives per-token decode time from the difference.
    """

    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    input_len = inputs["input_ids"].size(1)

    def _timed_generate(n_tokens):
        # Synchronize around the call so we time actual GPU work,
        # not just kernel launches.
        torch.cuda.synchronize()
        begin = time.perf_counter()
        with torch.no_grad():
            generated = model.generate(
                **inputs,
                max_new_tokens=n_tokens,
                do_sample=False
            )
        torch.cuda.synchronize()
        return generated, time.perf_counter() - begin

    # TTFT: prefill plus one decode step
    _, ttft = _timed_generate(1)

    # Full generation pass
    full_output, total_time = _timed_generate(max_new_tokens)

    output_tokens = full_output.size(1) - input_len
    decode_time = total_time - ttft
    tpot = decode_time / max(output_tokens - 1, 1)

    print(f"Input tokens: {input_len}")
    print(f"Output tokens: {output_tokens}")
    print(f"TTFT (first token latency): {ttft * 1000:.1f} ms")
    print(f"TPOT (per token time): {tpot * 1000:.1f} ms")
    print(f"Throughput: {output_tokens / total_time:.1f} tokens/sec")

    return ttft, tpot

1.2 Memory Bandwidth Analysis

Understanding why the decode phase is memory-bound:

def analyze_memory_bandwidth():
    """LLM inference memory bandwidth analysis (Llama-2-7B on an A100).

    Estimates model-weight and KV Cache memory, then compares the
    memory-bandwidth-bound and compute-bound token throughputs for
    batch_size=1 decode.

    Returns:
        dict with weight size (GB), KV cache size (MB), both throughput
        bounds (tokens/sec), and which one is the actual bottleneck.
    """

    # Example: Llama-2-7B config
    model_params = {
        "num_layers": 32,
        "hidden_size": 4096,
        "num_heads": 32,
        "head_dim": 128,
        "vocab_size": 32000,
    }

    dtype_bytes = 2  # FP16: 2 bytes

    # Parameter counts per layer (approximate)
    attn_params = 4 * model_params["hidden_size"] ** 2  # Q, K, V, O projections
    ffn_params = 8 * model_params["hidden_size"] ** 2   # SwiGLU: Up, Gate, Down
    total_params = (attn_params + ffn_params) * model_params["num_layers"]

    total_weight_bytes = total_params * dtype_bytes
    total_weight_gb = total_weight_bytes / 1e9

    print(f"Model weights: {total_weight_gb:.2f} GB")

    # KV Cache memory (proportional to sequence length)
    seq_len = 2048
    kv_cache_per_token = (
        2 *
        model_params["num_layers"] *
        model_params["num_heads"] *
        model_params["head_dim"] *
        dtype_bytes
    )

    kv_cache_total = kv_cache_per_token * seq_len / 1e6
    print(f"KV Cache ({seq_len} tokens): {kv_cache_total:.2f} MB")
    print(f"KV Cache per token: {kv_cache_per_token} bytes")

    # A100 memory bandwidth: 2 TB/s
    memory_bandwidth_tbs = 2.0

    # Decode: every weight byte is read once per generated token
    memory_bound_tps = memory_bandwidth_tbs * 1e12 / total_weight_bytes

    # A100 FP16: 312 TFLOPS; decoding costs ~2 FLOPs per PARAMETER per token.
    # BUG FIX: the original used `2 * total_weight_bytes` (bytes, not
    # parameters), overstating FLOPs per token by the dtype size (2x).
    compute_throughput = 312e12
    flops_per_token = 2 * total_params
    compute_bound_tps = compute_throughput / flops_per_token

    bottleneck = "memory" if memory_bound_tps < compute_bound_tps else "compute"

    print(f"\nFor batch_size=1:")
    print(f"Memory-bound throughput: {memory_bound_tps:.1f} tokens/sec")
    print(f"Compute-bound throughput: {compute_bound_tps:.1f} tokens/sec")
    print(f"Actual bottleneck: {bottleneck}")

    return {
        "total_weight_gb": total_weight_gb,
        "kv_cache_mb": kv_cache_total,
        "memory_bound_tps": memory_bound_tps,
        "compute_bound_tps": compute_bound_tps,
        "bottleneck": bottleneck,
    }

analyze_memory_bandwidth()

1.3 Inference Cost Analysis

def estimate_inference_cost(
    model_size_b: float,
    tokens_per_request: int,
    requests_per_day: int,
    gpu_cost_per_hour: float = 3.0  # A100 hourly price (USD)
):
    """Estimate daily GPU cost for serving an LLM.

    Uses an empirical throughput scaling law (7B ~ 100 tok/s on an A100,
    scaling with (size/7)^-0.6) to convert daily token volume into GPU
    hours and dollars.

    Args:
        model_size_b: model size in billions of parameters.
        tokens_per_request: average tokens generated per request.
        requests_per_day: daily request volume.
        gpu_cost_per_hour: GPU rental price in USD/hour.

    Returns:
        dict with throughput, GPU hours, daily cost and cost per 1K tokens
        (previously this function only printed and returned None).
    """

    # Empirical throughput estimates
    # 7B model: ~100 tok/s, 70B model: ~20 tok/s (A100)
    throughput_tps = 100 / (model_size_b / 7) ** 0.6

    total_tokens_per_day = tokens_per_request * requests_per_day
    seconds_needed = total_tokens_per_day / throughput_tps
    hours_needed = seconds_needed / 3600

    daily_cost = hours_needed * gpu_cost_per_hour
    cost_per_1k_tokens = daily_cost / (total_tokens_per_day / 1000)

    print(f"Model: {model_size_b}B parameters")
    print(f"Daily requests: {requests_per_day:,}")
    print(f"Tokens per request: {tokens_per_request}")
    print(f"Total daily tokens: {total_tokens_per_day:,}")
    print(f"Estimated throughput: {throughput_tps:.1f} tokens/sec")
    print(f"GPU hours needed: {hours_needed:.2f}")
    print(f"Daily cost: ${daily_cost:.2f}")
    print(f"Cost per 1K tokens: ${cost_per_1k_tokens:.4f}")

    return {
        "throughput_tps": throughput_tps,
        "gpu_hours": hours_needed,
        "daily_cost": daily_cost,
        "cost_per_1k_tokens": cost_per_1k_tokens,
    }

estimate_inference_cost(
    model_size_b=7.0,
    tokens_per_request=500,
    requests_per_day=10000
)

2. KV Cache: The Core Optimization

2.1 Why KV Cache is Necessary

In transformer attention, each token computes attention against all previous tokens. To avoid recomputing already-processed tokens, we cache the K and V matrices.

import torch
import torch.nn as nn
import math

class MultiHeadAttentionWithKVCache(nn.Module):
    """Multi-head attention with KV Cache support.

    During autoregressive decoding, K/V projections of already-processed
    tokens are kept in a cache so each new token only computes its own
    projections and attends over the cached history.

    Note: no causal mask is applied — during prefill every position attends
    to the full input (demo simplification).
    """

    def __init__(self, d_model: int, num_heads: int, max_seq_len: int = 4096):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.max_seq_len = max_seq_len

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        # KV Cache buffers. BUG FIX: the original fixed the batch dimension
        # at 1 and crashed for any batch_size > 1; the cache is now lazily
        # re-allocated when the batch size / dtype / device changes.
        self.register_buffer(
            'k_cache',
            torch.zeros(1, max_seq_len, num_heads, self.d_head)
        )
        self.register_buffer(
            'v_cache',
            torch.zeros(1, max_seq_len, num_heads, self.d_head)
        )
        self.cache_pos = 0

    def _ensure_cache(self, batch_size: int, ref: torch.Tensor):
        # Re-allocate the cache buffers if batch size, dtype or device
        # no longer match the incoming tensors. Resets the cache position.
        if (self.k_cache.size(0) != batch_size
                or self.k_cache.dtype != ref.dtype
                or self.k_cache.device != ref.device):
            shape = (batch_size, self.max_seq_len, self.num_heads, self.d_head)
            self.k_cache = torch.zeros(shape, dtype=ref.dtype, device=ref.device)
            self.v_cache = torch.zeros(shape, dtype=ref.dtype, device=ref.device)
            self.cache_pos = 0

    def forward(self, x: torch.Tensor, use_cache: bool = True, position: int = None):
        """
        Args:
            x: [batch, seq_len, d_model] input.
            use_cache: store K/V in the cache and attend over all cached tokens.
            position: explicit write offset into the cache; None appends at
                the current cache position and advances it.
        """
        batch_size, seq_len, _ = x.shape

        q = self.W_q(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        k = self.W_k(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        v = self.W_v(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)

        if use_cache:
            self._ensure_cache(batch_size, k)
            start_pos = self.cache_pos if position is None else position
            self.k_cache[:, start_pos:start_pos + seq_len] = k
            self.v_cache[:, start_pos:start_pos + seq_len] = v

            if position is None:
                self.cache_pos += seq_len

            # Attend over all cached tokens up to the current position
            total_len = self.cache_pos if position is None else start_pos + seq_len
            k = self.k_cache[:, :total_len]
            v = self.v_cache[:, :total_len]

        # [batch, heads, seq, d_head]
        scale = math.sqrt(self.d_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        attn_weights = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        output = self.W_o(output)

        return output

    def clear_cache(self):
        """Reset cache contents and position for a new sequence."""
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.cache_pos = 0

2.2 KV Cache Memory Calculation

def calculate_kv_cache_memory(
    model_config: dict,
    batch_size: int,
    seq_len: int,
    dtype_bytes: int = 2  # FP16
) -> dict:
    """Calculate KV Cache memory usage.

    Args:
        model_config: dict with ``num_layers``, ``num_heads``, ``head_dim``
            and optionally ``num_kv_heads`` (falls back to ``num_heads``
            for plain MHA models).
        batch_size: concurrent sequences.
        seq_len: tokens per sequence.
        dtype_bytes: bytes per element (2 for FP16/BF16).

    Returns:
        dict with the total in bytes/MB/GB plus bytes per token.
    """
    layers = model_config["num_layers"]
    # GQA models store fewer KV heads than query heads.
    kv_heads = model_config.get("num_kv_heads", model_config["num_heads"])
    dim = model_config["head_dim"]

    # 2 (K and V) * layers * kv_heads * head_dim * seq_len * batch * dtype
    total_bytes = 2 * layers * kv_heads * dim * seq_len * batch_size * dtype_bytes

    return {
        "kv_cache_bytes": total_bytes,
        "kv_cache_mb": total_bytes / 1e6,
        "kv_cache_gb": total_bytes / 1e9,
        "per_token_bytes": total_bytes // seq_len,
    }

# Reference model geometries (layers / attention heads / KV heads / head dim).
# num_kv_heads < num_heads marks a GQA model, which shrinks the KV cache.
models = {
    "Llama-2-7B (MHA)": {
        "num_layers": 32, "num_heads": 32,
        "num_kv_heads": 32, "head_dim": 128
    },
    "Llama-2-70B (GQA)": {
        "num_layers": 80, "num_heads": 64,
        "num_kv_heads": 8, "head_dim": 128
    },
    "Mistral-7B (GQA)": {
        "num_layers": 32, "num_heads": 32,
        "num_kv_heads": 8, "head_dim": 128
    },
}

# Per-model KV-cache footprint at a 4K context, batch size 1.
print("KV Cache memory (batch=1, seq=4096)")
print("=" * 65)
for name, config in models.items():
    result = calculate_kv_cache_memory(config, batch_size=1, seq_len=4096)
    print(f"{name:<25} {result['kv_cache_gb']:.2f} GB  "
          f"({result['per_token_bytes']:,} bytes/token)")

2.3 Grouped Query Attention (GQA)

GQA is a key technique for reducing KV Cache — multiple query heads share fewer KV heads.

import torch
import torch.nn as nn
import math

class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (GQA).

    Multiple query heads share a smaller number of KV heads, cutting the
    KV-cache footprint by ``num_q_heads // num_kv_heads``.
    """

    def __init__(self, d_model: int, num_q_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_q_heads % num_kv_heads == 0

        self.d_model = d_model
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.num_groups = num_q_heads // num_kv_heads
        self.d_head = d_model // num_q_heads

        # Queries keep the full head count; K/V project to fewer heads.
        self.W_q = nn.Linear(d_model, num_q_heads * self.d_head, bias=False)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, kv_cache=None):
        """Return (attention output, updated kv cache dict)."""
        bsz, n_tokens, _ = x.shape

        q = self.W_q(x).reshape(bsz, n_tokens, self.num_q_heads, self.d_head)
        k = self.W_k(x).reshape(bsz, n_tokens, self.num_kv_heads, self.d_head)
        v = self.W_v(x).reshape(bsz, n_tokens, self.num_kv_heads, self.d_head)

        # Prepend any cached K/V along the sequence axis before attention.
        if kv_cache is not None:
            k = torch.cat([kv_cache["k"], k], dim=1)
            v = torch.cat([kv_cache["v"], v], dim=1)
        new_kv_cache = {"k": k, "v": v}

        # [B, heads, seq, d_head] layout for batched matmul.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        # GQA: replicate each KV head across its group of query heads.
        k = k.repeat_interleave(self.num_groups, dim=1)
        v = v.repeat_interleave(self.num_groups, dim=1)

        weights = torch.softmax(
            torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head), dim=-1
        )
        ctx = torch.matmul(weights, v)
        ctx = ctx.transpose(1, 2).reshape(bsz, n_tokens, self.d_model)

        return self.W_o(ctx), new_kv_cache

def compare_attention_variants():
    """Print KV-cache sizes for MHA / GQA / MQA on a 70B-class model."""

    # Llama-2-70B-like geometry at a 4K context, batch 1, FP16.
    num_layers, d_head = 80, 128
    seq_len = 4096
    dtype_bytes = 2

    variants = {
        "MHA (32 KV heads)": 32,
        "GQA (8 KV heads)": 8,
        "MQA (1 KV head)": 1,
    }

    print("70B model attention variant KV Cache comparison")
    print(f"(seq_len={seq_len}, batch=1)")
    print("=" * 55)

    for name, heads in variants.items():
        # 2 accounts for storing both K and V.
        kv_gb = 2 * num_layers * heads * d_head * seq_len * dtype_bytes / 1e9
        print(f"{name:<25} {kv_gb:.2f} GB")

compare_attention_variants()

2.4 DeepSeek MLA (Multi-Head Latent Attention)

MLA, introduced in DeepSeek-V2, compresses KV Cache into low-dimensional latent vectors.

class MultiHeadLatentAttention(nn.Module):
    """
    DeepSeek MLA — compresses KV Cache to low-rank latent vectors

    Core idea:
    - Instead of storing high-dim K,V, store low-dim latent c_kv
    - Recover K, V from c_kv via up-projection
    - KV Cache size: num_layers * kv_lora_rank * seq_len
      (vs standard: 2 * num_layers * num_kv_heads * d_head * seq_len)
    """

    def __init__(
        self,
        d_model: int = 5120,
        num_heads: int = 128,
        kv_lora_rank: int = 512,
        qk_nope_head_dim: int = 128,
        qk_rope_head_dim: int = 64,
        v_head_dim: int = 128,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.kv_lora_rank = kv_lora_rank

        # Query path: LoRA-style two-stage projection.
        self.q_a_proj = nn.Linear(d_model, 1536, bias=False)
        self.q_b_proj = nn.Linear(
            1536,
            num_heads * (qk_nope_head_dim + qk_rope_head_dim),
            bias=False
        )

        # KV down-projection: d_model -> kv_lora_rank (+ RoPE slice).
        # Only this compressed output would ever land in the KV cache.
        self.kv_a_proj = nn.Linear(
            d_model,
            kv_lora_rank + qk_rope_head_dim,
            bias=False
        )

        # KV up-projection: recover full K and V from the latent.
        self.kv_b_proj = nn.Linear(
            kv_lora_rank,
            num_heads * (qk_nope_head_dim + v_head_dim),
            bias=False
        )

        self.o_proj = nn.Linear(num_heads * v_head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor, compressed_kv_cache=None):
        """Compress ``x`` into the latent KV representation.

        Illustrative only: the attention computation itself is elided and the
        method returns ``(None, latent)`` where ``latent`` is this step's
        compressed KV — the only thing a real cache would need to store.
        """
        bsz, n_tokens, _ = x.shape

        latent = self.kv_a_proj(x)  # [B, S, kv_lora_rank + rope_dim]

        # Concatenate with previously cached latents along the sequence axis.
        history = latent if compressed_kv_cache is None else torch.cat(
            [compressed_kv_cache, latent], dim=1
        )

        # Demonstrate recovering the full K/V from the compressed history
        # (result intentionally unused in this sketch).
        recovered = self.kv_b_proj(history[:, :, :self.kv_lora_rank])

        return None, latent

def compare_mla_vs_mha():
    """Print the KV-cache footprint of MLA vs standard MHA (DeepSeek-V2 dims)."""

    seq_len = 4096
    dtype_bytes = 2  # BF16
    num_layers = 60  # DeepSeek-V2

    # Standard MHA: full K and V per head, per layer, per token.
    num_heads, head_dim = 128, 128
    mha_kv_gb = 2 * num_layers * num_heads * head_dim * seq_len * dtype_bytes / 1e9

    # MLA: a single latent (rank + RoPE slice) per layer, per token.
    kv_lora_rank, rope_dim = 512, 64
    mla_kv_gb = (kv_lora_rank + rope_dim) * num_layers * seq_len * dtype_bytes / 1e9

    print(f"MHA KV Cache: {mha_kv_gb:.2f} GB")
    print(f"MLA KV Cache: {mla_kv_gb:.2f} GB")
    print(f"Reduction: {mha_kv_gb / mla_kv_gb:.1f}x")

compare_mla_vs_mha()

3. PagedAttention: vLLM's Core Innovation

3.1 The Problem with Conventional KV Cache

Traditional LLM serving systems pre-allocate memory for the maximum sequence length per request:

Request 1: [PROMPT=200 tokens] [KV_CACHE=up to 1848 tokens reserved] → internal fragmentation
Request 2: [PROMPT=100 tokens] [KV_CACHE=1948 tokens reserved]
Request 3: Blocked waiting (external fragmentation)

This wastes 60–80% of GPU memory.

3.2 PagedAttention: How It Works

Inspired by OS virtual memory, PagedAttention manages KV Cache in fixed-size physical blocks.

from dataclasses import dataclass, field
from typing import Dict, List, Optional
import torch

@dataclass
class PhysicalBlock:
    """Physical memory block in the global KV-cache pool."""
    block_id: int       # index into the pool / kv_cache tensor
    block_size: int     # tokens per block (e.g., 16)
    ref_count: int = 0  # logical mappings pointing here; >1 means a shared prefix

@dataclass
class LogicalBlock:
    """Logical block (mapped to a request)."""
    physical_block_id: int  # backing physical block
    num_filled: int = 0     # tokens written into this block so far

class PagedKVCacheManager:
    """PagedAttention-style KV cache manager.

    Manages a pool of fixed-size physical blocks plus per-request block
    tables, mirroring OS paging: each request sees a contiguous sequence of
    logical blocks that map onto arbitrary physical blocks, and shared
    prefixes are handled by reference counting.
    """

    def __init__(
        self,
        num_physical_blocks: int,
        block_size: int,
        num_layers: int,
        num_kv_heads: int,
        head_dim: int,
        device: str = "cuda"
    ):
        self.block_size = block_size
        self.num_layers = num_layers

        # Free list of physical block ids. Allocation pops from the TAIL,
        # which is O(1); the previous pop(0) was O(n) per allocated block.
        self.free_blocks: List[int] = list(range(num_physical_blocks))
        self.all_blocks: Dict[int, PhysicalBlock] = {
            i: PhysicalBlock(block_id=i, block_size=block_size)
            for i in range(num_physical_blocks)
        }
        # request_id -> ordered list of logical blocks (the "page table").
        self.block_tables: Dict[int, List[LogicalBlock]] = {}

        # Backing tensor pool: [block, K/V, layer, slot, kv_head, head_dim].
        self.kv_cache = torch.zeros(
            num_physical_blocks, 2, num_layers, block_size, num_kv_heads, head_dim,
            dtype=torch.float16,
            device=device
        )

    def allocate_blocks_for_request(self, request_id: int, num_tokens: int):
        """Allocate enough blocks to hold ``num_tokens`` tokens.

        Raises:
            RuntimeError: if the free pool cannot satisfy the request.
        """
        # Round up to whole blocks.
        num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size

        if len(self.free_blocks) < num_blocks_needed:
            raise RuntimeError(
                f"OOM: need {num_blocks_needed} blocks, {len(self.free_blocks)} available"
            )

        logical_blocks = []
        for _ in range(num_blocks_needed):
            physical_id = self.free_blocks.pop()  # O(1) tail pop
            self.all_blocks[physical_id].ref_count = 1
            logical_blocks.append(LogicalBlock(physical_block_id=physical_id))

        self.block_tables[request_id] = logical_blocks
        print(f"Request {request_id}: {num_blocks_needed} blocks allocated, "
              f"{len(self.free_blocks)} remaining")

    def append_token(self, request_id: int, layer: int, token_pos: int,
                     k: torch.Tensor, v: torch.Tensor):
        """Store one token's K/V ([num_kv_heads, head_dim] each) at token_pos."""
        block_idx = token_pos // self.block_size
        token_in_block = token_pos % self.block_size

        # Translate the logical position into a physical block slot.
        logical_block = self.block_tables[request_id][block_idx]
        physical_id = logical_block.physical_block_id

        self.kv_cache[physical_id, 0, layer, token_in_block] = k
        self.kv_cache[physical_id, 1, layer, token_in_block] = v
        logical_block.num_filled = token_in_block + 1

    def free_request(self, request_id: int):
        """Drop a request's block table; return unreferenced blocks to the pool."""
        if request_id in self.block_tables:
            for logical_block in self.block_tables[request_id]:
                phys_id = logical_block.physical_block_id
                self.all_blocks[phys_id].ref_count -= 1

                # Blocks shared via copy_on_write stay alive until the last
                # referencing request is freed.
                if self.all_blocks[phys_id].ref_count == 0:
                    self.free_blocks.append(phys_id)

            del self.block_tables[request_id]

    def copy_on_write(self, src_request_id: int, dst_request_id: int):
        """Share src's blocks with dst (prefix caching) by bumping ref counts.

        The "write" half of copy-on-write — duplicating a block when dst
        diverges — is out of scope for this sketch.
        """
        src_blocks = self.block_tables[src_request_id]
        dst_blocks = []

        for logical_block in src_blocks:
            phys_id = logical_block.physical_block_id
            self.all_blocks[phys_id].ref_count += 1
            dst_blocks.append(
                LogicalBlock(
                    physical_block_id=phys_id,
                    num_filled=logical_block.num_filled
                )
            )

        self.block_tables[dst_request_id] = dst_blocks


# Demo
# NOTE(review): PagedKVCacheManager defaults to device="cuda", so this demo
# needs a GPU (pass device="cpu" to run without one).
manager = PagedKVCacheManager(
    num_physical_blocks=1000,
    block_size=16,
    num_layers=32,
    num_kv_heads=32,
    head_dim=128
)

# With block_size=16: 200 tokens -> 13 blocks, 500 -> 32, 100 -> 7 (rounded up).
manager.allocate_blocks_for_request(request_id=1, num_tokens=200)
manager.allocate_blocks_for_request(request_id=2, num_tokens=500)
manager.allocate_blocks_for_request(request_id=3, num_tokens=100)

# Freeing returns unshared blocks straight to the free pool.
manager.free_request(request_id=1)
print(f"\nAfter request 1 completes — free blocks: {len(manager.free_blocks)}")

4. Continuous Batching

4.1 The Problem with Static Batching

Static batching waits for all requests in a batch to complete before starting the next:

t=0: [Request A: 500 tokens] [Request B: 100 tokens] [Request C: 300 tokens]
t=1: Request B done — but can't start new requests until A and C finish
t=2: Request C done
t=3: Request A done → only now can a new batch start

4.2 Continuous Batching (Iteration-level Scheduling)

from dataclasses import dataclass, field
from typing import List, Tuple
import time

@dataclass
class Request:
    """Inference request"""
    request_id: str       # unique id used for scheduler bookkeeping
    input_ids: List[int]  # prompt token ids (the prefill input)
    max_new_tokens: int   # generation budget for this request
    generated_ids: List[int] = field(default_factory=list)  # tokens produced so far
    is_finished: bool = False  # set by the scheduler when generation completes

class ContinuousBatchingScheduler:
    """Continuous Batching (iteration-level) scheduler.

    Keeps a FIFO waiting queue and a running batch. Finished requests are
    evicted and their slots refilled at EVERY iteration, instead of waiting
    for the whole batch to drain as static batching does.
    """

    def __init__(self, max_batch_size: int = 32, max_seq_len: int = 4096,
                 eos_token_id: int = 2):
        # eos_token_id generalizes the previously hard-coded `2` (the
        # conventional </s> id for Llama-style tokenizers); default preserved.
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.eos_token_id = eos_token_id

        self.waiting_queue: List["Request"] = []
        self.running_requests: List["Request"] = []
        self.finished_requests: List["Request"] = []

    def add_request(self, request: "Request"):
        """Enqueue a request; it joins the batch at the next iteration."""
        self.waiting_queue.append(request)

    def _can_add_request(self, request: "Request") -> bool:
        """True if the request fits the batch-size and token budgets."""
        if len(self.running_requests) + 1 > self.max_batch_size:
            return False

        # Tokens currently held by the batch plus the newcomer's prompt.
        total_tokens = sum(
            len(r.input_ids) + len(r.generated_ids)
            for r in self.running_requests
        ) + len(request.input_ids)

        return total_tokens < self.max_seq_len * self.max_batch_size

    def schedule_iteration(self) -> Tuple[List["Request"], List[str]]:
        """
        Schedule one iteration's batch

        Returns:
            (active requests, list of just-completed request IDs)
        """
        completed_ids = []

        # Evict requests that finished during the previous iteration.
        still_running = []
        for req in self.running_requests:
            if req.is_finished:
                self.finished_requests.append(req)
                completed_ids.append(req.request_id)
            else:
                still_running.append(req)
        self.running_requests = still_running

        # Fill empty slots immediately with waiting requests
        while self.waiting_queue and self._can_add_request(self.waiting_queue[0]):
            new_request = self.waiting_queue.pop(0)
            self.running_requests.append(new_request)
            print(f"Added request {new_request.request_id} to batch "
                  f"(batch size: {len(self.running_requests)})")

        return self.running_requests, completed_ids

    def simulate_one_step(self, model_forward_fn):
        """Run one decode iteration; returns IDs of requests that completed."""
        active_requests, completed = self.schedule_iteration()

        if not active_requests:
            # BUGFIX: still report requests that just finished even when the
            # batch is now empty (previously these completions were dropped
            # by returning []).
            return completed

        batch_input_ids = []
        for req in active_requests:
            if len(req.generated_ids) == 0:
                batch_input_ids.append(req.input_ids)  # Prefill
            else:
                batch_input_ids.append([req.generated_ids[-1]])  # Decode

        outputs = model_forward_fn(batch_input_ids)

        for req, next_token_id in zip(active_requests, outputs):
            req.generated_ids.append(next_token_id)

            # Stop on EOS or when the generation budget is exhausted.
            if next_token_id == self.eos_token_id or len(req.generated_ids) >= req.max_new_tokens:
                req.is_finished = True

        return completed

5. Speculative Decoding

5.1 The Idea: Draft + Verify

Speculative Decoding's core: a small draft model generates several tokens in parallel, and the large target model verifies them all at once (like a prefill).

Standard: [large model] → token1 → token2 → token3 → token4 → token5
Speculative: [small model] → generate (t1, t2, t3, t4, t5) in parallel
             [large model] → verify all 5 tokens at once (parallel, like prefill)
             only accepted tokens are kept

5.2 Acceptance Rate and Speedup Analysis

import torch
from typing import List, Tuple

def speculative_decode_step(
    draft_model,
    target_model,
    input_ids: torch.Tensor,
    draft_steps: int = 4,
    temperature: float = 1.0
) -> Tuple[torch.Tensor, int, int]:
    """
    One step of Speculative Decoding (draft-then-verify).

    Args:
        draft_model: small model; calling it must return an object with
            ``.logits`` of shape [B, S, vocab].
        target_model: large model being accelerated; same call contract.
        input_ids: [B, S] context token ids.
        draft_steps: number of candidate tokens the draft model proposes.
        temperature: softmax temperature applied to both models.

    Returns:
        (generated tokens [B, num_accepted + 1], num_accepted, num_drafted)

    IMPROVEMENT: the target model now runs exactly once per step. The
    original sliced the verification logits with ``[:, in_len-1:-1]`` and
    then issued a SECOND full target forward just to sample the final
    bonus/correction token; slicing with ``[:, in_len-1:]`` keeps that
    extra position from the first pass (identical distributions).
    """
    # 1. Draft model proposes draft_steps candidate tokens autoregressively.
    draft_tokens = []
    draft_probs = []
    current_ids = input_ids.clone()

    for _ in range(draft_steps):
        with torch.no_grad():
            draft_logits = draft_model(current_ids).logits[:, -1, :]

        draft_prob = torch.softmax(draft_logits / (temperature + 1e-8), dim=-1)
        draft_token = torch.multinomial(draft_prob, num_samples=1)
        draft_tokens.append(draft_token)
        draft_probs.append(draft_prob)
        current_ids = torch.cat([current_ids, draft_token], dim=1)

    draft_sequence = torch.cat(draft_tokens, dim=1)  # [B, K]
    candidate_ids = torch.cat([input_ids, draft_sequence], dim=1)

    # 2. Single target forward verifies all K draft positions AND yields the
    #    distribution for the bonus/correction token (K+1 positions total).
    with torch.no_grad():
        target_logits = target_model(candidate_ids).logits[
            :, input_ids.size(1) - 1:, :
        ]

    target_probs = torch.softmax(target_logits / (temperature + 1e-8), dim=-1)

    # 3. Sequentially accept/reject each draft token with probability
    #    min(1, p_target / p_draft) — the standard speculative sampling rule.
    accepted_tokens = []
    num_accepted = 0

    for step in range(draft_steps):
        token = draft_sequence[:, step]
        p_draft = draft_probs[step].gather(1, token.unsqueeze(1)).squeeze(1)
        p_target = target_probs[:, step, :].gather(1, token.unsqueeze(1)).squeeze(1)

        acceptance_prob = torch.clamp(p_target / (p_draft + 1e-8), max=1.0)
        accepted = torch.rand_like(acceptance_prob) < acceptance_prob

        # Batch-conservative: stop at the first position ANY sequence rejects.
        if not accepted.all():
            break

        accepted_tokens.append(token)
        num_accepted += 1

    # 4. One extra token from the target distribution at the first
    #    unverified position, reusing step 2's logits (no second forward).
    last_prob = target_probs[:, num_accepted, :]

    if num_accepted < draft_steps:
        # Rejection: sample from the residual max(0, p_target - p_draft),
        # which keeps the overall output distribution equal to the target's.
        correction = torch.clamp(last_prob - draft_probs[num_accepted], min=0)
        correction = correction / (correction.sum(dim=-1, keepdim=True) + 1e-8)
        last_token = torch.multinomial(correction, num_samples=1)
    else:
        last_token = torch.multinomial(last_prob, num_samples=1)

    accepted_tokens.append(last_token.squeeze(1))
    final_tokens = torch.stack(accepted_tokens, dim=1)

    return final_tokens, num_accepted, draft_steps


def analyze_speedup(acceptance_rate: float, draft_steps: int = 4) -> dict:
    """Speedup analysis based on acceptance rate.

    Expected accepted tokens per verify step follow the truncated geometric
    series E = sum_{k=0}^{K} alpha^k for K draft tokens and per-token
    acceptance rate alpha.
    """
    # Accumulate the series term by term (same summation order as the
    # original generator-sum, so float results match exactly).
    expected_accepted = 0.0
    for exponent in range(draft_steps + 1):
        expected_accepted += acceptance_rate ** exponent

    # Cost model: draft model assumed ~1/10 the target's per-step cost.
    draft_model_ratio = 0.1
    steps_with_speculative = draft_steps * draft_model_ratio + 1
    speedup = expected_accepted / steps_with_speculative

    return {
        "acceptance_rate": acceptance_rate,
        "expected_accepted": expected_accepted,
        "speedup": speedup,
    }

# Sweep draft acceptance rates and report the modeled speedup (K=4 draft tokens).
print("Speculative Decoding speedup by acceptance rate (K=4)")
print("=" * 55)
for alpha in [0.5, 0.6, 0.7, 0.8, 0.9, 0.95]:
    result = analyze_speedup(alpha, draft_steps=4)
    print(f"Accept rate {alpha:.0%}: expected {result['expected_accepted']:.2f} tokens, "
          f"speedup {result['speedup']:.2f}x")

5.3 Medusa: Multiple Draft Heads

import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """
    Medusa: attach multiple draft heads to a single model

    Each head predicts a future token:
    - Head 1: predicts t+1
    - Head 2: predicts t+2
    - Head N: predicts t+N

    BUGFIX: the original built the hidden stack as
    ``[nn.Linear(...), nn.SiLU()] * hidden_layers``. Python list repetition
    reuses the SAME module instances, so every "hidden layer" shared one
    weight matrix. Each layer now gets freshly constructed modules
    (identical behavior for the default hidden_layers=1).
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        num_heads: int = 4,
        hidden_layers: int = 1
    ):
        super().__init__()
        self.num_heads = num_heads

        def build_head() -> nn.Sequential:
            # Per-head MLP: hidden_layers blocks of (Linear + SiLU), then
            # the vocab projection. New modules per layer, per head.
            layers = []
            for _ in range(hidden_layers):
                layers.append(nn.Linear(hidden_size, hidden_size, bias=False))
                layers.append(nn.SiLU())
            layers.append(nn.Linear(hidden_size, vocab_size, bias=False))
            return nn.Sequential(*layers)

        self.heads = nn.ModuleList([build_head() for _ in range(num_heads)])

    def forward(self, hidden_states: torch.Tensor):
        """
        Returns list of logits for each future position
        """
        return [head(hidden_states) for head in self.heads]


class MedusaModel(nn.Module):
    """Medusa full model: a base LM paired with its draft heads."""

    def __init__(self, base_model, vocab_size: int, num_medusa_heads: int = 4):
        super().__init__()
        self.base_model = base_model
        self.medusa_heads = MedusaHead(
            base_model.config.hidden_size, vocab_size, num_medusa_heads
        )

    def forward(self, input_ids: torch.Tensor, use_medusa: bool = False):
        """Return (base logits, per-head Medusa logits or None)."""
        base_output = self.base_model(input_ids, output_hidden_states=True)

        if not use_medusa:
            return base_output.logits, None

        # Draft heads read the base model's final hidden states.
        medusa_logits = self.medusa_heads(base_output.hidden_states[-1])
        return base_output.logits, medusa_logits

6. FlashAttention: Memory-Efficient Attention

6.1 Standard Attention's HBM Bottleneck

Standard attention repeatedly writes and reads intermediate results to HBM:

Standard Attention memory ops:
1. Read Q, K from HBM           → read:  O(N * d)
2. Compute S = Q @ K.T          → write: O(N^2)  ← bottleneck!
3. Read S from HBM for softmax  → read:  O(N^2)
4. Store P = softmax(S)         → write: O(N^2)
5. Read P for P @ V             → read:  O(N^2)
6. Store final output           → write: O(N * d)

Total HBM accesses: O(N^2) — quadratic in sequence length!

6.2 FlashAttention's Tiling Strategy

import torch
import math

def flash_attention_v1(Q, K, V, block_size=64):
    """
    FlashAttention v1, simplified reference implementation.

    Processes the [S, S] score matrix tile by tile so it is never fully
    materialized, using the online-softmax recurrence to keep a running
    max, running denominator, and rescaled partial output per query row.
    """
    bsz, n_heads, seq_len, d_head = Q.shape

    # Fold the softmax scale into the queries up front.
    Q = Q * (1.0 / math.sqrt(d_head))

    out = torch.zeros_like(Q)
    denom = torch.zeros(bsz, n_heads, seq_len, 1, device=Q.device)
    running_max = torch.full((bsz, n_heads, seq_len, 1), float('-inf'), device=Q.device)

    n_tiles = (seq_len + block_size - 1) // block_size

    for qi in range(n_tiles):
        qs, qe = qi * block_size, min((qi + 1) * block_size, seq_len)
        Q_blk = Q[:, :, qs:qe, :]

        for kj in range(n_tiles):
            ks, ke = kj * block_size, min((kj + 1) * block_size, seq_len)

            # Tile of attention scores (lives in SRAM on real hardware).
            scores = torch.matmul(Q_blk, K[:, :, ks:ke, :].transpose(-2, -1))

            # Online-softmax statistics update for this query tile.
            prev_max = running_max[:, :, qs:qe, :]
            new_max = torch.maximum(prev_max, scores.max(dim=-1, keepdim=True)[0])
            probs = torch.exp(scores - new_max)
            rescale = torch.exp(prev_max - new_max)

            denom[:, :, qs:qe, :] = (
                rescale * denom[:, :, qs:qe, :] + probs.sum(dim=-1, keepdim=True)
            )
            out[:, :, qs:qe, :] = (
                rescale * out[:, :, qs:qe, :]
                + torch.matmul(probs, V[:, :, ks:ke, :])
            )
            running_max[:, :, qs:qe, :] = new_max

    # Normalize the accumulated numerators by the softmax denominators.
    return out / denom


def compare_attention_implementations():
    """Compare FlashAttention vs standard attention.

    Requires a CUDA device: tensors are allocated on "cuda" and the forced
    SDPA flash backend is GPU-only. Prints the max element-wise difference
    between the two outputs and the attention-matrix memory each needs.
    """

    batch_size, num_heads, seq_len, d_head = 2, 32, 4096, 128

    Q = torch.randn(batch_size, num_heads, seq_len, d_head, device='cuda', dtype=torch.float16)
    K = torch.randn_like(Q)
    V = torch.randn_like(Q)

    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Force the FlashAttention kernel (no fallback) for the fused path.
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        flash_output = F.scaled_dot_product_attention(Q, K, V)

    # Naive reference: materializes the full [S, S] score matrix in HBM.
    scale = 1.0 / math.sqrt(d_head)
    attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
    attn_weights = torch.softmax(attn_scores, dim=-1)
    standard_output = torch.matmul(attn_weights, V)

    max_diff = (flash_output - standard_output).abs().max().item()
    print(f"FlashAttention vs standard max diff: {max_diff:.6f}")

    # Size of the [B, H, S, S] score matrix alone, in FP16.
    standard_attn_bytes = batch_size * num_heads * seq_len * seq_len * 2  # FP16
    print(f"Standard attention matrix memory: {standard_attn_bytes / 1e9:.2f} GB")
    print(f"FlashAttention matrix memory: ~0 GB (tiled, never fully materialized)")

6.3 Using PyTorch SDPA

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def modern_attention(q, k, v, is_causal=True, dropout_p=0.0):
    """
    PyTorch 2.0+ scaled_dot_product_attention

    Automatically selects FlashAttention 2/3
    """
    return F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=None  # defaults to 1/sqrt(d_head)
    )

# Flash Attention version highlights — summary data for the table printed below.
flash_versions = {
    "FlashAttention 1 (arXiv:2205.14135)": {
        "key_innovation": "Tiling + Online Softmax",
        "memory": "O(N) — no attention matrix stored",
        "speedup": "2-4x vs standard"
    },
    "FlashAttention 2 (arXiv:2307.08691)": {
        "key_innovation": "Work partitioning, FP16/BF16",
        "memory": "O(N)",
        "speedup": "5-9x vs standard on H100"
    },
    "FlashAttention 3 (arXiv:2407.08608)": {
        "key_innovation": "H100-specific, FP8, async pipeline",
        "memory": "O(N)",
        "speedup": "1.5-2x vs FA2 on H100"
    },
}

# Pretty-print one section per version.
for name, info in flash_versions.items():
    print(f"\n{name}")
    for k, v in info.items():
        print(f"  {k}: {v}")

7. Multi-GPU Inference

7.1 Tensor Parallelism

Weight matrices are split across GPUs; each GPU handles a shard.

import torch
import torch.nn as nn

class TensorParallelLinear(nn.Module):
    """
    Column-parallel Linear layer for tensor parallelism.

    The weight matrix is split along the output dimension: this rank owns
    out_features // world_size output neurons and produces the matching
    slice of the activation.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        world_size: int,
        rank: int
    ):
        super().__init__()
        self.world_size = world_size
        self.rank = rank
        self.local_out_features = out_features // world_size

        # Scaled random init for this rank's weight shard.
        self.weight = nn.Parameter(
            torch.randn(self.local_out_features, in_features) / (in_features ** 0.5)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # This rank's partial output; a real distributed deployment would
        # follow with dist.all_gather(...) to reassemble the full activation.
        shard_out = nn.functional.linear(x, self.weight)
        return shard_out


def setup_vllm_multiGPU(model_name: str, tp_size: int):
    """Set up multi-GPU vLLM inference.

    Args:
        model_name: HF model id or local checkpoint path.
        tp_size: passed through as vLLM's tensor_parallel_size — the number
            of GPUs the weights are sharded across.

    Returns:
        A vLLM ``LLM`` engine ready for ``generate`` calls.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,
        gpu_memory_utilization=0.9  # fraction of VRAM vLLM may claim
    )

    return llm

7.2 Full vLLM Usage

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
import asyncio
import time

def vllm_basic_usage():
    """vLLM basic usage.

    Builds an offline LLM engine for Llama-2-7B, generates completions for
    three prompts, prints a short summary for each, and returns the raw
    vLLM outputs. Requires a GPU and the vLLM package.
    """

    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        tensor_parallel_size=1,
        gpu_memory_utilization=0.90,   # fraction of VRAM vLLM may claim
        max_model_len=4096,
        quantization=None,             # "awq", "gptq", "squeezellm"
        dtype="auto",
        max_num_seqs=256,              # max concurrent sequences per batch
        enable_prefix_caching=True,
        use_v2_block_manager=True,     # NOTE(review): removed in newer vLLM releases — verify against the installed version
    )

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=200,
    )

    prompts = [
        "Explain quantum computing in simple terms",
        "What is the future of artificial intelligence?",
        "How does the human brain work?",
    ]

    # generate() batches all prompts through the engine in one call.
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        print(f"Prompt: {output.prompt[:50]}...")
        print(f"Output: {output.outputs[0].text[:100]}...")
        print(f"Tokens: {len(output.outputs[0].token_ids)}")
        print()

    return outputs


async def vllm_async_server():
    """vLLM async engine usage.

    Builds an AsyncLLMEngine and streams three generations concurrently,
    printing incremental text deltas per request. Requires a GPU and vLLM.
    """

    engine_args = AsyncEngineArgs(
        model="meta-llama/Llama-2-7b-hf",
        tensor_parallel_size=1,
        gpu_memory_utilization=0.90,
        max_model_len=4096,
        enable_prefix_caching=True,
    )

    engine = AsyncLLMEngine.from_engine_args(engine_args)

    async def generate_stream(prompt: str, request_id: str):
        # Stream one request; engine.generate yields cumulative outputs.
        sampling_params = SamplingParams(temperature=0.8, max_tokens=200)

        full_text = ""
        async for output in engine.generate(prompt, sampling_params, request_id):
            if output.outputs:
                # Each yield carries the FULL text so far; diff to get the delta.
                delta = output.outputs[0].text[len(full_text):]
                full_text = output.outputs[0].text

                if delta:
                    print(f"[{request_id}] {delta}", end="", flush=True)

            if output.finished:
                print(f"\n[{request_id}] done")

    # Three concurrent streams share the engine's continuous batch.
    await asyncio.gather(
        generate_stream("What is AI?", "req_1"),
        generate_stream("Explain machine learning", "req_2"),
        generate_stream("What is deep learning?", "req_3"),
    )

8. Inference Engine Comparison

8.1 Major Engine Features

| Engine | Author | Key Features | Best For |
|---|---|---|---|
| vLLM | UC Berkeley | PagedAttention, Continuous Batching | General LLM serving |
| TGI | HuggingFace | Flash Attention 2, Speculative | HF model serving |
| TensorRT-LLM | NVIDIA | NVIDIA-optimized, FP8 | Max NVIDIA perf |
| DeepSpeed-MII | Microsoft | ZeRO inference, huge models | Multi-GPU giant models |
| llama.cpp | G. Gerganov | CPU-optimized, GGUF | Local execution |

8.2 Benchmark Results

import time

def run_inference_benchmark(engine_name: str, model, tokenizer, prompts, max_tokens=100):
    """Throughput/latency micro-benchmark over greedy ``model.generate`` calls."""

    num_warmup = 5
    num_runs = 50

    def _generate(text):
        # One greedy generation of `max_tokens` new tokens on the GPU.
        ids = tokenizer(text, return_tensors="pt").input_ids.cuda()
        return model.generate(ids, max_new_tokens=max_tokens, do_sample=False)

    # Warmup (excluded from timing)
    for p in prompts[:num_warmup]:
        _generate(p)

    # Timed runs
    import torch
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    total_tokens = 0
    for p in prompts[:num_runs]:
        _generate(p)
        total_tokens += max_tokens  # assumes generation never stops early

    torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0

    throughput = total_tokens / elapsed
    latency_ms = elapsed / num_runs * 1000

    print(f"\n{engine_name}")
    print(f"  Throughput: {throughput:.1f} tokens/sec")
    print(f"  Avg latency: {latency_ms:.1f} ms/request")

    return throughput, latency_ms


# Example comparison (A100 80GB, Llama-2-7B, batch=1, 100 output tokens)
# NOTE(review): hard-coded illustrative figures, not produced by
# run_inference_benchmark above — treat as ballpark values only.
benchmark_results = {
    "HuggingFace (FP16)":         {"throughput": 52,  "latency_ms": 1924},
    "HuggingFace (Flash Attn 2)": {"throughput": 78,  "latency_ms": 1280},
    "vLLM":                        {"throughput": 120, "latency_ms": 832},
    "vLLM + AWQ 4bit":             {"throughput": 165, "latency_ms": 606},
    "TensorRT-LLM":                {"throughput": 180, "latency_ms": 555},
}

# Render the comparison as a fixed-width table.
print("LLM inference engine benchmark (A100 80GB, Llama-2-7B)")
print("=" * 65)
print(f"{'Engine':<30} {'Throughput (tok/s)':<22} {'Latency (ms)':<15}")
print("-" * 65)
for engine, stats in benchmark_results.items():
    print(f"{engine:<30} {stats['throughput']:<22} {stats['latency_ms']:<15}")

9. Prompt Caching

9.1 Prefix Caching

Reuse KV Cache when the same system prompt or document is processed repeatedly.

from vllm import LLM, SamplingParams
import time

def demonstrate_prefix_caching():
    """Demonstrate the latency benefit of vLLM prefix caching.

    Issues the same batch of prompts twice: the first pass populates the
    prefix KV cache, the second pass reuses it, and the speedup is printed.
    """

    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        enable_prefix_caching=True,
        max_model_len=4096,
    )

    # Long system prompt common to all requests (1000+ tokens)
    system_prompt = (
        "You are a helpful AI assistant with expertise in Python, "
        "machine learning, data science, and cloud computing. "
    ) * 50

    questions = [
        "How do I optimize a Python loop?",
        "What is gradient descent?",
        "Explain containerization.",
        "What is a neural network?",
    ]

    sampling_params = SamplingParams(temperature=0.7, max_tokens=100)
    cold_prompts = [f"{system_prompt}\n\nQuestion: {q}" for q in questions]

    # Use time.perf_counter() (monotonic, high-resolution) rather than
    # time.time() (wall clock, can jump under NTP adjustment) — this also
    # matches the timing convention used by the other benchmarks here.
    # Cold start (no cache)
    cold_start = time.perf_counter()
    llm.generate(cold_prompts, sampling_params)
    cold_time = time.perf_counter() - cold_start

    # Warm start (cache hit)
    warm_start = time.perf_counter()
    llm.generate(cold_prompts, sampling_params)
    warm_time = time.perf_counter() - warm_start

    print(f"Cold start (no cache): {cold_time:.2f}s")
    print(f"Warm start (cache hit): {warm_time:.2f}s")
    print(f"Speedup: {cold_time / warm_time:.2f}x")


def radix_tree_prefix_cache():
    """Create a Radix-tree prefix cache for sharing common-prefix KV caches.

    Returns:
        An empty RadixTreeCache instance supporting insert/lookup and
        hit-rate statistics.
    """
    from typing import Optional

    class RadixNode:
        """One node per token position; children are keyed by token id."""

        def __init__(self):
            self.children: dict = {}
            # KV-cache block covering the token path that ends at this node;
            # None when no cached entry terminates here. (The original
            # annotation `int = None` was wrong — the value is optional.)
            self.kv_cache_block_id: Optional[int] = None

    class RadixTreeCache:
        """
        Manages token sequences in a Radix Tree to share
        common-prefix KV caches
        """

        def __init__(self):
            self.root = RadixNode()
            self.cache_hits = 0
            self.cache_misses = 0

        def insert(self, token_ids: list, block_id: int):
            """Register block_id as the KV-cache block for token_ids."""
            node = self.root
            for token_id in token_ids:
                if token_id not in node.children:
                    node.children[token_id] = RadixNode()
                node = node.children[token_id]
            node.kv_cache_block_id = block_id

        def lookup(self, token_ids: list) -> tuple:
            """Find the longest matching prefix.

            Returns:
                (matched_len, block_id): number of leading tokens present in
                the tree, and the block id of the deepest cached prefix along
                that path (None on a cache miss). Note matched_len may exceed
                the length covered by block_id when the tail of the match has
                no cached block.
            """
            node = self.root
            matched_len = 0
            last_block_id = None

            for i, token_id in enumerate(token_ids):
                if token_id in node.children:
                    node = node.children[token_id]
                    matched_len = i + 1
                    if node.kv_cache_block_id is not None:
                        last_block_id = node.kv_cache_block_id
                else:
                    break

            # A lookup counts as a hit only if some cached block was found,
            # not merely because a few tokens matched tree edges.
            if last_block_id is not None:
                self.cache_hits += 1
            else:
                self.cache_misses += 1

            return matched_len, last_block_id

        def get_hit_rate(self) -> float:
            """Fraction of lookups that found a cached block (0.0 before any)."""
            total = self.cache_hits + self.cache_misses
            return self.cache_hits / total if total > 0 else 0.0

    return RadixTreeCache()

10. Practical Optimization Checklist

10.1 Step-by-Step Optimization Guide

class LLMOptimizationChecklist:
    """LLM inference optimization checklist.

    `optimizations` is static reference data: five levels, each with a list
    of items carrying name/impact/effort/description and an optional code
    snippet. `print_checklist` renders it as a human-readable report.
    """

    # Mapping from impact label to its star rating, shared by all renders.
    # Hoisted to class level so it is not rebuilt on every loop iteration.
    _IMPACT_STARS = {"Very High": "★★★", "High": "★★", "Medium": "★", "Low": "☆"}

    optimizations = [
        {
            "category": "Baseline",
            "level": 1,
            "items": [
                {
                    "name": "Use FP16/BF16",
                    "impact": "High",
                    "effort": "Low",
                    "description": "FP32 → FP16: 2x memory saving, speed improvement",
                    "code": """
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)"""
                },
                {
                    "name": "Enable Flash Attention 2",
                    "impact": "High",
                    "effort": "Low",
                    "description": "2-4x attention speedup, memory saving",
                    "code": """
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16
)"""
                },
            ]
        },
        {
            "category": "KV Cache Optimization",
            "level": 2,
            "items": [
                {
                    "name": "Choose GQA/MQA model",
                    "impact": "High",
                    "effort": "Medium",
                    "description": "4-8x KV Cache reduction, larger effective batch"
                },
                {
                    "name": "Prefix caching",
                    "impact": "Medium",
                    "effort": "Low",
                    "description": "Reuse KV Cache for common system prompts"
                },
            ]
        },
        {
            "category": "Batching Optimization",
            "level": 3,
            "items": [
                {
                    "name": "Continuous Batching with vLLM",
                    "impact": "Very High",
                    "effort": "Low",
                    "description": "2-5x throughput improvement",
                    "code": """
from vllm import LLM, SamplingParams

llm = LLM(
    model=model_name,
    gpu_memory_utilization=0.90,
    enable_prefix_caching=True,
)"""
                },
            ]
        },
        {
            "category": "Model Optimization",
            "level": 4,
            "items": [
                {
                    "name": "AWQ 4-bit quantization",
                    "impact": "High",
                    "effort": "Medium",
                    "description": "4x memory reduction, 1.5-2x speed",
                    "code": """
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "model-awq-4bit",
    fuse_layers=True
)"""
                },
                {
                    "name": "Speculative Decoding",
                    "impact": "Medium",
                    "effort": "High",
                    "description": "2-3x speedup (requires suitable draft model)"
                },
            ]
        },
        {
            "category": "Hardware Optimization",
            "level": 5,
            "items": [
                {
                    "name": "Tensor Parallelism",
                    "impact": "Very High",
                    "effort": "Medium",
                    "description": "Linear throughput scaling with multiple GPUs"
                },
                {
                    "name": "CUDA graph capture",
                    "impact": "Medium",
                    "effort": "High",
                    "description": "Eliminate kernel launch overhead"
                },
            ]
        }
    ]

    @classmethod
    def print_checklist(cls):
        """Print every checklist item grouped by level, then the recommended order."""
        print("=" * 70)
        print("LLM Inference Optimization — Step-by-Step Checklist")
        print("=" * 70)

        for category in cls.optimizations:
            print(f"\n[Level {category['level']}] {category['category']}")
            print("-" * 50)

            for item in category['items']:
                print(f"  ✓ {item['name']}")
                print(f"    Impact: {cls._IMPACT_STARS.get(item['impact'], '?')} {item['impact']}")
                print(f"    Note: {item['description']}")

        print("\nRecommended optimization order:")
        print("1. Switch to BF16/FP16                  (immediate, free)")
        print("2. Enable Flash Attention 2             (immediate, just install)")
        print("3. Serve with vLLM                      (max throughput)")
        print("4. AWQ/GPTQ 4-bit quantization          (4x memory reduction)")
        print("5. Speculative Decoding                 (latency improvement)")
        print("6. Multi-GPU Tensor Parallelism         (scale out)")

# Emit the full checklist report when this script is executed.
LLMOptimizationChecklist.print_checklist()

10.2 End-to-End Production Setup

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from fastapi import FastAPI
from pydantic import BaseModel
import asyncio
import uvicorn

# FastAPI application exposing the vLLM engine over HTTP.
app = FastAPI(title="LLM Inference API")

class GenerateRequest(BaseModel):
    """Request body for the /generate endpoint."""

    prompt: str               # input text to complete
    max_tokens: int = 200     # upper bound on generated tokens
    temperature: float = 0.7  # sampling temperature
    top_p: float = 0.95       # nucleus sampling probability mass
    stream: bool = False      # NOTE(review): declared but not read by the /generate handler below — confirm intended

class GenerateResponse(BaseModel):
    """Response body returned by the /generate endpoint."""

    text: str              # generated completion text
    tokens_generated: int  # number of output token ids produced
    finish_reason: str     # why generation stopped (e.g. "length", or "error")

# Global engine instance.
# None until the FastAPI startup hook runs; the annotation understates this
# (the value is effectively Optional[AsyncLLMEngine]).
engine: AsyncLLMEngine = None

def create_optimized_engine(model_name: str, **kwargs) -> AsyncLLMEngine:
    """Build an AsyncLLMEngine configured for production serving.

    Recognized kwargs (all optional): tp_size, gpu_util, max_model_len,
    quantization ("awq" or "gptq"), max_num_seqs, draft_model,
    num_spec_tokens.
    """

    # Resolve all tuning knobs up front so the engine-args construction
    # below reads as a plain configuration table.
    tp_size = kwargs.get("tp_size", 1)
    gpu_util = kwargs.get("gpu_util", 0.90)
    max_model_len = kwargs.get("max_model_len", 4096)
    quantization = kwargs.get("quantization", None)   # "awq" or "gptq"
    max_num_seqs = kwargs.get("max_num_seqs", 256)
    draft_model = kwargs.get("draft_model", None)     # optional draft model
    num_spec_tokens = kwargs.get("num_spec_tokens", 5)

    args = AsyncEngineArgs(
        model=model_name,
        tensor_parallel_size=tp_size,
        gpu_memory_utilization=gpu_util,
        max_model_len=max_model_len,
        quantization=quantization,
        dtype="auto",
        max_num_seqs=max_num_seqs,
        enable_prefix_caching=True,
        use_v2_block_manager=True,
        speculative_model=draft_model,
        num_speculative_tokens=num_spec_tokens,
    )
    return AsyncLLMEngine.from_engine_args(args)

@app.on_event("startup")
async def startup():
    """Initialize the shared AsyncLLMEngine once, when the server boots."""
    global engine
    engine = create_optimized_engine(
        model_name="meta-llama/Llama-2-7b-hf",
        tp_size=1,
        gpu_util=0.90,
        max_model_len=4096,
    )

@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Generate a completion for request.prompt via the shared vLLM engine.

    Drains the engine's async output stream and returns the final result;
    falls back to an empty "error" response if the engine yielded nothing.
    """
    import uuid

    sampling_params = SamplingParams(
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens,
    )

    # id(request) is only unique while the object is alive — CPython reuses
    # ids after garbage collection, which could collide with an in-flight
    # request id inside the engine. uuid4 is unique for the process lifetime.
    request_id = f"req_{uuid.uuid4().hex}"

    # The engine yields progressively longer outputs; keep only the last one.
    final_output = None
    async for output in engine.generate(request.prompt, sampling_params, request_id):
        final_output = output

    if final_output and final_output.outputs:
        result = final_output.outputs[0]
        return GenerateResponse(
            text=result.text,
            tokens_generated=len(result.token_ids),
            finish_reason=result.finish_reason or "length"
        )

    return GenerateResponse(text="", tokens_generated=0, finish_reason="error")

# Run with: uvicorn script:app --host 0.0.0.0 --port 8000 --workers 1

Conclusion

LLM inference optimization requires a layered approach.

Key takeaways:

  1. Understand KV Cache: Memorize the per-sequence size, 2 * layers * kv_heads * d_head * seq_len * dtype_bytes (multiply by batch size for the total). Use GQA/MQA to cut KV Cache by 4–8x.

  2. PagedAttention: vLLM's core innovation — borrowed from OS virtual memory to eliminate KV Cache fragmentation.

  3. Continuous Batching: Immediately insert new requests as completions happen, maximizing GPU utilization.

  4. Speculative Decoding: Small draft model + large verifier = 2–3x speedup at the right acceptance rate.

  5. FlashAttention: Reduces attention's memory from O(N^2) to O(N), enabling long contexts.

Production deployment recommendations:

  • Small services: vLLM + AWQ 4-bit + prefix caching
  • Large services: TensorRT-LLM or vLLM + Tensor Parallelism
  • Lowest latency: Speculative Decoding + CUDA graphs

References