Skip to content
Published on

LLM推論最適化完全ガイド: KVキャッシュ・投機的デコーディング・連続バッチング

Authors

はじめに

大規模言語モデル(LLM)をプロダクションにデプロイすると、すぐに課題に直面します。推論速度とコストです。GPT-4クラスのモデルへの最初のクエリは数秒かかり、同時ユーザーが増えるとスループットが急激に低下します。

このガイドでは、LLM推論最適化のコアテクニックを徹底解説します。KVキャッシュの内部構造からPagedAttention、投機的デコーディング、FlashAttention、そして最新のvLLMとTensorRT-LLMエンジンまで、使い方だけでなく、なぜそれが機能するのかを理解します。


1. LLM推論パイプラインの理解

1.1 2つのフェーズ: PrefillとDecode

LLMのテキスト生成は2つの明確なフェーズに分かれます。

Prefillフェーズ(プロンプト処理)

  • 入力プロンプトの全トークンを同時に処理(並列)
  • 各レイヤーでKey/Valueキャッシュを生成して保存
  • Compute-Bound: GPUの計算スループットがボトルネック
  • TTFT(Time To First Token)に直接影響

Decodeフェーズ(トークン生成)

  • 自己回帰的に1トークンずつ生成
  • 以前に生成した全トークンのKVキャッシュを参照
  • Memory-Bandwidth-Bound: HBMの読み取り速度がボトルネック
  • TPOT(Time Per Output Token)に直接影響
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time

def measure_prefill_decode_time(model, tokenizer, prompt: str, max_new_tokens: int = 100):
    """PrefillとDecodeフェーズのタイミングを計測"""

    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    input_len = inputs["input_ids"].size(1)

    # TTFTを計測
    torch.cuda.synchronize()
    prefill_start = time.perf_counter()

    with torch.no_grad():
        first_output = model.generate(
            **inputs,
            max_new_tokens=1,
            do_sample=False
        )

    torch.cuda.synchronize()
    ttft = time.perf_counter() - prefill_start

    # 全体の生成を計測
    torch.cuda.synchronize()
    total_start = time.perf_counter()

    with torch.no_grad():
        full_output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )

    torch.cuda.synchronize()
    total_time = time.perf_counter() - total_start

    output_tokens = full_output.size(1) - input_len
    decode_time = total_time - ttft
    tpot = decode_time / max(output_tokens - 1, 1)

    print(f"入力トークン数: {input_len}")
    print(f"出力トークン数: {output_tokens}")
    print(f"TTFT(最初のトークンレイテンシ): {ttft * 1000:.1f} ms")
    print(f"TPOT(トークンあたりの時間): {tpot * 1000:.1f} ms")
    print(f"スループット: {output_tokens / total_time:.1f} tokens/sec")

    return ttft, tpot

1.2 メモリ帯域幅の分析

Decodeフェーズがメモリバウンドになる理由を理解します:

def analyze_memory_bandwidth():
    """LLM推論のメモリ帯域幅分析"""

    # 例: Llama-2-7Bの設定
    model_params = {
        "num_layers": 32,
        "hidden_size": 4096,
        "num_heads": 32,
        "head_dim": 128,
        "vocab_size": 32000,
    }

    dtype_bytes = 2  # FP16: 2バイト

    # 重みメモリ
    attn_weight = 4 * model_params["hidden_size"] ** 2  # Q, K, V, Oプロジェクション
    ffn_weight = 8 * model_params["hidden_size"] ** 2   # SwiGLU: Up, Gate, Down
    layer_weight = (attn_weight + ffn_weight) * dtype_bytes

    total_weight_bytes = layer_weight * model_params["num_layers"]
    total_weight_gb = total_weight_bytes / 1e9

    print(f"モデル重み: {total_weight_gb:.2f} GB")

    # KVキャッシュメモリ(シーケンス長に比例)
    seq_len = 2048
    kv_cache_per_token = (
        2 *
        model_params["num_layers"] *
        model_params["num_heads"] *
        model_params["head_dim"] *
        dtype_bytes
    )

    kv_cache_total = kv_cache_per_token * seq_len / 1e6
    print(f"KVキャッシュ({seq_len}トークン): {kv_cache_total:.2f} MB")
    print(f"トークンあたりのKVキャッシュ: {kv_cache_per_token} バイト")

    # A100のメモリ帯域幅: 2 TB/s
    memory_bandwidth_tbs = 2.0

    # Decode: 重みはトークンごとに1回読み込まれる
    memory_bound_tps = memory_bandwidth_tbs * 1e12 / total_weight_bytes
    # A100 FP16: 312 TFLOPS
    compute_throughput = 312e12
    flops_per_token = 2 * total_weight_bytes
    compute_bound_tps = compute_throughput / flops_per_token

    print(f"\nbatch_size=1の場合:")
    print(f"メモリバウンドスループット: {memory_bound_tps:.1f} tokens/sec")
    print(f"コンピュートバウンドスループット: {compute_bound_tps:.1f} tokens/sec")
    print(f"実際のボトルネック: {'メモリ' if memory_bound_tps < compute_bound_tps else 'コンピュート'}")

analyze_memory_bandwidth()

1.3 推論コスト分析

def estimate_inference_cost(
    model_size_b: float,
    tokens_per_request: int,
    requests_per_day: int,
    gpu_cost_per_hour: float = 3.0  # A100の時間単価(USD)
):
    """推論コストを推定"""

    # 経験的スループット推定
    # 7Bモデル: ~100 tok/s、70Bモデル: ~20 tok/s (A100)
    throughput_tps = 100 / (model_size_b / 7) ** 0.6

    total_tokens_per_day = tokens_per_request * requests_per_day
    seconds_needed = total_tokens_per_day / throughput_tps
    hours_needed = seconds_needed / 3600

    daily_cost = hours_needed * gpu_cost_per_hour
    cost_per_1k_tokens = daily_cost / (total_tokens_per_day / 1000)

    print(f"モデル: {model_size_b}Bパラメータ")
    print(f"日次リクエスト数: {requests_per_day:,}")
    print(f"リクエストあたりトークン数: {tokens_per_request}")
    print(f"日次総トークン数: {total_tokens_per_day:,}")
    print(f"推定スループット: {throughput_tps:.1f} tokens/sec")
    print(f"必要GPU時間: {hours_needed:.2f}")
    print(f"日次コスト: ${daily_cost:.2f}")
    print(f"1Kトークンあたりのコスト: ${cost_per_1k_tokens:.4f}")

estimate_inference_cost(
    model_size_b=7.0,
    tokens_per_request=500,
    requests_per_day=10000
)

2. KVキャッシュ: コアとなる最適化

2.1 KVキャッシュが必要な理由

Transformerのアテンションでは、各トークンは全ての過去のトークンに対してアテンションを計算します。既に処理済みのトークンの再計算を避けるために、KとVの行列をキャッシュします。

import torch
import torch.nn as nn
import math

class MultiHeadAttentionWithKVCache(nn.Module):
    """KVキャッシュをサポートするマルチヘッドアテンション"""

    def __init__(self, d_model: int, num_heads: int, max_seq_len: int = 4096):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        # KVキャッシュの初期化
        self.register_buffer(
            'k_cache',
            torch.zeros(1, max_seq_len, num_heads, self.d_head)
        )
        self.register_buffer(
            'v_cache',
            torch.zeros(1, max_seq_len, num_heads, self.d_head)
        )
        self.cache_pos = 0

    def forward(self, x: torch.Tensor, use_cache: bool = True, position: int = None):
        batch_size, seq_len, _ = x.shape

        q = self.W_q(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        k = self.W_k(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
        v = self.W_v(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)

        if use_cache:
            start_pos = self.cache_pos if position is None else position
            self.k_cache[:, start_pos:start_pos + seq_len] = k
            self.v_cache[:, start_pos:start_pos + seq_len] = v

            if position is None:
                self.cache_pos += seq_len

            total_len = self.cache_pos if position is None else start_pos + seq_len
            k = self.k_cache[:, :total_len]
            v = self.v_cache[:, :total_len]

        scale = math.sqrt(self.d_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        attn_weights = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        output = self.W_o(output)

        return output

    def clear_cache(self):
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.cache_pos = 0

2.2 KVキャッシュのメモリ計算

def calculate_kv_cache_memory(
    model_config: dict,
    batch_size: int,
    seq_len: int,
    dtype_bytes: int = 2  # FP16
) -> dict:
    """KVキャッシュのメモリ使用量を計算"""

    num_layers = model_config["num_layers"]
    num_kv_heads = model_config.get("num_kv_heads", model_config["num_heads"])
    head_dim = model_config["head_dim"]

    # 計算式: 2(K+V) * レイヤー数 * kvヘッド数 * ヘッド次元 * シーケンス長 * バッチ * dtype
    kv_cache_bytes = (
        2 *
        num_layers *
        num_kv_heads *
        head_dim *
        seq_len *
        batch_size *
        dtype_bytes
    )

    return {
        "kv_cache_bytes": kv_cache_bytes,
        "kv_cache_mb": kv_cache_bytes / 1e6,
        "kv_cache_gb": kv_cache_bytes / 1e9,
        "per_token_bytes": kv_cache_bytes // seq_len,
    }

models = {
    "Llama-2-7B (MHA)": {
        "num_layers": 32, "num_heads": 32,
        "num_kv_heads": 32, "head_dim": 128
    },
    "Llama-2-70B (GQA)": {
        "num_layers": 80, "num_heads": 64,
        "num_kv_heads": 8, "head_dim": 128
    },
    "Mistral-7B (GQA)": {
        "num_layers": 32, "num_heads": 32,
        "num_kv_heads": 8, "head_dim": 128
    },
}

print("KVキャッシュメモリ (batch=1, seq=4096)")
print("=" * 65)
for name, config in models.items():
    result = calculate_kv_cache_memory(config, batch_size=1, seq_len=4096)
    print(f"{name:<25} {result['kv_cache_gb']:.2f} GB  "
          f"({result['per_token_bytes']:,} bytes/token)")

2.3 グループクエリアテンション(GQA)

GQAはKVキャッシュを削減するキーテクニックです。複数のクエリヘッドがより少ないKVヘッドを共有します。

import torch
import torch.nn as nn
import math

class GroupedQueryAttention(nn.Module):
    """グループクエリアテンション(GQA)の実装"""

    def __init__(self, d_model: int, num_q_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_q_heads % num_kv_heads == 0

        self.d_model = d_model
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.num_groups = num_q_heads // num_kv_heads
        self.d_head = d_model // num_q_heads

        self.W_q = nn.Linear(d_model, num_q_heads * self.d_head, bias=False)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, kv_cache=None):
        batch_size, seq_len, _ = x.shape

        q = self.W_q(x).reshape(batch_size, seq_len, self.num_q_heads, self.d_head)
        k = self.W_k(x).reshape(batch_size, seq_len, self.num_kv_heads, self.d_head)
        v = self.W_v(x).reshape(batch_size, seq_len, self.num_kv_heads, self.d_head)

        if kv_cache is not None:
            k = torch.cat([kv_cache["k"], k], dim=1)
            v = torch.cat([kv_cache["v"], v], dim=1)
        new_kv_cache = {"k": k, "v": v}

        q = q.transpose(1, 2)   # [B, Q_heads, S, d_head]
        k = k.transpose(1, 2)   # [B, KV_heads, T, d_head]
        v = v.transpose(1, 2)

        # GQA: KVヘッドをQヘッドに合わせて拡張
        k = k.repeat_interleave(self.num_groups, dim=1)
        v = v.repeat_interleave(self.num_groups, dim=1)

        scale = math.sqrt(self.d_head)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        attn_weights = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)

        return self.W_o(output), new_kv_cache

2.4 DeepSeek MLA(マルチヘッド潜在アテンション)

DeepSeek-V2で導入されたMLAは、KVキャッシュを低次元の潜在ベクトルに圧縮します。

class MultiHeadLatentAttention(nn.Module):
    """
    DeepSeek MLA — KVキャッシュを低ランク潜在ベクトルに圧縮

    コアアイデア:
    - 高次元のK,Vを保存するのではなく、低次元の潜在c_kvを保存
    - c_kvからアッププロジェクションでK, Vを復元
    - KVキャッシュサイズ: num_layers * kv_lora_rank * seq_len
      (標準: 2 * num_layers * num_kv_heads * d_head * seq_len と比較)
    """

    def __init__(
        self,
        d_model: int = 5120,
        num_heads: int = 128,
        kv_lora_rank: int = 512,
        qk_nope_head_dim: int = 128,
        qk_rope_head_dim: int = 64,
        v_head_dim: int = 128,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.kv_lora_rank = kv_lora_rank

        # Qプロジェクション(LoRAスタイル)
        self.q_a_proj = nn.Linear(d_model, 1536, bias=False)
        self.q_b_proj = nn.Linear(
            1536,
            num_heads * (qk_nope_head_dim + qk_rope_head_dim),
            bias=False
        )

        # KVダウンプロジェクション: d_model -> kv_lora_rank
        # これだけがKVキャッシュに保存される!
        self.kv_a_proj = nn.Linear(
            d_model,
            kv_lora_rank + qk_rope_head_dim,
            bias=False
        )

        # KVアッププロジェクション: kv_lora_rank -> K, V
        self.kv_b_proj = nn.Linear(
            kv_lora_rank,
            num_heads * (qk_nope_head_dim + v_head_dim),
            bias=False
        )

        self.o_proj = nn.Linear(num_heads * v_head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor, compressed_kv_cache=None):
        batch_size, seq_len, _ = x.shape

        # KV圧縮(この結果のみキャッシュに入る)
        kv_compressed = self.kv_a_proj(x)  # [B, S, kv_lora_rank + rope_dim]

        if compressed_kv_cache is not None:
            kv_compressed_total = torch.cat([compressed_kv_cache, kv_compressed], dim=1)
        else:
            kv_compressed_total = kv_compressed

        # キャッシュされた圧縮表現から完全なK, Vを復元
        kv_content = kv_compressed_total[:, :, :self.kv_lora_rank]
        kv_full = self.kv_b_proj(kv_content)

        return None, kv_compressed

3. PagedAttention: vLLMのコアイノベーション

3.1 従来のKVキャッシュの問題

従来のLLMサービングシステムはリクエストごとに最大シーケンス長のメモリを事前割り当てします:

リクエスト1: [PROMPT=200トークン] [KV_CACHE=最大1848トークン予約済み] → 内部断片化
リクエスト2: [PROMPT=100トークン] [KV_CACHE=1948トークン予約済み]
リクエスト3: 待機中(外部断片化)

これにより60-80%のGPUメモリが無駄になります。

3.2 PagedAttention: 仕組み

OSの仮想メモリに着想を得て、PagedAttentionは固定サイズの物理ブロックでKVキャッシュを管理します。

from dataclasses import dataclass, field
from typing import Dict, List, Optional
import torch

@dataclass
class PhysicalBlock:
    """物理メモリブロック"""
    block_id: int
    block_size: int  # ブロックあたりのトークン数(例: 16)
    ref_count: int = 0

@dataclass
class LogicalBlock:
    """論理ブロック(リクエストにマッピングされる)"""
    physical_block_id: int
    num_filled: int = 0

class PagedKVCacheManager:
    """PagedAttention KVキャッシュマネージャー"""

    def __init__(
        self,
        num_physical_blocks: int,
        block_size: int,
        num_layers: int,
        num_kv_heads: int,
        head_dim: int,
        device: str = "cuda"
    ):
        self.block_size = block_size
        self.num_layers = num_layers

        self.free_blocks: List[int] = list(range(num_physical_blocks))
        self.all_blocks: Dict[int, PhysicalBlock] = {
            i: PhysicalBlock(block_id=i, block_size=block_size)
            for i in range(num_physical_blocks)
        }
        self.block_tables: Dict[int, List[LogicalBlock]] = {}

        # 実際のKVキャッシュテンソルプール
        self.kv_cache = torch.zeros(
            num_physical_blocks, 2, num_layers, block_size, num_kv_heads, head_dim,
            dtype=torch.float16,
            device=device
        )

    def allocate_blocks_for_request(self, request_id: int, num_tokens: int):
        """リクエスト用のブロックを割り当て"""
        num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size

        if len(self.free_blocks) < num_blocks_needed:
            raise RuntimeError(
                f"OOM: {num_blocks_needed}ブロック必要、{len(self.free_blocks)}ブロック利用可能"
            )

        logical_blocks = []
        for _ in range(num_blocks_needed):
            physical_id = self.free_blocks.pop(0)
            self.all_blocks[physical_id].ref_count = 1
            logical_blocks.append(LogicalBlock(physical_block_id=physical_id))

        self.block_tables[request_id] = logical_blocks
        print(f"リクエスト {request_id}: {num_blocks_needed}ブロック割り当て済み、"
              f"{len(self.free_blocks)}ブロック残り")

    def free_request(self, request_id: int):
        """リクエスト完了後にブロックを解放"""
        if request_id in self.block_tables:
            for logical_block in self.block_tables[request_id]:
                phys_id = logical_block.physical_block_id
                self.all_blocks[phys_id].ref_count -= 1

                if self.all_blocks[phys_id].ref_count == 0:
                    self.free_blocks.append(phys_id)

            del self.block_tables[request_id]


# デモ
manager = PagedKVCacheManager(
    num_physical_blocks=1000,
    block_size=16,
    num_layers=32,
    num_kv_heads=32,
    head_dim=128
)

manager.allocate_blocks_for_request(request_id=1, num_tokens=200)
manager.allocate_blocks_for_request(request_id=2, num_tokens=500)
manager.allocate_blocks_for_request(request_id=3, num_tokens=100)

manager.free_request(request_id=1)
print(f"\nリクエスト1完了後 — 空きブロック: {len(manager.free_blocks)}")

4. 連続バッチング

4.1 静的バッチングの問題

静的バッチングはバッチ内の全リクエストが完了するまで次のバッチを開始しません:

t=0: [リクエストA: 500トークン] [リクエストB: 100トークン] [リクエストC: 300トークン]
t=1: リクエストB完了 — ACが終わるまで新規リクエストを開始できない
t=2: リクエストC完了
t=3: リクエストA完了 → ここでやっと新しいバッチを開始できる

4.2 連続バッチング(イテレーションレベルスケジューリング)

from dataclasses import dataclass, field
from typing import List, Tuple
import time

@dataclass
class Request:
    """推論リクエスト"""
    request_id: str
    input_ids: List[int]
    max_new_tokens: int
    generated_ids: List[int] = field(default_factory=list)
    is_finished: bool = False

class ContinuousBatchingScheduler:
    """連続バッチングスケジューラー"""

    def __init__(self, max_batch_size: int = 32, max_seq_len: int = 4096):
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len

        self.waiting_queue: List[Request] = []
        self.running_requests: List[Request] = []
        self.finished_requests: List[Request] = []

    def add_request(self, request: Request):
        self.waiting_queue.append(request)

    def schedule_iteration(self) -> Tuple[List[Request], List[str]]:
        """
        1イテレーションのバッチをスケジュール

        戻り値:
            (アクティブなリクエスト, 完了したリクエストIDのリスト)
        """
        completed_ids = []

        still_running = []
        for req in self.running_requests:
            if req.is_finished:
                self.finished_requests.append(req)
                completed_ids.append(req.request_id)
            else:
                still_running.append(req)
        self.running_requests = still_running

        # 空いたスロットに待機中のリクエストをすぐに追加
        while self.waiting_queue and self._can_add_request(self.waiting_queue[0]):
            new_request = self.waiting_queue.pop(0)
            self.running_requests.append(new_request)
            print(f"リクエスト {new_request.request_id} をバッチに追加 "
                  f"(バッチサイズ: {len(self.running_requests)})")

        return self.running_requests, completed_ids

5. 投機的デコーディング

5.1 アイデア: ドラフト + 検証

投機的デコーディングのコア: 小さいドラフトモデルが複数のトークンを並列で生成し、大きいターゲットモデルが全てを一度に(Prefillのように)検証します。

標準: [大モデル] → token1 → token2 → token3 → token4 → token5
投機的: [小モデル]  (t1, t2, t3, t4, t5) を並列で生成
        [大モデル]5トークンを一度に検証(並列、Prefillのように)
        受け入れられたトークンのみ保持

5.2 受け入れ率とスピードアップ分析

import torch
from typing import List, Tuple

def speculative_decode_step(
    draft_model,
    target_model,
    input_ids: torch.Tensor,
    draft_steps: int = 4,
    temperature: float = 1.0
) -> Tuple[torch.Tensor, int, int]:
    """
    投機的デコーディングの1ステップ

    戻り値:
        (生成トークン, 受け入れ数, ドラフト数)
    """
    # 1. ドラフトモデルが候補トークンを生成
    draft_tokens = []
    draft_probs = []
    current_ids = input_ids.clone()

    for _ in range(draft_steps):
        with torch.no_grad():
            draft_logits = draft_model(current_ids).logits[:, -1, :]

        draft_prob = torch.softmax(draft_logits / (temperature + 1e-8), dim=-1)
        draft_token = torch.multinomial(draft_prob, num_samples=1)
        draft_tokens.append(draft_token)
        draft_probs.append(draft_prob)
        current_ids = torch.cat([current_ids, draft_token], dim=1)

    draft_sequence = torch.cat(draft_tokens, dim=1)
    candidate_ids = torch.cat([input_ids, draft_sequence], dim=1)

    # 2. ターゲットモデルが全ドラフトトークンを1回のパスで検証
    with torch.no_grad():
        target_logits = target_model(candidate_ids).logits[
            :, input_ids.size(1) - 1:-1, :
        ]

    target_probs = torch.softmax(target_logits / (temperature + 1e-8), dim=-1)

    # 3. 各ドラフトトークンの受け入れ/拒否
    accepted_tokens = []
    num_accepted = 0

    for step in range(draft_steps):
        token = draft_sequence[:, step]
        p_draft = draft_probs[step].gather(1, token.unsqueeze(1)).squeeze(1)
        p_target = target_probs[:, step, :].gather(1, token.unsqueeze(1)).squeeze(1)

        acceptance_prob = torch.clamp(p_target / (p_draft + 1e-8), max=1.0)
        accepted = torch.rand_like(acceptance_prob) < acceptance_prob

        if not accepted.all():
            break

        accepted_tokens.append(token)
        num_accepted += 1

    # 4. ターゲットモデルからの最後のトークン
    last_logits = target_model(candidate_ids).logits[:, input_ids.size(1) + num_accepted - 1, :]
    last_prob = torch.softmax(last_logits / (temperature + 1e-8), dim=-1)

    if num_accepted < draft_steps:
        correction = torch.clamp(last_prob - draft_probs[num_accepted], min=0)
        correction = correction / (correction.sum(dim=-1, keepdim=True) + 1e-8)
        last_token = torch.multinomial(correction, num_samples=1)
    else:
        last_token = torch.multinomial(last_prob, num_samples=1)

    accepted_tokens.append(last_token.squeeze(1))
    final_tokens = torch.stack(accepted_tokens, dim=1)

    return final_tokens, num_accepted, draft_steps


def analyze_speedup(acceptance_rate: float, draft_steps: int = 4) -> dict:
    """受け入れ率に基づくスピードアップ分析"""

    # E[受け入れトークン数] = sum_{k=0}^{K} alpha^k
    expected_accepted = sum(
        acceptance_rate ** k for k in range(draft_steps + 1)
    )

    # ドラフトモデルはターゲットモデルの1/10のサイズ
    draft_model_ratio = 0.1

    steps_with_speculative = draft_steps * draft_model_ratio + 1
    speedup = expected_accepted / steps_with_speculative

    return {
        "acceptance_rate": acceptance_rate,
        "expected_accepted": expected_accepted,
        "speedup": speedup
    }

print("受け入れ率別の投機的デコーディングスピードアップ (K=4)")
print("=" * 55)
for alpha in [0.5, 0.6, 0.7, 0.8, 0.9, 0.95]:
    result = analyze_speedup(alpha, draft_steps=4)
    print(f"受け入れ率 {alpha:.0%}: 期待値 {result['expected_accepted']:.2f} トークン、"
          f"スピードアップ {result['speedup']:.2f}x")

5.3 Medusa: 複数のドラフトヘッド

import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """
    Medusa: 単一モデルに複数のドラフトヘッドを付加

    各ヘッドが将来のトークンを予測:
    - ヘッド1: t+1を予測
    - ヘッド2: t+2を予測
    - ヘッドN: t+Nを予測
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        num_heads: int = 4,
        hidden_layers: int = 1
    ):
        super().__init__()
        self.num_heads = num_heads

        self.heads = nn.ModuleList([
            nn.Sequential(
                *[nn.Linear(hidden_size, hidden_size, bias=False),
                  nn.SiLU()] * hidden_layers,
                nn.Linear(hidden_size, vocab_size, bias=False)
            )
            for _ in range(num_heads)
        ])

    def forward(self, hidden_states: torch.Tensor):
        """
        各将来の位置のロジットリストを返す
        """
        return [head(hidden_states) for head in self.heads]

6. FlashAttention: メモリ効率的なアテンション

6.1 標準アテンションのHBMボトルネック

標準アテンションは中間結果をHBMに繰り返し書き込み・読み取りします:

標準アテンションのメモリ操作:
1. Q, KHBMから読み込む        → 読み込み: O(N * d)
2. S = Q @ K.T を計算           → 書き込み: O(N^2)  ← ボトルネック!
3. SoftmaxのためSHBMから読む  → 読み込み: O(N^2)
4. P = softmax(S) を保存        → 書き込み: O(N^2)
5. P @ V のためPを読む          → 読み込み: O(N^2)
6. 最終出力を保存               → 書き込み: O(N * d)

HBMアクセス: O(N^2) — シーケンス長に対して二次!

6.2 FlashAttentionのタイリング戦略

import torch
import math

def flash_attention_v1(Q, K, V, block_size=64):
    """
    FlashAttention v1の簡略実装
    タイリングによって完全なアテンション行列のHBM保存を回避

    キー: ブロックごとの処理のためのオンラインSoftmax
    """
    batch_size, num_heads, seq_len, d_head = Q.shape

    scale = 1.0 / math.sqrt(d_head)
    Q = Q * scale

    O = torch.zeros_like(Q)
    L = torch.zeros(batch_size, num_heads, seq_len, 1, device=Q.device)
    M = torch.full((batch_size, num_heads, seq_len, 1), float('-inf'), device=Q.device)

    num_blocks = (seq_len + block_size - 1) // block_size

    for j in range(num_blocks):
        k_start = j * block_size
        k_end = min((j + 1) * block_size, seq_len)
        K_j = K[:, :, k_start:k_end, :]
        V_j = V[:, :, k_start:k_end, :]

        for i in range(num_blocks):
            q_start = i * block_size
            q_end = min((i + 1) * block_size, seq_len)
            Q_i = Q[:, :, q_start:q_end, :]
            O_i = O[:, :, q_start:q_end, :]
            L_i = L[:, :, q_start:q_end, :]
            M_i = M[:, :, q_start:q_end, :]

            # SRAMでアテンションスコアを計算
            S_ij = torch.matmul(Q_i, K_j.transpose(-2, -1))

            # オンラインSoftmaxの更新
            M_new = torch.maximum(M_i, S_ij.max(dim=-1, keepdim=True)[0])
            P_ij = torch.exp(S_ij - M_new)
            L_new = torch.exp(M_i - M_new) * L_i + P_ij.sum(dim=-1, keepdim=True)

            # スケール変更と出力の更新
            O_new = torch.exp(M_i - M_new) * O_i + torch.matmul(P_ij, V_j)

            O[:, :, q_start:q_end, :] = O_new
            L[:, :, q_start:q_end, :] = L_new
            M[:, :, q_start:q_end, :] = M_new

    return O / L

6.3 PyTorch SDPAの使用

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def modern_attention(q, k, v, is_causal=True, dropout_p=0.0):
    """
    PyTorch 2.0以降のscaled_dot_product_attention

    FlashAttention 2/3を自動で選択
    """
    return F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=None  # デフォルトは1/sqrt(d_head)
    )

# Flash Attentionバージョンのハイライト
flash_versions = {
    "FlashAttention 1 (arXiv:2205.14135)": {
        "key_innovation": "タイリング + オンラインSoftmax",
        "memory": "O(N) — アテンション行列を保存しない",
        "speedup": "標準比2-4x"
    },
    "FlashAttention 2 (arXiv:2307.08691)": {
        "key_innovation": "ワーク分割、FP16/BF16",
        "memory": "O(N)",
        "speedup": "H100での標準比5-9x"
    },
    "FlashAttention 3 (arXiv:2407.08608)": {
        "key_innovation": "H100特化、FP8、非同期パイプライン",
        "memory": "O(N)",
        "speedup": "H100でFA2比1.5-2x"
    },
}

for name, info in flash_versions.items():
    print(f"\n{name}")
    for k, v in info.items():
        print(f"  {k}: {v}")

7. マルチGPU推論

7.1 テンソル並列

重み行列をGPU間で分割し、各GPUがシャードを処理します。

import torch
import torch.nn as nn

class TensorParallelLinear(nn.Module):
    """
    テンソル並列Linearレイヤー(列並列)
    各GPUがout_features // world_sizeの出力ニューロンを所有
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        world_size: int,
        rank: int
    ):
        super().__init__()
        self.world_size = world_size
        self.rank = rank
        self.local_out_features = out_features // world_size

        self.weight = nn.Parameter(
            torch.randn(self.local_out_features, in_features) / (in_features ** 0.5)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        local_output = nn.functional.linear(x, self.weight)
        # 実際の分散セットアップでは: dist.all_gather(output_list, local_output)
        return local_output


def setup_vllm_multiGPU(model_name: str, tp_size: int):
    """マルチGPU vLLM推論をセットアップ"""
    from vllm import LLM, SamplingParams

    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,
        gpu_memory_utilization=0.9
    )

    return llm

7.2 vLLMの完全な使用方法

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
import asyncio
import time

def vllm_basic_usage():
    """vLLMの基本的な使用方法"""

    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        tensor_parallel_size=1,
        gpu_memory_utilization=0.90,
        max_model_len=4096,
        quantization=None,             # "awq", "gptq", "squeezellm"
        dtype="auto",
        max_num_seqs=256,
        enable_prefix_caching=True,
        use_v2_block_manager=True,
    )

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=200,
    )

    prompts = [
        "量子コンピューティングを簡単に説明してください",
        "人工知能の未来は?",
        "人間の脳はどのように機能しますか?",
    ]

    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        print(f"プロンプト: {output.prompt[:50]}...")
        print(f"出力: {output.outputs[0].text[:100]}...")
        print(f"トークン数: {len(output.outputs[0].token_ids)}")
        print()

    return outputs

8. 推論エンジンの比較

8.1 主要エンジンの特徴

エンジン開発者主要機能最適な用途
vLLMUC BerkeleyPagedAttention、連続バッチング汎用LLMサービング
TGIHuggingFaceFlash Attention 2、投機的デコーディングHFモデルサービング
TensorRT-LLMNVIDIANVIDIA最適化、FP8NVIDIAの最大パフォーマンス
DeepSpeed-MIIMicrosoftZeRO推論、大規模モデルマルチGPUの巨大モデル
llama.cppG. GerganovCPU最適化、GGUFローカル実行

8.2 ベンチマーク結果

import time

def run_inference_benchmark(engine_name: str, model, tokenizer, prompts, max_tokens=100):
    """シンプルな推論ベンチマーク"""

    num_warmup = 5
    num_runs = 50

    # ウォームアップ
    for prompt in prompts[:num_warmup]:
        _ = model.generate(
            tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
            max_new_tokens=max_tokens,
            do_sample=False
        )

    # 時間計測
    import torch
    torch.cuda.synchronize()
    start = time.perf_counter()

    total_tokens = 0
    for prompt in prompts[:num_runs]:
        output = model.generate(
            tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
            max_new_tokens=max_tokens,
            do_sample=False
        )
        total_tokens += max_tokens

    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    throughput = total_tokens / elapsed
    latency_ms = elapsed / num_runs * 1000

    print(f"\n{engine_name}")
    print(f"  スループット: {throughput:.1f} tokens/sec")
    print(f"  平均レイテンシ: {latency_ms:.1f} ms/リクエスト")

    return throughput, latency_ms


# 比較例(A100 80GB、Llama-2-7B、batch=1、100出力トークン)
benchmark_results = {
    "HuggingFace (FP16)":         {"throughput": 52,  "latency_ms": 1924},
    "HuggingFace (Flash Attn 2)": {"throughput": 78,  "latency_ms": 1280},
    "vLLM":                        {"throughput": 120, "latency_ms": 832},
    "vLLM + AWQ 4bit":             {"throughput": 165, "latency_ms": 606},
    "TensorRT-LLM":                {"throughput": 180, "latency_ms": 555},
}

print("LLM推論エンジンベンチマーク (A100 80GB、Llama-2-7B)")
print("=" * 65)
print(f"{'エンジン':<30} {'スループット (tok/s)':<22} {'レイテンシ (ms)':<15}")
print("-" * 65)
for engine, stats in benchmark_results.items():
    print(f"{engine:<30} {stats['throughput']:<22} {stats['latency_ms']:<15}")

9. プロンプトキャッシング

9.1 プレフィックスキャッシング

同じシステムプロンプトやドキュメントが繰り返し処理される場合にKVキャッシュを再利用します。

from vllm import LLM, SamplingParams
import time

def demonstrate_prefix_caching():
    """プレフィックスキャッシングの効果を実演"""

    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        enable_prefix_caching=True,
        max_model_len=4096,
    )

    # 全リクエストに共通の長いシステムプロンプト(1000以上のトークン)
    system_prompt = (
        "あなたはPython、機械学習、データサイエンス、クラウドコンピューティングの専門知識を持つ"
        "有能なAIアシスタントです。"
    ) * 50

    questions = [
        "Pythonのループを最適化するには?",
        "勾配降下法とは何ですか?",
        "コンテナ化について説明してください。",
        "ニューラルネットワークとは何ですか?",
    ]

    sampling_params = SamplingParams(temperature=0.7, max_tokens=100)
    cold_prompts = [f"{system_prompt}\n\n質問: {q}" for q in questions]

    # コールドスタート(キャッシュなし)
    cold_start = time.time()
    llm.generate(cold_prompts, sampling_params)
    cold_time = time.time() - cold_start

    # ウォームスタート(キャッシュヒット)
    warm_start = time.time()
    llm.generate(cold_prompts, sampling_params)
    warm_time = time.time() - warm_start

    print(f"コールドスタート(キャッシュなし): {cold_time:.2f}s")
    print(f"ウォームスタート(キャッシュヒット): {warm_time:.2f}s")
    print(f"スピードアップ: {cold_time / warm_time:.2f}x")

10. 実践的な最適化チェックリスト

10.1 ステップごとの最適化ガイド

class LLMOptimizationChecklist:
    """LLM推論最適化チェックリスト"""

    optimizations = [
        {
            "category": "ベースライン",
            "level": 1,
            "items": [
                {
                    "name": "FP16/BF16を使用",
                    "impact": "高",
                    "effort": "低",
                    "description": "FP32 → FP16: 2xメモリ節約、速度向上",
                },
                {
                    "name": "Flash Attention 2を有効化",
                    "impact": "高",
                    "effort": "低",
                    "description": "2-4xアテンションスピードアップ、メモリ節約",
                },
            ]
        },
        {
            "category": "KVキャッシュ最適化",
            "level": 2,
            "items": [
                {
                    "name": "GQA/MQAモデルを選択",
                    "impact": "高",
                    "effort": "中",
                    "description": "4-8x KVキャッシュ削減、より大きな有効バッチ"
                },
                {
                    "name": "プレフィックスキャッシング",
                    "impact": "中",
                    "effort": "低",
                    "description": "共通システムプロンプトのKVキャッシュを再利用"
                },
            ]
        },
        {
            "category": "バッチング最適化",
            "level": 3,
            "items": [
                {
                    "name": "vLLMで連続バッチング",
                    "impact": "非常に高",
                    "effort": "低",
                    "description": "2-5xスループット向上",
                },
            ]
        },
        {
            "category": "モデル最適化",
            "level": 4,
            "items": [
                {
                    "name": "AWQ 4ビット量子化",
                    "impact": "高",
                    "effort": "中",
                    "description": "4xメモリ削減、1.5-2x速度向上",
                },
                {
                    "name": "投機的デコーディング",
                    "impact": "中",
                    "effort": "高",
                    "description": "2-3xスピードアップ(適切なドラフトモデルが必要)"
                },
            ]
        },
        {
            "category": "ハードウェア最適化",
            "level": 5,
            "items": [
                {
                    "name": "テンソル並列",
                    "impact": "非常に高",
                    "effort": "中",
                    "description": "複数GPUで線形スループットスケーリング"
                },
                {
                    "name": "CUDAグラフキャプチャ",
                    "impact": "中",
                    "effort": "高",
                    "description": "カーネル起動オーバーヘッドを排除"
                },
            ]
        }
    ]

10.2 エンドツーエンドのプロダクションセットアップ

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from fastapi import FastAPI
from pydantic import BaseModel
import asyncio
import uvicorn

app = FastAPI(title="LLM推論API")

class GenerateRequest(BaseModel):
    prompt: str
    max_tokens: int = 200
    temperature: float = 0.7
    top_p: float = 0.95
    stream: bool = False

class GenerateResponse(BaseModel):
    text: str
    tokens_generated: int
    finish_reason: str

# グローバルエンジンインスタンス
engine: AsyncLLMEngine = None

def create_optimized_engine(model_name: str, **kwargs) -> AsyncLLMEngine:
    """プロダクション最適化されたvLLMエンジンを作成"""

    engine_args = AsyncEngineArgs(
        model=model_name,
        tensor_parallel_size=kwargs.get("tp_size", 1),
        gpu_memory_utilization=kwargs.get("gpu_util", 0.90),
        max_model_len=kwargs.get("max_model_len", 4096),
        quantization=kwargs.get("quantization", None),  # "awq" または "gptq"
        dtype="auto",
        max_num_seqs=kwargs.get("max_num_seqs", 256),
        enable_prefix_caching=True,
        use_v2_block_manager=True,
        speculative_model=kwargs.get("draft_model", None),  # オプションのドラフトモデル
        num_speculative_tokens=kwargs.get("num_spec_tokens", 5),
    )

    return AsyncLLMEngine.from_engine_args(engine_args)

@app.on_event("startup")
async def startup():
    global engine
    engine = create_optimized_engine(
        model_name="meta-llama/Llama-2-7b-hf",
        tp_size=1,
        gpu_util=0.90,
        max_model_len=4096,
    )

@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    sampling_params = SamplingParams(
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens,
    )

    request_id = f"req_{id(request)}"

    final_output = None
    async for output in engine.generate(request.prompt, sampling_params, request_id):
        final_output = output

    if final_output and final_output.outputs:
        result = final_output.outputs[0]
        return GenerateResponse(
            text=result.text,
            tokens_generated=len(result.token_ids),
            finish_reason=result.finish_reason or "length"
        )

    return GenerateResponse(text="", tokens_generated=0, finish_reason="error")

# 実行: uvicorn script:app --host 0.0.0.0 --port 8000 --workers 1

まとめ

LLM推論最適化は階層的なアプローチが必要です。

重要なポイント:

  1. KVキャッシュを理解する: 2 * レイヤー数 * kvヘッド数 * d_head * seq_len * dtype_bytesを暗記する。GQA/MQAでKVキャッシュを4-8x削減。

  2. PagedAttention: vLLMのコアイノベーション — OSの仮想メモリから着想を得てKVキャッシュの断片化を排除。

  3. 連続バッチング: 完了と同時に新しいリクエストをすぐに挿入し、GPU利用率を最大化。

  4. 投機的デコーディング: 小さいドラフトモデル + 大きな検証モデル = 適切な受け入れ率で2-3xスピードアップ。

  5. FlashAttention: アテンションのメモリをO(N^2)からO(N)に削減し、長いコンテキストを可能に。

プロダクションデプロイの推奨:

  • 小規模サービス: vLLM + AWQ 4ビット + プレフィックスキャッシング
  • 大規模サービス: TensorRT-LLMまたはvLLM + テンソル並列
  • 最低レイテンシ: 投機的デコーディング + CUDAグラフ

参考文献