Skip to content
Published on

アテンションの進化 — MQA、GQA、FlashAttention、そして長いコンテキスト

Authors

はじめに — アテンションが高価な理由

Transformerのself-attentionは強力ですが高価です。系列長Nに対してすべてのトークン対を比較するため、演算量と中間メモリがNの2乗に比例します。コンテキストが2Kから32Kへ16倍長くなると、アテンションスコア行列は256倍に膨らみます。長くなったコンテキストはそのままコストの爆発につながります。

さらに推論段階には別のボトルネックがあります。自己回帰生成はトークンを一つずつ作るため、前のトークンのKey/Valueを保存したKV cacheが系列長に比例して線形に大きくなります。長コンテキストのサービングでは、モデル重みよりKV cacheがGPUメモリを多く食う場合も珍しくありません。

本記事は二つのボトルネック — アテンションの2乗コストとKV cacheの線形増加 — を攻略する核心手法を扱います。MQA/GQAはKV cacheを削減し、FlashAttentionはアテンション演算のメモリIOを削減し、スライディングウィンドウのような手法はアテンション範囲そのものを制限します。

1. 標準アテンションのコスト分析

演算量とメモリ

標準マルチヘッドアテンションの核心コストを整理します。

記号:
B = バッチ, H = ヘッド数, N = 系列長, d_k = ヘッドあたり次元, D = H x d_k

Q·K^T スコア行列: (B, H, N, N)
 - 演算量: O(B x H x N^2 x d_k)
 - メモリ: O(B x H x N^2)   <- Nの2乗!

softmax後のV加重和: さらに O(B x H x N^2 x d_k)

問題の核心は(N, N)スコア行列です。N=8192, H=32, B=8なら、スコア行列一つだけでFP16で約34GBに達します。素朴な実装の本当のボトルネックは、この巨大な中間行列をGPUの高速メモリ(SRAM)ではなく遅いメモリ(HBM)に書いて読むことです。

推論段階のKV cache

推論ではKV cacheサイズを次のように概算します。

KV cacheバイト
 = 2 x L x N x D x B x bytes_per_element
   (2はKとV、Lはレイヤー数)

例: L=32, D=4096, N=8192, B=1, FP16
 約 4.3 GB

バッチを32に上げると 約 137 GB -> 単一GPUメモリを超過

つまりスループットのためにバッチを上げたくても、KV cacheがメモリを食い尽くしてバッチサイズが制限されます。これがMQA/GQAが登場した直接の動機です。

2. MQAとGQA — KVヘッドを共有せよ

核心アイデア

標準マルチヘッドアテンション(MHA)はH個のQueryヘッドそれぞれに対応するH個のKey/Valueヘッドを持ちます。MQA(Multi-Query Attention)は発想を逆転させます。QueryヘッドはH個のままにしつつ、Key/Valueヘッドはただ1個だけ持ち、すべてのQueryヘッドがそれを共有します。

GQA(Grouped-Query Attention)は両者の折衷案です。QueryヘッドをG個のグループに分け、各グループが一つのKVヘッドを共有します。G=1ならMQA、G=HならMHAと同じになります。

MHA  : QueryヘッドH個, KVヘッドH個   (KV cache最大)
GQA  : QueryヘッドH個, KVヘッドG個   (G < H, 折衷)
MQA  : QueryヘッドH個, KVヘッド1個   (KV cache最小)

例) H=32のとき
 MHA: KVヘッド32  ->  基準
 GQA(G=8): KVヘッド8  ->  KV cache 1/4
 MQA: KVヘッド1  ->  KV cache 1/32

なぜ機能するか

KV cacheはKVヘッド数に比例します。KVヘッドをH個からG個に減らせば、KV cacheもG/Hに縮みます。MQAのように1個まで減らすと劇的に小さくなり、より長いコンテキストやより大きなバッチを同じメモリに収められます。

代償は表現力です。KVヘッドを減らしすぎる(MQA)と品質がわずかに落ちることがあります。そこで実務ではG=8のようなGQAが、品質損失をほぼなくしつつKV cacheを大きく削減する均衡点として広く採用されます。Llama 2/3、Qwenなど多数の最新モデルがGQAを使います。

import torch
import torch.nn.functional as F

def gqa_attention(q, k, v, num_kv_groups):
    # q: (B, H, N, d_k)
    # k, v: (B, G, N, d_k)   G = num_kv_groups
    B, H, N, d_k = q.shape
    G = num_kv_groups
    rep = H // G  # 各KVヘッドを何個のQueryヘッドが共有するか

    # KVヘッドをQueryヘッド数に合わせて反復拡張
    k = k.repeat_interleave(rep, dim=1)  # (B, H, N, d_k)
    v = v.repeat_interleave(rep, dim=1)

    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v)

比較テーブル

方式KVヘッド数KV cacheサイズ品質代表モデル
MHAH基準(最大)最高初期GPT, 原論文
GQAG (例: 8)約 G/HMHAに近いLlama 2/3, Qwen
MQA1約 1/H(最小)わずかに低下しうるPaLM, 一部の軽量モデル

3. FlashAttention — IOを削減せよ

真のボトルネックは演算ではなくメモリ移動

標準アテンションの遅さは、浮動小数点演算量そのものよりも、巨大な(N, N)スコア行列を遅いHBMに書いて読み直すメモリの往復から来ます。GPUのSRAMは非常に速いが小さく、HBMは大きいが遅いです。スコア行列がSRAMに入りきらないため、HBMを行き来するのです。

FlashAttentionの核心はIO-awareアルゴリズムです。(N, N)スコア行列全体を決してHBMに実体化しません。代わりにQ, K, Vを小さなタイル(ブロック)に分けてSRAMに載せ、ブロック単位でアテンションを計算しながらsoftmaxをオンラインで漸進更新します。

標準アテンション:
  1) S = Q·K^T  ->  (N, N) 全体をHBMに保存
  2) P = softmax(S)  ->  (N, N) を再び読み書き
  3) O = P·V  ->  また読み込み
  => HBM往復 O(N^2)

FlashAttention (タイリング):
  Q, K, V をブロックに分割 -> SRAMに載せる
  ブロックごとに部分スコア計算 -> オンラインsoftmaxで累積
  (N, N) 行列を丸ごとは作らない
  => HBM往復が大幅減少、メモリも O(N) に削減

オンラインsoftmaxの直観

softmaxは通常、正規化定数を求めるのに行全体を見る必要があります。FlashAttentionはブロックを巡回しながら、これまで見た最大値と合計を持ち回り、新しいブロックが来たら安全にスケールを補正しつつ累積します。こうすれば全体行列を一度にメモリに持たなくても、まったく同じ結果が得られます。

# 概念的な擬似コード: 実際のFlashAttentionはCUDAカーネルで実装される
# 核心はブロック巡回 + オンラインsoftmax累積

def flash_attention_concept(Q, K, V, block_size):
    N, d = Q.shape
    O = zeros(N, d)
    for i in range(0, N, block_size):          # Qブロックを巡回
        q_block = Q[i:i+block_size]
        running_max = -inf
        running_sum = 0
        acc = zeros(block_size, d)
        for j in range(0, N, block_size):      # K, Vブロックを巡回
            k_block = K[j:j+block_size]
            v_block = V[j:j+block_size]
            s = q_block @ k_block.T / sqrt(d)   # 小さなブロックスコアだけSRAMに
            block_max = s.max(axis=-1)
            new_max = maximum(running_max, block_max)
            # 以前の累積値と現在のブロックを同じスケールに補正
            acc = rescale_and_accumulate(acc, s, v_block, running_max, new_max)
            running_max = new_max
        O[i:i+block_size] = finalize(acc, running_sum)
    return O

実際の実装では、PyTorchのscaled_dot_product_attentionでバックエンドとして自動選択されるか、FlashAttentionのライブラリ/カーネルで提供されます。自分で擬似コードを書くことはほぼなく、ライブラリを有効にするだけです。

効果の整理

項目標準アテンションFlashAttention
中間メモリNの2乗に比例Nに線形
HBM往復多いタイリングで大幅減少
精度基準数値的に同一(近似ではない)
効果が大きい区間短い系列長い系列で加速大

重要な点は、FlashAttentionが近似ではなく、まったく同じアテンションをより速く、より少ないメモリで計算するということです。品質損失なしでタダで得られる最適化なので、事実上の標準になりました。

4. 長コンテキストのためのアテンション変種

KV cache削減(GQA/MQA)とIO最適化(FlashAttention)は標準アテンションを効率化しますが、アテンション範囲は依然として系列全体です。コンテキストが数十万トークンに伸びると、範囲そのものを制限する手法が必要です。

スライディングウィンドウアテンション

各トークンが全体ではなく直近W個のトークンだけを見るよう制限します。アテンションコストがNの2乗からN かける Wに減り、KV cacheもウィンドウサイズで上限が生まれます。深い層を積むとウィンドウが間接的に重なり、遠くの情報もある程度伝わります。Mistralなどが採用しました。

全体アテンション (N=8):
 各queryがすべてのkeyを見る -> コスト O(N^2)

スライディングウィンドウ (W=3):
 query5はkey3, key4, key5だけを見る
 query6はkey4, key5, key6だけを見る
 -> コスト O(N x W), KV cache上限 W

その他の長コンテキストアプローチ

  • 全域+局所の混合: 大半のトークンは局所ウィンドウだけを見つつ、少数の全域トークンは全体を見て、遠くの情報経路を確保します。
  • 階層的/疎なアテンション: 遠いトークンはまばらに、近いトークンは密に見るようにパターンを疎にします。
  • 位置エンコーディングの外挿: 学習より長い入力でRoPEを補正(NTK, YaRNなど)してコンテキストを伸ばします。これは位置エンコーディングの記事で扱います。

5. サービングメモリへの影響

これらの選択は結局、GPUメモリ予算とスループットに帰着します。2026年のサービングスタックでは、次の要素が一緒に働きます。

GPUメモリ = モデル重み + KV cache + 活性値/作業バッファ

KV cacheを削減すると (GQA/MQA, スライディングウィンドウ, KV量子化)
 -> より大きなバッチが可能 -> スループット上昇
 -> より長いコンテキストを収容可能
  • continuous(in-flight) batching: リクエストが終わった瞬間に新しいリクエストを差し込み、GPUを遊ばせません。2026年の標準です。
  • paged KV cache: KV cacheを固定サイズのブロックに分け、OSの仮想メモリのように管理(PagedAttention)し、断片化をなくしてメモリ利用率を上げます。
  • KV量子化: KV cacheをFP8/INT4で保存し、メモリをさらに半分以下に削減します。decode段階がメモリバウンドなので効果が大きいです。
  • prefill/decode分離, chunked prefill: 演算バウンドのprefillとメモリバウンドのdecodeを分離したり、prefillを分割して遅延を平坦化します。
手法削減対象スループット影響品質影響
GQA/MQAKV cache(ヘッド数)バッチ拡大で上昇ほぼなし~わずか
FlashAttentionアテンションIO/中間メモリ直接加速なし(同一結果)
スライディングウィンドウKV cache/アテンション範囲長コンテキストで上昇長距離依存に注意
KV量子化KV cache精度バッチ拡大で上昇わずか(設定依存)
paged KV cacheメモリ断片化利用率で上昇なし

5.5. FlashAttentionとcausal mask、そして後続の改善

FlashAttentionのタイリングはcausal maskとも相性がよいです。デコーダでは、queryブロックより未来にあるkeyブロックはそもそも計算する必要がないため、ブロック単位でスキップして演算をさらに減らせます。

causalアテンションでのブロックスキップ:
 queryブロックi, keyブロックj について
  j > i のブロック (完全に未来) -> 丸ごとスキップ
  j == i のブロック -> ブロック内部だけ三角マスク
  j < i のブロック -> すべて計算

-> causalのときアテンション演算が約半分に減少

FlashAttentionはその後のバージョンを経て、ハードウェア活用をさらに引き上げました。核心アイデア(タイリング + オンラインsoftmax + IO認識)は同じですが、ワープ/スレッドのスケジューリングと作業分割を改善し、最新GPUの演算ユニットをより隙間なく満たします。実務ではライブラリやフレームワークがハードウェアに合った実装を自動選択するので、ユーザーは「オンになっているか」だけ確認すればよいです。

世代ごとの発展方向(概念):
 第1世代: (N,N)を非実体化 + オンラインsoftmaxでメモリ O(N), HBM往復減少
 以降   : 並列化の軸を再配置, ワープ単位の作業分割を改善
          -> 同じアルゴリズムでGPU占有率(occupancy)が上昇

重要な点は、これらの改善が結果を変えないことです。どの世代を使っても数値的に同一のアテンションを計算し、違いは速度とメモリ効率だけです。

6. サービングメモリ予算を一度自分で計算してみる

理論を実戦の感覚に変えるには、自分で数字を入れてみる必要があります。7B級モデルを80GBのGPU 1枚に載せると仮定し、バッチをどこまで増やせるか見積もります。

仮定:
 モデル重み: FP16, 約 14 GB
 L=32, D=4096, FP16(2バイト)
 コンテキスト N=8192

GPUメモリ 80 GB のうち:
 重み 14 GB
 活性値/作業バッファ 約 6 GB (ラフ)
 KV cacheに使える余裕 = 80 - 14 - 6 = 約 60 GB

MHA(KVヘッド32)のときリクエストあたりKV cache:
 = 2 x 32 x 8192 x 4096 x 2 (バイト)
 約 4.3 GB
 -> 60 / 4.3 = 約 13 件の同時リクエスト

GQA(G=8, KVヘッド8)に変えるとリクエストあたりKV cache:
 = 上の 8/32 = 約 1.1 GB
 -> 60 / 1.1 = 約 54 件の同時リクエスト

KVヘッドを32から8に減らしただけで、同時処理可能なリクエスト数が約4倍に増えました。同時リクエスト数はそのままスループットなので、GQA一つで同じGPUのサービングコストを大きく下げられるわけです。ここにKV量子化(FP8)まで加えれば、もう一度半分に減らせます。

7. decodeがメモリバウンドである理由と意味

推論は二つの段階に分かれます。入力プロンプトを一度に処理するprefill、そしてトークンを一つずつ作るdecodeです。両者の性格は正反対です。

prefill (プロンプトN個のトークン処理):
 大きな行列積 -> 演算(FLOPs)バウンド
 GPUテンソルコアを満たして使う

decode (トークン1個ずつ生成):
 薄い行列-ベクトル積 -> メモリ帯域バウンド
 毎ステップ重み + KV cacheをメモリから読む
 演算量は少ないがメモリ読み込みがボトルネック

decodeがメモリバウンドであるという事実は、最適化の方向を定めます。演算ではなくメモリ読み込みがボトルネックなので、(1)重みを小さく(量子化)し、(2)KV cacheを小さく(GQA/MQA, KV量子化)し、(3)一度のメモリ読み込みでより多くの仕事をさせる(バッチ拡大, speculative decoding)のが効果的です。

メモリバウンドdecodeを攻略する2026の手法:
 continuous batching : 終わったリクエストの席に新しいリクエストを即投入
 paged KV cache      : KVをブロック単位で管理し断片化を除去
 FP8/INT4量子化       : 重み/KVを小さく -> メモリ読み込み減少
 speculative decoding: ドラフトモデルで複数トークンを先に作り検証
                       メモリバウンド区間で約2~3倍の加速

8. サービングフレームワーク別のアテンション最適化

2026年の主要サービングフレームワークは、上の手法をそれぞれの方式で実装します。

フレームワーク特徴アテンション/KV関連の強み
vLLM汎用, 広いモデル/ハードウェア対応PagedAttentionの元祖, continuous batching
TensorRT-LLMコンパイルベース, NVIDIA最適化H100で高いスループット(約15~30%上回る事例)
SGLangRadixAttentionprefixキャッシュ再利用, マルチターンに強い

三つのフレームワークはすべてFlashAttention系カーネルとpaged KV cacheを基本に敷き、GQA/MQAモデルを対応します。違いはコンパイル最適化の水準、prefixキャッシュ再利用の戦略、対応ハードウェアの幅などで分かれます。vLLMは汎用性と広い対応で、TensorRT-LLMはNVIDIAハードウェアでの頂点性能で、SGLangはprefixを多く共有するマルチターン/構造化ワークロードで、それぞれ強みを示します。

9. アテンション変種の選択ガイド

状況別にどの組み合わせを選ぶか整理すると次のとおりです。

一般的なテキストLLMサービング:
 GQAモデル + FlashAttention + paged KV cache + continuous batching
 (大半のデフォルト)

メモリが厳しくスループット最優先:
 上に加えて KV量子化(FP8) + より積極的なバッチ

超長文コンテキスト(数十万トークン):
 スライディングウィンドウ/疎アテンションモデル + 位置外挿(NTK/YaRN)
 + paged KV cache は必須

マルチターンチャット/共通prefixが多いワークロード:
 prefixキャッシュ再利用(例: RadixAttention)の利得が大きい

核心の原則は、「標準アテンションはIO最適化(FlashAttention)で、KV cacheはヘッド共有と量子化で、アテンション範囲はウィンドウ/疎化で抑える」という三つの軸を状況に合わせて組み合わせることです。

落とし穴とトラブルシューティング

  • GQAヘッド数のミスマッチ: KVヘッドGがQueryヘッドHを割り切れないと、repeat_interleave段階でshapeエラーになります。HはGの倍数である必要があります。
  • MQA品質低下の見落とし: メモリだけを見てMQAに行くと、特定のタスクで品質が落ちることがあります。まずGQA(G=8など)から始めて均衡を取ってください。
  • FlashAttention未適用: ライブラリ/バックエンドがオフだと、長い系列でOOMになったり遅くなったりします。PyTorchなら適切なバックエンドが選ばれているか確認してください。
  • スライディングウィンドウの長距離損失: ウィンドウが小さすぎると、文書前半の情報が後ろへ伝わりません。ウィンドウサイズと層数の関係を考慮してください。
  • KV量子化の精度の落とし穴: KVを低すぎるビットに量子化すると、長コンテキストで累積誤差により品質が落ちることがあります。INT4よりFP8が安全な場合が多いです。
  • バッチメモリ推定の抜け: KV cacheはバッチに線形比例します。バッチを上げる前に上の式でメモリを事前計算しないと、運用中にOOMになります。

9.5. アテンション最適化と他の効率化手法の相互作用

アテンション・KV最適化は単独で使われず、モデル・サービング全般の他の効率化手法と一緒に働きます。いくつかの相互作用を押さえておきます。

量子化(重み)との関係:
 重み量子化はモデルメモリを減らし、KV量子化はKV cacheを減らす
 -> 両者は独立なので一緒に適用可能
 -> 一緒に使えばより大きなバッチ/長いコンテキストを収容

MoEとの関係:
 MoEはFFNを複数のエキスパートに分割し一部だけを活性化
 -> アテンション部分はそのままなのでGQA/FlashAttentionと直交(独立)
 -> ただしMoEは活性パラメータは少なくても全体パラメータ(メモリ)は大きい

speculative decodingとの関係:
 ドラフトモデルで複数トークンを先に作り、本モデルで一度に検証
 -> KV cache管理がより複雑になるが、メモリバウンドdecodeで加速大
 -> GQAでKVを減らしておくと、ドラフト・検証ともにメモリの余裕ができる

核心は、これらの手法が大半は互いに直交しているという点です。アテンション変種(GQA/FlashAttention)は「アテンションを安く」、量子化は「重み/KVを小さく」、MoEは「演算を条件付きに」、speculative decodingは「decodeを速く」します。異なる軸を攻略するので一緒に積み重ねられ、実際のプロダクションサービングはこれらを組み合わせて構成されます。

9.7. よくある誤解を正す

アテンション効率化についてのよくある誤解を整理します。

誤解1: 「FlashAttentionは近似なので品質が少し落ちる」
 -> 誤り。数値的に同一のアテンション。品質損失なし。

誤解2: 「MQAを使えば必ずGQAより速い」
 -> メモリは小さいが、品質低下でより大きなモデル/再学習が必要になりうる。
    実効コストはワークロード次第。

誤解3: 「長コンテキストはKV cacheを大きくすればよい」
 -> メモリだけでなく、位置外挿(品質)とアテンションコスト(演算)も併せて扱う必要がある。

誤解4: 「バッチを大きくすれば必ず良い」
 -> スループットは上がるが、個々のリクエスト遅延(TPOT)が増えうる。SLAとの均衡が必要。

10. ベンチマーク数値の読み方

サービング最適化の記事を読むと、「スループットX%向上」「遅延Yms」のような数値があふれます。これを正しく解釈するには、いくつかの軸を区別する必要があります。

主要指標:
 スループット(throughput): 単位時間あたりのトークン数 (全ユーザー基準)
 遅延(latency):
   TTFT (time to first token): 最初のトークンまで -> prefillの影響が大きい
   TPOT (time per output token): トークン間の間隔 -> decodeの影響が大きい
 同時実行数: 同時に処理するリクエスト数 (バッチと直結)

同じ最適化でも改善する指標が異なります。たとえばバッチを大きくすると全体のスループットは上がりますが、個々のリクエストの遅延は増えることがあります。KV cache削減(GQA)はバッチを大きくしてスループットを上げ、FlashAttentionはprefillと長い系列でTTFTを減らし、speculative decodingは主にTPOT(decode遅延)を減らします。

指標主に改善する手法注意点
スループットGQA, バッチ拡大, KV量子化遅延とトレードオフしうる
TTFTFlashAttention, chunked prefill長いプロンプトで体感大
TPOTspeculative decoding, KV最適化受容率により効果が変動

ベンチマーク数値を引用するときは、「どのハードウェアで、どのモデルで、どの系列長とバッチで」測定したかが核心です。同じ手法でも条件によって向上幅が大きく変わるので、絶対数値よりも、どの軸をどう改善するかのメカニズムを理解するほうが重要です。

おわりに

標準アテンションの二つのボトルネック — 2乗の演算コストと線形のKV cache — は、それぞれ別の道具で攻略します。FlashAttentionは精度を保ったままアテンションのメモリIOを削減し、事実上タダの加速を与えます。GQA/MQAはKVヘッドを共有してKV cacheを削減し、より大きなバッチと長いコンテキストを可能にします。スライディングウィンドウや疎アテンションは範囲そのものを制限し、超長文コンテキストのコストを抑えます。

実務ではこれらが一緒に使われます。GQAでKVを削減し、FlashAttentionでアテンションを加速し、continuous batchingとpaged KV cacheでメモリを隙間なく活用するのが、2026年のサービングの基本の組み合わせです。次の記事では、位置エンコーディング、特にRoPEと長さの外挿を深く扱います。

参考資料