Skip to content

필사 모드: Transformer構造の完全解剖 — AttentionからKV Cacheまで

日本語
0%
정확도 0%
💡 왼쪽 원문을 읽으면서 오른쪽에 따라 써보세요. Tab 키로 힌트를 받을 수 있습니다.
원문 렌더가 준비되기 전까지 텍스트 가이드로 표시합니다.

はじめに — なぜ今もTransformerなのか

2017年に「Attention Is All You Need」が登場して以来、Transformerは自然言語処理だけでなく、ビジョン、音声、コード、マルチモーダルまで、ほぼすべての深層学習分野の基本骨格になりました。2026年現在、私たちが使うほぼすべての大規模言語モデル(LLM)は本質的にTransformerデコーダの変種です。GQA、RoPE、FlashAttention、MoEといった最新手法も、すべて元のTransformerの上に乗せた改良にすぎず、核となる骨格は変わっていません。

そのため、Transformerをきちんと理解すれば、新しいモデルや論文を読むときに「どこが変わったか」だけを把握すればよくなります。逆に、self-attentionとKV cacheの動作を曖昧にしか理解していないと、サービング最適化でもファインチューニングでも、あらゆる段階でつまずきます。

本記事ではTransformerの1ブロックを最初から最後まで分解します。各演算がどのテンソルshapeを入力に取り、どのshapeを出力するか、パラメータがいくつ生じるか、そして推論時になぜKV cacheが登場するかを一つの流れでつなげます。数式はドル記号を使わず通常表記で、コードは動作する形で示します。

全体構造をひと目で

入力トークン列が入り、次トークンの分布が出るまでの流れは次のとおりです。

[入力トークンID] (整数列, 長さN)

|

v

[トークン埋め込み + 位置エンコーディング] -> X: (B, N, D)

|

v

+---------------------------+

| Transformerブロック x L |

| |

| LayerNorm |

| Multi-Head Attention |

| + 残差(residual) |

| |

| LayerNorm |

| Feed-Forward Network |

| + 残差(residual) |

+---------------------------+

|

v

[最終LayerNorm]

|

v

[アンエンベディング / LM head] -> logits: (B, N, V)

|

v

[softmax] -> 次トークンの確率分布

記号は次のように使います。

B = バッチサイズ

N = 系列長 (トークン数)

D = モデル次元 (d_model)

H = アテンションヘッド数

d_k = ヘッドあたり次元 = D / H

V = 語彙サイズ

L = ブロック(layer)数

GPT系は上図でデコーダブロックだけをL個積み重ねた構造です。エンコーダ-デコーダ構造(原論文、T5など)はエンコーダスタックとデコーダスタックを別に持ちます。本記事ではデコーダ中心に説明しつつ、エンコーダとの違いを明確に押さえます。

1. 埋め込み — トークンをベクトルに

最初に整数トークンIDをD次元ベクトルに変換します。これは単にサイズ`(V, D)`のルックアップテーブルです。

class TokenEmbedding(nn.Module):

def __init__(self, vocab_size, d_model):

super().__init__()

self.embed = nn.Embedding(vocab_size, d_model)

self.d_model = d_model

def forward(self, token_ids):

token_ids: (B, N) -> (B, N, D)

return self.embed(token_ids) * (self.d_model ** 0.5)

原論文は埋め込みにsqrt(D)を掛けてスケールを合わせますが、これは位置エンコーディングと大きさを近づけるための慣例です。埋め込みのパラメータ数は`V * D`です。たとえば語彙5万、D=4096なら埋め込みだけで約2億個のパラメータを占めます。

2. 位置エンコーディング — 順序情報の注入

self-attentionはそれ自体では順序を知りません。つまり入力トークンを並べ替えてもattention出力の集合は同じです(順列同変性)。そこで位置情報を明示的に入れる必要があります。最も古典的な方式は正弦/余弦関数で作る絶対位置エンコーディングです。

PE(pos, 2i) = sin( pos / 10000^(2i / D) )

PE(pos, 2i+1) = cos( pos / 10000^(2i / D) )

pos = 位置インデックス (0, 1, 2, ...)

i = 次元インデックス

各位置ごとに異なる周波数の正弦波の組み合わせで固有のパターンを作り、埋め込みに足します。最新モデルは絶対位置の代わりにRoPE(回転位置エンコーディング)やALiBi(アテンションバイアス)をより多く使います。この部分は別記事で深く扱い、ここでは「位置情報がアテンション入力のどこかに必ず注入される」という点だけ覚えておけば十分です。

3. Self-Attention — 核心メカニズム

Query, Key, Value

self-attentionの直観はこうです。各トークンは「自分は何を探したいか」(Query)を投げ、すべてのトークンは「自分はこういう内容だ」(Key)という標識を持ち、実際に取り出す内容はValueに入っています。QueryとKeyの類似度が高いほど、そのValueをより多く取り出します。

入力`X: (B, N, D)`に対して、三つの線形変換でQ, K, Vを作ります。

Q = X · W_Q W_Q: (D, D) -> Q: (B, N, D)

K = X · W_K W_K: (D, D) -> K: (B, N, D)

V = X · W_V W_V: (D, D) -> V: (B, N, D)

スケールド・ドットプロダクト・アテンション

核心の数式は次のとおりです(ドル記号なしで表記)。

Attention(Q, K, V) = softmax( (Q · K^T) / sqrt(d_k) ) · V

Q · K^T : (B, N, N) -- すべてのトークン対の類似度スコア

/ sqrt(d_k) : スケーリング (スコアが大きくなりすぎるのを防止)

softmax : 行方向に正規化 -> アテンション重み

· V : 加重平均で出力を生成 -> (B, N, D)

sqrt(d_k)で割る理由が重要です。d_kが大きいとQとKの内積の分散がd_kに比例して大きくなり、softmax入力が大きすぎると勾配がほぼ0の領域(飽和)に押しやられて学習が不安定になります。標準偏差で割って分散を1付近に保つわけです。

def scaled_dot_product_attention(q, k, v, mask=None):

q, k, v: (B, H, N, d_k)

d_k = q.size(-1)

scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)

scores: (B, H, N, N)

if mask is not None:

scores = scores.masked_fill(mask == 0, float('-inf'))

weights = F.softmax(scores, dim=-1)

out = torch.matmul(weights, v) # (B, H, N, d_k)

return out, weights

Causal Mask — 未来を見せない

デコーダ(GPT系)では、位置iのトークンがiより後ろに来るトークンを見てはいけません。学習時に正解を先に見ることになるからです。そこでアテンションスコア行列の上三角部分(未来の位置)を負の無限大で埋め、softmax後の重みが0になるようにします。

causal mask (N=4, 1=許可, 0=遮断):

key0 key1 key2 key3

query0 1 0 0 0

query1 1 1 0 0

query2 1 1 1 0

query3 1 1 1 1

エンコーダ(BERT、原論文のエンコーダスタック)にはこのマスクがありません。すべてのトークンが互いを見られる双方向アテンションです。この違いがデコーダ(生成)とエンコーダ(理解/表現)モデルの根本的な区別点です。

4. Multi-Head Attention — 複数の視点で見る

一つのアテンションだけを使うと、モデルは一つの観点しか学べません。マルチヘッドはD次元をH個のヘッドに分割し、各ヘッドがd_k = D / H次元で独立にアテンションを行うようにします。あるヘッドは文法的依存を、別のヘッドは意味的関連を学ぶ、というように役割が分化します。

1) Q, K, VをH個のヘッドに分割:

(B, N, D) -> (B, N, H, d_k) -> (B, H, N, d_k)

2) 各ヘッドでスケールド・ドットプロダクト・アテンション:

(B, H, N, d_k) -> (B, H, N, d_k)

3) ヘッドを再び結合:

(B, H, N, d_k) -> (B, N, H, d_k) -> (B, N, D)

4) 出力射影:

(B, N, D) · W_O -> (B, N, D)

class MultiHeadAttention(nn.Module):

def __init__(self, d_model, num_heads):

super().__init__()

assert d_model % num_heads == 0

self.d_model = d_model

self.num_heads = num_heads

self.d_k = d_model // num_heads

self.w_q = nn.Linear(d_model, d_model)

self.w_k = nn.Linear(d_model, d_model)

self.w_v = nn.Linear(d_model, d_model)

self.w_o = nn.Linear(d_model, d_model)

def forward(self, x, mask=None):

B, N, D = x.shape

H, d_k = self.num_heads, self.d_k

def split_heads(t):

return t.view(B, N, H, d_k).transpose(1, 2) # (B, H, N, d_k)

q = split_heads(self.w_q(x))

k = split_heads(self.w_k(x))

v = split_heads(self.w_v(x))

out, _ = scaled_dot_product_attention(q, k, v, mask)

(B, H, N, d_k) -> (B, N, D)

out = out.transpose(1, 2).contiguous().view(B, N, D)

return self.w_o(out)

マルチヘッドアテンションのパラメータ数は、W_Q, W_K, W_V, W_Oの四つの`(D, D)`行列で、バイアスを無視すれば`4 * D * D`です。D=4096なら約6700万個です。

5. Feed-Forward Network — 位置ごとの変換

アテンション出力に対し、各位置ごとに独立に適用される2層MLPです。通常は中間次元をDの4倍に拡張してから再び縮小します。

FFN(x) = activation(x · W_1 + b_1) · W_2 + b_2

W_1: (D, 4D) -> 中間表現 (B, N, 4D)

activation: ReLU / GELU / SwiGLU など

W_2: (4D, D) -> 再び (B, N, D)

最新モデルは活性化にSwiGLU(ゲート付き変種)をよく使います。FFNのパラメータ数はおよそ`2 * D * 4D = 8 * D * D`で、実は1ブロックの中でアテンション(4 D^2)よりFFN(8 D^2)のほうが多くのパラメータを占めます。MoEはまさにこのFFNを複数のエキスパートに分割し、一部だけを活性化する手法です。

class FeedForward(nn.Module):

def __init__(self, d_model, d_ff):

super().__init__()

self.fc1 = nn.Linear(d_model, d_ff)

self.fc2 = nn.Linear(d_ff, d_model)

self.act = nn.GELU()

def forward(self, x):

return self.fc2(self.act(self.fc1(x)))

6. 残差接続とLayerNorm — 深いネットを安定化

数十個のブロックを積むと勾配消失や学習の不安定が生じます。これを二つの仕組みで解決します。第一に、残差接続(residual)はサブ層の出力に入力をそのまま足し、勾配が深いネットを通れる近道を作ります。第二に、LayerNormは各トークンベクトルを正規化し、活性値の分布を安定させます。

原論文はサブ層の後に正規化を置くpost-norm構造でしたが、深いモデルで学習が不安定になるため、最近はほぼすべてpre-normを使います。pre-normはサブ層の前で正規化します。

post-norm (原論文):

x = LayerNorm(x + Sublayer(x))

pre-norm (最新標準):

x = x + Sublayer(LayerNorm(x))

pre-normが安定する理由は、残差経路が正規化を経ずにそのまま足され、深いネットでも信号が損なわれずに伝わるからです。多くの最新モデルはLayerNormの代わりに平均引きを省いたRMSNormを使い、わずかに演算を減らします。

class TransformerBlock(nn.Module):

def __init__(self, d_model, num_heads, d_ff):

super().__init__()

self.ln1 = nn.LayerNorm(d_model)

self.attn = MultiHeadAttention(d_model, num_heads)

self.ln2 = nn.LayerNorm(d_model)

self.ffn = FeedForward(d_model, d_ff)

def forward(self, x, mask=None):

pre-norm構造

x = x + self.attn(self.ln1(x), mask)

x = x + self.ffn(self.ln2(x))

return x

7. 1ブロックのテンソルの流れの整理

一つのデコーダブロックを通る間にshapeがどう変わるかを整理すると次のとおりです。

入力 x : (B, N, D)

LayerNorm(x) : (B, N, D)

W_Q/W_K/W_V 射影 : (B, N, D) それぞれ

ヘッド分割 : (B, H, N, d_k)

Q·K^T : (B, H, N, N)

softmax · V : (B, H, N, d_k)

ヘッド結合 : (B, N, D)

W_O 射影 : (B, N, D)

残差を足す : (B, N, D)

LayerNorm(x) : (B, N, D)

FFN 拡張 : (B, N, 4D)

FFN 縮小 : (B, N, D)

残差を足す : (B, N, D)

出力 : (B, N, D)

ブロックを通ってもshapeは`(B, N, D)`で保存されます。だから同じブロックをL回繰り返して積み重ねられるのです。

8. パラメータ数の概算

1ブロックの主要パラメータを整理すると次のとおりです。

| 構成要素 | パラメータ数(概算) |

| --- | --- |

| アテンション (W_Q, W_K, W_V, W_O) | 4 かける D かける D |

| FFN (拡張 + 縮小) | 8 かける D かける D |

| ブロック合計 | 約 12 かける D かける D |

D=4096、L=32ブロックなら、ブロックあたり約12 かける 4096 かける 4096 = 約2億個、全体は約64億個です。ここに埋め込み(V かける D)を足すと、よく耳にする「7Bモデル」の規模になります。この概算はモデルサイズを直観的に見積もるのに役立ちます。

9. 推論とKV Cache — なぜ必要か

学習時は系列全体を一度に並列処理します。しかし推論(生成)はトークンを一つずつ自己回帰的に作ります。トークンを一つ生成するたびに最初からアテンションを計算し直すと、すでに計算した前のトークンのKey/Valueを毎回再計算する大きな無駄が生じます。

KV cacheは、すでに生成したトークンのKとVをメモリに保存しておき、新しいトークンが入ったときにそのトークンのQだけを新たに計算してキャッシュ済みのK, Vとアテンションする手法です。

KV cacheなし (再計算):

t番目のトークン生成時 -> 1..tトークンのK, Vをすべて再計算 (O(t) の計算)

KV cache使用:

t番目のトークン生成時 -> t番目のK, Vだけ計算してキャッシュに追加

キャッシュ済みの1..t-1のK, Vはそのまま再利用 (O(1) の追加計算)

このときKV cacheが占めるメモリは次のように概算できます。

KV cacheサイズ(バイト)

= 2(KとV) x L(レイヤー) x N(系列長) x D(次元)

x B(バッチ) x bytes_per_element

例: L=32, N=8192, D=4096, B=1, FP16(2バイト)

= 2 x 32 x 8192 x 4096 x 1 x 2

約 4.3 GB

系列が長くなりバッチが大きくなるほどKV cacheは線形に増え、長コンテキストのサービングではモデル重みよりKV cacheがメモリを多く食うこともあります。この問題を解くための手法がまさにGQA/MQA(KVヘッド共有でキャッシュ縮小)、PagedAttention(仮想メモリ式のブロック管理)などで、これは後続記事で深く扱います。

10. エンコーダとデコーダ、そして変種

原論文のエンコーダ-デコーダ構造と、今日主流のデコーダ専用構造の違いを整理します。

| 区分 | エンコーダ | デコーダ(GPT系) | エンコーダ-デコーダ(T5など) |

| --- | --- | --- | --- |

| アテンション方向 | 双方向 | 因果的(causal) | エンコーダは双方向、デコーダはcausal + 交差アテンション |

| マスク | なし | causal mask | デコーダにcausal mask |

| 代表的用途 | 理解/表現(分類, 埋め込み) | 生成 | 翻訳, 要約などseq-to-seq |

| 代表モデル | BERT | GPT, Llama, Qwen | T5, BART |

交差アテンション(cross-attention)はデコーダがエンコーダの出力をKey/Valueとしてアテンションするもので、入力(ソース)と出力(ターゲット)が異なる翻訳のようなタスクに適します。一方、純粋な生成モデルはデコーダブロックだけを積み、単純でありながら強力です。

11. 小さな数値例で追うアテンション

抽象的な数式だけでは感覚がつかみにくいです。とても小さな例で流れを一度追ってみます。トークン3個、次元D=4、ヘッド1個としましょう。

入力 X: (N=3, D=4)

x0 = [1, 0, 1, 0]

x1 = [0, 1, 0, 1]

x2 = [1, 1, 0, 0]

1) Q, K, Vを線形変換で作る (重みは学習された値)

-> Q, K, V: それぞれ (3, 4)

2) スコア = Q · K^T -> (3, 3) 行列

各要素 score[i][j] はトークンiがトークンjにどれだけ注目するかの原スコア

3) sqrt(d_k)=2 で割ってスケーリング

4) causal mask 適用 (デコーダなら):

score[0][1], score[0][2], score[1][2] -> 負の無限大

5) 行ごとに softmax -> アテンション重み (各行の合計が1)

6) 重み · V -> 出力 (3, 4)

トークンiの出力 = 自分自身と過去のトークンのVを加重平均したもの

核心は、score行列の各行が「このトークンが他のトークンをどれだけ参照するか」の分布であること、そしてcausal maskがその分布で未来を0にすることです。出力は結局、過去のトークンのValueを自分の関心に合わせて混ぜた結果です。

12. 出力層 — logitsからトークンへ

最後のブロックを通った`(B, N, D)`の表現は、LM head(通常は埋め込みと重みを共有する`(D, V)`行列)を経て、語彙全体に対するスコア(logits)になります。

最終表現: (B, N, D)

LM head (D, V) を掛ける: -> logits: (B, N, V)

softmax (最後の位置): -> 次トークンの確率分布 (V,)

生成時には最後の位置のlogitsだけを見て次のトークンを選びます。選び方(サンプリング)が出力の多様性を左右します。

デコーディング戦略のまとめ:

greedy : 毎ステップ最も確率の高いトークンを選択 -> 決定的, 単調になりうる

temperature: logitsをTで割り分布を平坦(T>1)/尖鋭(T<1)に調整

top-k : 上位k個のトークンの中だけからサンプリング

top-p : 累積確率pまでのトークンの中だけからサンプリング (nucleus)

def sample_next_token(logits, temperature=1.0, top_k=None):

logits: (V,)

logits = logits / max(temperature, 1e-6)

if top_k is not None:

v, _ = torch.topk(logits, top_k)

logits[logits < v[-1]] = float('-inf')

probs = F.softmax(logits, dim=-1)

return torch.multinomial(probs, num_samples=1)

この出力層とデコーディング戦略は、学習されたTransformerを実際の生成器にする最後のつなぎ目です。モデル構造自体は同じでも、どうサンプリングするかで結果の性格が大きく変わります。

13. 学習と推論の演算パターンの違い

同じTransformerでも、学習と推論は演算の性格が異なります。この違いを知ると、サービング最適化の出発点が見えてきます。

| 区分 | 学習(またはprefill) | 推論 decode段階 |

| --- | --- | --- |

| 処理単位 | 系列全体を並列 | トークン1個ずつ順次 |

| ボトルネック | 演算(行列積)バウンド | メモリ帯域バウンド |

| 行列の形 | 大きな行列積(N x N) | 薄い行列-ベクトル積 |

| 最適化ポイント | FLOPs, テンソルコア活用 | KV cache, メモリ帯域 |

推論のdecode段階がメモリバウンドである点が重要です。トークン一つを作るたびにモデル重みとKV cacheをメモリから読む必要があるため、演算よりメモリ読み込みがボトルネックになります。だから量子化(重みを小さく)とKV cache最適化がdecodeスループットに大きな効果を出します。2026年のサービングスタックのcontinuous(in-flight) batching, paged KV cache, FP8/INT4量子化, speculative decodingは、すべてこのメモリバウンド特性を攻略します。

落とし穴とトラブルシューティング

- **softmaxスケーリングの抜け**: sqrt(d_k)で割るのを忘れると、ヘッド次元が大きいとき学習が発散したり、アテンションが一つのトークンに過度に偏ったりします。

- **マスク適用位置の誤り**: causal maskはsoftmax前のスコア段階で負の無限大で埋めるべきです。softmax後に0を掛けると正規化が壊れます。

- **pre-normとpost-normの混同**: 深いモデルをゼロから学習するときpost-normを使うと、warmupなしでは発散しやすいです。特別な理由がなければpre-normを推奨します。

- **KV cacheのdtype不一致**: 重みがFP16なのにKV cacheをFP32にするとメモリが2倍になります。意図的な量子化でなければdtypeを合わせてください。

- **位置エンコーディングの長さ超過**: 絶対位置エンコーディングを学習長より長い入力に使うと性能が急落します。長コンテキストが必要ならRoPE系と外挿手法を検討すべきです。

- **ヘッド分割時のtranspose抜け**: `(B, N, H, d_k)`を`(B, H, N, d_k)`に変えるtransposeを忘れると、ヘッド次元と系列次元が混ざって静かに誤った結果が出ます。

おわりに

Transformerの1ブロックは結局、(1)アテンションでトークン間の情報を混ぜ、(2)FFNで位置ごとの非線形変換を行い、(3)残差と正規化で深いネットを安定化する、という三つの単純なアイデアの繰り返しです。そこに位置エンコーディングで順序を注入し、推論段階ではKV cacheで再計算を避けます。

この骨格を手にしていれば、RoPEは位置エンコーディングの変種、GQA/MQAはアテンションのKV効率化、FlashAttentionはアテンション演算のIO最適化、MoEはFFNの条件付き活性化として、すっきり整理できます。次の記事では、アテンションの進化(MQA/GQA/FlashAttention)と位置エンコーディング(RoPE)をそれぞれ深く掘り下げます。

参考資料

- Vaswani et al., "Attention Is All You Need" (arxiv 1706.03762): https://arxiv.org/abs/1706.03762

- Dao et al., "FlashAttention" (arxiv 2205.14135): https://arxiv.org/abs/2205.14135

- PyTorch Transformer ドキュメント: https://pytorch.org/docs/stable/nn.html#transformer-layers

- Hugging Face Transformers ドキュメント: https://huggingface.co/docs/transformers/index

- vLLM 公式ドキュメント(KV cache, サービング): https://docs.vllm.ai

- vLLM リポジトリ: https://github.com/vllm-project/vllm

- The Illustrated Transformer (Jay Alammar): https://jalammar.github.io/illustrated-transformer/

- Qwen モデルリポジトリ: https://github.com/QwenLM

현재 단락 (1/250)

2017年に「Attention Is All You Need」が登場して以来、Transformerは自然言語処理だけでなく、ビジョン、音声、コード、マルチモーダルまで、ほぼすべての深層学習分野の...

작성 글자: 0원문 글자: 10,952작성 단락: 0/250