LLM Inference Optimization Complete Guide: KV Cache, Speculative Decoding, Continuous Batching

Introduction

When you deploy a large language model (LLM) to production, you immediately run into two constraints: latency and cost. A single query to a GPT-4-class model can take several seconds, and throughput degrades sharply as the number of concurrent users grows.

This guide explores the core techniques of LLM inference optimization in depth: from KV Cache internals to PagedAttention, Speculative Decoding, FlashAttention, and the vLLM and TensorRT-LLM engines. We cover not just how to use each technique but why it works.


1. Understanding the LLM Inference Pipeline

1.1 Two Phases: Prefill and Decode

LLM text generation is split into two distinct phases.

Prefill Phase (Prompt Processing)

  • Processes all tokens of the input prompt simultaneously (parallel)
  • Generates and stores Key/Value caches at each layer
  • Compute-Bound: GPU compute throughput is the bottleneck
  • Directly affects TTFT (Time To First Token)

Decode Phase (Token Generation)

  • Generates one token at a time in an autoregressive manner
  • References KV Cache from all previously generated tokens
  • Memory-Bandwidth-Bound: HBM read speed is the bottleneck
  • Directly affects TPOT (Time Per Output Token)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time

def measure_prefill_decode_time(model, tokenizer, prompt: str, max_new_tokens: int = 100):
    """Measure prefill and decode phase timing"""

    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    input_len = inputs["input_ids"].size(1)

    # Measure TTFT
    torch.cuda.synchronize()
    prefill_start = time.perf_counter()

    with torch.no_grad():
        first_output = model.generate(
            **inputs,
            max_new_tokens=1,
            do_sample=False
        )

    torch.cuda.synchronize()
    ttft = time.perf_counter() - prefill_start

    # Measure total generation
    torch.cuda.synchronize()
    total_start = time.perf_counter()

    with torch.no_grad():
        full_output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )

    torch.cuda.synchronize()
    total_time = time.perf_counter() - total_start

    output_tokens = full_output.size(1) - input_len
    # Approximation: the separate 1-token run stands in for TTFT, so the
    # remaining time is attributed to the decode phase
    decode_time = total_time - ttft
    tpot = decode_time / max(output_tokens - 1, 1)

    print(f"Input tokens: {input_len}")
    print(f"Output tokens: {output_tokens}")
    print(f"TTFT (first token latency): {ttft * 1000:.1f} ms")
    print(f"TPOT (per token time): {tpot * 1000:.1f} ms")
    print(f"Throughput: {output_tokens / total_time:.1f} tokens/sec")

    return ttft, tpot

1.2 Memory Bandwidth Analysis

Understanding why the decode phase is memory-bound:

def analyze_memory_bandwidth():
    """LLM inference memory bandwidth analysis"""

    # Example: Llama-2-7B config
    model_params = {
        "num_layers": 32,
        "hidden_size": 4096,
        "num_heads": 32,
        "head_dim": 128,
        "vocab_size": 32000,
    }

    dtype_bytes = 2  # FP16: 2 bytes

    # Weight memory
    attn_weight = 4 * model_params["hidden_size"] ** 2  # Q, K, V, O projections
    ffn_weight = 8 * model_params["hidden_size"] ** 2   # SwiGLU: Up, Gate, Down
    layer_weight = (attn_weight + ffn_weight) * dtype_bytes

    total_weight_bytes = layer_weight * model_params["num_layers"]
    total_weight_gb = total_weight_bytes / 1e9

    print(f"Model weights: {total_weight_gb:.2f} GB")

    # KV Cache memory (proportional to sequence length)
    seq_len = 2048
    kv_cache_per_token = (
        2 *
        model_params["num_layers"] *
        model_params["num_heads"] *
        model_params["head_dim"] *
        dtype_bytes
    )

    kv_cache_total = kv_cache_per_token * seq_len / 1e6
    print(f"KV Cache ({seq_len} tokens): {kv_cache_total:.2f} MB")
    print(f"KV Cache per token: {kv_cache_per_token} bytes")

    # A100 memory bandwidth: 2 TB/s
    memory_bandwidth_tbs = 2.0

    # Decode: every weight is read from HBM once per generated token
    memory_bound_tps = memory_bandwidth_tbs * 1e12 / total_weight_bytes

    # A100 FP16: 312 TFLOPS; ~2 FLOPs per parameter per token
    compute_throughput = 312e12
    total_params = total_weight_bytes / dtype_bytes
    flops_per_token = 2 * total_params
    compute_bound_tps = compute_throughput / flops_per_token

    print(f"\nFor batch_size=1:")
    print(f"Memory-bound throughput: {memory_bound_tps:.1f} tokens/sec")
    print(f"Compute-bound throughput: {compute_bound_tps:.1f} tokens/sec")
    print(f"Actual bottleneck: {'memory' if memory_bound_tps < compute_bound_tps else 'compute'}")

analyze_memory_bandwidth()
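The memory-vs-compute crossover can be made concrete with a roofline-style calculation: decode becomes compute-bound only once the batch size pushes arithmetic intensity past the GPU's ridge point. A small sketch using A100 spec-sheet numbers (the function name is ours, not a library API):

```python
def critical_batch_size(peak_flops: float, bandwidth_bytes: float,
                        dtype_bytes: int = 2) -> float:
    """Batch size at which decode shifts from memory-bound to compute-bound.

    Per generated token, a batch of B requests performs ~2*P*B FLOPs while
    reading the P parameters (P * dtype_bytes bytes) from HBM once, so the
    arithmetic intensity is 2*B / dtype_bytes FLOP/byte. Setting that equal
    to the hardware ridge point gives the crossover batch size.
    """
    ridge = peak_flops / bandwidth_bytes      # FLOP/byte
    return ridge * dtype_bytes / 2

# A100: 312 TFLOPS FP16 dense, ~2 TB/s HBM bandwidth
b_crit = critical_batch_size(peak_flops=312e12, bandwidth_bytes=2.0e12)
print(f"Decode is memory-bound below batch size ~{b_crit:.0f}")  # ~156
```

This is why batching is the single biggest throughput lever: below roughly this batch size, the GPU's compute units sit idle waiting on HBM.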

1.3 Inference Cost Analysis

def estimate_inference_cost(
    model_size_b: float,
    tokens_per_request: int,
    requests_per_day: int,
    gpu_cost_per_hour: float = 3.0  # A100 hourly price (USD)
):
    """Estimate inference cost"""

    # Empirical throughput estimates
    # 7B model: ~100 tok/s, 70B model: ~20 tok/s (A100)
    throughput_tps = 100 / (model_size_b / 7) ** 0.6

    total_tokens_per_day = tokens_per_request * requests_per_day
    seconds_needed = total_tokens_per_day / throughput_tps
    hours_needed = seconds_needed / 3600

    daily_cost = hours_needed * gpu_cost_per_hour
    cost_per_1k_tokens = daily_cost / (total_tokens_per_day / 1000)

    print(f"Model: {model_size_b}B parameters")
    print(f"Daily requests: {requests_per_day:,}")
    print(f"Tokens per request: {tokens_per_request}")
    print(f"Total daily tokens: {total_tokens_per_day:,}")
    print(f"Estimated throughput: {throughput_tps:.1f} tokens/sec")
    print(f"GPU hours needed: {hours_needed:.2f}")
    print(f"Daily cost: ${daily_cost:.2f}")
    print(f"Cost per 1K tokens: ${cost_per_1k_tokens:.4f}")

estimate_inference_cost(
    model_size_b=7.0,
    tokens_per_request=500,
    requests_per_day=10000
)

2. KV Cache: The Core Optimization

2.1 Why KV Cache is Necessary

In transformer attention, each new token attends to all previous tokens. In causal decoding the keys and values of earlier tokens never change, so instead of recomputing them at every step we compute the K and V projections once and cache them.

import torch
import torch.nn as nn
import math

class MultiHeadAttentionWithKVCache(nn.Module):
    """Multi-head attention with KV Cache support"""

    def __init__(self, d_model: int, num_heads: int, max_seq_len: int = 4096):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        # KV Cache initialization
        self.register_buffer(
            'k_cache',
            torch.zeros(1, max_seq_len, num_heads, self.d_head)
        )
        self.register_buffer(
            'v_cache',
            torch.zeros(1, max_seq_len, num_heads, self.d_head)
        )
        self.cache_pos = 0

    def forward(self, x: torch.Tensor, use_cache: bool = True, position: int = None):
        batch_size, seq_len, _ = x.shape

        q = self.W_q(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        k = self.W_k(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        v = self.W_v(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)

        if use_cache:
            start_pos = self.cache_pos if position is None else position
            self.k_cache[:, start_pos:start_pos + seq_len] = k
            self.v_cache[:, start_pos:start_pos + seq_len] = v

            if position is None:
                self.cache_pos += seq_len

            total_len = self.cache_pos if position is None else start_pos + seq_len
            k = self.k_cache[:, :total_len]
            v = self.v_cache[:, :total_len]

        scale = math.sqrt(self.d_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale

        # Causal mask is needed during prefill (seq_len > 1); a single-token
        # decode step may attend to every cached position
        if seq_len > 1:
            total_len_k = k.size(2)
            causal_mask = torch.triu(
                torch.ones(seq_len, total_len_k, dtype=torch.bool, device=x.device),
                diagonal=total_len_k - seq_len + 1
            )
            attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))

        attn_weights = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        output = self.W_o(output)

        return output

    def clear_cache(self):
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.cache_pos = 0
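The cache's correctness rests on one invariant: a decode step's single query attending to all cached keys must reproduce the last row of full causal attention. A quick check with PyTorch's built-in attention (shapes chosen arbitrarily):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, d = 1, 2, 6, 8
q = torch.randn(B, H, T, d)
k = torch.randn(B, H, T, d)  # stands in for the K cache after T tokens
v = torch.randn(B, H, T, d)

# Prefill: full causal attention over all T tokens at once
full = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Decode: the newest token's query attends to every cached position,
# so no mask is needed for a single-query step
step = F.scaled_dot_product_attention(q[:, :, -1:], k, v)

assert torch.allclose(full[:, :, -1:], step, atol=1e-5)
```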

2.2 KV Cache Memory Calculation

def calculate_kv_cache_memory(
    model_config: dict,
    batch_size: int,
    seq_len: int,
    dtype_bytes: int = 2  # FP16
) -> dict:
    """Calculate KV Cache memory usage"""

    num_layers = model_config["num_layers"]
    num_kv_heads = model_config.get("num_kv_heads", model_config["num_heads"])
    head_dim = model_config["head_dim"]

    # Formula: 2 (K+V) * layers * kv_heads * head_dim * seq_len * batch * dtype
    kv_cache_bytes = (
        2 *
        num_layers *
        num_kv_heads *
        head_dim *
        seq_len *
        batch_size *
        dtype_bytes
    )

    return {
        "kv_cache_bytes": kv_cache_bytes,
        "kv_cache_mb": kv_cache_bytes / 1e6,
        "kv_cache_gb": kv_cache_bytes / 1e9,
        "per_token_bytes": kv_cache_bytes // seq_len,
    }

models = {
    "Llama-2-7B (MHA)": {
        "num_layers": 32, "num_heads": 32,
        "num_kv_heads": 32, "head_dim": 128
    },
    "Llama-2-70B (GQA)": {
        "num_layers": 80, "num_heads": 64,
        "num_kv_heads": 8, "head_dim": 128
    },
    "Mistral-7B (GQA)": {
        "num_layers": 32, "num_heads": 32,
        "num_kv_heads": 8, "head_dim": 128
    },
}

print("KV Cache memory (batch=1, seq=4096)")
print("=" * 65)
for name, config in models.items():
    result = calculate_kv_cache_memory(config, batch_size=1, seq_len=4096)
    print(f"{name:<25} {result['kv_cache_gb']:.2f} GB  "
          f"({result['per_token_bytes']:,} bytes/token)")

2.3 Grouped Query Attention (GQA)

GQA is a key technique for reducing KV Cache — multiple query heads share fewer KV heads.

import torch
import torch.nn as nn
import math

class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (GQA) implementation"""

    def __init__(self, d_model: int, num_q_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_q_heads % num_kv_heads == 0

        self.d_model = d_model
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.num_groups = num_q_heads // num_kv_heads
        self.d_head = d_model // num_q_heads

        self.W_q = nn.Linear(d_model, num_q_heads * self.d_head, bias=False)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, kv_cache=None):
        batch_size, seq_len, _ = x.shape

        q = self.W_q(x).reshape(batch_size, seq_len, self.num_q_heads, self.d_head)
        k = self.W_k(x).reshape(batch_size, seq_len, self.num_kv_heads, self.d_head)
        v = self.W_v(x).reshape(batch_size, seq_len, self.num_kv_heads, self.d_head)

        if kv_cache is not None:
            k = torch.cat([kv_cache["k"], k], dim=1)
            v = torch.cat([kv_cache["v"], v], dim=1)
        new_kv_cache = {"k": k, "v": v}

        q = q.transpose(1, 2)   # [B, Q_heads, S, d_head]
        k = k.transpose(1, 2)   # [B, KV_heads, T, d_head]
        v = v.transpose(1, 2)

        # GQA: expand KV heads to match Q heads
        k = k.repeat_interleave(self.num_groups, dim=1)
        v = v.repeat_interleave(self.num_groups, dim=1)

        scale = math.sqrt(self.d_head)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        # (causal mask omitted to keep the sketch focused on head grouping)
        attn_weights = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)

        return self.W_o(output), new_kv_cache

def compare_attention_variants():
    """Compare KV Cache memory across attention variants"""

    # 70B model (Llama-2-70B)
    num_layers = 80
    d_head = 128
    seq_len = 4096
    dtype_bytes = 2

    variants = {
        "MHA (32 KV heads)": 32,
        "GQA (8 KV heads)": 8,
        "MQA (1 KV head)": 1,
    }

    print("70B model attention variant KV Cache comparison")
    print(f"(seq_len={seq_len}, batch=1)")
    print("=" * 55)

    for name, num_kv_heads in variants.items():
        kv_bytes = 2 * num_layers * num_kv_heads * d_head * seq_len * dtype_bytes
        kv_gb = kv_bytes / 1e9
        print(f"{name:<25} {kv_gb:.2f} GB")

compare_attention_variants()
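The head-grouping mechanic can also be checked in isolation: expanding 2 KV heads to serve 8 query heads simply repeats each KV head for its group (toy shapes below):

```python
import torch

B, S, d_head = 1, 5, 16
num_q_heads, num_kv_heads = 8, 2
group = num_q_heads // num_kv_heads  # 4 query heads per KV head

k = torch.randn(B, num_kv_heads, S, d_head)
k_exp = k.repeat_interleave(group, dim=1)

assert k_exp.shape == (B, num_q_heads, S, d_head)
# query heads 0-3 all read KV head 0; heads 4-7 read KV head 1
assert torch.equal(k_exp[:, 0], k[:, 0])
assert torch.equal(k_exp[:, 3], k[:, 0])
assert torch.equal(k_exp[:, 4], k[:, 1])
```

Only the 2 KV heads are cached; the expansion exists transiently at compute time (and fused kernels avoid materializing it at all).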

2.4 DeepSeek MLA (Multi-Head Latent Attention)

MLA, introduced in DeepSeek-V2, compresses KV Cache into low-dimensional latent vectors.

class MultiHeadLatentAttention(nn.Module):
    """
    DeepSeek MLA — compresses KV Cache to low-rank latent vectors

    Core idea:
    - Instead of storing high-dim K,V, store low-dim latent c_kv
    - Recover K, V from c_kv via up-projection
    - KV Cache size: num_layers * kv_lora_rank * seq_len
      (vs standard: 2 * num_layers * num_kv_heads * d_head * seq_len)
    """

    def __init__(
        self,
        d_model: int = 5120,
        num_heads: int = 128,
        kv_lora_rank: int = 512,
        qk_nope_head_dim: int = 128,
        qk_rope_head_dim: int = 64,
        v_head_dim: int = 128,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.kv_lora_rank = kv_lora_rank

        # Q projection (LoRA-style)
        self.q_a_proj = nn.Linear(d_model, 1536, bias=False)
        self.q_b_proj = nn.Linear(
            1536,
            num_heads * (qk_nope_head_dim + qk_rope_head_dim),
            bias=False
        )

        # KV down-projection: d_model -> kv_lora_rank
        # Only THIS is stored in KV Cache!
        self.kv_a_proj = nn.Linear(
            d_model,
            kv_lora_rank + qk_rope_head_dim,
            bias=False
        )

        # KV up-projection: kv_lora_rank -> K, V
        self.kv_b_proj = nn.Linear(
            kv_lora_rank,
            num_heads * (qk_nope_head_dim + v_head_dim),
            bias=False
        )

        self.o_proj = nn.Linear(num_heads * v_head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor, compressed_kv_cache=None):
        batch_size, seq_len, _ = x.shape

        # KV compression (only this result goes in cache)
        kv_compressed = self.kv_a_proj(x)  # [B, S, kv_lora_rank + rope_dim]

        if compressed_kv_cache is not None:
            kv_compressed_total = torch.cat([compressed_kv_cache, kv_compressed], dim=1)
        else:
            kv_compressed_total = kv_compressed

        # Recover full K, V from the cached compressed representation
        kv_content = kv_compressed_total[:, :, :self.kv_lora_rank]
        kv_full = self.kv_b_proj(kv_content)  # [B, T, num_heads * (nope_dim + v_dim)]

        # (Attention with decoupled RoPE is omitted to keep the sketch short)
        # Return the updated compressed cache so the next step can extend it
        return kv_full, kv_compressed_total

def compare_mla_vs_mha():
    """Compare MLA vs MHA KV Cache"""

    seq_len = 4096
    dtype_bytes = 2  # BF16
    num_layers = 60  # DeepSeek-V2

    num_heads = 128
    head_dim = 128
    mha_kv_gb = 2 * num_layers * num_heads * head_dim * seq_len * dtype_bytes / 1e9

    kv_lora_rank = 512
    rope_dim = 64
    mla_kv_gb = (kv_lora_rank + rope_dim) * num_layers * seq_len * dtype_bytes / 1e9

    print(f"MHA KV Cache: {mha_kv_gb:.2f} GB")
    print(f"MLA KV Cache: {mla_kv_gb:.2f} GB")
    print(f"Reduction: {mha_kv_gb / mla_kv_gb:.1f}x")

compare_mla_vs_mha()
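The compression mechanic itself is just a down-projection whose output is what gets cached: per-token cache cost drops from 2 * num_heads * head_dim values to kv_lora_rank (plus the small RoPE part). A toy sketch with made-up dimensions, not DeepSeek's actual sizes:

```python
import torch

d_model, kv_lora_rank = 1024, 64      # toy sizes for illustration
num_heads, head_dim = 16, 64

W_down = torch.randn(kv_lora_rank, d_model) / d_model ** 0.5
W_up_kv = torch.randn(2 * num_heads * head_dim, kv_lora_rank) / kv_lora_rank ** 0.5

x = torch.randn(d_model)      # one token's hidden state
c_kv = W_down @ x             # cached: only 64 values per token
kv = W_up_kv @ c_kv           # recovered K and V: 2048 values

cache_reduction = (2 * num_heads * head_dim) / kv_lora_rank
assert c_kv.shape == (kv_lora_rank,)
assert kv.shape == (2 * num_heads * head_dim,)
assert cache_reduction == 32.0
```

The trade is extra up-projection compute per step in exchange for a much smaller cache, which is favorable precisely because decode is memory-bandwidth-bound.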

3. PagedAttention: vLLM's Core Innovation

3.1 The Problem with Conventional KV Cache

Traditional LLM serving systems pre-allocate memory for the maximum sequence length per request:

Request 1: [PROMPT=200 tokens] [KV_CACHE=up to 1848 tokens reserved] → internal fragmentation
Request 2: [PROMPT=100 tokens] [KV_CACHE=1948 tokens reserved]
Request 3: Blocked waiting (external fragmentation)

This wastes 60–80% of GPU memory.
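A back-of-the-envelope check of that waste figure, with hypothetical request lengths: each request reserves max_seq_len KV slots up front, but only prompt + generated tokens are ever used.

```python
max_seq_len = 2048
# (prompt + generated) tokens actually used per request; values are hypothetical
used_per_request = [200 + 120, 100 + 400, 900 + 50]

reserved = max_seq_len * len(used_per_request)
used = sum(used_per_request)
waste = 1 - used / reserved

print(f"Reserved: {reserved} slots, used: {used} -> {waste:.0%} wasted")  # 71% wasted
```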

3.2 PagedAttention: How It Works

Inspired by OS virtual memory, PagedAttention manages KV Cache in fixed-size physical blocks.

from dataclasses import dataclass, field
from typing import Dict, List, Optional
import torch

@dataclass
class PhysicalBlock:
    """Physical memory block"""
    block_id: int
    block_size: int  # tokens per block (e.g., 16)
    ref_count: int = 0

@dataclass
class LogicalBlock:
    """Logical block (mapped to a request)"""
    physical_block_id: int
    num_filled: int = 0

class PagedKVCacheManager:
    """PagedAttention KV Cache manager"""

    def __init__(
        self,
        num_physical_blocks: int,
        block_size: int,
        num_layers: int,
        num_kv_heads: int,
        head_dim: int,
        device: str = "cuda"
    ):
        self.block_size = block_size
        self.num_layers = num_layers

        self.free_blocks: List[int] = list(range(num_physical_blocks))
        self.all_blocks: Dict[int, PhysicalBlock] = {
            i: PhysicalBlock(block_id=i, block_size=block_size)
            for i in range(num_physical_blocks)
        }
        self.block_tables: Dict[int, List[LogicalBlock]] = {}

        # Actual KV Cache tensor pool
        self.kv_cache = torch.zeros(
            num_physical_blocks, 2, num_layers, block_size, num_kv_heads, head_dim,
            dtype=torch.float16,
            device=device
        )

    def allocate_blocks_for_request(self, request_id: int, num_tokens: int):
        """Allocate blocks for a request"""
        num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size

        if len(self.free_blocks) < num_blocks_needed:
            raise RuntimeError(
                f"OOM: need {num_blocks_needed} blocks, {len(self.free_blocks)} available"
            )

        logical_blocks = []
        for _ in range(num_blocks_needed):
            physical_id = self.free_blocks.pop(0)
            self.all_blocks[physical_id].ref_count = 1
            logical_blocks.append(LogicalBlock(physical_block_id=physical_id))

        self.block_tables[request_id] = logical_blocks
        print(f"Request {request_id}: {num_blocks_needed} blocks allocated, "
              f"{len(self.free_blocks)} remaining")

    def append_token(self, request_id: int, layer: int, token_pos: int,
                     k: torch.Tensor, v: torch.Tensor):
        """Store KV for a new token in the cache"""
        block_idx = token_pos // self.block_size
        token_in_block = token_pos % self.block_size

        logical_block = self.block_tables[request_id][block_idx]
        physical_id = logical_block.physical_block_id

        self.kv_cache[physical_id, 0, layer, token_in_block] = k
        self.kv_cache[physical_id, 1, layer, token_in_block] = v
        logical_block.num_filled = token_in_block + 1

    def free_request(self, request_id: int):
        """Free blocks after request completes"""
        if request_id in self.block_tables:
            for logical_block in self.block_tables[request_id]:
                phys_id = logical_block.physical_block_id
                self.all_blocks[phys_id].ref_count -= 1

                if self.all_blocks[phys_id].ref_count == 0:
                    self.free_blocks.append(phys_id)

            del self.block_tables[request_id]

    def copy_on_write(self, src_request_id: int, dst_request_id: int):
        """Share blocks for prefix reuse; the actual copy happens lazily
        (copy-on-write) when a shared block is first written"""
        src_blocks = self.block_tables[src_request_id]
        dst_blocks = []

        for logical_block in src_blocks:
            phys_id = logical_block.physical_block_id
            self.all_blocks[phys_id].ref_count += 1
            dst_blocks.append(
                LogicalBlock(
                    physical_block_id=phys_id,
                    num_filled=logical_block.num_filled
                )
            )

        self.block_tables[dst_request_id] = dst_blocks


# Demo
manager = PagedKVCacheManager(
    num_physical_blocks=1000,
    block_size=16,
    num_layers=32,
    num_kv_heads=32,
    head_dim=128
)

manager.allocate_blocks_for_request(request_id=1, num_tokens=200)
manager.allocate_blocks_for_request(request_id=2, num_tokens=500)
manager.allocate_blocks_for_request(request_id=3, num_tokens=100)

manager.free_request(request_id=1)
print(f"\nAfter request 1 completes — free blocks: {len(manager.free_blocks)}")

4. Continuous Batching

4.1 The Problem with Static Batching

Static batching waits for all requests in a batch to complete before starting the next:

t=0: [Request A: 500 tokens] [Request B: 100 tokens] [Request C: 300 tokens]
t=1: Request B done — but can't start new requests until A and C finish
t=2: Request C done
t=3: Request A done → only now can a new batch start

4.2 Continuous Batching (Iteration-level Scheduling)

from dataclasses import dataclass, field
from typing import List, Tuple
import time

@dataclass
class Request:
    """Inference request"""
    request_id: str
    input_ids: List[int]
    max_new_tokens: int
    generated_ids: List[int] = field(default_factory=list)
    is_finished: bool = False

class ContinuousBatchingScheduler:
    """Continuous Batching scheduler"""

    def __init__(self, max_batch_size: int = 32, max_seq_len: int = 4096):
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len

        self.waiting_queue: List[Request] = []
        self.running_requests: List[Request] = []
        self.finished_requests: List[Request] = []

    def add_request(self, request: Request):
        self.waiting_queue.append(request)

    def _can_add_request(self, request: Request) -> bool:
        if len(self.running_requests) + 1 > self.max_batch_size:
            return False

        total_tokens = sum(
            len(r.input_ids) + len(r.generated_ids)
            for r in self.running_requests
        ) + len(request.input_ids)

        return total_tokens < self.max_seq_len * self.max_batch_size

    def schedule_iteration(self) -> Tuple[List[Request], List[str]]:
        """
        Schedule one iteration's batch

        Returns:
            (active requests, list of just-completed request IDs)
        """
        completed_ids = []

        still_running = []
        for req in self.running_requests:
            if req.is_finished:
                self.finished_requests.append(req)
                completed_ids.append(req.request_id)
            else:
                still_running.append(req)
        self.running_requests = still_running

        # Fill empty slots immediately with waiting requests
        while self.waiting_queue and self._can_add_request(self.waiting_queue[0]):
            new_request = self.waiting_queue.pop(0)
            self.running_requests.append(new_request)
            print(f"Added request {new_request.request_id} to batch "
                  f"(batch size: {len(self.running_requests)})")

        return self.running_requests, completed_ids

    def simulate_one_step(self, model_forward_fn):
        """Simulate one step"""
        active_requests, completed = self.schedule_iteration()

        if not active_requests:
            return []

        batch_input_ids = []
        for req in active_requests:
            if len(req.generated_ids) == 0:
                batch_input_ids.append(req.input_ids)  # Prefill
            else:
                batch_input_ids.append([req.generated_ids[-1]])  # Decode

        outputs = model_forward_fn(batch_input_ids)

        for req, next_token_id in zip(active_requests, outputs):
            req.generated_ids.append(next_token_id)

            # token id 2 = EOS for Llama-family tokenizers
            if next_token_id == 2 or len(req.generated_ids) >= req.max_new_tokens:
                req.is_finished = True

        return completed

5. Speculative Decoding

5.1 The Idea: Draft + Verify

Speculative Decoding's core idea: a small draft model cheaply proposes several candidate tokens one at a time, then the large target model verifies all of them in a single forward pass (like a prefill).

Standard:    [large model] → token1 → token2 → token3 → token4 → token5
Speculative: [small model] → drafts t1, t2, t3, t4, t5 autoregressively (cheap)
             [large model] → verifies all 5 tokens in one pass (parallel, like prefill)
             only accepted tokens are kept

5.2 Acceptance Rate and Speedup Analysis

import torch
from typing import List, Tuple

def speculative_decode_step(
    draft_model,
    target_model,
    input_ids: torch.Tensor,
    draft_steps: int = 4,
    temperature: float = 1.0
) -> Tuple[torch.Tensor, int, int]:
    """
    One step of Speculative Decoding

    Returns:
        (generated tokens, num_accepted, num_drafted)
    """
    # 1. Draft model proposes candidate tokens autoregressively
    #    (a production implementation would reuse the draft model's KV cache
    #    rather than re-running the full sequence each step)
    draft_tokens = []
    draft_probs = []
    current_ids = input_ids.clone()

    for _ in range(draft_steps):
        with torch.no_grad():
            draft_logits = draft_model(current_ids).logits[:, -1, :]

        draft_prob = torch.softmax(draft_logits / (temperature + 1e-8), dim=-1)
        draft_token = torch.multinomial(draft_prob, num_samples=1)
        draft_tokens.append(draft_token)
        draft_probs.append(draft_prob)
        current_ids = torch.cat([current_ids, draft_token], dim=1)

    draft_sequence = torch.cat(draft_tokens, dim=1)
    candidate_ids = torch.cat([input_ids, draft_sequence], dim=1)

    # 2. Target model scores all draft tokens (plus one bonus position)
    #    in a single forward pass
    with torch.no_grad():
        # K+1 distributions: one for each draft position and one for the
        # bonus token after the last draft
        target_logits = target_model(candidate_ids).logits[
            :, input_ids.size(1) - 1:, :
        ]

    target_probs = torch.softmax(target_logits / (temperature + 1e-8), dim=-1)

    # 3. Accept/reject each draft token (rejection sampling)
    accepted_tokens = []
    num_accepted = 0

    for step in range(draft_steps):
        token = draft_sequence[:, step]
        p_draft = draft_probs[step].gather(1, token.unsqueeze(1)).squeeze(1)
        p_target = target_probs[:, step, :].gather(1, token.unsqueeze(1)).squeeze(1)

        # Accept with probability min(1, p_target / p_draft)
        acceptance_prob = torch.clamp(p_target / (p_draft + 1e-8), max=1.0)
        accepted = torch.rand_like(acceptance_prob) < acceptance_prob

        if not accepted.all():
            break

        accepted_tokens.append(token)
        num_accepted += 1

    # 4. Final token: reuse the target distribution at the first rejected
    #    position (or the bonus position if all drafts were accepted);
    #    no extra target-model forward pass is needed
    last_prob = target_probs[:, num_accepted, :]

    if num_accepted < draft_steps:
        # Rejected: resample from the residual distribution max(0, p - q)
        correction = torch.clamp(last_prob - draft_probs[num_accepted], min=0)
        correction = correction / (correction.sum(dim=-1, keepdim=True) + 1e-8)
        last_token = torch.multinomial(correction, num_samples=1)
    else:
        last_token = torch.multinomial(last_prob, num_samples=1)

    accepted_tokens.append(last_token.squeeze(1))
    final_tokens = torch.stack(accepted_tokens, dim=1)

    return final_tokens, num_accepted, draft_steps


def analyze_speedup(acceptance_rate: float, draft_steps: int = 4) -> dict:
    """Speedup analysis based on acceptance rate"""

    # Expected tokens per target pass (accepted drafts + 1 bonus token):
    # E = sum_{k=0}^{K} alpha^k = (1 - alpha^(K+1)) / (1 - alpha)
    expected_accepted = sum(
        acceptance_rate ** k for k in range(draft_steps + 1)
    )

    # Draft model is 1/10 the size of target model
    draft_model_ratio = 0.1

    steps_with_speculative = draft_steps * draft_model_ratio + 1
    speedup = expected_accepted / steps_with_speculative

    return {
        "acceptance_rate": acceptance_rate,
        "expected_accepted": expected_accepted,
        "speedup": speedup
    }

print("Speculative Decoding speedup by acceptance rate (K=4)")
print("=" * 55)
for alpha in [0.5, 0.6, 0.7, 0.8, 0.9, 0.95]:
    result = analyze_speedup(alpha, draft_steps=4)
    print(f"Accept rate {alpha:.0%}: expected {result['expected_accepted']:.2f} tokens, "
          f"speedup {result['speedup']:.2f}x")

5.3 Medusa: Multiple Draft Heads

import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """
    Medusa: attach multiple draft heads to a single model

    Each head predicts a future token:
    - Head 1: predicts t+1
    - Head 2: predicts t+2
    - Head N: predicts t+N
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        num_heads: int = 4,
        hidden_layers: int = 1
    ):
        super().__init__()
        self.num_heads = num_heads

        # Build each head with its own freshly constructed layers;
        # `[module] * n` would reuse the same module objects (tied weights)
        def _make_head() -> nn.Sequential:
            layers = []
            for _ in range(hidden_layers):
                layers.append(nn.Linear(hidden_size, hidden_size, bias=False))
                layers.append(nn.SiLU())
            layers.append(nn.Linear(hidden_size, vocab_size, bias=False))
            return nn.Sequential(*layers)

        self.heads = nn.ModuleList([_make_head() for _ in range(num_heads)])

    def forward(self, hidden_states: torch.Tensor):
        """
        Returns list of logits for each future position
        """
        return [head(hidden_states) for head in self.heads]


class MedusaModel(nn.Module):
    """Medusa full model"""

    def __init__(self, base_model, vocab_size: int, num_medusa_heads: int = 4):
        super().__init__()
        self.base_model = base_model
        hidden_size = base_model.config.hidden_size
        self.medusa_heads = MedusaHead(hidden_size, vocab_size, num_medusa_heads)

    def forward(self, input_ids: torch.Tensor, use_medusa: bool = False):
        base_output = self.base_model(input_ids, output_hidden_states=True)
        base_logits = base_output.logits

        if not use_medusa:
            return base_logits, None

        last_hidden = base_output.hidden_states[-1]
        medusa_logits = self.medusa_heads(last_hidden)

        return base_logits, medusa_logits

6. FlashAttention: Memory-Efficient Attention

6.1 Standard Attention's HBM Bottleneck

Standard attention repeatedly writes and reads intermediate results to HBM:

Standard Attention memory ops:
1. Read Q, K from HBM           → read:  O(N * d)
2. Compute S = Q @ K.T          → write: O(N^2)  ← bottleneck!
3. Read S from HBM for softmax  → read:  O(N^2)
4. Store P = softmax(S)         → write: O(N^2)
5. Read P for P @ V             → read:  O(N^2)
6. Store final output           → write: O(N * d)

Total HBM accesses: O(N^2) — quadratic in sequence length!

6.2 FlashAttention's Tiling Strategy

import torch
import math

def flash_attention_v1(Q, K, V, block_size=64):
    """
    FlashAttention v1 simplified implementation
    Avoids storing the full attention matrix in HBM via tiling

    Key: Online Softmax for block-wise processing
    """
    batch_size, num_heads, seq_len, d_head = Q.shape

    scale = 1.0 / math.sqrt(d_head)
    Q = Q * scale

    O = torch.zeros_like(Q)
    L = torch.zeros(batch_size, num_heads, seq_len, 1, device=Q.device)
    M = torch.full((batch_size, num_heads, seq_len, 1), float('-inf'), device=Q.device)

    num_blocks = (seq_len + block_size - 1) // block_size

    for j in range(num_blocks):
        k_start = j * block_size
        k_end = min((j + 1) * block_size, seq_len)
        K_j = K[:, :, k_start:k_end, :]
        V_j = V[:, :, k_start:k_end, :]

        for i in range(num_blocks):
            q_start = i * block_size
            q_end = min((i + 1) * block_size, seq_len)
            Q_i = Q[:, :, q_start:q_end, :]
            O_i = O[:, :, q_start:q_end, :]
            L_i = L[:, :, q_start:q_end, :]
            M_i = M[:, :, q_start:q_end, :]

            # Compute attention scores in SRAM
            S_ij = torch.matmul(Q_i, K_j.transpose(-2, -1))

            # Online Softmax update
            M_new = torch.maximum(M_i, S_ij.max(dim=-1, keepdim=True)[0])
            P_ij = torch.exp(S_ij - M_new)
            L_new = torch.exp(M_i - M_new) * L_i + P_ij.sum(dim=-1, keepdim=True)

            # Rescale and update output
            O_new = torch.exp(M_i - M_new) * O_i + torch.matmul(P_ij, V_j)

            O[:, :, q_start:q_end, :] = O_new
            L[:, :, q_start:q_end, :] = L_new
            M[:, :, q_start:q_end, :] = M_new

    return O / L
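The online-softmax update at the heart of that loop can be verified in isolation: processing a score vector in two chunks while carrying only the running max m and normalizer l reproduces the exact softmax.

```python
import torch

torch.manual_seed(0)
s = torch.randn(128)
s1, s2 = s[:64], s[64:]

# Chunk 1: local max and normalizer
m1 = s1.max()
l1 = torch.exp(s1 - m1).sum()

# Chunk 2: rescale the old normalizer to the new running max, then accumulate
m2 = torch.maximum(m1, s2.max())
l2 = torch.exp(m1 - m2) * l1 + torch.exp(s2 - m2).sum()

online = torch.exp(s - m2) / l2
assert torch.allclose(online, torch.softmax(s, dim=0), atol=1e-6)
```

This rescaling trick is what lets FlashAttention discard each S block after use instead of materializing the full N×N matrix.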


def compare_attention_implementations():
    """Compare FlashAttention vs standard attention"""

    batch_size, num_heads, seq_len, d_head = 2, 32, 4096, 128

    Q = torch.randn(batch_size, num_heads, seq_len, d_head, device='cuda', dtype=torch.float16)
    K = torch.randn_like(Q)
    V = torch.randn_like(Q)

    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        flash_output = F.scaled_dot_product_attention(Q, K, V)

    scale = 1.0 / math.sqrt(d_head)
    attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
    attn_weights = torch.softmax(attn_scores, dim=-1)
    standard_output = torch.matmul(attn_weights, V)

    max_diff = (flash_output - standard_output).abs().max().item()
    print(f"FlashAttention vs standard max diff: {max_diff:.6f}")

    standard_attn_bytes = batch_size * num_heads * seq_len * seq_len * 2  # FP16
    print(f"Standard attention matrix memory: {standard_attn_bytes / 1e9:.2f} GB")
    print(f"FlashAttention matrix memory: ~0 GB (tiled, never fully materialized)")
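The online-softmax recurrence is easy to verify on CPU. The sketch below re-derives it tiling only over K/V blocks (keeping Q whole for brevity) and compares against naive attention; it is a didactic correctness check, not a performance kernel:

```python
import math
import torch

def tiled_attention(Q, K, V, block_size=4):
    """One pass over K/V blocks using the online-softmax recurrence."""
    scale = 1.0 / math.sqrt(Q.size(-1))
    Qs = Q * scale
    O = torch.zeros_like(Q)                               # unnormalized output
    L = torch.zeros(*Q.shape[:-1], 1)                     # running denominator
    M = torch.full((*Q.shape[:-1], 1), float("-inf"))     # running row max
    for s in range(0, K.size(-2), block_size):
        K_j = K[..., s:s + block_size, :]
        V_j = V[..., s:s + block_size, :]
        S = torch.matmul(Qs, K_j.transpose(-2, -1))
        M_new = torch.maximum(M, S.max(dim=-1, keepdim=True)[0])
        P = torch.exp(S - M_new)
        L = torch.exp(M - M_new) * L + P.sum(dim=-1, keepdim=True)
        O = torch.exp(M - M_new) * O + torch.matmul(P, V_j)
        M = M_new
    return O / L

torch.manual_seed(0)
Q = torch.randn(1, 2, 8, 16)
K, V = torch.randn_like(Q), torch.randn_like(Q)

reference = torch.softmax(
    torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1)), dim=-1
) @ V
max_diff = (tiled_attention(Q, K, V) - reference).abs().max().item()
print(max_diff)  # at floating-point noise level
```

Despite never materializing the full 8x8 score matrix, the tiled result matches the naive computation to numerical precision.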

6.3 Using PyTorch SDPA

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def modern_attention(q, k, v, is_causal=True, dropout_p=0.0):
    """
    PyTorch 2.0+ scaled_dot_product_attention

    Dispatches to the fastest available backend:
    FlashAttention, memory-efficient attention, or the math fallback
    """
    return F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=None  # defaults to 1/sqrt(d_head)
    )

# Flash Attention version highlights
flash_versions = {
    "FlashAttention 1 (arXiv:2205.14135)": {
        "key_innovation": "Tiling + Online Softmax",
        "memory": "O(N) — no attention matrix stored",
        "speedup": "2-4x vs standard"
    },
    "FlashAttention 2 (arXiv:2307.08691)": {
        "key_innovation": "Work partitioning, FP16/BF16",
        "memory": "O(N)",
        "speedup": "5-9x vs standard on H100"
    },
    "FlashAttention 3 (arXiv:2407.08608)": {
        "key_innovation": "H100-specific, FP8, async pipeline",
        "memory": "O(N)",
        "speedup": "1.5-2x vs FA2 on H100"
    },
}

for name, info in flash_versions.items():
    print(f"\n{name}")
    for k, v in info.items():
        print(f"  {k}: {v}")
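A quick usage example of the SDPA call: on CPU this runs through the math fallback, while CUDA builds dispatch to a fused kernel when shapes and dtypes allow.

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64)   # (batch, heads, seq_len, d_head)
k = torch.randn_like(q)
v = torch.randn_like(q)

# is_causal=True applies the lower-triangular mask inside the kernel,
# so no explicit attn_mask tensor needs to be materialized.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Position 0 can only attend to itself, so its output equals v[..., 0, :]
print(out.shape, torch.allclose(out[..., 0, :], v[..., 0, :], atol=1e-5))
# → torch.Size([1, 8, 128, 64]) True
```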

7. Multi-GPU Inference

7.1 Tensor Parallelism

Weight matrices are split across GPUs; each GPU handles a shard.

import torch
import torch.nn as nn

class TensorParallelLinear(nn.Module):
    """
    Tensor Parallel Linear layer (column-parallel)
    Each GPU owns out_features // world_size output neurons
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        world_size: int,
        rank: int
    ):
        super().__init__()
        self.world_size = world_size
        self.rank = rank
        self.local_out_features = out_features // world_size

        self.weight = nn.Parameter(
            torch.randn(self.local_out_features, in_features) / (in_features ** 0.5)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        local_output = nn.functional.linear(x, self.weight)
        # In a real distributed setup: dist.all_gather(output_list, local_output)
        return local_output


def setup_vllm_multi_gpu(model_name: str, tp_size: int):
    """Set up multi-GPU vLLM inference"""
    from vllm import LLM, SamplingParams

    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,
        gpu_memory_utilization=0.9
    )

    return llm
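To see why the column-parallel layer above is paired with a row-parallel one in real tensor-parallel implementations (e.g. Megatron-style MLPs), here is a single-process sketch that simulates two ranks with tensor slices; all shapes and names are illustrative:

```python
import torch

torch.manual_seed(0)
in_f, out_f, world = 8, 6, 2
x = torch.randn(4, in_f)
W = torch.randn(out_f, in_f)

# Column parallel: each rank owns a slice of the OUTPUT features;
# local results are concatenated (all-gather) across ranks.
col_shards = W.chunk(world, dim=0)
col_out = torch.cat([x @ w.t() for w in col_shards], dim=-1)

# Row parallel: each rank owns a slice of the INPUT features;
# local partial products are summed (all-reduce) across ranks.
row_shards = W.chunk(world, dim=1)
x_shards = x.chunk(world, dim=-1)
row_out = sum(xs @ w.t() for xs, w in zip(x_shards, row_shards))

full = x @ W.t()
print(torch.allclose(col_out, full, atol=1e-5),
      torch.allclose(row_out, full, atol=1e-5))
# → True True
```

Chaining column-parallel then row-parallel layers means the intermediate activations never need to be gathered: only one all-reduce per layer pair.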

7.2 Full vLLM Usage

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
import asyncio
import time

def vllm_basic_usage():
    """vLLM basic usage"""

    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        tensor_parallel_size=1,
        gpu_memory_utilization=0.90,
        max_model_len=4096,
        quantization=None,             # "awq", "gptq", "squeezellm"
        dtype="auto",
        max_num_seqs=256,
        enable_prefix_caching=True,
        use_v2_block_manager=True,
    )

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=200,
    )

    prompts = [
        "Explain quantum computing in simple terms",
        "What is the future of artificial intelligence?",
        "How does the human brain work?",
    ]

    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        print(f"Prompt: {output.prompt[:50]}...")
        print(f"Output: {output.outputs[0].text[:100]}...")
        print(f"Tokens: {len(output.outputs[0].token_ids)}")
        print()

    return outputs


async def vllm_async_server():
    """vLLM async engine usage"""

    engine_args = AsyncEngineArgs(
        model="meta-llama/Llama-2-7b-hf",
        tensor_parallel_size=1,
        gpu_memory_utilization=0.90,
        max_model_len=4096,
        enable_prefix_caching=True,
    )

    engine = AsyncLLMEngine.from_engine_args(engine_args)

    async def generate_stream(prompt: str, request_id: str):
        sampling_params = SamplingParams(temperature=0.8, max_tokens=200)

        full_text = ""
        async for output in engine.generate(prompt, sampling_params, request_id):
            if output.outputs:
                delta = output.outputs[0].text[len(full_text):]
                full_text = output.outputs[0].text

                if delta:
                    print(f"[{request_id}] {delta}", end="", flush=True)

            if output.finished:
                print(f"\n[{request_id}] done")

    await asyncio.gather(
        generate_stream("What is AI?", "req_1"),
        generate_stream("Explain machine learning", "req_2"),
        generate_stream("What is deep learning?", "req_3"),
    )

8. Inference Engine Comparison

8.1 Major Engine Features

| Engine | Author | Key Features | Best For |
|--------|--------|--------------|----------|
| vLLM | UC Berkeley | PagedAttention, Continuous Batching | General LLM serving |
| TGI | Hugging Face | Flash Attention 2, Speculative Decoding | HF model serving |
| TensorRT-LLM | NVIDIA | NVIDIA-optimized kernels, FP8 | Max NVIDIA perf |
| DeepSpeed-MII | Microsoft | ZeRO inference, huge models | Multi-GPU giant models |
| llama.cpp | G. Gerganov | CPU-optimized, GGUF | Local execution |

8.2 Benchmark Results

import time

def run_inference_benchmark(engine_name: str, model, tokenizer, prompts, max_tokens=100):
    """Simple inference benchmark"""

    num_warmup = 5
    num_runs = 50

    # Warmup
    for prompt in prompts[:num_warmup]:
        _ = model.generate(
            tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
            max_new_tokens=max_tokens,
            do_sample=False
        )

    # Timed runs
    import torch
    torch.cuda.synchronize()
    start = time.perf_counter()

    total_tokens = 0
    for prompt in prompts[:num_runs]:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
        output = model.generate(
            input_ids,
            max_new_tokens=max_tokens,
            do_sample=False
        )
        # Count tokens actually generated (EOS can end generation early)
        total_tokens += output.shape[1] - input_ids.shape[1]

    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    throughput = total_tokens / elapsed
    latency_ms = elapsed / num_runs * 1000

    print(f"\n{engine_name}")
    print(f"  Throughput: {throughput:.1f} tokens/sec")
    print(f"  Avg latency: {latency_ms:.1f} ms/request")

    return throughput, latency_ms


# Example comparison (A100 80GB, Llama-2-7B, batch=1, 100 output tokens)
benchmark_results = {
    "HuggingFace (FP16)":         {"throughput": 52,  "latency_ms": 1924},
    "HuggingFace (Flash Attn 2)": {"throughput": 78,  "latency_ms": 1280},
    "vLLM":                        {"throughput": 120, "latency_ms": 832},
    "vLLM + AWQ 4bit":             {"throughput": 165, "latency_ms": 606},
    "TensorRT-LLM":                {"throughput": 180, "latency_ms": 555},
}

print("LLM inference engine benchmark (A100 80GB, Llama-2-7B)")
print("=" * 65)
print(f"{'Engine':<30} {'Throughput (tok/s)':<22} {'Latency (ms)':<15}")
print("-" * 65)
for engine, stats in benchmark_results.items():
    print(f"{engine:<30} {stats['throughput']:<22} {stats['latency_ms']:<15}")

9. Prompt Caching

9.1 Prefix Caching

Reuse KV Cache when the same system prompt or document is processed repeatedly.

from vllm import LLM, SamplingParams
import time

def demonstrate_prefix_caching():
    """Demonstrate prefix caching benefit"""

    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        enable_prefix_caching=True,
        max_model_len=4096,
    )

    # Long system prompt common to all requests (1000+ tokens)
    system_prompt = (
        "You are a helpful AI assistant with expertise in Python, "
        "machine learning, data science, and cloud computing. "
    ) * 50

    questions = [
        "How do I optimize a Python loop?",
        "What is gradient descent?",
        "Explain containerization.",
        "What is a neural network?",
    ]

    sampling_params = SamplingParams(temperature=0.7, max_tokens=100)
    cold_prompts = [f"{system_prompt}\n\nQuestion: {q}" for q in questions]

    # Cold start (no cache)
    cold_start = time.perf_counter()
    llm.generate(cold_prompts, sampling_params)
    cold_time = time.perf_counter() - cold_start

    # Warm start (cache hit)
    warm_start = time.perf_counter()
    llm.generate(cold_prompts, sampling_params)
    warm_time = time.perf_counter() - warm_start

    print(f"Cold start (no cache): {cold_time:.2f}s")
    print(f"Warm start (cache hit): {warm_time:.2f}s")
    print(f"Speedup: {cold_time / warm_time:.2f}x")


def radix_tree_prefix_cache():
    """Radix Tree prefix cache implementation"""

    class RadixNode:
        # Simplified: one token per edge (strictly a trie; a true radix
        # tree compresses single-child chains into multi-token edges)
        def __init__(self):
            self.children: dict = {}
            self.kv_cache_block_id: int | None = None

    class RadixTreeCache:
        """
        Manages token sequences in a Radix Tree to share
        common-prefix KV caches
        """

        def __init__(self):
            self.root = RadixNode()
            self.cache_hits = 0
            self.cache_misses = 0

        def insert(self, token_ids: list, block_id: int):
            node = self.root
            for token_id in token_ids:
                if token_id not in node.children:
                    node.children[token_id] = RadixNode()
                node = node.children[token_id]
            node.kv_cache_block_id = block_id

        def lookup(self, token_ids: list) -> tuple:
            """Find the longest matching prefix"""
            node = self.root
            matched_len = 0
            last_block_id = None

            for i, token_id in enumerate(token_ids):
                if token_id in node.children:
                    node = node.children[token_id]
                    matched_len = i + 1
                    if node.kv_cache_block_id is not None:
                        last_block_id = node.kv_cache_block_id
                else:
                    break

            if last_block_id is not None:
                self.cache_hits += 1
            else:
                self.cache_misses += 1

            return matched_len, last_block_id

        def get_hit_rate(self) -> float:
            total = self.cache_hits + self.cache_misses
            return self.cache_hits / total if total > 0 else 0.0

    return RadixTreeCache()
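The lookup logic above can be illustrated with a self-contained toy: a flat dict of cached token prefixes and a linear-scan longest-prefix match (the same answer the radix tree computes in O(prefix length)). The token IDs and block names here are made up:

```python
# Cached prefixes → KV cache block IDs (hypothetical values)
cached = {
    (1, 2, 3): "block_A",        # e.g. a shared system prompt
    (1, 2, 3, 7, 8): "block_B",  # system prompt + one full question
}

def longest_prefix(token_ids: list) -> tuple:
    """Return (matched length, block id) of the longest cached prefix."""
    best_len, best_block = 0, None
    for prefix, block in cached.items():
        n = len(prefix)
        if n <= len(token_ids) and tuple(token_ids[:n]) == prefix and n > best_len:
            best_len, best_block = n, block
    return best_len, best_block

print(longest_prefix([1, 2, 3, 7, 9]))  # → (3, 'block_A')
```

A request sharing the first three tokens reuses `block_A` and only needs prefill for the remaining tokens; the tree structure just makes this match cheap at scale.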

10. Practical Optimization Checklist

10.1 Step-by-Step Optimization Guide

class LLMOptimizationChecklist:
    """LLM inference optimization checklist"""

    optimizations = [
        {
            "category": "Baseline",
            "level": 1,
            "items": [
                {
                    "name": "Use FP16/BF16",
                    "impact": "High",
                    "effort": "Low",
                    "description": "FP32 → FP16: 2x memory saving, speed improvement",
                    "code": """
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)"""
                },
                {
                    "name": "Enable Flash Attention 2",
                    "impact": "High",
                    "effort": "Low",
                    "description": "2-4x attention speedup, memory saving",
                    "code": """
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16
)"""
                },
            ]
        },
        {
            "category": "KV Cache Optimization",
            "level": 2,
            "items": [
                {
                    "name": "Choose GQA/MQA model",
                    "impact": "High",
                    "effort": "Medium",
                    "description": "4-8x KV Cache reduction, larger effective batch"
                },
                {
                    "name": "Prefix caching",
                    "impact": "Medium",
                    "effort": "Low",
                    "description": "Reuse KV Cache for common system prompts"
                },
            ]
        },
        {
            "category": "Batching Optimization",
            "level": 3,
            "items": [
                {
                    "name": "Continuous Batching with vLLM",
                    "impact": "Very High",
                    "effort": "Low",
                    "description": "2-5x throughput improvement",
                    "code": """
from vllm import LLM, SamplingParams

llm = LLM(
    model=model_name,
    gpu_memory_utilization=0.90,
    enable_prefix_caching=True,
)"""
                },
            ]
        },
        {
            "category": "Model Optimization",
            "level": 4,
            "items": [
                {
                    "name": "AWQ 4-bit quantization",
                    "impact": "High",
                    "effort": "Medium",
                    "description": "4x memory reduction, 1.5-2x speed",
                    "code": """
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "model-awq-4bit",
    fuse_layers=True
)"""
                },
                {
                    "name": "Speculative Decoding",
                    "impact": "Medium",
                    "effort": "High",
                    "description": "2-3x speedup (requires suitable draft model)"
                },
            ]
        },
        {
            "category": "Hardware Optimization",
            "level": 5,
            "items": [
                {
                    "name": "Tensor Parallelism",
                    "impact": "Very High",
                    "effort": "Medium",
                    "description": "Linear throughput scaling with multiple GPUs"
                },
                {
                    "name": "CUDA graph capture",
                    "impact": "Medium",
                    "effort": "High",
                    "description": "Eliminate kernel launch overhead"
                },
            ]
        }
    ]

    @classmethod
    def print_checklist(cls):
        print("=" * 70)
        print("LLM Inference Optimization — Step-by-Step Checklist")
        print("=" * 70)

        for category in cls.optimizations:
            print(f"\n[Level {category['level']}] {category['category']}")
            print("-" * 50)

            for item in category['items']:
                impact_stars = {"Very High": "★★★", "High": "★★", "Medium": "★", "Low": "☆"}
                print(f"  ✓ {item['name']}")
                print(f"    Impact: {impact_stars.get(item['impact'], '?')} {item['impact']}")
                print(f"    Note: {item['description']}")

        print("\nRecommended optimization order:")
        print("1. Switch to BF16/FP16                  (immediate, free)")
        print("2. Enable Flash Attention 2             (immediate, just install)")
        print("3. Serve with vLLM                      (max throughput)")
        print("4. AWQ/GPTQ 4-bit quantization          (4x memory reduction)")
        print("5. Speculative Decoding                 (latency improvement)")
        print("6. Multi-GPU Tensor Parallelism         (scale out)")

LLMOptimizationChecklist.print_checklist()

10.2 End-to-End Production Setup

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from fastapi import FastAPI
from pydantic import BaseModel
import asyncio
import uvicorn

app = FastAPI(title="LLM Inference API")

class GenerateRequest(BaseModel):
    prompt: str
    max_tokens: int = 200
    temperature: float = 0.7
    top_p: float = 0.95
    stream: bool = False

class GenerateResponse(BaseModel):
    text: str
    tokens_generated: int
    finish_reason: str

# Global engine instance (initialized at startup)
engine: AsyncLLMEngine | None = None

def create_optimized_engine(model_name: str, **kwargs) -> AsyncLLMEngine:
    """Create a production-optimized vLLM engine"""

    engine_args = AsyncEngineArgs(
        model=model_name,
        tensor_parallel_size=kwargs.get("tp_size", 1),
        gpu_memory_utilization=kwargs.get("gpu_util", 0.90),
        max_model_len=kwargs.get("max_model_len", 4096),
        quantization=kwargs.get("quantization", None),  # "awq" or "gptq"
        dtype="auto",
        max_num_seqs=kwargs.get("max_num_seqs", 256),
        enable_prefix_caching=True,
        use_v2_block_manager=True,
        speculative_model=kwargs.get("draft_model", None),  # optional draft model
        num_speculative_tokens=kwargs.get("num_spec_tokens", 5),
    )

    return AsyncLLMEngine.from_engine_args(engine_args)

@app.on_event("startup")
async def startup():
    global engine
    engine = create_optimized_engine(
        model_name="meta-llama/Llama-2-7b-hf",
        tp_size=1,
        gpu_util=0.90,
        max_model_len=4096,
    )

@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    sampling_params = SamplingParams(
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens,
    )

    request_id = f"req_{id(request)}"

    final_output = None
    async for output in engine.generate(request.prompt, sampling_params, request_id):
        final_output = output

    if final_output and final_output.outputs:
        result = final_output.outputs[0]
        return GenerateResponse(
            text=result.text,
            tokens_generated=len(result.token_ids),
            finish_reason=result.finish_reason or "length"
        )

    return GenerateResponse(text="", tokens_generated=0, finish_reason="error")

# Run with: uvicorn script:app --host 0.0.0.0 --port 8000 --workers 1
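On the client side, a request body just needs to match the `GenerateRequest` schema. A minimal sketch of building one (the localhost URL in the comment assumes the server above is running; the helper name is illustrative):

```python
import json

def build_generate_payload(prompt: str, max_tokens: int = 200,
                           temperature: float = 0.7, top_p: float = 0.95) -> str:
    # Mirrors the GenerateRequest schema of the FastAPI service
    return json.dumps({
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": False,
    })

payload = build_generate_payload("What is AI?", max_tokens=64)
print(payload)
# Send with e.g.:
#   requests.post("http://localhost:8000/generate", data=payload,
#                 headers={"Content-Type": "application/json"})
```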

Conclusion

LLM inference optimization requires a layered approach.

Key takeaways:

  1. Understand KV Cache: Know the per-sequence size formula (2 * layers * kv_heads * d_head * seq_len * dtype_bytes) and use GQA/MQA to cut KV Cache by 4–8x.

  2. PagedAttention: vLLM's core innovation — borrowed from OS virtual memory to eliminate KV Cache fragmentation.

  3. Continuous Batching: Immediately insert new requests as completions happen, maximizing GPU utilization.

  4. Speculative Decoding: Small draft model + large verifier = 2–3x speedup at the right acceptance rate.

  5. FlashAttention: Reduces attention's memory from O(N^2) to O(N), enabling long contexts.

Production deployment recommendations:

  • Small services: vLLM + AWQ 4-bit + prefix caching
  • Large services: TensorRT-LLM or vLLM + Tensor Parallelism
  • Lowest latency: Speculative Decoding + CUDA graphs

References