- Authors

- Name
- Youngju Kim
- @fjvbn20031
はじめに
大規模言語モデル(LLM)をプロダクションにデプロイすると、すぐに課題に直面します。推論速度とコストです。GPT-4クラスのモデルへの最初のクエリは数秒かかり、同時ユーザーが増えるとスループットが急激に低下します。
このガイドでは、LLM推論最適化のコアテクニックを徹底解説します。KVキャッシュの内部構造からPagedAttention、投機的デコーディング、FlashAttention、そして最新のvLLMとTensorRT-LLMエンジンまで、使い方だけでなく、なぜそれが機能するのかを理解します。
1. LLM推論パイプラインの理解
1.1 2つのフェーズ: PrefillとDecode
LLMのテキスト生成は2つの明確なフェーズに分かれます。
Prefillフェーズ(プロンプト処理)
- 入力プロンプトの全トークンを同時に処理(並列)
- 各レイヤーでKey/Valueキャッシュを生成して保存
- Compute-Bound: GPUの計算スループットがボトルネック
- TTFT(Time To First Token)に直接影響
Decodeフェーズ(トークン生成)
- 自己回帰的に1トークンずつ生成
- 以前に生成した全トークンのKVキャッシュを参照
- Memory-Bandwidth-Bound: HBMの読み取り速度がボトルネック
- TPOT(Time Per Output Token)に直接影響
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
def measure_prefill_decode_time(model, tokenizer, prompt: str, max_new_tokens: int = 100):
"""PrefillとDecodeフェーズのタイミングを計測"""
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
input_len = inputs["input_ids"].size(1)
# TTFTを計測
torch.cuda.synchronize()
prefill_start = time.perf_counter()
with torch.no_grad():
first_output = model.generate(
**inputs,
max_new_tokens=1,
do_sample=False
)
torch.cuda.synchronize()
ttft = time.perf_counter() - prefill_start
# 全体の生成を計測
torch.cuda.synchronize()
total_start = time.perf_counter()
with torch.no_grad():
full_output = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False
)
torch.cuda.synchronize()
total_time = time.perf_counter() - total_start
output_tokens = full_output.size(1) - input_len
decode_time = total_time - ttft
tpot = decode_time / max(output_tokens - 1, 1)
print(f"入力トークン数: {input_len}")
print(f"出力トークン数: {output_tokens}")
print(f"TTFT(最初のトークンレイテンシ): {ttft * 1000:.1f} ms")
print(f"TPOT(トークンあたりの時間): {tpot * 1000:.1f} ms")
print(f"スループット: {output_tokens / total_time:.1f} tokens/sec")
return ttft, tpot
1.2 メモリ帯域幅の分析
Decodeフェーズがメモリバウンドになる理由を理解します:
def analyze_memory_bandwidth():
"""LLM推論のメモリ帯域幅分析"""
# 例: Llama-2-7Bの設定
model_params = {
"num_layers": 32,
"hidden_size": 4096,
"num_heads": 32,
"head_dim": 128,
"vocab_size": 32000,
}
dtype_bytes = 2 # FP16: 2バイト
# 重みメモリ
attn_weight = 4 * model_params["hidden_size"] ** 2 # Q, K, V, Oプロジェクション
ffn_weight = 8 * model_params["hidden_size"] ** 2 # SwiGLU: Up, Gate, Down
layer_weight = (attn_weight + ffn_weight) * dtype_bytes
total_weight_bytes = layer_weight * model_params["num_layers"]
total_weight_gb = total_weight_bytes / 1e9
print(f"モデル重み: {total_weight_gb:.2f} GB")
# KVキャッシュメモリ(シーケンス長に比例)
seq_len = 2048
kv_cache_per_token = (
2 *
model_params["num_layers"] *
model_params["num_heads"] *
model_params["head_dim"] *
dtype_bytes
)
kv_cache_total = kv_cache_per_token * seq_len / 1e6
print(f"KVキャッシュ({seq_len}トークン): {kv_cache_total:.2f} MB")
print(f"トークンあたりのKVキャッシュ: {kv_cache_per_token} バイト")
# A100のメモリ帯域幅: 2 TB/s
memory_bandwidth_tbs = 2.0
# Decode: 重みはトークンごとに1回読み込まれる
memory_bound_tps = memory_bandwidth_tbs * 1e12 / total_weight_bytes
# A100 FP16: 312 TFLOPS
compute_throughput = 312e12
flops_per_token = 2 * total_weight_bytes
compute_bound_tps = compute_throughput / flops_per_token
print(f"\nbatch_size=1の場合:")
print(f"メモリバウンドスループット: {memory_bound_tps:.1f} tokens/sec")
print(f"コンピュートバウンドスループット: {compute_bound_tps:.1f} tokens/sec")
print(f"実際のボトルネック: {'メモリ' if memory_bound_tps < compute_bound_tps else 'コンピュート'}")
analyze_memory_bandwidth()
1.3 推論コスト分析
def estimate_inference_cost(
model_size_b: float,
tokens_per_request: int,
requests_per_day: int,
gpu_cost_per_hour: float = 3.0 # A100の時間単価(USD)
):
"""推論コストを推定"""
# 経験的スループット推定
# 7Bモデル: ~100 tok/s、70Bモデル: ~20 tok/s (A100)
throughput_tps = 100 / (model_size_b / 7) ** 0.6
total_tokens_per_day = tokens_per_request * requests_per_day
seconds_needed = total_tokens_per_day / throughput_tps
hours_needed = seconds_needed / 3600
daily_cost = hours_needed * gpu_cost_per_hour
cost_per_1k_tokens = daily_cost / (total_tokens_per_day / 1000)
print(f"モデル: {model_size_b}Bパラメータ")
print(f"日次リクエスト数: {requests_per_day:,}")
print(f"リクエストあたりトークン数: {tokens_per_request}")
print(f"日次総トークン数: {total_tokens_per_day:,}")
print(f"推定スループット: {throughput_tps:.1f} tokens/sec")
print(f"必要GPU時間: {hours_needed:.2f}")
print(f"日次コスト: ${daily_cost:.2f}")
print(f"1Kトークンあたりのコスト: ${cost_per_1k_tokens:.4f}")
estimate_inference_cost(
model_size_b=7.0,
tokens_per_request=500,
requests_per_day=10000
)
2. KVキャッシュ: コアとなる最適化
2.1 KVキャッシュが必要な理由
Transformerのアテンションでは、各トークンは全ての過去のトークンに対してアテンションを計算します。既に処理済みのトークンの再計算を避けるために、KとVの行列をキャッシュします。
import torch
import torch.nn as nn
import math
class MultiHeadAttentionWithKVCache(nn.Module):
"""KVキャッシュをサポートするマルチヘッドアテンション"""
def __init__(self, d_model: int, num_heads: int, max_seq_len: int = 4096):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_head = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
# KVキャッシュの初期化
self.register_buffer(
'k_cache',
torch.zeros(1, max_seq_len, num_heads, self.d_head)
)
self.register_buffer(
'v_cache',
torch.zeros(1, max_seq_len, num_heads, self.d_head)
)
self.cache_pos = 0
def forward(self, x: torch.Tensor, use_cache: bool = True, position: int = None):
batch_size, seq_len, _ = x.shape
q = self.W_q(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
k = self.W_k(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
v = self.W_v(x).reshape(batch_size, seq_len, self.num_heads, self.d_head)
if use_cache:
start_pos = self.cache_pos if position is None else position
self.k_cache[:, start_pos:start_pos + seq_len] = k
self.v_cache[:, start_pos:start_pos + seq_len] = v
if position is None:
self.cache_pos += seq_len
total_len = self.cache_pos if position is None else start_pos + seq_len
k = self.k_cache[:, :total_len]
v = self.v_cache[:, :total_len]
scale = math.sqrt(self.d_head)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
attn_weights = torch.softmax(attn_scores, dim=-1)
output = torch.matmul(attn_weights, v)
output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
output = self.W_o(output)
return output
def clear_cache(self):
self.k_cache.zero_()
self.v_cache.zero_()
self.cache_pos = 0
2.2 KVキャッシュのメモリ計算
def calculate_kv_cache_memory(
model_config: dict,
batch_size: int,
seq_len: int,
dtype_bytes: int = 2 # FP16
) -> dict:
"""KVキャッシュのメモリ使用量を計算"""
num_layers = model_config["num_layers"]
num_kv_heads = model_config.get("num_kv_heads", model_config["num_heads"])
head_dim = model_config["head_dim"]
# 計算式: 2(K+V) * レイヤー数 * kvヘッド数 * ヘッド次元 * シーケンス長 * バッチ * dtype
kv_cache_bytes = (
2 *
num_layers *
num_kv_heads *
head_dim *
seq_len *
batch_size *
dtype_bytes
)
return {
"kv_cache_bytes": kv_cache_bytes,
"kv_cache_mb": kv_cache_bytes / 1e6,
"kv_cache_gb": kv_cache_bytes / 1e9,
"per_token_bytes": kv_cache_bytes // seq_len,
}
models = {
"Llama-2-7B (MHA)": {
"num_layers": 32, "num_heads": 32,
"num_kv_heads": 32, "head_dim": 128
},
"Llama-2-70B (GQA)": {
"num_layers": 80, "num_heads": 64,
"num_kv_heads": 8, "head_dim": 128
},
"Mistral-7B (GQA)": {
"num_layers": 32, "num_heads": 32,
"num_kv_heads": 8, "head_dim": 128
},
}
print("KVキャッシュメモリ (batch=1, seq=4096)")
print("=" * 65)
for name, config in models.items():
result = calculate_kv_cache_memory(config, batch_size=1, seq_len=4096)
print(f"{name:<25} {result['kv_cache_gb']:.2f} GB "
f"({result['per_token_bytes']:,} bytes/token)")
2.3 グループクエリアテンション(GQA)
GQAはKVキャッシュを削減するキーテクニックです。複数のクエリヘッドがより少ないKVヘッドを共有します。
import torch
import torch.nn as nn
import math
class GroupedQueryAttention(nn.Module):
"""グループクエリアテンション(GQA)の実装"""
def __init__(self, d_model: int, num_q_heads: int, num_kv_heads: int):
super().__init__()
assert num_q_heads % num_kv_heads == 0
self.d_model = d_model
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.num_groups = num_q_heads // num_kv_heads
self.d_head = d_model // num_q_heads
self.W_q = nn.Linear(d_model, num_q_heads * self.d_head, bias=False)
self.W_k = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
self.W_v = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
def forward(self, x: torch.Tensor, kv_cache=None):
batch_size, seq_len, _ = x.shape
q = self.W_q(x).reshape(batch_size, seq_len, self.num_q_heads, self.d_head)
k = self.W_k(x).reshape(batch_size, seq_len, self.num_kv_heads, self.d_head)
v = self.W_v(x).reshape(batch_size, seq_len, self.num_kv_heads, self.d_head)
if kv_cache is not None:
k = torch.cat([kv_cache["k"], k], dim=1)
v = torch.cat([kv_cache["v"], v], dim=1)
new_kv_cache = {"k": k, "v": v}
q = q.transpose(1, 2) # [B, Q_heads, S, d_head]
k = k.transpose(1, 2) # [B, KV_heads, T, d_head]
v = v.transpose(1, 2)
# GQA: KVヘッドをQヘッドに合わせて拡張
k = k.repeat_interleave(self.num_groups, dim=1)
v = v.repeat_interleave(self.num_groups, dim=1)
scale = math.sqrt(self.d_head)
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
attn_weights = torch.softmax(attn_scores, dim=-1)
output = torch.matmul(attn_weights, v)
output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
return self.W_o(output), new_kv_cache
2.4 DeepSeek MLA(マルチヘッド潜在アテンション)
DeepSeek-V2で導入されたMLAは、KVキャッシュを低次元の潜在ベクトルに圧縮します。
class MultiHeadLatentAttention(nn.Module):
"""
DeepSeek MLA — KVキャッシュを低ランク潜在ベクトルに圧縮
コアアイデア:
- 高次元のK,Vを保存するのではなく、低次元の潜在c_kvを保存
- c_kvからアッププロジェクションでK, Vを復元
- KVキャッシュサイズ: num_layers * kv_lora_rank * seq_len
(標準: 2 * num_layers * num_kv_heads * d_head * seq_len と比較)
"""
def __init__(
self,
d_model: int = 5120,
num_heads: int = 128,
kv_lora_rank: int = 512,
qk_nope_head_dim: int = 128,
qk_rope_head_dim: int = 64,
v_head_dim: int = 128,
):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.kv_lora_rank = kv_lora_rank
# Qプロジェクション(LoRAスタイル)
self.q_a_proj = nn.Linear(d_model, 1536, bias=False)
self.q_b_proj = nn.Linear(
1536,
num_heads * (qk_nope_head_dim + qk_rope_head_dim),
bias=False
)
# KVダウンプロジェクション: d_model -> kv_lora_rank
# これだけがKVキャッシュに保存される!
self.kv_a_proj = nn.Linear(
d_model,
kv_lora_rank + qk_rope_head_dim,
bias=False
)
# KVアッププロジェクション: kv_lora_rank -> K, V
self.kv_b_proj = nn.Linear(
kv_lora_rank,
num_heads * (qk_nope_head_dim + v_head_dim),
bias=False
)
self.o_proj = nn.Linear(num_heads * v_head_dim, d_model, bias=False)
def forward(self, x: torch.Tensor, compressed_kv_cache=None):
batch_size, seq_len, _ = x.shape
# KV圧縮(この結果のみキャッシュに入る)
kv_compressed = self.kv_a_proj(x) # [B, S, kv_lora_rank + rope_dim]
if compressed_kv_cache is not None:
kv_compressed_total = torch.cat([compressed_kv_cache, kv_compressed], dim=1)
else:
kv_compressed_total = kv_compressed
# キャッシュされた圧縮表現から完全なK, Vを復元
kv_content = kv_compressed_total[:, :, :self.kv_lora_rank]
kv_full = self.kv_b_proj(kv_content)
return None, kv_compressed
3. PagedAttention: vLLMのコアイノベーション
3.1 従来のKVキャッシュの問題
従来のLLMサービングシステムはリクエストごとに最大シーケンス長のメモリを事前割り当てします:
リクエスト1: [PROMPT=200トークン] [KV_CACHE=最大1848トークン予約済み] → 内部断片化
リクエスト2: [PROMPT=100トークン] [KV_CACHE=1948トークン予約済み]
リクエスト3: 待機中(外部断片化)
これにより60-80%のGPUメモリが無駄になります。
3.2 PagedAttention: 仕組み
OSの仮想メモリに着想を得て、PagedAttentionは固定サイズの物理ブロックでKVキャッシュを管理します。
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import torch
@dataclass
class PhysicalBlock:
"""物理メモリブロック"""
block_id: int
block_size: int # ブロックあたりのトークン数(例: 16)
ref_count: int = 0
@dataclass
class LogicalBlock:
"""論理ブロック(リクエストにマッピングされる)"""
physical_block_id: int
num_filled: int = 0
class PagedKVCacheManager:
"""PagedAttention KVキャッシュマネージャー"""
def __init__(
self,
num_physical_blocks: int,
block_size: int,
num_layers: int,
num_kv_heads: int,
head_dim: int,
device: str = "cuda"
):
self.block_size = block_size
self.num_layers = num_layers
self.free_blocks: List[int] = list(range(num_physical_blocks))
self.all_blocks: Dict[int, PhysicalBlock] = {
i: PhysicalBlock(block_id=i, block_size=block_size)
for i in range(num_physical_blocks)
}
self.block_tables: Dict[int, List[LogicalBlock]] = {}
# 実際のKVキャッシュテンソルプール
self.kv_cache = torch.zeros(
num_physical_blocks, 2, num_layers, block_size, num_kv_heads, head_dim,
dtype=torch.float16,
device=device
)
def allocate_blocks_for_request(self, request_id: int, num_tokens: int):
"""リクエスト用のブロックを割り当て"""
num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
if len(self.free_blocks) < num_blocks_needed:
raise RuntimeError(
f"OOM: {num_blocks_needed}ブロック必要、{len(self.free_blocks)}ブロック利用可能"
)
logical_blocks = []
for _ in range(num_blocks_needed):
physical_id = self.free_blocks.pop(0)
self.all_blocks[physical_id].ref_count = 1
logical_blocks.append(LogicalBlock(physical_block_id=physical_id))
self.block_tables[request_id] = logical_blocks
print(f"リクエスト {request_id}: {num_blocks_needed}ブロック割り当て済み、"
f"{len(self.free_blocks)}ブロック残り")
def free_request(self, request_id: int):
"""リクエスト完了後にブロックを解放"""
if request_id in self.block_tables:
for logical_block in self.block_tables[request_id]:
phys_id = logical_block.physical_block_id
self.all_blocks[phys_id].ref_count -= 1
if self.all_blocks[phys_id].ref_count == 0:
self.free_blocks.append(phys_id)
del self.block_tables[request_id]
# デモ
manager = PagedKVCacheManager(
num_physical_blocks=1000,
block_size=16,
num_layers=32,
num_kv_heads=32,
head_dim=128
)
manager.allocate_blocks_for_request(request_id=1, num_tokens=200)
manager.allocate_blocks_for_request(request_id=2, num_tokens=500)
manager.allocate_blocks_for_request(request_id=3, num_tokens=100)
manager.free_request(request_id=1)
print(f"\nリクエスト1完了後 — 空きブロック: {len(manager.free_blocks)}")
4. 連続バッチング
4.1 静的バッチングの問題
静的バッチングはバッチ内の全リクエストが完了するまで次のバッチを開始しません:
t=0: [リクエストA: 500トークン] [リクエストB: 100トークン] [リクエストC: 300トークン]
t=1: リクエストB完了 — AとCが終わるまで新規リクエストを開始できない
t=2: リクエストC完了
t=3: リクエストA完了 → ここでやっと新しいバッチを開始できる
4.2 連続バッチング(イテレーションレベルスケジューリング)
from dataclasses import dataclass, field
from typing import List, Tuple
import time
@dataclass
class Request:
"""推論リクエスト"""
request_id: str
input_ids: List[int]
max_new_tokens: int
generated_ids: List[int] = field(default_factory=list)
is_finished: bool = False
class ContinuousBatchingScheduler:
"""連続バッチングスケジューラー"""
def __init__(self, max_batch_size: int = 32, max_seq_len: int = 4096):
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.waiting_queue: List[Request] = []
self.running_requests: List[Request] = []
self.finished_requests: List[Request] = []
def add_request(self, request: Request):
self.waiting_queue.append(request)
def schedule_iteration(self) -> Tuple[List[Request], List[str]]:
"""
1イテレーションのバッチをスケジュール
戻り値:
(アクティブなリクエスト, 完了したリクエストIDのリスト)
"""
completed_ids = []
still_running = []
for req in self.running_requests:
if req.is_finished:
self.finished_requests.append(req)
completed_ids.append(req.request_id)
else:
still_running.append(req)
self.running_requests = still_running
# 空いたスロットに待機中のリクエストをすぐに追加
while self.waiting_queue and self._can_add_request(self.waiting_queue[0]):
new_request = self.waiting_queue.pop(0)
self.running_requests.append(new_request)
print(f"リクエスト {new_request.request_id} をバッチに追加 "
f"(バッチサイズ: {len(self.running_requests)})")
return self.running_requests, completed_ids
5. 投機的デコーディング
5.1 アイデア: ドラフト + 検証
投機的デコーディングのコア: 小さいドラフトモデルが複数のトークンを並列で生成し、大きいターゲットモデルが全てを一度に(Prefillのように)検証します。
標準: [大モデル] → token1 → token2 → token3 → token4 → token5
投機的: [小モデル] → (t1, t2, t3, t4, t5) を並列で生成
[大モデル] → 5トークンを一度に検証(並列、Prefillのように)
受け入れられたトークンのみ保持
5.2 受け入れ率とスピードアップ分析
import torch
from typing import List, Tuple
def speculative_decode_step(
draft_model,
target_model,
input_ids: torch.Tensor,
draft_steps: int = 4,
temperature: float = 1.0
) -> Tuple[torch.Tensor, int, int]:
"""
投機的デコーディングの1ステップ
戻り値:
(生成トークン, 受け入れ数, ドラフト数)
"""
# 1. ドラフトモデルが候補トークンを生成
draft_tokens = []
draft_probs = []
current_ids = input_ids.clone()
for _ in range(draft_steps):
with torch.no_grad():
draft_logits = draft_model(current_ids).logits[:, -1, :]
draft_prob = torch.softmax(draft_logits / (temperature + 1e-8), dim=-1)
draft_token = torch.multinomial(draft_prob, num_samples=1)
draft_tokens.append(draft_token)
draft_probs.append(draft_prob)
current_ids = torch.cat([current_ids, draft_token], dim=1)
draft_sequence = torch.cat(draft_tokens, dim=1)
candidate_ids = torch.cat([input_ids, draft_sequence], dim=1)
# 2. ターゲットモデルが全ドラフトトークンを1回のパスで検証
with torch.no_grad():
target_logits = target_model(candidate_ids).logits[
:, input_ids.size(1) - 1:-1, :
]
target_probs = torch.softmax(target_logits / (temperature + 1e-8), dim=-1)
# 3. 各ドラフトトークンの受け入れ/拒否
accepted_tokens = []
num_accepted = 0
for step in range(draft_steps):
token = draft_sequence[:, step]
p_draft = draft_probs[step].gather(1, token.unsqueeze(1)).squeeze(1)
p_target = target_probs[:, step, :].gather(1, token.unsqueeze(1)).squeeze(1)
acceptance_prob = torch.clamp(p_target / (p_draft + 1e-8), max=1.0)
accepted = torch.rand_like(acceptance_prob) < acceptance_prob
if not accepted.all():
break
accepted_tokens.append(token)
num_accepted += 1
# 4. ターゲットモデルからの最後のトークン
last_logits = target_model(candidate_ids).logits[:, input_ids.size(1) + num_accepted - 1, :]
last_prob = torch.softmax(last_logits / (temperature + 1e-8), dim=-1)
if num_accepted < draft_steps:
correction = torch.clamp(last_prob - draft_probs[num_accepted], min=0)
correction = correction / (correction.sum(dim=-1, keepdim=True) + 1e-8)
last_token = torch.multinomial(correction, num_samples=1)
else:
last_token = torch.multinomial(last_prob, num_samples=1)
accepted_tokens.append(last_token.squeeze(1))
final_tokens = torch.stack(accepted_tokens, dim=1)
return final_tokens, num_accepted, draft_steps
def analyze_speedup(acceptance_rate: float, draft_steps: int = 4) -> dict:
"""受け入れ率に基づくスピードアップ分析"""
# E[受け入れトークン数] = sum_{k=0}^{K} alpha^k
expected_accepted = sum(
acceptance_rate ** k for k in range(draft_steps + 1)
)
# ドラフトモデルはターゲットモデルの1/10のサイズ
draft_model_ratio = 0.1
steps_with_speculative = draft_steps * draft_model_ratio + 1
speedup = expected_accepted / steps_with_speculative
return {
"acceptance_rate": acceptance_rate,
"expected_accepted": expected_accepted,
"speedup": speedup
}
print("受け入れ率別の投機的デコーディングスピードアップ (K=4)")
print("=" * 55)
for alpha in [0.5, 0.6, 0.7, 0.8, 0.9, 0.95]:
result = analyze_speedup(alpha, draft_steps=4)
print(f"受け入れ率 {alpha:.0%}: 期待値 {result['expected_accepted']:.2f} トークン、"
f"スピードアップ {result['speedup']:.2f}x")
5.3 Medusa: 複数のドラフトヘッド
import torch
import torch.nn as nn
class MedusaHead(nn.Module):
"""
Medusa: 単一モデルに複数のドラフトヘッドを付加
各ヘッドが将来のトークンを予測:
- ヘッド1: t+1を予測
- ヘッド2: t+2を予測
- ヘッドN: t+Nを予測
"""
def __init__(
self,
hidden_size: int,
vocab_size: int,
num_heads: int = 4,
hidden_layers: int = 1
):
super().__init__()
self.num_heads = num_heads
self.heads = nn.ModuleList([
nn.Sequential(
*[nn.Linear(hidden_size, hidden_size, bias=False),
nn.SiLU()] * hidden_layers,
nn.Linear(hidden_size, vocab_size, bias=False)
)
for _ in range(num_heads)
])
def forward(self, hidden_states: torch.Tensor):
"""
各将来の位置のロジットリストを返す
"""
return [head(hidden_states) for head in self.heads]
6. FlashAttention: メモリ効率的なアテンション
6.1 標準アテンションのHBMボトルネック
標準アテンションは中間結果をHBMに繰り返し書き込み・読み取りします:
標準アテンションのメモリ操作:
1. Q, KをHBMから読み込む → 読み込み: O(N * d)
2. S = Q @ K.T を計算 → 書き込み: O(N^2) ← ボトルネック!
3. SoftmaxのためSをHBMから読む → 読み込み: O(N^2)
4. P = softmax(S) を保存 → 書き込み: O(N^2)
5. P @ V のためPを読む → 読み込み: O(N^2)
6. 最終出力を保存 → 書き込み: O(N * d)
総HBMアクセス: O(N^2) — シーケンス長に対して二次!
6.2 FlashAttentionのタイリング戦略
import torch
import math
def flash_attention_v1(Q, K, V, block_size=64):
"""
FlashAttention v1の簡略実装
タイリングによって完全なアテンション行列のHBM保存を回避
キー: ブロックごとの処理のためのオンラインSoftmax
"""
batch_size, num_heads, seq_len, d_head = Q.shape
scale = 1.0 / math.sqrt(d_head)
Q = Q * scale
O = torch.zeros_like(Q)
L = torch.zeros(batch_size, num_heads, seq_len, 1, device=Q.device)
M = torch.full((batch_size, num_heads, seq_len, 1), float('-inf'), device=Q.device)
num_blocks = (seq_len + block_size - 1) // block_size
for j in range(num_blocks):
k_start = j * block_size
k_end = min((j + 1) * block_size, seq_len)
K_j = K[:, :, k_start:k_end, :]
V_j = V[:, :, k_start:k_end, :]
for i in range(num_blocks):
q_start = i * block_size
q_end = min((i + 1) * block_size, seq_len)
Q_i = Q[:, :, q_start:q_end, :]
O_i = O[:, :, q_start:q_end, :]
L_i = L[:, :, q_start:q_end, :]
M_i = M[:, :, q_start:q_end, :]
# SRAMでアテンションスコアを計算
S_ij = torch.matmul(Q_i, K_j.transpose(-2, -1))
# オンラインSoftmaxの更新
M_new = torch.maximum(M_i, S_ij.max(dim=-1, keepdim=True)[0])
P_ij = torch.exp(S_ij - M_new)
L_new = torch.exp(M_i - M_new) * L_i + P_ij.sum(dim=-1, keepdim=True)
# スケール変更と出力の更新
O_new = torch.exp(M_i - M_new) * O_i + torch.matmul(P_ij, V_j)
O[:, :, q_start:q_end, :] = O_new
L[:, :, q_start:q_end, :] = L_new
M[:, :, q_start:q_end, :] = M_new
return O / L
6.3 PyTorch SDPAの使用
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
def modern_attention(q, k, v, is_causal=True, dropout_p=0.0):
"""
PyTorch 2.0以降のscaled_dot_product_attention
FlashAttention 2/3を自動で選択
"""
return F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=dropout_p,
is_causal=is_causal,
scale=None # デフォルトは1/sqrt(d_head)
)
# Flash Attentionバージョンのハイライト
flash_versions = {
"FlashAttention 1 (arXiv:2205.14135)": {
"key_innovation": "タイリング + オンラインSoftmax",
"memory": "O(N) — アテンション行列を保存しない",
"speedup": "標準比2-4x"
},
"FlashAttention 2 (arXiv:2307.08691)": {
"key_innovation": "ワーク分割、FP16/BF16",
"memory": "O(N)",
"speedup": "H100での標準比5-9x"
},
"FlashAttention 3 (arXiv:2407.08608)": {
"key_innovation": "H100特化、FP8、非同期パイプライン",
"memory": "O(N)",
"speedup": "H100でFA2比1.5-2x"
},
}
for name, info in flash_versions.items():
print(f"\n{name}")
for k, v in info.items():
print(f" {k}: {v}")
7. マルチGPU推論
7.1 テンソル並列
重み行列をGPU間で分割し、各GPUがシャードを処理します。
import torch
import torch.nn as nn
class TensorParallelLinear(nn.Module):
"""
テンソル並列Linearレイヤー(列並列)
各GPUがout_features // world_sizeの出力ニューロンを所有
"""
def __init__(
self,
in_features: int,
out_features: int,
world_size: int,
rank: int
):
super().__init__()
self.world_size = world_size
self.rank = rank
self.local_out_features = out_features // world_size
self.weight = nn.Parameter(
torch.randn(self.local_out_features, in_features) / (in_features ** 0.5)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
local_output = nn.functional.linear(x, self.weight)
# 実際の分散セットアップでは: dist.all_gather(output_list, local_output)
return local_output
def setup_vllm_multiGPU(model_name: str, tp_size: int):
"""マルチGPU vLLM推論をセットアップ"""
from vllm import LLM, SamplingParams
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
gpu_memory_utilization=0.9
)
return llm
7.2 vLLMの完全な使用方法
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
import asyncio
import time
def vllm_basic_usage():
"""vLLMの基本的な使用方法"""
llm = LLM(
model="meta-llama/Llama-2-7b-hf",
tensor_parallel_size=1,
gpu_memory_utilization=0.90,
max_model_len=4096,
quantization=None, # "awq", "gptq", "squeezellm"
dtype="auto",
max_num_seqs=256,
enable_prefix_caching=True,
use_v2_block_manager=True,
)
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=200,
)
prompts = [
"量子コンピューティングを簡単に説明してください",
"人工知能の未来は?",
"人間の脳はどのように機能しますか?",
]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
print(f"プロンプト: {output.prompt[:50]}...")
print(f"出力: {output.outputs[0].text[:100]}...")
print(f"トークン数: {len(output.outputs[0].token_ids)}")
print()
return outputs
8. 推論エンジンの比較
8.1 主要エンジンの特徴
| エンジン | 開発者 | 主要機能 | 最適な用途 |
|---|---|---|---|
| vLLM | UC Berkeley | PagedAttention、連続バッチング | 汎用LLMサービング |
| TGI | HuggingFace | Flash Attention 2、投機的デコーディング | HFモデルサービング |
| TensorRT-LLM | NVIDIA | NVIDIA最適化、FP8 | NVIDIAの最大パフォーマンス |
| DeepSpeed-MII | Microsoft | ZeRO推論、大規模モデル | マルチGPUの巨大モデル |
| llama.cpp | G. Gerganov | CPU最適化、GGUF | ローカル実行 |
8.2 ベンチマーク結果
import time
def run_inference_benchmark(engine_name: str, model, tokenizer, prompts, max_tokens=100):
"""シンプルな推論ベンチマーク"""
num_warmup = 5
num_runs = 50
# ウォームアップ
for prompt in prompts[:num_warmup]:
_ = model.generate(
tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
max_new_tokens=max_tokens,
do_sample=False
)
# 時間計測
import torch
torch.cuda.synchronize()
start = time.perf_counter()
total_tokens = 0
for prompt in prompts[:num_runs]:
output = model.generate(
tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
max_new_tokens=max_tokens,
do_sample=False
)
total_tokens += max_tokens
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
throughput = total_tokens / elapsed
latency_ms = elapsed / num_runs * 1000
print(f"\n{engine_name}")
print(f" スループット: {throughput:.1f} tokens/sec")
print(f" 平均レイテンシ: {latency_ms:.1f} ms/リクエスト")
return throughput, latency_ms
# 比較例(A100 80GB、Llama-2-7B、batch=1、100出力トークン)
benchmark_results = {
"HuggingFace (FP16)": {"throughput": 52, "latency_ms": 1924},
"HuggingFace (Flash Attn 2)": {"throughput": 78, "latency_ms": 1280},
"vLLM": {"throughput": 120, "latency_ms": 832},
"vLLM + AWQ 4bit": {"throughput": 165, "latency_ms": 606},
"TensorRT-LLM": {"throughput": 180, "latency_ms": 555},
}
print("LLM推論エンジンベンチマーク (A100 80GB、Llama-2-7B)")
print("=" * 65)
print(f"{'エンジン':<30} {'スループット (tok/s)':<22} {'レイテンシ (ms)':<15}")
print("-" * 65)
for engine, stats in benchmark_results.items():
print(f"{engine:<30} {stats['throughput']:<22} {stats['latency_ms']:<15}")
9. プロンプトキャッシング
9.1 プレフィックスキャッシング
同じシステムプロンプトやドキュメントが繰り返し処理される場合にKVキャッシュを再利用します。
from vllm import LLM, SamplingParams
import time
def demonstrate_prefix_caching():
"""プレフィックスキャッシングの効果を実演"""
llm = LLM(
model="meta-llama/Llama-2-7b-hf",
enable_prefix_caching=True,
max_model_len=4096,
)
# 全リクエストに共通の長いシステムプロンプト(1000以上のトークン)
system_prompt = (
"あなたはPython、機械学習、データサイエンス、クラウドコンピューティングの専門知識を持つ"
"有能なAIアシスタントです。"
) * 50
questions = [
"Pythonのループを最適化するには?",
"勾配降下法とは何ですか?",
"コンテナ化について説明してください。",
"ニューラルネットワークとは何ですか?",
]
sampling_params = SamplingParams(temperature=0.7, max_tokens=100)
cold_prompts = [f"{system_prompt}\n\n質問: {q}" for q in questions]
# コールドスタート(キャッシュなし)
cold_start = time.time()
llm.generate(cold_prompts, sampling_params)
cold_time = time.time() - cold_start
# ウォームスタート(キャッシュヒット)
warm_start = time.time()
llm.generate(cold_prompts, sampling_params)
warm_time = time.time() - warm_start
print(f"コールドスタート(キャッシュなし): {cold_time:.2f}s")
print(f"ウォームスタート(キャッシュヒット): {warm_time:.2f}s")
print(f"スピードアップ: {cold_time / warm_time:.2f}x")
10. 実践的な最適化チェックリスト
10.1 ステップごとの最適化ガイド
class LLMOptimizationChecklist:
"""LLM推論最適化チェックリスト"""
optimizations = [
{
"category": "ベースライン",
"level": 1,
"items": [
{
"name": "FP16/BF16を使用",
"impact": "高",
"effort": "低",
"description": "FP32 → FP16: 2xメモリ節約、速度向上",
},
{
"name": "Flash Attention 2を有効化",
"impact": "高",
"effort": "低",
"description": "2-4xアテンションスピードアップ、メモリ節約",
},
]
},
{
"category": "KVキャッシュ最適化",
"level": 2,
"items": [
{
"name": "GQA/MQAモデルを選択",
"impact": "高",
"effort": "中",
"description": "4-8x KVキャッシュ削減、より大きな有効バッチ"
},
{
"name": "プレフィックスキャッシング",
"impact": "中",
"effort": "低",
"description": "共通システムプロンプトのKVキャッシュを再利用"
},
]
},
{
"category": "バッチング最適化",
"level": 3,
"items": [
{
"name": "vLLMで連続バッチング",
"impact": "非常に高",
"effort": "低",
"description": "2-5xスループット向上",
},
]
},
{
"category": "モデル最適化",
"level": 4,
"items": [
{
"name": "AWQ 4ビット量子化",
"impact": "高",
"effort": "中",
"description": "4xメモリ削減、1.5-2x速度向上",
},
{
"name": "投機的デコーディング",
"impact": "中",
"effort": "高",
"description": "2-3xスピードアップ(適切なドラフトモデルが必要)"
},
]
},
{
"category": "ハードウェア最適化",
"level": 5,
"items": [
{
"name": "テンソル並列",
"impact": "非常に高",
"effort": "中",
"description": "複数GPUで線形スループットスケーリング"
},
{
"name": "CUDAグラフキャプチャ",
"impact": "中",
"effort": "高",
"description": "カーネル起動オーバーヘッドを排除"
},
]
}
]
10.2 エンドツーエンドのプロダクションセットアップ
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from fastapi import FastAPI
from pydantic import BaseModel
import asyncio
import uvicorn
app = FastAPI(title="LLM推論API")
class GenerateRequest(BaseModel):
prompt: str
max_tokens: int = 200
temperature: float = 0.7
top_p: float = 0.95
stream: bool = False
class GenerateResponse(BaseModel):
text: str
tokens_generated: int
finish_reason: str
# グローバルエンジンインスタンス
engine: AsyncLLMEngine = None
def create_optimized_engine(model_name: str, **kwargs) -> AsyncLLMEngine:
"""プロダクション最適化されたvLLMエンジンを作成"""
engine_args = AsyncEngineArgs(
model=model_name,
tensor_parallel_size=kwargs.get("tp_size", 1),
gpu_memory_utilization=kwargs.get("gpu_util", 0.90),
max_model_len=kwargs.get("max_model_len", 4096),
quantization=kwargs.get("quantization", None), # "awq" または "gptq"
dtype="auto",
max_num_seqs=kwargs.get("max_num_seqs", 256),
enable_prefix_caching=True,
use_v2_block_manager=True,
speculative_model=kwargs.get("draft_model", None), # オプションのドラフトモデル
num_speculative_tokens=kwargs.get("num_spec_tokens", 5),
)
return AsyncLLMEngine.from_engine_args(engine_args)
@app.on_event("startup")
async def startup():
global engine
engine = create_optimized_engine(
model_name="meta-llama/Llama-2-7b-hf",
tp_size=1,
gpu_util=0.90,
max_model_len=4096,
)
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
sampling_params = SamplingParams(
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens,
)
request_id = f"req_{id(request)}"
final_output = None
async for output in engine.generate(request.prompt, sampling_params, request_id):
final_output = output
if final_output and final_output.outputs:
result = final_output.outputs[0]
return GenerateResponse(
text=result.text,
tokens_generated=len(result.token_ids),
finish_reason=result.finish_reason or "length"
)
return GenerateResponse(text="", tokens_generated=0, finish_reason="error")
# 実行: uvicorn script:app --host 0.0.0.0 --port 8000 --workers 1
まとめ
LLM推論最適化は階層的なアプローチが必要です。
重要なポイント:
-
KVキャッシュを理解する:
2 * レイヤー数 * kvヘッド数 * d_head * seq_len * dtype_bytesを暗記する。GQA/MQAでKVキャッシュを4-8x削減。 -
PagedAttention: vLLMのコアイノベーション — OSの仮想メモリから着想を得てKVキャッシュの断片化を排除。
-
連続バッチング: 完了と同時に新しいリクエストをすぐに挿入し、GPU利用率を最大化。
-
投機的デコーディング: 小さいドラフトモデル + 大きな検証モデル = 適切な受け入れ率で2-3xスピードアップ。
-
FlashAttention: アテンションのメモリをO(N^2)からO(N)に削減し、長いコンテキストを可能に。
プロダクションデプロイの推奨:
- 小規模サービス: vLLM + AWQ 4ビット + プレフィックスキャッシング
- 大規模サービス: TensorRT-LLMまたはvLLM + テンソル並列
- 最低レイテンシ: 投機的デコーディング + CUDAグラフ
参考文献
- vLLM/PagedAttention: arXiv:2309.06180
- 投機的デコーディング: arXiv:2211.17192
- FlashAttention: arXiv:2205.14135
- FlashAttention-2: arXiv:2307.08691
- Medusa: arXiv:2401.10774
- 連続バッチング: Anyscale Blog
- DeepSeek-V2 MLA: arXiv:2405.04434