Skip to content

필사 모드: 推論を速く — Speculative Decodingとスループット最適化

日本語
0%
정확도 0%
💡 왼쪽 원문을 읽으면서 오른쪽에 따라 써보세요. Tab 키로 힌트를 받을 수 있습니다.
원문 렌더가 준비되기 전까지 텍스트 가이드로 표시합니다.

はじめに

LLM推論を初めて運用してみた人が最初に驚くのは「なぜこんなに遅いのか」ということです。同じGPUで学習時には膨大な演算を注ぎ込んでいたモデルが、いざテキストを生成するときには、もどかしいほどゆっくりとトークンを吐き出します。これは実装が間違っているからではなく、decodeという作業の根本的な性質のためです。

この記事では、まずdecodeがなぜ遅いのかを明確にしたうえで、その限界を回避する代表的な手法であるspeculative decodingを詳しく扱います。続いて、Medusa、n-gram、EAGLEのような変種や、chunked prefill、prefill/decodeの分離といったシステムレベルの最適化、そして遅延とスループットのトレードオフと測定方法までを整理します。目標は「どのつまみを回せば何が速くなるのか」についての明確な感覚を身につけることです。

decodeが遅い理由

LLMはトークンを一つずつ自己回帰的に生成します。トークンを一つ作るにはモデル全体を一度forwardしなければなりません。つまりN個のトークンを作るにはforwardをN回行う必要があり、各forwardは直前のトークンに依存するため逐次的です。並列化することができません。

より根本的な問題は、各forwardがメモリバウンドであるという点です。

トークン1個生成時:

- モデル重み全体(数十GB)をメモリから読む

- KVキャッシュを読む

- 実際の計算量はトークン1個分のみ

結果: GPU演算ユニットは遊び、メモリ帯域幅がボトルネック。

演算能力がいくら高くてもメモリを読む速度が限界。

ここで重要な洞察が出てきます。decodeはメモリバウンドなので、メモリを一度読むついでにより多くの有用な仕事をすれば、ほぼタダに近い利得が得られます。バッチング(複数のリクエストを一度に)も、speculative decoding(一度に複数のトークンを検証)も、すべてこの原理を利用しています。

Speculative Decodingの原理

speculative decodingの発想は単純でありながら巧妙です。小さく速い「ドラフトモデル」が次のトークンを複数あらかじめ推測し、大きく正確な「ターゲットモデル」がその推測たちを一度のforwardで並列に検証するというものです。

通常のデコーディング (遅い):

ターゲットモデル forward -> トークン1

ターゲットモデル forward -> トークン2

ターゲットモデル forward -> トークン3

(forward 3回)

speculative decoding:

1) ドラフトモデルが素早く推測: [t1', t2', t3', t4']

2) ターゲットモデル forward 1回で4個を一度に検証

3) 前から合っているトークンは採択、最初に外れた所で停止

例) t1', t2'は採択、t3'で不一致 -> t3はターゲットが訂正

(ターゲット forward 1回で2~3個のトークンを確定可能)

核心は、検証が並列であるという点です。ターゲットモデルはK個の推測トークンを一度のforwardで同時に検査できます。どうせメモリバウンドなので重みを一度読むコストは同じなのに、そのついでに複数のトークンを処理するので利得です。推測がよく当たれば、forward一度で複数のトークンが確定します。

重要な保証は、speculative decodingが出力分布を変えないという点です。検証段階がターゲットモデルの分布をそのまま追従するように設計されているため、結果物はターゲットモデルが単独で生成したものと統計的に同一です。つまり品質の損失なしに速度だけを得ます。推測の的中率とモデルの組み合わせによって異なりますが、メモリバウンドの状況でおよそ2~3倍の速度向上が報告されています。

変種: Medusa、n-gram、EAGLE

別途のドラフトモデルを置くことが負担になりうるため、さまざまな変種が登場しました。概念だけ押さえます。

- **Medusa**: 別途のドラフトモデルなしで、ターゲットモデルに複数の追加予測ヘッドを付けます。各ヘッドが未来のトークンを同時に予測し、その候補をツリー形式で検証します。別途のモデルを管理する必要がないのが利点です。

- **n-gram (lookahead系)**: モデルなしで、これまでのテキストから頻繁に出たパターンを辞書のように活用して次のトークンを推測します。コードや繰り返しの多いテキストのようにパターンが明確な場合に効果的です。

- **EAGLE**: ドラフト段階をより精巧にしたアプローチで、モデルの中間表現(feature)のレベルで次のトークンを予測して推測の的中率を高めます。的中率が高いほど採択されるトークンが多くなり、加速効果が大きくなります。

これらの共通の目標は一つです。推測の的中率を高めて、ターゲットモデルの一度のforwardでより多くのトークンを確定することです。

システムレベルの最適化: chunked prefill

speculative decodingがdecode自体を速くするなら、システムレベルでprefillとdecodeをより上手く重ねる最適化もあります。

prefillは演算バウンド、decodeはメモリバウンドです。ところが二つの段階を同じバッチに混ぜると、互いの空いた資源を埋めることができます。問題は、長いプロンプトのprefillが一度に入ってくると、その間進行中だった他のリクエストのdecodeが止まってしまうという点です(応答遅延が跳ねる)。

chunked prefillは長いprefillを複数の断片に分割し、各ステップごとにprefill断片の一部とdecodeを一緒に処理します。

chunked prefillなし:

[長いプロンプト prefillまるごと] ... その間 他のリクエスト decode停止

-> 進行中のリクエストの遅延が跳ねる

chunked prefill適用:

ステップ1: [prefill断片A] + [reqたち decode]

ステップ2: [prefill断片B] + [reqたち decode]

ステップ3: [prefill断片C] + [reqたち decode]

-> prefillを流しながら decodeも着実に進行

こうすれば長いプロンプトが入ってきても他のユーザーのトークン生成が止まらないため、遅延が安定します。

prefill/decodeの分離 (disaggregation)

もう一歩進んで、prefillとdecodeをまったく別のGPU(または別のインスタンス)で処理する方式がdisaggregationです。

disaggregation構造:

[prefill専担ノード] --(KVキャッシュ転送)--> [decode専担ノード]

prefillノード: 演算バウンド作業に最適化、短く太く

decodeノード: メモリバウンド作業に最適化、長く続けて

二つの段階の資源の性質が異なるため、それぞれを独立に最適化しスケーリングできます。prefill負荷が集中すればprefillノードだけ増やし、長い生成が多ければdecodeノードだけ増やすという具合です。欠点はKVキャッシュをノード間で転送しなければならないコストとシステムの複雑さです。大規模サービングで資源効率を極限まで絞るときに検討する高度な手法です。

バッチング、量子化との結合

これまでの手法たちは互いに排他的ではありません。むしろ結合するときに効果が大きいです。

- **バッチング**: decodeがメモリバウンドなので、複数のリクエストをまとめて一度の重み読み込みで複数のトークンを処理します。スループットを上げる最も基本的な手段です。

- **量子化**: 重みとKVを低い精度で保存すればメモリ読み込み量が減ってdecodeが速くなります。

- **speculative decoding**: 一度のforwardで複数のトークンを確定します。

これらを一緒に使うと効果が掛け合わさる傾向があります。ただし無限に足されるわけではありません。例えばバッチがすでに大きくてGPUが演算バウンドに近づくと、speculative decodingの利得は減ります。speculative decodingはメモリバウンドのとき(つまりバッチが小さいとき)に最も効果的だからです。したがって手法の組み合わせはワークロードに合わせてバランスを取らなければなりません。

遅延 vs スループットのトレードオフ

推論最適化で最も重要な緊張関係が遅延(latency)とスループット(throughput)です。この二つはしばしば逆方向に動きます。

バッチを大きくすると:

スループット(秒あたり総トークン)増加 ↑

しかし個別リクエストの遅延も増加 ↑ (大きなバッチを待つ)

バッチを小さくすると:

個別リクエストの遅延減少 ↓

しかしスループット減少 ↓ (GPUを満たしきれない)

どちらを優先するかはサービスの性質次第です。リアルタイム対話のように一人の応答速度が重要なら遅延を優先し、大量バッチ処理のように全体のスループットが重要ならスループットを優先します。この二つを同時に最大化することはできないので、サービスの目標をまず定めてそれに合わせてつまみを回さなければなりません。

測定: TTFT、TPOT、throughput

最適化をするにはまず正しく測定しなければなりません。LLMサービングの核心指標は三つです。

TTFT (Time To First Token):

リクエストを送った後、最初のトークンが出るまでの時間。

主にprefill速度と待ち行列に左右される。

対話型UXで「応答が始まる体感速度」。

TPOT (Time Per Output Token):

最初のトークン以降、トークン一つを作るのにかかる平均時間。

主にdecode速度に左右される。

ストリーミングがどれだけ滑らかかを決める。

Throughput:

システム全体が秒あたり処理する総トークン数。

バッチングと同時実行性に左右される。コスト効率の尺度。

この三つの指標は互いに異なるものを測定します。TTFTが良くてもTPOTが悪ければ、最初の一文字だけ速く出て以降はもどかしいです。throughputが高くても大きなバッチのせいでTTFTが悪いことがあります。したがって一つの数字だけを見ず、三つの指標を一緒に見て、自分のサービスが何を優先するかに合わせて判断しなければなりません。

推測の的中率がすべてを左右する

speculative decodingの利得は「推測がどれだけ頻繁に当たるか」にほぼ全面的にかかっています。直感的に考えてみましょう。ドラフトが一度にK個を推測し、平均a個が採択されるなら、ターゲットモデルのforward一度で平均a+1個のトークンが確定します(採択されたa個 + 訂正された1個)。

加速の直感 (概念的):

ドラフトがK個推測 -> ターゲット forward 1回で検証

平均採択数をaとすれば -> forwardあたり約(a+1)トークン確定

的中率高い(a大) -> forwardあたり多くのトークン -> 大きな加速

的中率低い(a小) -> ドラフトコストだけかかって利得少ない

ここで二つのコストを一緒に見なければなりません。ドラフトモデルを回すコストと、ターゲットが検証するコストです。ドラフトが重すぎると推測自体が高くつき、軽すぎると的中率が下がります。なのでドラフトモデルは普通ターゲットよりはるかに小さいモデル(例: ターゲットの数十分の一の大きさ)を使います。また推測の長さKもつまみです。Kを大きくしすぎると後ろの推測はほとんど外れて無駄骨になり、小さすぎると一度に確定するトークンが少なくなります。

推測の長さKのトレードオフ:

K小 -> 検証は安いがforwardあたりトークン少ない

K大 -> forwardあたり潜在トークン多いが後ろはほとんど外れる

-> 適切なKはドラフト-ターゲットの整合度によって異なる

核心的な教訓は、speculative decodingが万能スイッチではないという点です。ワークロードとモデルの組み合わせによって的中率が変わり、的中率が低ければむしろ損です。オンにする前に自分のトラフィックで的中率を測定するのが正しいです。

なぜメモリバウンドのとき効果が大きいのか

speculative decodingがバッチが小さいとき(メモリバウンドのとき)に特に効果的だという点を、もう少し掘り下げる価値があります。

バッチが小さいとき (メモリバウンド):

GPU演算ユニットが多く遊んでいる

-> ターゲット forwardでK個のトークンを並列検証しても

追加の演算コストがほとんど感じられない (どうせ遊んでいた資源)

-> speculative decodingの利得が大きい

バッチが大きいとき (演算バウンドに近い):

GPU演算ユニットがすでに忙しい

-> K個の並列検証の追加演算が実際のコストとして迫る

-> 利得が減る

これが重要な理由は、同じシステムでもトラフィックの状況によってspeculative decodingの価値が変わるという意味だからです。暇な時間帯(小さなバッチ)には大きな利得を、混む時間帯(大きなバッチ)には小さな利得を与えます。一部のシステムはこの点を利用して、バッチサイズに応じてspeculative decodingを動的にオン・オフしたりもします。

バッチサイズのスケジューリングと動作点

遅延とスループットのトレードオフを実際に扱う方法は、バッチサイズと待機ポリシーを調整することです。リクエストが入ってきたときにすぐ処理するか、少し集めてより大きなバッチを作るかがつまみです。

待機(バッチング)ポリシー:

即時処理 -> 遅延最小、しかしGPUを満たしきれずスループット低い

少し集めて処理 -> スループット高い、しかし集める時間だけ遅延追加

運用点を見つける:

許容可能なTTFT/TPOTの上限をSLAとして定め

-> その中でバッチを最大限大きくしてスループット最大化

核心は「SLAをまず定めて、その限度の中でスループットを最大化する」順序です。遅延の上限なしにスループットだけ追えばユーザー体験が崩れ、スループットを無視して遅延だけ追えばGPUコストが爆増します。二つの指標のバランス点を見つけるのがサービングエンジニアリングの核心的な作業です。

測定ツールと負荷テスト

指標を正しく測定するには現実的な負荷を真似なければなりません。合成負荷を使うとき最もよくある間違いは、すべてのリクエストの入力/出力の長さを同じにすることです。実際のトラフィックは長さの分布が広く、この分布がcontinuous batchingの効率を左右します。

現実的な負荷テストのチェックリスト:

1) 入力長の分布を実際と類似させる (短いものから長いものまで)

2) 出力長の分布も多様に

3) 同時実行性を段階的に上げながら測定

4) p50だけでなくp95/p99のテール遅延を確認

5) 十分長く回してウォームアップ以降の定常状態を測定

特にテール遅延(p95、p99)が重要です。平均は良く見えても、一部のユーザーが非常に長い応答遅延を経験していることがあります。長いプロンプトが他のリクエストを塞いだり、待ち行列がたまに滞ったりするとテールが長くなります。平均だけを見るとこうした問題を見逃します。

指標を見る順序:

スループットで「容量」を見て

-> TTFT/TPOTのp50で「普通の体験」を見て

-> p95/p99で「最悪の体験」を見る

三つの層をすべて見てこそ本当の状態が見える

落とし穴とトラブルシューティング

- **speculative decodingをオンにしたのに遅くなった**: 推測の的中率が低すぎるか、ドラフトモデルが重すぎる場合です。的中率とドラフトコストのバランスを点検してください。バッチがすでに大きくて演算バウンドなら利得は少ないです。

- **TTFTがばらつく**: 長いプロンプトのprefillが他のリクエストを塞いでいる場合です。chunked prefillを検討してください。

- **スループットを上げたらユーザーの不満が増えた**: バッチを大きくして遅延が悪化したのです。TPOTとTTFTを一緒に見てバランスを取り直してください。

- **ベンチマーク数値と実際が違う**: 合成負荷の入出力長の分布が実際と異なります。実際のトラフィックのサンプルで測定してください。

- **disaggregationを入れたのに複雑になっただけ**: 規模が十分に大きくなければKV転送コストと複雑さが利得を相殺します。本当に必要な規模かをまず吟味してください。

おわりに

decodeが遅いのはメモリバウンドという根本的な性質のためであり、推論高速化のほぼすべての手法はこの事実から出発します。speculative decodingは一度のforwardで複数のトークンを確定し、メモリを一度読むコストを倹約に使います。Medusa、n-gram、EAGLEは推測の的中率を高める変種であり、chunked prefillとdisaggregationはシステムレベルで資源をより上手く重ね、分けます。

これらすべては結局、遅延とスループットの間のバランスの上に置かれます。正解はサービスの目標によって変わり、TTFT、TPOT、throughputを一緒に測定してこそ初めて正しいつまみを回すことができます。推論高速化は華やかな手法の羅列ではなく、ボトルネックを正確に見て、それに合った道具を選ぶ仕事です。

参考資料

- [vLLM公式ドキュメント](https://docs.vllm.ai/)

- [vLLM GitHub](https://github.com/vllm-project/vllm)

- [SGLang GitHub](https://github.com/sgl-project/sglang)

- [TensorRT-LLM GitHub](https://github.com/NVIDIA/TensorRT-LLM)

- [Hugging Faceドキュメント](https://huggingface.co/docs)

- [PyTorch](https://pytorch.org/)

- [Attention Is All You Need (arXiv:1706.03762)](https://arxiv.org/abs/1706.03762)

- [FlashAttention (arXiv:2205.14135)](https://arxiv.org/abs/2205.14135)

현재 단락 (1/133)

LLM推論を初めて運用してみた人が最初に驚くのは「なぜこんなに遅いのか」ということです。同じGPUで学習時には膨大な演算を注ぎ込んでいたモデルが、いざテキストを生成するときには、もどかしいほどゆっくり...

작성 글자: 0원문 글자: 7,815작성 단락: 0/133