Skip to content
Published on

Speculative DecodingでLLM推論を2〜3倍高速化:原理から実践実装まで

Authors
  • Name
    Twitter

1. LLM推論の根本的なボトルネック

LLMのAutoregressiveデコーディングは本質的にシリアルである:

トークン1生成 → トークン2生成 → トークン3生成 → ...
      ↓              ↓              ↓
  モデル全体       モデル全体      モデル全体
  Forward Pass    Forward Pass   Forward Pass

各トークン生成時に70Bモデル全体のForward Passが必要であり、このプロセスはMemory-Bandwidth Boundである。GPUの演算能力は余っているが、メモリ帯域幅がボトルネックとなる。

1.1 算術強度分析

70Bモデル、FP16- モデルサイズ:〜140GB
- 1トークン生成:140GBのメモリ読み取り
- A100 80GBメモリ帯域幅:2TB/s
- 理論的最大値:2000/14014 tokens/s

実際にはKV Cacheアクセス等で〜10 tokens/s
GPU演算活用率:12%

核心的な洞察:1トークンでもKトークンでもモデルの重みを読み取るコストは同じである。1回の読み取りで複数トークンを処理すれば効率が向上する。

2. Speculative Decodingの原理

2.1 基本アイデア

小さなDraftモデルがK個のトークンを素早く提案し、大きなTargetモデルが1回のForward PassでK個を同時に検証する:

Draft Model (1B):  t1 → t2 → t3 → t4 → t5  (素早く5個提案)
                    ↓    ↓    ↓    ↓    ↓
Target Model (70B): ✅   ✅   ✅   ❌   -    (一括検証)
                              t3' 再生成    (棄却後補正)

結果: [t1, t2, t3, t3'] → 70Bの1回のForward Passで4トークン生成!

2.2 数学的保証:Rejection Sampling

Speculative Decodingの核心は、出力分布がTargetモデルと正確に同一であるという数学的保証である。

Draftモデル分布 q(x)q(x)、Targetモデル分布 p(x)p(x) において:

受容確率:

α(x)=min(1,p(x)q(x))\alpha(x) = \min\left(1, \frac{p(x)}{q(x)}\right)

棄却時の補正分布:

p(x)=norm(max(0,p(x)q(x)))p'(x) = \text{norm}\left(\max(0, p(x) - q(x))\right)

このプロセスを経ると、最終出力分布は正確に p(x)p(x) となる。

import torch

def speculative_decode(draft_model, target_model, input_ids, K=5):
    """Speculative Decodingコアアルゴリズム"""

    # 1) DraftモデルでK個のトークンを生成
    draft_tokens = []
    draft_probs = []
    current = input_ids.clone()

    for _ in range(K):
        logits = draft_model(current).logits[:, -1]
        probs = torch.softmax(logits, dim=-1)
        token = torch.multinomial(probs, 1)
        draft_tokens.append(token)
        draft_probs.append(probs.gather(-1, token))
        current = torch.cat([current, token], dim=-1)

    # 2) Targetモデルで一括検証
    all_tokens = torch.cat([input_ids] + draft_tokens, dim=-1)
    target_logits = target_model(all_tokens).logits

    # 3) Rejection Sampling
    accepted = []
    n = input_ids.shape[-1]

    for i in range(K):
        target_prob = torch.softmax(target_logits[:, n+i-1], dim=-1)
        p_target = target_prob.gather(-1, draft_tokens[i])
        q_draft = draft_probs[i]

        # 受容確率
        accept_prob = torch.min(
            torch.ones_like(p_target),
            p_target / q_draft
        )

        if torch.rand(1) < accept_prob:
            accepted.append(draft_tokens[i])
        else:
            # 棄却:補正分布から新しいトークンをサンプリング
            residual = torch.clamp(target_prob -
                torch.softmax(draft_model(all_tokens[:, :n+i]).logits[:, -1], dim=-1),
                min=0)
            residual = residual / residual.sum(dim=-1, keepdim=True)
            new_token = torch.multinomial(residual, 1)
            accepted.append(new_token)
            break
    else:
        # 全て受容された場合のボーナストークン
        bonus = torch.multinomial(
            torch.softmax(target_logits[:, n+K-1], dim=-1), 1
        )
        accepted.append(bonus)

    return torch.cat(accepted, dim=-1)

2.3 受容率と速度向上

受容率 α\alpha のとき、平均生成トークン数:

E[tokens per step]=1αK+11αE[\text{tokens per step}] = \frac{1 - \alpha^{K+1}}{1 - \alpha}
Draft-Targetペア受容率 αK=5 平均トークン速度向上
GPT-2 → GPT-40.41.61.3x
Llama-68M → Llama-70B0.72.82.3x
Llama-1B → Llama-70B0.83.62.8x

3. 高度な技法

3.1 Self-Speculative Decoding(Draftモデル不要)

別途のDraftモデルなしにTargetモデル自体のEarly ExitまたはLayer Skippingを活用:

# Layer Skip方式
class SelfSpeculativeModel(nn.Module):
    def draft_forward(self, x):
        """最初の8層のみ使用して高速draft"""
        for layer in self.layers[:8]:
            x = layer(x)
        return self.lm_head(self.norm(x))

    def verify_forward(self, x):
        """全層で検証"""
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(self.norm(x))

利点:別途Draftモデルのロード不要、メモリ節約

3.2 Medusa:Multi-Head Speculative Decoding

Draftモデルの代わりに複数のLM Headを追加して、複数位置のトークンを同時に予測:

          Target LM Head → t[n+1]
InputHidden States          Medusa Head 1 → t[n+2]  (予測)
          Medusa Head 2 → t[n+3]  (予測)
          Medusa Head 3 → t[n+4]  (予測)

3.3 Apple Mirror Speculative Decoding (2026)

Appleの最新研究(2026.01)。既存のSpeculative Decodingのシリアル検証ボトルネックを解決:

  • Mirror Model:Targetモデルの軽量化バージョンがDraftとVerifyを同時に実行
  • 既存:Draft → Verify → Draft → Verify(シリアル)
  • Mirror:Draft₁ + Verify₀ → Draft₂ + Verify₁ → ...(パイプライン)

4. vLLMでのSpeculative Decoding使用法

4.1 設定

from vllm import LLM, SamplingParams

# Draftモデル指定
llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",
    speculative_model="meta-llama/Llama-3.2-1B-Instruct",
    num_speculative_tokens=5,
    tensor_parallel_size=4,
    gpu_memory_utilization=0.9,
)

params = SamplingParams(temperature=0.7, max_tokens=512)
outputs = llm.generate(["Explain quantum computing:"], params)

4.2 ベンチマークスクリプト

# 基本デコーディング
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Llama-3.1-70B-Instruct \
    --tensor-parallel-size 4

# Speculative Decoding
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Llama-3.1-70B-Instruct \
    --speculative-model meta-llama/Llama-3.2-1B-Instruct \
    --num-speculative-tokens 5 \
    --tensor-parallel-size 4

4.3 TensorRT-LLMでの使用

import tensorrt_llm
from tensorrt_llm import BuildConfig

# DraftモデルとTargetモデルの同時ビルド
build_config = BuildConfig(
    max_batch_size=8,
    max_input_len=2048,
    max_seq_len=4096,
    speculative_decoding_mode="draft_tokens_external",
    max_draft_len=5,
)

5. 最適なK値の選択

import time

def find_optimal_k(draft_model, target_model, test_prompts, k_range=range(1, 11)):
    """最適なspeculativeトークン数の探索"""
    results = {}

    for k in k_range:
        start = time.time()
        total_tokens = 0

        for prompt in test_prompts:
            output = speculative_generate(
                draft_model, target_model, prompt,
                num_speculative_tokens=k, max_tokens=256
            )
            total_tokens += len(output)

        elapsed = time.time() - start
        throughput = total_tokens / elapsed
        results[k] = throughput
        print(f"K={k}: {throughput:.1f} tokens/s")

    optimal_k = max(results, key=results.get)
    print(f"\nOptimal K = {optimal_k} ({results[optimal_k]:.1f} tokens/s)")
    return optimal_k

一般的なガイドライン:

  • Draftが強いほど(受容率が高い):Kを大きく(7〜10)
  • Draftが弱いほど:Kを小さく(3〜5)
  • コード生成:K=5〜7(繰り返しパターンが多く受容率が高い)
  • 創造的テキスト:K=3〜4(多様性が高く受容率が低い)

6. 実践上の考慮事項

6.1 Draftモデル選択基準

  1. 同じトークナイザ:トークナイザが異なるとトークン整列問題が発生
  2. 同じファミリー:Llama-1B → Llama-70B(同じ学習データ、高い受容率)
  3. 適切なサイズ比率:Targetの1/50〜1/10(大きすぎるとDraftコストが増加)
  4. 高速推論:Draftはレイテンシが重要

6.2 バッチ環境での注意点

Speculative Decodingはバッチサイズが大きくなると効果が低下する:

  • バッチ内の各リクエストの受容率が異なり同期問題が発生
  • 既にcompute-boundのバッチでは追加演算が負担
  • **スループット(throughput)よりもレイテンシ(latency)**最適化に適している

7. クイズ

Q1. Speculative Decodingが出力分布を変更しない理由は?

Rejection Sampling技法のおかげ。受容確率 min(1,p(x)/q(x))\min(1, p(x)/q(x)) でサンプリングし、棄却時に補正分布 max(0,p(x)q(x))\max(0, p(x)-q(x)) から再サンプリングすれば、最終分布が正確にTarget分布 p(x)p(x) となる。

Q2. LLM推論がMemory-Bandwidth Boundである理由は?

トークン1個の生成にモデル全体の重みをメモリから読み取る必要があるが、実際の演算量(FLOP)は少ない。GPU演算能力に対してメモリ帯域幅がボトルネック。70B FP16 = 140GBをトークンごとに読み取る。

Q3. Self-Speculative Decodingの長所と短所は?

長所:別途Draftモデル不要、メモリ節約。短所:Targetモデルの一部のレイヤーのみ使用するため、専用Draftモデルより受容率が低い場合がある。

Q4. K値が大きすぎると非効率的な理由は?

受容率が指数的に減少(αK\alpha^K)するため、後半のトークンが棄却される確率が高くなる。DraftモデルのK回のForward Passコストは常に発生するため、棄却されたトークンに対するDraftコストが無駄になる。

Q5. バッチサイズが大きい場合にSpeculative Decodingの効果が低下する理由は?

(1) バッチ内の受容率が異なり同期オーバーヘッドが発生 (2) バッチが大きい場合は既にcompute-boundでGPU活用率が高い (3) Draftトークン管理のメモリオーバーヘッドが増加。

Q6. Medusa方式が既存のSpeculative Decodingと異なる点は?

別途のDraftモデルではなくMultiple LM HeadをTargetモデルに追加して、1回のForward Passで複数位置のトークンを同時に予測。追加モデルのロード不要。