Skip to content
Published on

LLM 추론 최적화 완전 가이드: KV Cache, Speculative Decoding, Continuous Batching

Authors

들어가며

대형 언어 모델(LLM)을 프로덕션에 배포하면 즉시 직면하는 문제가 있습니다. 바로 추론 속도와 비용입니다. GPT-4 수준의 모델을 처음 쿼리했을 때 수 초의 지연이 발생하고, 동시 사용자가 늘어나면 처리량이 급격히 저하됩니다.

이 가이드는 LLM 추론 최적화의 핵심 기술을 완전히 파헤칩니다. KV Cache의 작동 원리부터 PagedAttention, Speculative Decoding, FlashAttention, 그리고 최신 vLLM과 TensorRT-LLM 엔진까지 — 단순한 사용법이 아니라 왜 작동하는지 원리를 이해합니다.


1. LLM 추론 과정 이해

1.1 두 단계: Prefill과 Decode

LLM 텍스트 생성은 두 단계로 나뉩니다.

Prefill 단계 (프롬프트 처리)

  • 입력 프롬프트의 모든 토큰을 동시에 처리
  • 각 레이어에서 Key/Value 캐시를 생성하여 저장
  • 연산 집약적(Compute-Bound): GPU 연산 능력이 병목
  • TTFT (Time To First Token)에 직접 영향

Decode 단계 (토큰 생성)

  • 한 번에 하나의 토큰을 자기회귀 방식으로 생성
  • 이전에 생성된 모든 토큰의 KV Cache를 참조
  • 메모리 대역폭 집약적(Memory-Bound): HBM 읽기 속도가 병목
  • TPOT (Time Per Output Token)에 직접 영향
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time

def measure_prefill_decode_time(model, tokenizer, prompt: str, max_new_tokens: int = 100):
    """Prefill과 Decode 단계 시간 측정"""

    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    input_len = inputs["input_ids"].size(1)

    # Prefill 측정
    torch.cuda.synchronize()
    prefill_start = time.perf_counter()

    with torch.no_grad():
        # 첫 번째 토큰까지 (Prefill + 첫 Decode)
        first_output = model.generate(
            **inputs,
            max_new_tokens=1,
            do_sample=False
        )

    torch.cuda.synchronize()
    ttft = time.perf_counter() - prefill_start

    # 전체 생성 측정
    torch.cuda.synchronize()
    total_start = time.perf_counter()

    with torch.no_grad():
        full_output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )

    torch.cuda.synchronize()
    total_time = time.perf_counter() - total_start

    output_tokens = full_output.size(1) - input_len
    decode_time = total_time - ttft
    tpot = decode_time / max(output_tokens - 1, 1)

    print(f"입력 토큰 수: {input_len}")
    print(f"생성 토큰 수: {output_tokens}")
    print(f"TTFT (첫 토큰 지연): {ttft * 1000:.1f} ms")
    print(f"TPOT (토큰당 시간): {tpot * 1000:.1f} ms")
    print(f"처리량: {output_tokens / total_time:.1f} 토큰/초")

    return ttft, tpot

1.2 메모리 병목 분석

Decode 단계가 왜 메모리 집약적인지 이해합니다.

def analyze_memory_bandwidth():
    """LLM 추론의 메모리 대역폭 분석"""

    # 예시: Llama-2-7B 설정
    model_params = {
        "num_layers": 32,
        "hidden_size": 4096,
        "num_heads": 32,
        "head_dim": 128,
        "vocab_size": 32000,
    }

    dtype_bytes = 2  # FP16: 2 bytes

    # 가중치 메모리 (한 번 로드)
    # 각 Transformer 레이어의 가중치
    attn_weight = 4 * model_params["hidden_size"] ** 2  # Q, K, V, O 프로젝션
    ffn_weight = 8 * model_params["hidden_size"] ** 2  # Up, Gate, Down 프로젝션 (SwiGLU)
    layer_weight = (attn_weight + ffn_weight) * dtype_bytes

    total_weight_bytes = layer_weight * model_params["num_layers"]
    total_weight_gb = total_weight_bytes / 1e9

    print(f"모델 가중치: {total_weight_gb:.2f} GB")

    # KV Cache 메모리 (시퀀스 길이에 비례)
    seq_len = 2048
    kv_cache_per_token = (
        2 *  # K와 V
        model_params["num_layers"] *
        model_params["num_heads"] *
        model_params["head_dim"] *
        dtype_bytes
    )

    kv_cache_total = kv_cache_per_token * seq_len / 1e6
    print(f"KV Cache ({seq_len} 토큰): {kv_cache_total:.2f} MB")
    print(f"토큰당 KV Cache: {kv_cache_per_token} bytes")

    # A100 메모리 대역폭: 2 TB/s
    memory_bandwidth_tbs = 2.0  # TB/s

    # Decode 단계: 한 토큰 생성 시 가중치를 한 번씩 읽음
    # 배치 크기가 작을수록 연산 대비 메모리 읽기 비율이 높아짐
    batch_size = 1
    flops_per_token = 2 * total_weight_bytes  # 대략적인 FLOPs

    # A100 FP16: 312 TFLOPS
    compute_throughput = 312e12  # FLOPS

    # 메모리 대역폭 기준 처리량
    memory_bound_tps = memory_bandwidth_tbs * 1e12 / total_weight_bytes
    # 연산 기준 처리량
    compute_bound_tps = compute_throughput / flops_per_token

    print(f"\n배치 크기 {batch_size}일 때:")
    print(f"메모리 병목 처리량: {memory_bound_tps:.1f} 토큰/초")
    print(f"연산 병목 처리량: {compute_bound_tps:.1f} 토큰/초")
    print(f"실제 병목: {'메모리' if memory_bound_tps < compute_bound_tps else '연산'}")

analyze_memory_bandwidth()

1.3 추론 비용 분석

def estimate_inference_cost(
    model_size_b: float,
    tokens_per_request: int,
    requests_per_day: int,
    gpu_cost_per_hour: float = 3.0  # A100 시간당 가격 (USD)
):
    """추론 비용 추정"""

    # 처리량 추정 (경험적 수치)
    # 7B 모델: ~100 tok/s, 70B 모델: ~20 tok/s (A100 기준)
    throughput_tps = 100 / (model_size_b / 7) ** 0.6

    total_tokens_per_day = tokens_per_request * requests_per_day
    seconds_needed = total_tokens_per_day / throughput_tps
    hours_needed = seconds_needed / 3600

    # GPU 수 (병렬 처리 고려)
    # 일반적으로 1개 GPU로 처리하면
    daily_cost = hours_needed * gpu_cost_per_hour

    cost_per_1k_tokens = daily_cost / (total_tokens_per_day / 1000)

    print(f"모델: {model_size_b}B 파라미터")
    print(f"일일 요청: {requests_per_day:,}건")
    print(f"요청당 토큰: {tokens_per_request}")
    print(f"총 일일 토큰: {total_tokens_per_day:,}")
    print(f"예상 처리량: {throughput_tps:.1f} 토큰/초")
    print(f"필요 GPU 시간: {hours_needed:.2f}시간")
    print(f"일일 비용: ${daily_cost:.2f}")
    print(f"1K 토큰당 비용: ${cost_per_1k_tokens:.4f}")

# 예시
estimate_inference_cost(
    model_size_b=7.0,
    tokens_per_request=500,
    requests_per_day=10000
)

2. KV Cache: 핵심 최적화 기술

2.1 KV Cache의 필요성

트랜스포머 어텐션에서 각 토큰은 이전 모든 토큰과 어텐션을 계산합니다. 이미 처리한 토큰을 재계산하지 않으려면 K와 V 행렬을 캐싱합니다.

import torch
import torch.nn as nn
import math

class MultiHeadAttentionWithKVCache(nn.Module):
    """KV Cache를 지원하는 멀티헤드 어텐션"""

    def __init__(self, d_model: int, num_heads: int, max_seq_len: int = 4096):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        # KV Cache 초기화
        self.register_buffer(
            'k_cache',
            torch.zeros(1, max_seq_len, num_heads, self.d_head)
        )
        self.register_buffer(
            'v_cache',
            torch.zeros(1, max_seq_len, num_heads, self.d_head)
        )
        self.cache_pos = 0

    def forward(
        self,
        x: torch.Tensor,
        use_cache: bool = True,
        position: int = None
    ):
        batch_size, seq_len, _ = x.shape

        # Q, K, V 계산
        q = self.W_q(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        k = self.W_k(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        v = self.W_v(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)

        if use_cache:
            # 캐시에 현재 K, V 저장
            start_pos = self.cache_pos if position is None else position
            self.k_cache[:, start_pos:start_pos + seq_len] = k
            self.v_cache[:, start_pos:start_pos + seq_len] = v

            if position is None:
                self.cache_pos += seq_len

            # 캐시된 전체 K, V 사용
            total_len = self.cache_pos if position is None else start_pos + seq_len
            k = self.k_cache[:, :total_len]
            v = self.v_cache[:, :total_len]

        # 어텐션 계산
        scale = math.sqrt(self.d_head)

        # [batch, num_heads, seq_len, d_head] 형태로 변환
        q = q.transpose(1, 2)  # [B, H, S, D]
        k = k.transpose(1, 2)  # [B, H, T, D] (T: total cached length)
        v = v.transpose(1, 2)

        # 어텐션 스코어: [B, H, S, T]
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale

        # 소프트맥스
        attn_weights = torch.softmax(attn_scores, dim=-1)

        # 가중 합: [B, H, S, D]
        output = torch.matmul(attn_weights, v)

        # 원래 형태로 변환
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        output = self.W_o(output)

        return output

    def clear_cache(self):
        """캐시 초기화"""
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.cache_pos = 0

2.2 KV Cache 메모리 계산

def calculate_kv_cache_memory(
    model_config: dict,
    batch_size: int,
    seq_len: int,
    dtype_bytes: int = 2  # FP16
) -> dict:
    """KV Cache 메모리 사용량 계산"""

    num_layers = model_config["num_layers"]
    num_kv_heads = model_config.get("num_kv_heads", model_config["num_heads"])
    head_dim = model_config["head_dim"]

    # KV Cache 크기: 2 (K+V) * layers * kv_heads * head_dim * seq_len * dtype
    kv_cache_bytes = (
        2 *           # K와 V
        num_layers *
        num_kv_heads *
        head_dim *
        seq_len *
        batch_size *
        dtype_bytes
    )

    return {
        "kv_cache_bytes": kv_cache_bytes,
        "kv_cache_mb": kv_cache_bytes / 1e6,
        "kv_cache_gb": kv_cache_bytes / 1e9,
        "per_token_bytes": kv_cache_bytes // seq_len,
    }

# 모델별 KV Cache 비교
models = {
    "Llama-2-7B": {
        "num_layers": 32, "num_heads": 32,
        "num_kv_heads": 32, "head_dim": 128
    },
    "Llama-2-13B": {
        "num_layers": 40, "num_heads": 40,
        "num_kv_heads": 40, "head_dim": 128
    },
    "Llama-2-70B (GQA)": {
        "num_layers": 80, "num_heads": 64,
        "num_kv_heads": 8, "head_dim": 128  # GQA: 8 KV 헤드
    },
    "Mistral-7B (GQA)": {
        "num_layers": 32, "num_heads": 32,
        "num_kv_heads": 8, "head_dim": 128  # GQA: 8 KV 헤드
    },
}

print("KV Cache 메모리 사용량 (배치=1, seq=4096)")
print("=" * 70)
for name, config in models.items():
    result = calculate_kv_cache_memory(config, batch_size=1, seq_len=4096)
    print(f"{name:<25} {result['kv_cache_gb']:.2f} GB  "
          f"(토큰당 {result['per_token_bytes']:,} bytes)")

2.3 Grouped Query Attention (GQA)

GQA는 KV Cache를 줄이는 핵심 기술입니다. 여러 Query 헤드가 더 적은 수의 KV 헤드를 공유합니다.

import torch
import torch.nn as nn
import math

class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (GQA) 구현"""

    def __init__(
        self,
        d_model: int,
        num_q_heads: int,
        num_kv_heads: int,
    ):
        super().__init__()
        assert num_q_heads % num_kv_heads == 0

        self.d_model = d_model
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.num_groups = num_q_heads // num_kv_heads
        self.d_head = d_model // num_q_heads

        self.W_q = nn.Linear(d_model, num_q_heads * self.d_head, bias=False)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, kv_cache=None):
        batch_size, seq_len, _ = x.shape

        # Q: [B, S, num_q_heads * d_head]
        q = self.W_q(x).reshape(batch_size, seq_len, self.num_q_heads, self.d_head)
        k = self.W_k(x).reshape(batch_size, seq_len, self.num_kv_heads, self.d_head)
        v = self.W_v(x).reshape(batch_size, seq_len, self.num_kv_heads, self.d_head)

        # KV Cache 업데이트
        if kv_cache is not None:
            k = torch.cat([kv_cache["k"], k], dim=1)
            v = torch.cat([kv_cache["v"], v], dim=1)
        new_kv_cache = {"k": k, "v": v}

        total_len = k.size(1)

        # [B, num_heads, S, d_head] 형태로
        q = q.transpose(1, 2)  # [B, Q_heads, S, d_head]
        k = k.transpose(1, 2)  # [B, KV_heads, T, d_head]
        v = v.transpose(1, 2)

        # GQA: KV 헤드를 Q 헤드 수만큼 반복
        # k: [B, KV_heads, T, d_head] -> [B, Q_heads, T, d_head]
        k = k.repeat_interleave(self.num_groups, dim=1)
        v = v.repeat_interleave(self.num_groups, dim=1)

        # 어텐션 계산
        scale = math.sqrt(self.d_head)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        attn_weights = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, v)

        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        return self.W_o(output), new_kv_cache

# MHA vs GQA vs MQA 메모리 비교
def compare_attention_variants():
    """어텐션 변형 KV Cache 메모리 비교"""

    # 70B 모델 기준 (Llama-2-70B)
    num_layers = 80
    d_head = 128
    seq_len = 4096
    batch_size = 1
    dtype_bytes = 2  # FP16

    variants = {
        "MHA (32 KV heads)": 32,
        "GQA (8 KV heads)": 8,
        "MQA (1 KV head)": 1,
    }

    print("70B 모델 어텐션 변형별 KV Cache 비교")
    print(f"(seq_len={seq_len}, batch={batch_size})")
    print("=" * 55)

    for name, num_kv_heads in variants.items():
        kv_bytes = 2 * num_layers * num_kv_heads * d_head * seq_len * batch_size * dtype_bytes
        kv_gb = kv_bytes / 1e9
        print(f"{name:<25} {kv_gb:.2f} GB")

compare_attention_variants()

2.4 DeepSeek MLA (Multi-Head Latent Attention)

DeepSeek-V2에서 도입된 MLA는 KV Cache를 저차원 잠재 벡터로 압축합니다.

class MultiHeadLatentAttention(nn.Module):
    """
    DeepSeek MLA - KV Cache를 저차원 잠재 벡터로 압축

    핵심 아이디어:
    - KV를 고차원에 저장하는 대신, 저차원 잠재 벡터 c_kv를 저장
    - c_kv에서 K, V를 복원 (업프로젝션)
    - KV Cache 크기: num_layers * kv_lora_rank * seq_len
      (vs 기존: 2 * num_layers * num_kv_heads * d_head * seq_len)
    """

    def __init__(
        self,
        d_model: int = 5120,
        num_heads: int = 128,
        kv_lora_rank: int = 512,  # 저차원 잠재 차원
        qk_nope_head_dim: int = 128,
        qk_rope_head_dim: int = 64,
        v_head_dim: int = 128,
    ):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim

        # Q 프로젝션 (LoRA 스타일)
        self.q_a_proj = nn.Linear(d_model, 1536, bias=False)  # 다운프로젝션
        self.q_b_proj = nn.Linear(1536, num_heads * (qk_nope_head_dim + qk_rope_head_dim), bias=False)

        # KV 다운프로젝션: d_model -> kv_lora_rank
        # 이것만 KV Cache에 저장!
        self.kv_a_proj = nn.Linear(
            d_model,
            kv_lora_rank + qk_rope_head_dim,
            bias=False
        )

        # KV 업프로젝션: kv_lora_rank -> K, V
        self.kv_b_proj = nn.Linear(
            kv_lora_rank,
            num_heads * (qk_nope_head_dim + v_head_dim),
            bias=False
        )

        self.o_proj = nn.Linear(num_heads * v_head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor, compressed_kv_cache=None):
        """
        Args:
            x: [batch, seq, d_model]
            compressed_kv_cache: [batch, cache_len, kv_lora_rank + rope_dim]
        """
        batch_size, seq_len, _ = x.shape

        # Q 계산
        q = self.q_b_proj(self.q_a_proj(x))

        # KV 압축 (이 결과만 캐시에 저장)
        kv_compressed = self.kv_a_proj(x)  # [B, S, kv_lora_rank + rope_dim]

        # KV Cache 업데이트
        if compressed_kv_cache is not None:
            kv_compressed_total = torch.cat([compressed_kv_cache, kv_compressed], dim=1)
        else:
            kv_compressed_total = kv_compressed

        # 캐시된 압축 KV에서 실제 K, V 복원 (업프로젝션)
        kv_content = kv_compressed_total[:, :, :self.kv_lora_rank]  # rope 제외
        kv_full = self.kv_b_proj(kv_content)  # [B, T, num_heads * (nope + v_dim)]

        # 최종 어텐션 계산 (생략)

        return None, kv_compressed

# KV Cache 크기 비교 (DeepSeek-V2 기준)
def compare_mla_vs_mha():
    """MLA vs MHA KV Cache 비교"""

    seq_len = 4096
    dtype_bytes = 2  # BF16
    num_layers = 60  # DeepSeek-V2

    # MHA (기존)
    num_heads = 128
    head_dim = 128
    mha_kv_gb = 2 * num_layers * num_heads * head_dim * seq_len * dtype_bytes / 1e9

    # MLA (DeepSeek-V2)
    kv_lora_rank = 512
    rope_dim = 64
    mla_kv_gb = (kv_lora_rank + rope_dim) * num_layers * seq_len * dtype_bytes / 1e9

    print(f"MHA KV Cache: {mha_kv_gb:.2f} GB")
    print(f"MLA KV Cache: {mla_kv_gb:.2f} GB")
    print(f"절감 비율: {mha_kv_gb / mla_kv_gb:.1f}x")

compare_mla_vs_mha()

3. PagedAttention: vLLM의 핵심 혁신

3.1 기존 KV Cache의 문제점

기존 LLM 서빙 시스템은 요청마다 최대 시퀀스 길이를 위한 메모리를 미리 할당합니다.

요청 1: [PROMPT=200 tokens] [KV_CACHE=최대 2048-200=1848 tokens 예약] → 내부 단편화
요청 2: [PROMPT=100 tokens] [KV_CACHE=1948 tokens 예약]
요청 3: 메모리 부족으로 대기 (외부 단편화)

이로 인해 실제 GPU 메모리의 60~80%가 낭비됩니다.

3.2 PagedAttention 원리

OS의 가상 메모리에서 영감을 받아 KV Cache를 고정 크기의 물리 블록으로 관리합니다.

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set
import torch

@dataclass
class PhysicalBlock:
    """물리 메모리 블록"""
    block_id: int
    block_size: int  # 블록에 저장 가능한 토큰 수 (예: 16)
    device: str = "cuda"
    ref_count: int = 0  # 참조 카운트 (CoW를 위해)

    def __post_init__(self):
        # 실제 KV 텐서 할당
        # [2, num_layers, block_size, num_heads, head_dim]
        pass


@dataclass
class LogicalBlock:
    """논리 블록 (요청과 매핑)"""
    physical_block_id: int
    num_filled: int = 0  # 현재 채워진 토큰 수

class PagedKVCacheManager:
    """PagedAttention KV Cache 관리자"""

    def __init__(
        self,
        num_physical_blocks: int,
        block_size: int,
        num_layers: int,
        num_kv_heads: int,
        head_dim: int,
        device: str = "cuda"
    ):
        self.block_size = block_size
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.device = device

        # 물리 블록 풀 초기화
        self.free_blocks: List[int] = list(range(num_physical_blocks))
        self.all_blocks: Dict[int, PhysicalBlock] = {
            i: PhysicalBlock(block_id=i, block_size=block_size)
            for i in range(num_physical_blocks)
        }

        # 요청별 논리 블록 테이블
        self.block_tables: Dict[int, List[LogicalBlock]] = {}

        # 실제 KV Cache 텐서
        # [num_blocks, 2, num_layers, block_size, num_kv_heads, head_dim]
        self.kv_cache = torch.zeros(
            num_physical_blocks, 2, num_layers, block_size, num_kv_heads, head_dim,
            dtype=torch.float16,
            device=device
        )

    def allocate_blocks_for_request(self, request_id: int, num_tokens: int):
        """요청에 필요한 블록 할당"""
        num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size

        if len(self.free_blocks) < num_blocks_needed:
            raise RuntimeError(f"메모리 부족: {num_blocks_needed} 블록 필요, {len(self.free_blocks)} 가용")

        logical_blocks = []
        for i in range(num_blocks_needed):
            physical_id = self.free_blocks.pop(0)
            self.all_blocks[physical_id].ref_count = 1
            logical_blocks.append(
                LogicalBlock(physical_block_id=physical_id)
            )

        self.block_tables[request_id] = logical_blocks
        print(f"요청 {request_id}: {num_blocks_needed} 블록 할당, "
              f"잔여 블록: {len(self.free_blocks)}")

    def append_token(self, request_id: int, layer: int, token_pos: int, k: torch.Tensor, v: torch.Tensor):
        """새 토큰의 KV를 캐시에 추가"""
        block_idx = token_pos // self.block_size
        token_in_block = token_pos % self.block_size

        logical_block = self.block_tables[request_id][block_idx]
        physical_id = logical_block.physical_block_id

        # KV Cache에 저장
        self.kv_cache[physical_id, 0, layer, token_in_block] = k  # K
        self.kv_cache[physical_id, 1, layer, token_in_block] = v  # V
        logical_block.num_filled = token_in_block + 1

    def get_physical_block_ids(self, request_id: int) -> List[int]:
        """요청의 물리 블록 ID 목록 반환"""
        return [lb.physical_block_id for lb in self.block_tables[request_id]]

    def free_request(self, request_id: int):
        """요청 완료 후 블록 해제"""
        if request_id in self.block_tables:
            for logical_block in self.block_tables[request_id]:
                phys_id = logical_block.physical_block_id
                self.all_blocks[phys_id].ref_count -= 1

                if self.all_blocks[phys_id].ref_count == 0:
                    self.free_blocks.append(phys_id)

            del self.block_tables[request_id]

    def copy_on_write(self, src_request_id: int, dst_request_id: int):
        """Prefix Caching을 위한 Copy-on-Write"""
        src_blocks = self.block_tables[src_request_id]
        dst_blocks = []

        for logical_block in src_blocks:
            phys_id = logical_block.physical_block_id
            # 참조 카운트 증가 (실제로는 쓰기 시에만 복사)
            self.all_blocks[phys_id].ref_count += 1
            dst_blocks.append(
                LogicalBlock(physical_block_id=phys_id, num_filled=logical_block.num_filled)
            )

        self.block_tables[dst_request_id] = dst_blocks


# 사용 예시
manager = PagedKVCacheManager(
    num_physical_blocks=1000,
    block_size=16,
    num_layers=32,
    num_kv_heads=32,
    head_dim=128
)

# 3개의 요청 처리
manager.allocate_blocks_for_request(request_id=1, num_tokens=200)
manager.allocate_blocks_for_request(request_id=2, num_tokens=500)
manager.allocate_blocks_for_request(request_id=3, num_tokens=100)

# 요청 1 완료 후 해제
manager.free_request(request_id=1)
print(f"\n요청 1 완료 후 가용 블록: {len(manager.free_blocks)}")

4. Continuous Batching

4.1 정적 배치의 문제점

기존 배치 처리 방식은 요청들이 모두 완료될 때까지 기다립니다.

시간 t=0: [요청A: 500토큰] [요청B: 100토큰] [요청C: 300토큰]
시간 t=1: 요청B 완료, 하지만 A, C 대기 중이므로 GPU에 여유 있어도 새 요청 불가
시간 t=2: 요청C 완료
시간 t=3: 요청A 완료 → 이제야 새 배치 시작!

4.2 Continuous Batching (Iteration-level Scheduling)

from typing import List, Optional, Tuple
from dataclasses import dataclass
import asyncio
import torch
from queue import Queue
import threading

@dataclass
class Request:
    """추론 요청"""
    request_id: str
    input_ids: List[int]
    max_new_tokens: int
    generated_ids: List[int] = None
    is_finished: bool = False

    def __post_init__(self):
        self.generated_ids = []

class ContinuousBatchingScheduler:
    """Continuous Batching 스케줄러"""

    def __init__(
        self,
        max_batch_size: int = 32,
        max_seq_len: int = 4096
    ):
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len

        self.waiting_queue: List[Request] = []
        self.running_requests: List[Request] = []
        self.finished_requests: List[Request] = []

    def add_request(self, request: Request):
        """새 요청 추가"""
        self.waiting_queue.append(request)

    def _can_add_request(self, request: Request) -> bool:
        """배치에 요청 추가 가능 여부 확인 (메모리 체크)"""
        current_batch_size = len(self.running_requests) + 1
        if current_batch_size > self.max_batch_size:
            return False

        # KV Cache 메모리 체크 (단순화)
        total_tokens = sum(
            len(r.input_ids) + len(r.generated_ids)
            for r in self.running_requests
        ) + len(request.input_ids)

        return total_tokens < self.max_seq_len * self.max_batch_size

    def schedule_iteration(self) -> Tuple[List[Request], List[str]]:
        """
        한 iteration을 위한 배치 스케줄링

        Returns:
            (실행할 요청들, 완료된 요청 ID들)
        """
        completed_ids = []

        # 완료된 요청 처리
        still_running = []
        for req in self.running_requests:
            if req.is_finished:
                self.finished_requests.append(req)
                completed_ids.append(req.request_id)
            else:
                still_running.append(req)
        self.running_requests = still_running

        # 대기 중인 요청을 배치에 추가 (핵심: 빈 슬롯을 즉시 채움)
        while self.waiting_queue and self._can_add_request(self.waiting_queue[0]):
            new_request = self.waiting_queue.pop(0)
            self.running_requests.append(new_request)
            print(f"배치에 요청 {new_request.request_id} 추가 "
                  f"(현재 배치 크기: {len(self.running_requests)})")

        return self.running_requests, completed_ids

    def simulate_one_step(self, model_forward_fn):
        """한 단계 시뮬레이션"""

        active_requests, completed = self.schedule_iteration()

        if not active_requests:
            return []

        # 현재 배치의 입력 준비
        # Prefill 요청: input_ids만 있는 경우
        # Decode 요청: 이전 KV Cache가 있는 경우
        batch_input_ids = []
        for req in active_requests:
            if len(req.generated_ids) == 0:
                # Prefill
                batch_input_ids.append(req.input_ids)
            else:
                # Decode (마지막 생성 토큰만)
                batch_input_ids.append([req.generated_ids[-1]])

        # 모델 실행 (실제로는 PagedAttention으로 처리)
        outputs = model_forward_fn(batch_input_ids)

        # 다음 토큰 처리
        for req, next_token_id in zip(active_requests, outputs):
            req.generated_ids.append(next_token_id)

            # 종료 조건 확인
            if (next_token_id == 2 or  # EOS token
                    len(req.generated_ids) >= req.max_new_tokens):
                req.is_finished = True

        return completed

# 처리량 비교 시뮬레이션
def simulate_throughput_comparison():
    """정적 배치 vs Continuous Batching 처리량 비교"""

    import random

    requests = [
        Request(
            request_id=str(i),
            input_ids=list(range(random.randint(50, 200))),
            max_new_tokens=random.randint(50, 500)
        )
        for i in range(20)
    ]

    # 정적 배치: 모든 요청이 완료될 때까지 기다림
    max_tokens_static = max(r.max_new_tokens for r in requests)
    total_iterations_static = max_tokens_static * 4  # 4개씩 배치

    # Continuous Batching: 완료되자마자 새 요청 추가
    total_tokens = sum(r.max_new_tokens for r in requests)
    total_iterations_cb = total_tokens  # 대략적인 추정

    print(f"정적 배치 예상 iteration 수: {total_iterations_static}")
    print(f"Continuous Batching 예상 iteration 수: {total_iterations_cb}")
    print(f"처리량 향상: {total_iterations_static / total_iterations_cb:.2f}x")

5. Speculative Decoding

5.1 아이디어: 초안 + 검증

Speculative Decoding의 핵심은 작은 드래프트 모델이 여러 토큰을 빠르게 생성하고, 큰 검증 모델이 한번에 검증하는 것입니다.

기존: [큰 모델] → 토큰1 → 토큰2 → 토큰3 → 토큰4 → 토큰5
투기적: [작은 모델]  (토큰1, 토큰2, 토큰3, 토큰4, 토큰5)를 병렬 생성
        [큰 모델]5개 토큰을 한번에 검증 (Prefill처럼 병렬!)
        수락된 토큰들만 사용

5.2 수락률 기반 속도 향상 분석

import numpy as np
import torch
from typing import List, Tuple

def speculative_decode_step(
    draft_model,
    target_model,
    input_ids: torch.Tensor,
    draft_steps: int = 4,
    temperature: float = 1.0
) -> Tuple[torch.Tensor, int, int]:
    """
    Speculative Decoding 한 단계

    Returns:
        (생성된 토큰들, 수락된 토큰 수, 드래프트 토큰 수)
    """
    batch_size = input_ids.size(0)

    # 1. 드래프트 모델로 후보 토큰 생성
    draft_tokens = []
    draft_probs = []

    current_ids = input_ids.clone()

    for _ in range(draft_steps):
        with torch.no_grad():
            draft_output = draft_model(current_ids)
            draft_logits = draft_output.logits[:, -1, :]  # [B, vocab_size]

        # 드래프트 확률 계산
        if temperature > 0:
            draft_prob = torch.softmax(draft_logits / temperature, dim=-1)
        else:
            draft_prob = torch.zeros_like(draft_logits)
            draft_prob.scatter_(1, draft_logits.argmax(dim=-1, keepdim=True), 1.0)

        # 드래프트 토큰 샘플링
        draft_token = torch.multinomial(draft_prob, num_samples=1)  # [B, 1]
        draft_tokens.append(draft_token)
        draft_probs.append(draft_prob)

        # 다음 스텝을 위해 토큰 추가
        current_ids = torch.cat([current_ids, draft_token], dim=1)

    # 드래프트 토큰들을 하나의 텐서로
    draft_sequence = torch.cat(draft_tokens, dim=1)  # [B, draft_steps]
    candidate_ids = torch.cat([input_ids, draft_sequence], dim=1)

    # 2. 검증 모델로 드래프트 토큰 한번에 검증
    with torch.no_grad():
        target_output = target_model(candidate_ids)
        target_logits = target_output.logits[:, input_ids.size(1) - 1:-1, :]  # [B, draft_steps, vocab_size]

    # 검증 모델의 확률
    if temperature > 0:
        target_probs = torch.softmax(target_logits / temperature, dim=-1)
    else:
        target_probs = torch.zeros_like(target_logits)
        target_probs.scatter_(2, target_logits.argmax(dim=-1, keepdim=True), 1.0)

    # 3. 각 드래프트 토큰 수락/거부 결정
    accepted_tokens = []
    num_accepted = 0

    for step in range(draft_steps):
        token = draft_sequence[:, step]  # [B]

        # 수락 확률 계산: min(1, p_target / p_draft)
        p_draft = draft_probs[step].gather(1, token.unsqueeze(1)).squeeze(1)
        p_target = target_probs[:, step, :].gather(1, token.unsqueeze(1)).squeeze(1)

        acceptance_prob = torch.clamp(p_target / (p_draft + 1e-8), max=1.0)

        # 랜덤 수락/거부
        random_val = torch.rand_like(acceptance_prob)
        accepted = random_val < acceptance_prob  # [B]

        if not accepted.all():
            # 첫 번째 거부된 토큰에서 중단
            break

        accepted_tokens.append(token)
        num_accepted += 1

    # 4. 마지막 토큰: 검증 모델이 생성 (또는 수정된 분포에서 샘플링)
    last_target_logits = target_output.logits[:, input_ids.size(1) + num_accepted - 1, :]

    if temperature > 0:
        last_prob = torch.softmax(last_target_logits / temperature, dim=-1)
    else:
        last_prob = torch.zeros_like(last_target_logits)
        last_prob.scatter_(1, last_target_logits.argmax(dim=-1, keepdim=True), 1.0)

    # 분포 수정 (거부된 경우)
    if num_accepted < draft_steps:
        # max(0, p_target - p_draft) 사용
        correction = torch.clamp(
            last_prob - draft_probs[num_accepted],
            min=0
        )
        correction = correction / (correction.sum(dim=-1, keepdim=True) + 1e-8)
        last_token = torch.multinomial(correction, num_samples=1)
    else:
        last_token = torch.multinomial(last_prob, num_samples=1)

    accepted_tokens.append(last_token.squeeze(1))

    final_tokens = torch.stack(accepted_tokens, dim=1)

    return final_tokens, num_accepted, draft_steps


def analyze_speedup(acceptance_rate: float, draft_steps: int = 4) -> dict:
    """수락률에 따른 속도 향상 분석"""

    # 기댓값 계산
    # E[accepted tokens] = sum_{k=0}^{K} alpha^k = (1 - alpha^{K+1}) / (1 - alpha)
    # 여기서 alpha = acceptance_rate, K = draft_steps

    expected_accepted = sum(
        acceptance_rate ** k for k in range(draft_steps + 1)
    )

    # 실제 속도 향상 (드래프트 모델 비용 포함)
    # 드래프트 모델이 검증 모델의 1/10 크기라 가정
    draft_model_ratio = 0.1

    # 시간: 드래프트 K 스텝 + 검증 1 스텝
    # 기존: K+1 스텝
    # 투기적: K * draft_ratio + 1 스텝 (검증)
    steps_with_speculative = draft_steps * draft_model_ratio + 1
    expected_tokens_with_speculative = expected_accepted

    speedup = expected_tokens_with_speculative / steps_with_speculative

    return {
        "acceptance_rate": acceptance_rate,
        "draft_steps": draft_steps,
        "expected_accepted_tokens": expected_accepted,
        "speedup": speedup
    }

# 수락률별 속도 향상 출력
print("Speculative Decoding 수락률별 속도 향상 (드래프트 K=4)")
print("=" * 55)
for alpha in [0.5, 0.6, 0.7, 0.8, 0.9, 0.95]:
    result = analyze_speedup(alpha, draft_steps=4)
    print(f"수락률 {alpha:.0%}: 기대 수락 {result['expected_accepted_tokens']:.2f}토큰, "
          f"속도향상 {result['speedup']:.2f}x")

5.3 Medusa: 멀티 드래프트 헤드

import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """
    Medusa: 단일 모델에 여러 드래프트 헤드 추가

    각 헤드가 미래 토큰을 예측:
    - Head 1: t+1 예측
    - Head 2: t+2 예측
    - Head N: t+N 예측
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        num_heads: int = 4,
        hidden_layers: int = 1
    ):
        super().__init__()
        self.num_heads = num_heads

        # 각 미래 토큰 위치를 위한 독립적인 헤드
        self.heads = nn.ModuleList([
            nn.Sequential(
                *[nn.Linear(hidden_size, hidden_size, bias=False),
                  nn.SiLU()] * hidden_layers,
                nn.Linear(hidden_size, vocab_size, bias=False)
            )
            for _ in range(num_heads)
        ])

    def forward(self, hidden_states: torch.Tensor):
        """
        Args:
            hidden_states: [batch, seq, hidden_size] - 기본 모델의 마지막 히든 스테이트

        Returns:
            List of logits for each future position [batch, seq, vocab]
        """
        return [head(hidden_states) for head in self.heads]


class MedusaModel(nn.Module):
    """Medusa 전체 모델"""

    def __init__(self, base_model, vocab_size: int, num_medusa_heads: int = 4):
        super().__init__()
        self.base_model = base_model
        hidden_size = base_model.config.hidden_size

        self.medusa_heads = MedusaHead(
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            num_heads=num_medusa_heads
        )

    def forward(self, input_ids: torch.Tensor, use_medusa: bool = False):
        # 기본 모델 실행
        base_output = self.base_model(
            input_ids,
            output_hidden_states=True
        )

        base_logits = base_output.logits

        if not use_medusa:
            return base_logits, None

        # Medusa 헤드로 미래 토큰 예측
        last_hidden_state = base_output.hidden_states[-1]
        medusa_logits = self.medusa_heads(last_hidden_state)

        return base_logits, medusa_logits

    def generate_with_medusa(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        medusa_choices: int = 16,  # 후보 토큰 수
        threshold: float = 0.09    # 수락 임계값
    ):
        """Medusa를 이용한 빠른 생성"""

        current_ids = input_ids.clone()
        all_accepted = []

        while len(all_accepted) < max_new_tokens:
            # Medusa 헤드로 후보 토큰 예측
            base_logits, medusa_logits = self.forward(current_ids, use_medusa=True)

            # 각 위치에서 상위 후보들 선택
            candidates = []
            base_probs = torch.softmax(base_logits[:, -1, :] / temperature, dim=-1)
            top_tokens = torch.topk(base_probs, medusa_choices)[1]

            for head_logits in medusa_logits:
                head_probs = torch.softmax(head_logits[:, -1, :] / temperature, dim=-1)
                candidates.append(torch.topk(head_probs, medusa_choices)[1])

            # 트리 어텐션으로 후보들 검증 (단순화)
            # 실제로는 트리 마스크를 이용한 효율적인 검증
            best_token = top_tokens[0, 0]
            all_accepted.append(best_token.item())

            current_ids = torch.cat([current_ids, best_token.unsqueeze(0).unsqueeze(0)], dim=1)

            if best_token.item() == 2:  # EOS
                break

        return all_accepted

6. FlashAttention: 메모리 효율적인 어텐션

6.1 표준 어텐션의 HBM 병목

표준 어텐션은 HBM(High Bandwidth Memory)에 중간 결과를 자주 쓰고 읽습니다.

표준 Attention 메모리 연산:
1. Q, KHBM에서 읽음       → 읽기: O(N * d)
2. S = Q @ K.T 계산          → 쓰기: O(N^2) ← 병목!
3. SHBM에서 읽어 소프트맥스 → 읽기: O(N^2)
4. P = softmax(S) 저장        → 쓰기: O(N^2)
5. P를 읽어 P @ V 계산        → 읽기: O(N^2)
6. 최종 결과 저장             → 쓰기: O(N * d)

HBM 접근: O(N^2) (시퀀스 길이 제곱에 비례!)

6.2 FlashAttention의 타일링 전략

import torch
import math

def flash_attention_v1(Q, K, V, block_size=64):
    """
    FlashAttention v1 단순화 구현
    타일링을 이용하여 전체 어텐션 행렬을 HBM에 저장하지 않음

    핵심: Online Softmax 알고리즘으로 블록 단위 처리
    """
    batch_size, num_heads, seq_len, d_head = Q.shape

    scale = 1.0 / math.sqrt(d_head)
    Q = Q * scale

    # 출력 텐서 초기화 (SRAM에 유지)
    O = torch.zeros_like(Q)
    L = torch.zeros(batch_size, num_heads, seq_len, 1, device=Q.device)  # 소프트맥스 분모
    M = torch.full((batch_size, num_heads, seq_len, 1), float('-inf'), device=Q.device)  # 최대값

    num_blocks = (seq_len + block_size - 1) // block_size

    for j in range(num_blocks):
        # K, V 블록 로드 (HBM → SRAM)
        k_start = j * block_size
        k_end = min((j + 1) * block_size, seq_len)
        K_j = K[:, :, k_start:k_end, :]
        V_j = V[:, :, k_start:k_end, :]

        for i in range(num_blocks):
            # Q 블록 로드
            q_start = i * block_size
            q_end = min((i + 1) * block_size, seq_len)
            Q_i = Q[:, :, q_start:q_end, :]
            O_i = O[:, :, q_start:q_end, :]
            L_i = L[:, :, q_start:q_end, :]
            M_i = M[:, :, q_start:q_end, :]

            # 어텐션 스코어 계산 (SRAM에서)
            S_ij = torch.matmul(Q_i, K_j.transpose(-2, -1))  # [B, H, Br, Bc]

            # Online Softmax 업데이트
            M_ij_new = torch.maximum(M_i, S_ij.max(dim=-1, keepdim=True)[0])
            P_ij = torch.exp(S_ij - M_ij_new)
            L_ij_new = torch.exp(M_i - M_ij_new) * L_i + P_ij.sum(dim=-1, keepdim=True)

            # 출력 업데이트 (재스케일링)
            O_i_new = (
                torch.exp(M_i - M_ij_new) * O_i +
                torch.matmul(P_ij, V_j)
            )

            # 블록 결과를 HBM에 저장
            O[:, :, q_start:q_end, :] = O_i_new
            L[:, :, q_start:q_end, :] = L_ij_new
            M[:, :, q_start:q_end, :] = M_ij_new

    # 최종 정규화
    O = O / L

    return O


def compare_attention_implementations():
    """FlashAttention vs 표준 어텐션 비교"""

    batch_size = 2
    num_heads = 32
    seq_len = 4096
    d_head = 128

    Q = torch.randn(batch_size, num_heads, seq_len, d_head, device='cuda', dtype=torch.float16)
    K = torch.randn_like(Q)
    V = torch.randn_like(Q)

    # PyTorch SDPA (FlashAttention 2 구현 포함)
    import torch.nn.functional as F

    with torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=False,
        enable_mem_efficient=False
    ):
        flash_output = F.scaled_dot_product_attention(Q, K, V)

    # 표준 어텐션
    scale = 1.0 / math.sqrt(d_head)
    attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
    attn_weights = torch.softmax(attn_scores, dim=-1)
    standard_output = torch.matmul(attn_weights, V)

    # 결과 비교
    max_diff = (flash_output - standard_output).abs().max().item()
    print(f"FlashAttention vs 표준 어텐션 최대 차이: {max_diff:.6f}")

    # 메모리 사용량 비교
    standard_attn_matrix_size = batch_size * num_heads * seq_len * seq_len * 2  # FP16
    print(f"표준 어텐션 행렬 메모리: {standard_attn_matrix_size / 1e9:.2f} GB")
    print(f"FlashAttention 행렬 메모리: ~0 GB (타일링으로 저장 불필요)")

6.3 PyTorch SDPA 사용법

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def modern_attention(q, k, v, is_causal=True, dropout_p=0.0):
    """
    PyTorch 2.0+ scaled_dot_product_attention 사용

    FlashAttention 2/3를 자동으로 선택
    """

    # 자동 백엔드 선택 (Flash, Memory-efficient, Math)
    output = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=is_causal,  # 인과적 마스킹
        scale=None  # None이면 1/sqrt(d_head) 사용
    )

    return output

# 특정 백엔드 강제 선택
def attention_with_flash_backend(q, k, v):
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

def attention_with_efficient_backend(q, k, v):
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)


# FlashAttention 버전별 특징
flash_versions = {
    "FlashAttention 1": {
        "paper": "arXiv:2205.14135",
        "key_innovation": "타일링 + Online Softmax",
        "memory": "O(N) (어텐션 행렬 저장 불필요)",
        "speedup": "2-4x vs 표준 어텐션"
    },
    "FlashAttention 2": {
        "paper": "arXiv:2307.08691",
        "key_innovation": "작업 분할 최적화, FP16/BF16 지원",
        "memory": "O(N)",
        "speedup": "5-9x vs 표준 어텐션 (H100에서)"
    },
    "FlashAttention 3": {
        "paper": "arXiv:2407.08608",
        "key_innovation": "H100 특화, FP8 지원, 비동기 파이프라인",
        "memory": "O(N)",
        "speedup": "1.5-2x vs FA2 (H100에서)"
    },
}

for name, info in flash_versions.items():
    print(f"\n{name}")
    for k, v in info.items():
        print(f"  {k}: {v}")

7. 멀티 GPU 추론

7.1 Tensor Parallelism

가중치 행렬을 여러 GPU에 분산하여 각 GPU가 일부를 처리합니다.

import torch
import torch.distributed as dist

class TensorParallelLinear(torch.nn.Module):
    """
    Tensor Parallel Linear 레이어
    컬럼 분할 방식 (Column Parallel)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        world_size: int,
        rank: int
    ):
        super().__init__()
        self.world_size = world_size
        self.rank = rank

        # 각 GPU는 out_features // world_size 개의 출력 뉴런을 담당
        self.local_out_features = out_features // world_size

        self.weight = torch.nn.Parameter(
            torch.randn(self.local_out_features, in_features) / (in_features ** 0.5)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 로컬 계산
        local_output = torch.nn.functional.linear(x, self.weight)

        # All-gather로 모든 GPU의 출력을 합침
        # (실제로는 분산 환경에서)
        # dist.all_gather(output_list, local_output)

        return local_output


def setup_tensor_parallel_llm(model_name: str, tp_size: int):
    """
    Tensor Parallel LLM 설정 예시 (vLLM 방식)

    vLLM 내부적으로 이 방식 사용
    """

    from vllm import LLM, SamplingParams

    # vLLM의 tensor_parallel_size 설정
    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,  # GPU 수
        gpu_memory_utilization=0.9
    )

    return llm

7.2 vLLM 완전 활용

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
import asyncio
import time

# vLLM 기본 사용
def vllm_basic_usage():
    """vLLM 기본 사용법"""

    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        tensor_parallel_size=1,         # GPU 수
        gpu_memory_utilization=0.90,    # GPU 메모리 사용률
        max_model_len=4096,             # 최대 시퀀스 길이
        quantization=None,              # "awq", "gptq", "squeezellm"
        dtype="auto",                   # "float16", "bfloat16"
        max_num_seqs=256,               # 최대 동시 시퀀스 수
        enable_prefix_caching=True,     # 프리픽스 캐싱 활성화
        use_v2_block_manager=True,      # PagedAttention v2
    )

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=200,
        presence_penalty=0.0,
        frequency_penalty=0.0,
    )

    prompts = [
        "Explain quantum computing in simple terms",
        "What is the future of artificial intelligence?",
        "How does the human brain work?",
    ]

    # 배치 추론
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        print(f"Prompt: {output.prompt[:50]}...")
        print(f"Output: {output.outputs[0].text[:100]}...")
        print(f"Tokens generated: {len(output.outputs[0].token_ids)}")
        print()

    return outputs


# vLLM 비동기 서버
async def vllm_async_server():
    """vLLM 비동기 엔진 사용"""

    engine_args = AsyncEngineArgs(
        model="meta-llama/Llama-2-7b-hf",
        tensor_parallel_size=1,
        gpu_memory_utilization=0.90,
        max_model_len=4096,
        enable_prefix_caching=True,
    )

    engine = AsyncLLMEngine.from_engine_args(engine_args)

    async def generate_stream(prompt: str, request_id: str):
        sampling_params = SamplingParams(
            temperature=0.8,
            max_tokens=200
        )

        full_text = ""
        async for output in engine.generate(prompt, sampling_params, request_id):
            if output.outputs:
                delta = output.outputs[0].text[len(full_text):]
                full_text = output.outputs[0].text

                if delta:
                    print(f"[{request_id}] {delta}", end="", flush=True)

            if output.finished:
                print(f"\n[{request_id}] 완료")

    # 여러 요청 동시 처리
    await asyncio.gather(
        generate_stream("What is AI?", "req_1"),
        generate_stream("Explain machine learning", "req_2"),
        generate_stream("What is deep learning?", "req_3"),
    )

8. 추론 엔진 비교

8.1 주요 추론 엔진 특징

엔진개발사핵심 기능적합한 용도
vLLMUC BerkeleyPagedAttention, Continuous Batching범용 LLM 서빙
TGIHuggingFaceFlash Attention 2, SpeculativeHF 모델 서빙
TensorRT-LLMNVIDIANVIDIA GPU 최적화, FP8NVIDIA 최대 성능
DeepSpeed-MIIMicrosoftZeRO 추론, 초대형 모델다중 GPU 대형 모델
llama.cppGeorgi GerganovCPU 최적화, GGUF로컬 실행

8.2 벤치마크 비교

import subprocess
import json
import time
import requests

def benchmark_vllm_server(
    model: str,
    num_requests: int = 100,
    max_tokens: int = 100,
    concurrency: int = 10
):
    """vLLM 서버 벤치마크"""

    results = {
        "total_requests": num_requests,
        "concurrency": concurrency,
        "latencies": [],
        "ttfts": [],
        "throughputs": []
    }

    import asyncio
    import aiohttp

    async def send_request(session, prompt, request_id):
        start = time.perf_counter()
        first_token_time = None

        payload = {
            "model": model,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "stream": True,
            "temperature": 0.0
        }

        async with session.post(
            "http://localhost:8000/v1/completions",
            json=payload
        ) as response:
            async for line in response.content:
                if line.startswith(b"data: "):
                    data = line[6:].decode()
                    if data.strip() == "[DONE]":
                        break

                    if first_token_time is None:
                        first_token_time = time.perf_counter() - start

        end = time.perf_counter()
        return {
            "latency": end - start,
            "ttft": first_token_time,
        }

    async def run_benchmark():
        prompts = [
            f"Tell me about topic number {i}." for i in range(num_requests)
        ]

        start = time.perf_counter()

        async with aiohttp.ClientSession() as session:
            # 동시 요청
            tasks = []
            for i, prompt in enumerate(prompts):
                if len(tasks) >= concurrency:
                    done, tasks = await asyncio.wait(
                        tasks, return_when=asyncio.FIRST_COMPLETED
                    )
                    for task in done:
                        result = await task
                        results["latencies"].append(result["latency"])
                        if result["ttft"]:
                            results["ttfts"].append(result["ttft"])

                tasks.add(asyncio.ensure_future(
                    send_request(session, prompt, i)
                ))

            # 남은 작업 처리
            for coro in asyncio.as_completed(tasks):
                result = await coro
                results["latencies"].append(result["latency"])

        total_time = time.perf_counter() - start
        total_tokens = num_requests * max_tokens
        results["throughput"] = total_tokens / total_time

    asyncio.run(run_benchmark())

    # 통계 계산
    import statistics
    latencies = results["latencies"]

    return {
        "avg_latency_ms": statistics.mean(latencies) * 1000,
        "p50_latency_ms": statistics.median(latencies) * 1000,
        "p99_latency_ms": sorted(latencies)[int(len(latencies) * 0.99)] * 1000,
        "avg_ttft_ms": statistics.mean(results["ttfts"]) * 1000 if results["ttfts"] else 0,
        "throughput_tps": results.get("throughput", 0),
    }

9. 프롬프트 캐싱

9.1 프리픽스 캐싱

동일한 시스템 프롬프트나 문서를 반복적으로 처리할 때 KV Cache를 재사용합니다.

from vllm import LLM, SamplingParams

def demonstrate_prefix_caching():
    """프리픽스 캐싱 효과 시연"""

    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        enable_prefix_caching=True,  # 프리픽스 캐싱 활성화
        max_model_len=4096,
    )

    # 긴 시스템 프롬프트 (모든 요청에 공통)
    system_prompt = """You are a helpful AI assistant with expertise in:
    - Python programming and software development
    - Machine learning and deep learning
    - Data science and statistics
    - Cloud computing and DevOps
    [... 긴 시스템 프롬프트 ...]""" * 10  # 1000+ 토큰

    questions = [
        "How do I optimize a Python loop?",
        "What is gradient descent?",
        "Explain containerization.",
        "What is a neural network?",
    ]

    sampling_params = SamplingParams(temperature=0.7, max_tokens=100)

    # 첫 번째 배치: 캐시 없음 (콜드 스타트)
    import time

    cold_prompts = [f"{system_prompt}\n\nQuestion: {q}" for q in questions]

    cold_start = time.time()
    llm.generate(cold_prompts, sampling_params)
    cold_time = time.time() - cold_start

    # 두 번째 배치: 동일한 시스템 프롬프트 (캐시 히트!)
    warm_start = time.time()
    llm.generate(cold_prompts, sampling_params)
    warm_time = time.time() - warm_start

    print(f"첫 번째 (캐시 없음): {cold_time:.2f}초")
    print(f"두 번째 (캐시 히트): {warm_time:.2f}초")
    print(f"속도 향상: {cold_time / warm_time:.2f}x")


def radix_tree_prefix_cache():
    """Radix Tree 기반 프리픽스 캐시 구현"""

    class RadixNode:
        def __init__(self):
            self.children: dict = {}
            self.kv_cache_block_id: int = None

    class RadixTreeCache:
        """토큰 시퀀스를 Radix Tree로 관리하여 공통 프리픽스 KV 캐시 공유"""

        def __init__(self):
            self.root = RadixNode()
            self.cache_hits = 0
            self.cache_misses = 0

        def insert(self, token_ids: list, block_id: int):
            """토큰 시퀀스와 해당 KV Cache 블록 ID 삽입"""
            node = self.root
            for token_id in token_ids:
                if token_id not in node.children:
                    node.children[token_id] = RadixNode()
                node = node.children[token_id]
            node.kv_cache_block_id = block_id

        def lookup(self, token_ids: list) -> tuple:
            """주어진 토큰 시퀀스의 최장 일치 프리픽스 찾기"""
            node = self.root
            matched_len = 0
            last_block_id = None

            for i, token_id in enumerate(token_ids):
                if token_id in node.children:
                    node = node.children[token_id]
                    matched_len = i + 1
                    if node.kv_cache_block_id is not None:
                        last_block_id = node.kv_cache_block_id
                else:
                    break

            if last_block_id is not None:
                self.cache_hits += 1
            else:
                self.cache_misses += 1

            return matched_len, last_block_id

        def get_hit_rate(self) -> float:
            total = self.cache_hits + self.cache_misses
            return self.cache_hits / total if total > 0 else 0.0

    return RadixTreeCache()

10. 실전 최적화 체크리스트

10.1 단계별 최적화 가이드

class LLMOptimizationChecklist:
    """LLM 추론 최적화 체크리스트"""

    optimizations = [
        {
            "category": "기본 설정",
            "level": 1,
            "items": [
                {
                    "name": "FP16/BF16 사용",
                    "impact": "높음",
                    "effort": "낮음",
                    "description": "FP32 → FP16으로 메모리 2배 절약, 속도 향상",
                    "code": """
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # 또는 float16
    device_map="auto"
)"""
                },
                {
                    "name": "Flash Attention 2 활성화",
                    "impact": "높음",
                    "effort": "낮음",
                    "description": "어텐션 연산 2-4x 속도 향상, 메모리 절약",
                    "code": """
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16
)"""
                },
            ]
        },
        {
            "category": "KV Cache 최적화",
            "level": 2,
            "items": [
                {
                    "name": "GQA/MQA 모델 선택",
                    "impact": "높음",
                    "effort": "중간",
                    "description": "KV Cache 4-8배 감소, 더 많은 배치 처리 가능"
                },
                {
                    "name": "프리픽스 캐싱",
                    "impact": "중간",
                    "effort": "낮음",
                    "description": "공통 시스템 프롬프트 KV Cache 재사용"
                },
            ]
        },
        {
            "category": "배치 최적화",
            "level": 3,
            "items": [
                {
                    "name": "Continuous Batching (vLLM)",
                    "impact": "매우 높음",
                    "effort": "낮음",
                    "description": "처리량 2-5x 향상",
                    "code": """
from vllm import LLM, SamplingParams

llm = LLM(
    model=model_name,
    gpu_memory_utilization=0.90,
    enable_prefix_caching=True,
)"""
                },
            ]
        },
        {
            "category": "모델 최적화",
            "level": 4,
            "items": [
                {
                    "name": "AWQ 4-bit 양자화",
                    "impact": "높음",
                    "effort": "중간",
                    "description": "메모리 4x 감소, 속도 1.5-2x 향상",
                    "code": """
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "model-awq-4bit",
    fuse_layers=True
)"""
                },
                {
                    "name": "Speculative Decoding",
                    "impact": "중간",
                    "effort": "높음",
                    "description": "2-3x 속도 향상 (적합한 드래프트 모델 필요)"
                },
            ]
        },
        {
            "category": "하드웨어 최적화",
            "level": 5,
            "items": [
                {
                    "name": "Tensor Parallelism",
                    "impact": "매우 높음",
                    "effort": "중간",
                    "description": "다중 GPU로 선형적 처리량 향상"
                },
                {
                    "name": "CUDA 그래프 캡처",
                    "impact": "중간",
                    "effort": "높음",
                    "description": "커널 론치 오버헤드 제거"
                },
            ]
        }
    ]

    @classmethod
    def print_checklist(cls):
        print("=" * 70)
        print("LLM 추론 최적화 단계별 체크리스트")
        print("=" * 70)

        for category in cls.optimizations:
            print(f"\n[레벨 {category['level']}] {category['category']}")
            print("-" * 50)

            for item in category['items']:
                impact_emoji = {"매우 높음": "★★★", "높음": "★★", "중간": "★", "낮음": "☆"}
                print(f"  ✓ {item['name']}")
                print(f"    효과: {impact_emoji.get(item['impact'], '?')} {item['impact']}")
                print(f"    설명: {item['description']}")

        print("\n추천 최적화 순서:")
        print("1. BF16/FP16 전환 (즉시, 무료)")
        print("2. Flash Attention 2 (즉시, 패키지 설치만)")
        print("3. vLLM으로 서빙 (처리량 극대화)")
        print("4. AWQ/GPTQ 4비트 양자화 (메모리 4배 절약)")
        print("5. Speculative Decoding (레이턴시 개선)")
        print("6. 멀티 GPU Tensor Parallelism (규모 확장)")

LLMOptimizationChecklist.print_checklist()

마무리

LLM 추론 최적화는 계층적 접근이 필요합니다.

핵심 요점 정리:

  1. KV Cache 이해: 메모리 사용량 공식 2 * layers * kv_heads * d_head * seq_len * dtype_bytes를 외우고, GQA/MQA로 KV Cache를 4-8배 줄이세요.

  2. PagedAttention: vLLM의 핵심 혁신으로, OS의 가상 메모리에서 영감 받아 KV Cache 단편화를 해결합니다.

  3. Continuous Batching: 요청이 완료되는 즉시 새 요청을 삽입하여 GPU 활용률을 극대화합니다.

  4. Speculative Decoding: 작은 드래프트 모델 + 큰 검증 모델 조합으로 2-3x 속도 향상이 가능합니다.

  5. FlashAttention: 어텐션 계산의 메모리 효율을 O(N^2)에서 O(N)으로 줄여 긴 컨텍스트를 가능하게 합니다.

프로덕션 배포 권고:

  • 소규모 서비스: vLLM + AWQ 4bit + 프리픽스 캐싱
  • 대규모 서비스: TensorRT-LLM 또는 vLLM + Tensor Parallelism
  • 최저 레이턴시 요구: Speculative Decoding + CUDA 그래프

참고 자료