Skip to content
Published on

Speculative Decoding으로 LLM 추론 2~3배 빠르게: 원리부터 실전 구현까지

Authors
  • Name
    Twitter

1. LLM 추론의 근본적 병목

LLM의 Autoregressive 디코딩은 본질적으로 시리얼하다:

토큰 1 생성 → 토큰 2 생성 → 토큰 3 생성 → ...
     ↓              ↓              ↓
  전체 모델        전체 모델       전체 모델
  Forward Pass    Forward Pass   Forward Pass

각 토큰 생성 시 70B 모델 전체를 Forward Pass해야 하고, 이 과정은 Memory-Bandwidth Bound다. GPU 연산 능력은 남지만 메모리 대역폭이 병목이 된다.

1.1 산술 강도 분석

70B 모델, FP16:
- 모델 크기: ~140GB
- 1토큰 생성: 140GB 메모리 읽기
- A100 80GB 메모리 대역폭: 2TB/s
- 이론적 최대: 2000/14014 tokens/s

실제로는 KV Cache 접근 등으로 ~10 tokens/s
GPU 연산 활용률: 1~2% 😱

핵심 통찰: 1토큰이든 K토큰이든 모델 가중치를 읽는 비용은 동일하다. 한 번 읽을 때 여러 토큰을 처리하면 효율이 올라간다.

2. Speculative Decoding 원리

2.1 기본 아이디어

작은 Draft 모델이 K개 토큰을 빠르게 제안하고, 큰 Target 모델이 한 번의 Forward Pass로 K개를 동시에 검증한다:

Draft Model (1B):  t1 → t2 → t3 → t4 → t5  (빠르게 5개 제안)
                    ↓    ↓    ↓    ↓    ↓
Target Model (70B): ✅   ✅   ✅   ❌   -    (한 번에 검증)
                              t3' 재생성    (거절 후 보정)

결과: [t1, t2, t3, t3']1번의 70B Forward Pass로 4토큰 생성!

2.2 수학적 보장: Rejection Sampling

Speculative Decoding의 핵심은 출력 분포가 Target 모델과 정확히 동일하다는 수학적 보장이다.

Draft 모델 분포 q(x)q(x), Target 모델 분포 p(x)p(x)에서:

수용 확률:

α(x)=min(1,p(x)q(x))\alpha(x) = \min\left(1, \frac{p(x)}{q(x)}\right)

거절 시 보정 분포:

p(x)=norm(max(0,p(x)q(x)))p'(x) = \text{norm}\left(\max(0, p(x) - q(x))\right)

이 과정을 거치면 최종 출력 분포는 **정확히 p(x)p(x)**가 된다.

import torch

def speculative_decode(draft_model, target_model, input_ids, K=5):
    """Speculative Decoding 핵심 알고리즘"""
    
    # 1) Draft 모델로 K개 토큰 생성
    draft_tokens = []
    draft_probs = []
    current = input_ids.clone()
    
    for _ in range(K):
        logits = draft_model(current).logits[:, -1]
        probs = torch.softmax(logits, dim=-1)
        token = torch.multinomial(probs, 1)
        draft_tokens.append(token)
        draft_probs.append(probs.gather(-1, token))
        current = torch.cat([current, token], dim=-1)
    
    # 2) Target 모델로 한 번에 검증
    all_tokens = torch.cat([input_ids] + draft_tokens, dim=-1)
    target_logits = target_model(all_tokens).logits
    
    # 3) Rejection Sampling
    accepted = []
    n = input_ids.shape[-1]
    
    for i in range(K):
        target_prob = torch.softmax(target_logits[:, n+i-1], dim=-1)
        p_target = target_prob.gather(-1, draft_tokens[i])
        q_draft = draft_probs[i]
        
        # 수용 확률
        accept_prob = torch.min(
            torch.ones_like(p_target),
            p_target / q_draft
        )
        
        if torch.rand(1) < accept_prob:
            accepted.append(draft_tokens[i])
        else:
            # 거절: 보정 분포에서 새 토큰 샘플링
            residual = torch.clamp(target_prob - 
                torch.softmax(draft_model(all_tokens[:, :n+i]).logits[:, -1], dim=-1),
                min=0)
            residual = residual / residual.sum(dim=-1, keepdim=True)
            new_token = torch.multinomial(residual, 1)
            accepted.append(new_token)
            break
    else:
        # 모두 수용 시 보너스 토큰
        bonus = torch.multinomial(
            torch.softmax(target_logits[:, n+K-1], dim=-1), 1
        )
        accepted.append(bonus)
    
    return torch.cat(accepted, dim=-1)

2.3 수용률과 속도 향상

수용률 α\alpha일 때, 평균 생성 토큰 수:

E[tokens per step]=1αK+11αE[\text{tokens per step}] = \frac{1 - \alpha^{K+1}}{1 - \alpha}
Draft-Target 쌍수용률 αK=5 평균 토큰속도 향상
GPT-2 → GPT-40.41.61.3x
Llama-68M → Llama-70B0.72.82.3x
Llama-1B → Llama-70B0.83.62.8x

3. 고급 기법들

3.1 Self-Speculative Decoding (Draft 모델 없이)

별도 Draft 모델 없이 Target 모델 자체의 Early Exit 또는 Layer Skipping을 활용:

# Layer Skip 방식
class SelfSpeculativeModel(nn.Module):
    def draft_forward(self, x):
        """처음 8개 레이어만 사용하여 빠른 draft"""
        for layer in self.layers[:8]:
            x = layer(x)
        return self.lm_head(self.norm(x))
    
    def verify_forward(self, x):
        """전체 레이어로 검증"""
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(self.norm(x))

장점: Draft 모델 별도 로딩 불필요, 메모리 절약

3.2 Medusa: Multi-Head Speculative Decoding

Draft 모델 대신 여러 개의 LM Head를 추가하여 동시에 여러 위치의 토큰을 예측:

          Target LM Head → t[n+1]
InputHidden States          Medusa Head 1 → t[n+2]  (예측)
          Medusa Head 2 → t[n+3]  (예측)
          Medusa Head 3 → t[n+4]  (예측)

3.3 Apple Mirror Speculative Decoding (2026)

Apple의 최신 연구(2026.01). 기존 Speculative Decoding의 시리얼 검증 병목을 해결:

  • Mirror Model: Target 모델의 경량화 버전이 Draft와 Verify를 동시에 수행
  • 기존: Draft → Verify → Draft → Verify (시리얼)
  • Mirror: Draft₁ + Verify₀ → Draft₂ + Verify₁ → ... (파이프라인)

4. vLLM에서 Speculative Decoding 사용하기

4.1 설정

from vllm import LLM, SamplingParams

# Draft 모델 지정
llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",
    speculative_model="meta-llama/Llama-3.2-1B-Instruct",
    num_speculative_tokens=5,
    tensor_parallel_size=4,
    gpu_memory_utilization=0.9,
)

params = SamplingParams(temperature=0.7, max_tokens=512)
outputs = llm.generate(["Explain quantum computing:"], params)

4.2 벤치마크 스크립트

# 기본 디코딩
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Llama-3.1-70B-Instruct \
    --tensor-parallel-size 4

# Speculative Decoding
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Llama-3.1-70B-Instruct \
    --speculative-model meta-llama/Llama-3.2-1B-Instruct \
    --num-speculative-tokens 5 \
    --tensor-parallel-size 4

4.3 TensorRT-LLM에서 사용

import tensorrt_llm
from tensorrt_llm import BuildConfig

# Draft 모델과 Target 모델 동시 빌드
build_config = BuildConfig(
    max_batch_size=8,
    max_input_len=2048,
    max_seq_len=4096,
    speculative_decoding_mode="draft_tokens_external",
    max_draft_len=5,
)

5. 최적의 K값 선택

import time

def find_optimal_k(draft_model, target_model, test_prompts, k_range=range(1, 11)):
    """최적의 speculative token 수 탐색"""
    results = {}
    
    for k in k_range:
        start = time.time()
        total_tokens = 0
        
        for prompt in test_prompts:
            output = speculative_generate(
                draft_model, target_model, prompt,
                num_speculative_tokens=k, max_tokens=256
            )
            total_tokens += len(output)
        
        elapsed = time.time() - start
        throughput = total_tokens / elapsed
        results[k] = throughput
        print(f"K={k}: {throughput:.1f} tokens/s")
    
    optimal_k = max(results, key=results.get)
    print(f"\nOptimal K = {optimal_k} ({results[optimal_k]:.1f} tokens/s)")
    return optimal_k

일반적 가이드라인:

  • Draft가 강할수록 (수용률 높음): K를 크게 (7~10)
  • Draft가 약할수록: K를 작게 (3~5)
  • 코드 생성: K=5~7 (반복 패턴 많아 수용률 높음)
  • 창의적 텍스트: K=3~4 (다양성 높아 수용률 낮음)

6. 실전 고려사항

6.1 Draft 모델 선택 기준

  1. 같은 토크나이저: 토크나이저가 다르면 토큰 정렬 문제 발생
  2. 같은 패밀리: Llama-1B → Llama-70B (같은 학습 데이터, 높은 수용률)
  3. 적절한 크기 비율: Target의 1/50~1/10 (너무 크면 Draft 비용 증가)
  4. 빠른 추론: Draft는 지연 시간이 핵심

6.2 Batch 환경에서의 주의점

Speculative Decoding은 배치 크기가 커지면 효과가 감소한다:

  • 배치 내 각 요청의 수용률이 달라 동기화 문제
  • 이미 compute-bound인 배치에서는 추가 연산이 부담
  • 처리량(throughput) 보다 지연 시간(latency) 최적화에 더 적합

7. 퀴즈

Q1. Speculative Decoding이 출력 분포를 변경하지 않는 이유는?

Rejection Sampling 기법 덕분. 수용 확률 min(1,p(x)/q(x))\min(1, p(x)/q(x))로 샘플링하고, 거절 시 보정 분포 max(0,p(x)q(x))\max(0, p(x)-q(x))에서 재샘플링하면 최종 분포가 정확히 Target 분포 p(x)p(x)가 됨.

Q2. LLM 추론이 Memory-Bandwidth Bound인 이유는?

토큰 1개 생성에 전체 모델 가중치를 메모리에서 읽어야 하지만, 실제 연산량(FLOP)은 적음. GPU 연산 능력 대비 메모리 대역폭이 병목. 70B FP16 = 140GB를 매 토큰마다 읽음.

Q3. Self-Speculative Decoding의 장단점은?

장점: 별도 Draft 모델 불필요, 메모리 절약. 단점: Target 모델의 일부 레이어만 사용하므로 수용률이 전용 Draft 모델보다 낮을 수 있음.

Q4. K값이 너무 크면 왜 비효율적인가?

수용률이 지수적으로 감소(αK\alpha^K)하여, 후반 토큰이 거절될 확률이 높아짐. Draft 모델의 K번 Forward Pass 비용은 항상 발생하므로, 거절된 토큰에 대한 Draft 비용이 낭비.

Q5. 배치 크기가 클 때 Speculative Decoding 효과가 감소하는 이유는?

(1) 배치 내 수용률이 달라 동기화 오버헤드 (2) 배치가 크면 이미 compute-bound여서 GPU 활용률이 높음 (3) Draft 토큰 관리의 메모리 오버헤드 증가.

Q6. Medusa 방식이 기존 Speculative Decoding과 다른 점은?

별도 Draft 모델 대신 Multiple LM Head를 Target 모델에 추가하여, 한 번의 Forward Pass로 여러 위치의 토큰을 동시에 예측. 추가 모델 로딩 불필요.