Skip to content
Published on

Transformer構造の完全解剖 — AttentionからKV Cacheまで

Authors

はじめに — なぜ今もTransformerなのか

2017年に「Attention Is All You Need」が登場して以来、Transformerは自然言語処理だけでなく、ビジョン、音声、コード、マルチモーダルまで、ほぼすべての深層学習分野の基本骨格になりました。2026年現在、私たちが使うほぼすべての大規模言語モデル(LLM)は本質的にTransformerデコーダの変種です。GQA、RoPE、FlashAttention、MoEといった最新手法も、すべて元のTransformerの上に乗せた改良にすぎず、核となる骨格は変わっていません。

そのため、Transformerをきちんと理解すれば、新しいモデルや論文を読むときに「どこが変わったか」だけを把握すればよくなります。逆に、self-attentionとKV cacheの動作を曖昧にしか理解していないと、サービング最適化でもファインチューニングでも、あらゆる段階でつまずきます。

本記事ではTransformerの1ブロックを最初から最後まで分解します。各演算がどのテンソルshapeを入力に取り、どのshapeを出力するか、パラメータがいくつ生じるか、そして推論時になぜKV cacheが登場するかを一つの流れでつなげます。数式はドル記号を使わず通常表記で、コードは動作する形で示します。

全体構造をひと目で

入力トークン列が入り、次トークンの分布が出るまでの流れは次のとおりです。

[入力トークンID]  (整数列, 長さN)
      |
      v
[トークン埋め込み + 位置エンコーディング]   ->  X: (B, N, D)
      |
      v
+---------------------------+
|  Transformerブロック x L   |
|                           |
|  LayerNorm                |
|  Multi-Head Attention     |
|  + 残差(residual)          |
|                           |
|  LayerNorm                |
|  Feed-Forward Network     |
|  + 残差(residual)          |
+---------------------------+
      |
      v
[最終LayerNorm]
      |
      v
[アンエンベディング / LM head]   ->  logits: (B, N, V)
      |
      v
[softmax]                       ->  次トークンの確率分布

記号は次のように使います。

B = バッチサイズ
N = 系列長 (トークン数)
D = モデル次元 (d_model)
H = アテンションヘッド数
d_k = ヘッドあたり次元 = D / H
V = 語彙サイズ
L = ブロック(layer)数

GPT系は上図でデコーダブロックだけをL個積み重ねた構造です。エンコーダ-デコーダ構造(原論文、T5など)はエンコーダスタックとデコーダスタックを別に持ちます。本記事ではデコーダ中心に説明しつつ、エンコーダとの違いを明確に押さえます。

1. 埋め込み — トークンをベクトルに

最初に整数トークンIDをD次元ベクトルに変換します。これは単にサイズ(V, D)のルックアップテーブルです。

import torch
import torch.nn as nn

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, token_ids):
        # token_ids: (B, N)  ->  (B, N, D)
        return self.embed(token_ids) * (self.d_model ** 0.5)

原論文は埋め込みにsqrt(D)を掛けてスケールを合わせますが、これは位置エンコーディングと大きさを近づけるための慣例です。埋め込みのパラメータ数はV * Dです。たとえば語彙5万、D=4096なら埋め込みだけで約2億個のパラメータを占めます。

2. 位置エンコーディング — 順序情報の注入

self-attentionはそれ自体では順序を知りません。つまり入力トークンを並べ替えてもattention出力の集合は同じです(順列同変性)。そこで位置情報を明示的に入れる必要があります。最も古典的な方式は正弦/余弦関数で作る絶対位置エンコーディングです。

PE(pos, 2i)   = sin( pos / 10000^(2i / D) )
PE(pos, 2i+1) = cos( pos / 10000^(2i / D) )

pos = 位置インデックス (0, 1, 2, ...)
i   = 次元インデックス

各位置ごとに異なる周波数の正弦波の組み合わせで固有のパターンを作り、埋め込みに足します。最新モデルは絶対位置の代わりにRoPE(回転位置エンコーディング)やALiBi(アテンションバイアス)をより多く使います。この部分は別記事で深く扱い、ここでは「位置情報がアテンション入力のどこかに必ず注入される」という点だけ覚えておけば十分です。

3. Self-Attention — 核心メカニズム

Query, Key, Value

self-attentionの直観はこうです。各トークンは「自分は何を探したいか」(Query)を投げ、すべてのトークンは「自分はこういう内容だ」(Key)という標識を持ち、実際に取り出す内容はValueに入っています。QueryとKeyの類似度が高いほど、そのValueをより多く取り出します。

入力X: (B, N, D)に対して、三つの線形変換でQ, K, Vを作ります。

Q = X · W_Q     W_Q: (D, D)   ->  Q: (B, N, D)
K = X · W_K     W_K: (D, D)   ->  K: (B, N, D)
V = X · W_V     W_V: (D, D)   ->  V: (B, N, D)

スケールド・ドットプロダクト・アテンション

核心の数式は次のとおりです(ドル記号なしで表記)。

Attention(Q, K, V) = softmax( (Q · K^T) / sqrt(d_k) ) · V

Q · K^T : (B, N, N)  -- すべてのトークン対の類似度スコア
/ sqrt(d_k) : スケーリング (スコアが大きくなりすぎるのを防止)
softmax : 行方向に正規化 -> アテンション重み
· V : 加重平均で出力を生成  ->  (B, N, D)

sqrt(d_k)で割る理由が重要です。d_kが大きいとQとKの内積の分散がd_kに比例して大きくなり、softmax入力が大きすぎると勾配がほぼ0の領域(飽和)に押しやられて学習が不安定になります。標準偏差で割って分散を1付近に保つわけです。

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (B, H, N, d_k)
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    # scores: (B, H, N, N)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    out = torch.matmul(weights, v)  # (B, H, N, d_k)
    return out, weights

Causal Mask — 未来を見せない

デコーダ(GPT系)では、位置iのトークンがiより後ろに来るトークンを見てはいけません。学習時に正解を先に見ることになるからです。そこでアテンションスコア行列の上三角部分(未来の位置)を負の無限大で埋め、softmax後の重みが0になるようにします。

causal mask (N=4, 1=許可, 0=遮断):

        key0  key1  key2  key3
query0   1     0     0     0
query1   1     1     0     0
query2   1     1     1     0
query3   1     1     1     1

エンコーダ(BERT、原論文のエンコーダスタック)にはこのマスクがありません。すべてのトークンが互いを見られる双方向アテンションです。この違いがデコーダ(生成)とエンコーダ(理解/表現)モデルの根本的な区別点です。

4. Multi-Head Attention — 複数の視点で見る

一つのアテンションだけを使うと、モデルは一つの観点しか学べません。マルチヘッドはD次元をH個のヘッドに分割し、各ヘッドがd_k = D / H次元で独立にアテンションを行うようにします。あるヘッドは文法的依存を、別のヘッドは意味的関連を学ぶ、というように役割が分化します。

1) Q, K, VをH個のヘッドに分割:
   (B, N, D)  ->  (B, N, H, d_k)  ->  (B, H, N, d_k)

2) 各ヘッドでスケールド・ドットプロダクト・アテンション:
   (B, H, N, d_k)  ->  (B, H, N, d_k)

3) ヘッドを再び結合:
   (B, H, N, d_k)  ->  (B, N, H, d_k)  ->  (B, N, D)

4) 出力射影:
   (B, N, D) · W_O  ->  (B, N, D)
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, N, D = x.shape
        H, d_k = self.num_heads, self.d_k

        def split_heads(t):
            return t.view(B, N, H, d_k).transpose(1, 2)  # (B, H, N, d_k)

        q = split_heads(self.w_q(x))
        k = split_heads(self.w_k(x))
        v = split_heads(self.w_v(x))

        out, _ = scaled_dot_product_attention(q, k, v, mask)
        # (B, H, N, d_k) -> (B, N, D)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.w_o(out)

マルチヘッドアテンションのパラメータ数は、W_Q, W_K, W_V, W_Oの四つの(D, D)行列で、バイアスを無視すれば4 * D * Dです。D=4096なら約6700万個です。

5. Feed-Forward Network — 位置ごとの変換

アテンション出力に対し、各位置ごとに独立に適用される2層MLPです。通常は中間次元をDの4倍に拡張してから再び縮小します。

FFN(x) = activation(x · W_1 + b_1) · W_2 + b_2

W_1: (D, 4D)   ->  中間表現 (B, N, 4D)
activation: ReLU / GELU / SwiGLU など
W_2: (4D, D)   ->  再び (B, N, D)

最新モデルは活性化にSwiGLU(ゲート付き変種)をよく使います。FFNのパラメータ数はおよそ2 * D * 4D = 8 * D * Dで、実は1ブロックの中でアテンション(4 D^2)よりFFN(8 D^2)のほうが多くのパラメータを占めます。MoEはまさにこのFFNを複数のエキスパートに分割し、一部だけを活性化する手法です。

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.act = nn.GELU()

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

6. 残差接続とLayerNorm — 深いネットを安定化

数十個のブロックを積むと勾配消失や学習の不安定が生じます。これを二つの仕組みで解決します。第一に、残差接続(residual)はサブ層の出力に入力をそのまま足し、勾配が深いネットを通れる近道を作ります。第二に、LayerNormは各トークンベクトルを正規化し、活性値の分布を安定させます。

原論文はサブ層の後に正規化を置くpost-norm構造でしたが、深いモデルで学習が不安定になるため、最近はほぼすべてpre-normを使います。pre-normはサブ層の前で正規化します。

post-norm (原論文):
   x = LayerNorm(x + Sublayer(x))

pre-norm (最新標準):
   x = x + Sublayer(LayerNorm(x))

pre-normが安定する理由は、残差経路が正規化を経ずにそのまま足され、深いネットでも信号が損なわれずに伝わるからです。多くの最新モデルはLayerNormの代わりに平均引きを省いたRMSNormを使い、わずかに演算を減らします。

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff)

    def forward(self, x, mask=None):
        # pre-norm構造
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ffn(self.ln2(x))
        return x

7. 1ブロックのテンソルの流れの整理

一つのデコーダブロックを通る間にshapeがどう変わるかを整理すると次のとおりです。

入力 x                          : (B, N, D)
LayerNorm(x)                    : (B, N, D)
  W_Q/W_K/W_V 射影               : (B, N, D) それぞれ
  ヘッド分割                       : (B, H, N, d_k)
  Q·K^T                         : (B, H, N, N)
  softmax · V                   : (B, H, N, d_k)
  ヘッド結合                       : (B, N, D)
  W_O 射影                       : (B, N, D)
残差を足す                        : (B, N, D)
LayerNorm(x)                    : (B, N, D)
  FFN 拡張                       : (B, N, 4D)
  FFN 縮小                       : (B, N, D)
残差を足す                        : (B, N, D)
出力                            : (B, N, D)

ブロックを通ってもshapeは(B, N, D)で保存されます。だから同じブロックをL回繰り返して積み重ねられるのです。

8. パラメータ数の概算

1ブロックの主要パラメータを整理すると次のとおりです。

構成要素パラメータ数(概算)
アテンション (W_Q, W_K, W_V, W_O)4 かける D かける D
FFN (拡張 + 縮小)8 かける D かける D
ブロック合計約 12 かける D かける D

D=4096、L=32ブロックなら、ブロックあたり約12 かける 4096 かける 4096 = 約2億個、全体は約64億個です。ここに埋め込み(V かける D)を足すと、よく耳にする「7Bモデル」の規模になります。この概算はモデルサイズを直観的に見積もるのに役立ちます。

9. 推論とKV Cache — なぜ必要か

学習時は系列全体を一度に並列処理します。しかし推論(生成)はトークンを一つずつ自己回帰的に作ります。トークンを一つ生成するたびに最初からアテンションを計算し直すと、すでに計算した前のトークンのKey/Valueを毎回再計算する大きな無駄が生じます。

KV cacheは、すでに生成したトークンのKとVをメモリに保存しておき、新しいトークンが入ったときにそのトークンのQだけを新たに計算してキャッシュ済みのK, Vとアテンションする手法です。

KV cacheなし (再計算):
   t番目のトークン生成時 -> 1..tトークンのK, Vをすべて再計算  (O(t) の計算)

KV cache使用:
   t番目のトークン生成時 -> t番目のK, Vだけ計算してキャッシュに追加
   キャッシュ済みの1..t-1のK, Vはそのまま再利用  (O(1) の追加計算)

このときKV cacheが占めるメモリは次のように概算できます。

KV cacheサイズ(バイト)
 = 2(KとV) x L(レイヤー) x N(系列長) x D(次元)
   x B(バッチ) x bytes_per_element

例: L=32, N=8192, D=4096, B=1, FP16(2バイト)
 = 2 x 32 x 8192 x 4096 x 1 x 2
 約 4.3 GB

系列が長くなりバッチが大きくなるほどKV cacheは線形に増え、長コンテキストのサービングではモデル重みよりKV cacheがメモリを多く食うこともあります。この問題を解くための手法がまさにGQA/MQA(KVヘッド共有でキャッシュ縮小)、PagedAttention(仮想メモリ式のブロック管理)などで、これは後続記事で深く扱います。

10. エンコーダとデコーダ、そして変種

原論文のエンコーダ-デコーダ構造と、今日主流のデコーダ専用構造の違いを整理します。

区分エンコーダデコーダ(GPT系)エンコーダ-デコーダ(T5など)
アテンション方向双方向因果的(causal)エンコーダは双方向、デコーダはcausal + 交差アテンション
マスクなしcausal maskデコーダにcausal mask
代表的用途理解/表現(分類, 埋め込み)生成翻訳, 要約などseq-to-seq
代表モデルBERTGPT, Llama, QwenT5, BART

交差アテンション(cross-attention)はデコーダがエンコーダの出力をKey/Valueとしてアテンションするもので、入力(ソース)と出力(ターゲット)が異なる翻訳のようなタスクに適します。一方、純粋な生成モデルはデコーダブロックだけを積み、単純でありながら強力です。

11. 小さな数値例で追うアテンション

抽象的な数式だけでは感覚がつかみにくいです。とても小さな例で流れを一度追ってみます。トークン3個、次元D=4、ヘッド1個としましょう。

入力 X: (N=3, D=4)
 x0 = [1, 0, 1, 0]
 x1 = [0, 1, 0, 1]
 x2 = [1, 1, 0, 0]

1) Q, K, Vを線形変換で作る (重みは学習された値)
   -> Q, K, V: それぞれ (3, 4)

2) スコア = Q · K^T  ->  (3, 3) 行列
   各要素 score[i][j] はトークンiがトークンjにどれだけ注目するかの原スコア

3) sqrt(d_k)=2 で割ってスケーリング

4) causal mask 適用 (デコーダなら):
   score[0][1], score[0][2], score[1][2] -> 負の無限大

5) 行ごとに softmax -> アテンション重み (各行の合計が1)

6) 重み · V  ->  出力 (3, 4)
   トークンiの出力 = 自分自身と過去のトークンのVを加重平均したもの

核心は、score行列の各行が「このトークンが他のトークンをどれだけ参照するか」の分布であること、そしてcausal maskがその分布で未来を0にすることです。出力は結局、過去のトークンのValueを自分の関心に合わせて混ぜた結果です。

12. 出力層 — logitsからトークンへ

最後のブロックを通った(B, N, D)の表現は、LM head(通常は埋め込みと重みを共有する(D, V)行列)を経て、語彙全体に対するスコア(logits)になります。

最終表現: (B, N, D)
LM head (D, V) を掛ける:  ->  logits: (B, N, V)
softmax (最後の位置):  ->  次トークンの確率分布 (V,)

生成時には最後の位置のlogitsだけを見て次のトークンを選びます。選び方(サンプリング)が出力の多様性を左右します。

デコーディング戦略のまとめ:
 greedy   : 毎ステップ最も確率の高いトークンを選択 -> 決定的, 単調になりうる
 temperature: logitsをTで割り分布を平坦(T>1)/尖鋭(T<1)に調整
 top-k    : 上位k個のトークンの中だけからサンプリング
 top-p    : 累積確率pまでのトークンの中だけからサンプリング (nucleus)
import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=1.0, top_k=None):
    # logits: (V,)
    logits = logits / max(temperature, 1e-6)
    if top_k is not None:
        v, _ = torch.topk(logits, top_k)
        logits[logits < v[-1]] = float('-inf')
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

この出力層とデコーディング戦略は、学習されたTransformerを実際の生成器にする最後のつなぎ目です。モデル構造自体は同じでも、どうサンプリングするかで結果の性格が大きく変わります。

13. 学習と推論の演算パターンの違い

同じTransformerでも、学習と推論は演算の性格が異なります。この違いを知ると、サービング最適化の出発点が見えてきます。

区分学習(またはprefill)推論 decode段階
処理単位系列全体を並列トークン1個ずつ順次
ボトルネック演算(行列積)バウンドメモリ帯域バウンド
行列の形大きな行列積(N x N)薄い行列-ベクトル積
最適化ポイントFLOPs, テンソルコア活用KV cache, メモリ帯域

推論のdecode段階がメモリバウンドである点が重要です。トークン一つを作るたびにモデル重みとKV cacheをメモリから読む必要があるため、演算よりメモリ読み込みがボトルネックになります。だから量子化(重みを小さく)とKV cache最適化がdecodeスループットに大きな効果を出します。2026年のサービングスタックのcontinuous(in-flight) batching, paged KV cache, FP8/INT4量子化, speculative decodingは、すべてこのメモリバウンド特性を攻略します。

落とし穴とトラブルシューティング

  • softmaxスケーリングの抜け: sqrt(d_k)で割るのを忘れると、ヘッド次元が大きいとき学習が発散したり、アテンションが一つのトークンに過度に偏ったりします。
  • マスク適用位置の誤り: causal maskはsoftmax前のスコア段階で負の無限大で埋めるべきです。softmax後に0を掛けると正規化が壊れます。
  • pre-normとpost-normの混同: 深いモデルをゼロから学習するときpost-normを使うと、warmupなしでは発散しやすいです。特別な理由がなければpre-normを推奨します。
  • KV cacheのdtype不一致: 重みがFP16なのにKV cacheをFP32にするとメモリが2倍になります。意図的な量子化でなければdtypeを合わせてください。
  • 位置エンコーディングの長さ超過: 絶対位置エンコーディングを学習長より長い入力に使うと性能が急落します。長コンテキストが必要ならRoPE系と外挿手法を検討すべきです。
  • ヘッド分割時のtranspose抜け: (B, N, H, d_k)(B, H, N, d_k)に変えるtransposeを忘れると、ヘッド次元と系列次元が混ざって静かに誤った結果が出ます。

おわりに

Transformerの1ブロックは結局、(1)アテンションでトークン間の情報を混ぜ、(2)FFNで位置ごとの非線形変換を行い、(3)残差と正規化で深いネットを安定化する、という三つの単純なアイデアの繰り返しです。そこに位置エンコーディングで順序を注入し、推論段階ではKV cacheで再計算を避けます。

この骨格を手にしていれば、RoPEは位置エンコーディングの変種、GQA/MQAはアテンションのKV効率化、FlashAttentionはアテンション演算のIO最適化、MoEはFFNの条件付き活性化として、すっきり整理できます。次の記事では、アテンションの進化(MQA/GQA/FlashAttention)と位置エンコーディング(RoPE)をそれぞれ深く掘り下げます。

参考資料