Skip to content
Published on

Transformerアーキテクチャ完全解析: AttentionからモダンLLMまで

Authors

Transformerアーキテクチャ完全解析: AttentionからモダンLLMまで

Googleが2017年に発表した「Attention is All You Need」は、自然言語処理の世界を一変させました。逐次処理を行うRNNやLSTMアーキテクチャから脱却し、Attentionメカニズムのみで構築されたTransformerは、GPT、BERT、T5、LLaMA、すべての主要なモダンLLMの基盤となりました。このガイドでは、基礎原理から完全なアーキテクチャまで、各ステップでPyTorch実装を交えながら解説します。


1. Attentionの起源

1.1 RNN/LSTMの限界

Transformerが登場する以前、シーケンスモデリングはRNNとLSTMに大きく依存していました。これらのモデルは一度に1タイムステップずつ処理される隠れ状態を通じて情報を伝達しており、根本的な制限がありました:

長距離依存問題

RNNは離れた位置にある単語間の関係を学習するのに苦労します。「The cat that sat on the mat was hungry」という文で、「cat」と「was hungry」の関係はすべての中間トークンを通過する必要があります。シーケンスが長くなるにつれて、情報が薄れたり上書きされたりします。LSTMはゲーティングでこれを緩和しますが、完全な解決策にはなりません。

逐次処理

RNNはステップt+1を開始する前にステップtが完了する必要があります。つまり、モダンGPUの並列処理能力がシーケンス内部の計算には全く活用できません。バッチ間では並列化できますが、シーケンス内では不可能です。

勾配消失/爆発

数百〜数千タイムステップを通じた誤差逆伝播では、勾配がゼロに近づいたり制御不能に大きくなったりします。LSTMのゲートでも、数百トークンを超えるシーケンスは問題があります。

1.2 Attentionの直感

Attentionメカニズムは人間の読み方からインスピレーションを得ています。私たちはすべての単語に均等な注意を払うわけではなく、現在処理しているものに関連する単語により多くの注意を向けます。

例えば:「I saw the Eiffel Tower in Paris; it was breathtaking.」の「it」を処理する際、モデルは「Eiffel Tower」と「Paris」の両方に強く注目すべきです。すべての単語は距離に関わらず、他のすべての単語と直接相互作用できます。

Bahdanauら(2014年)はseq2seqモデルの補助コンポーネントとして最初のAttentionメカニズムを導入しました。Vaswaniら(2017年)はその後、決定的な一歩を踏み出しました:RNNを完全に排除し、Attentionのみですべてを構築したのです。


2. Scaled Dot-Product Attention

2.1 Q、K、Vのフレームワーク

TransformerはQuery、Key、Valueの3つのベクトルを使用してAttentionを実装しています。ソフトなデータベース検索として考えることができます:

  • Query (Q): 「何を探しているか?」— 現在の位置の表現
  • Key (K): 「どんな情報を持っているか?」— 各位置のラベル/識別子
  • Value (V): 「実際に何が格納されているか?」— 各位置のコンテンツ

QueryとKeyの間の類似度スコアを計算し、softmaxで正規化してAttentionウェイトにし、そのウェイトを使用してValueの加重和を生成します。

2.2 数式

Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V

ここで:

  • Q: Queryマトリックス(seq_len x d_k)
  • K: Keyマトリックス(seq_len x d_k)
  • V: Valueマトリックス(seq_len x d_v)
  • d_k: Key/Queryの次元
  • sqrt(d_k): スケーリング係数

なぜsqrt(d_k)でスケーリングするのか?

d_kが大きくなるにつれて、内積の大きさも比例して大きくなります。d_k=512では、生の内積が非常に大きくなり、softmaxがほぼゼロの勾配を持つ領域に入る可能性があります。sqrt(d_k)で除算することで、内積の分散を1に近い状態に保ち(qとkの成分が単位分散を持つと仮定)、勾配の問題を防ぎます。

2.3 マスキング

パディングマスク: バッチ内のシーケンスが異なる長さを持つ場合、短いシーケンスがパディングされます。softmaxがパディング位置にゼロのウェイトを割り当てるよう、パディング位置に-infを追加します。

因果マスク(先読みマスク): Decoderで使用されます。位置iは位置0からiまでにしかAttentionを向けるべきでなく、未来の位置には向けません。スコアマトリックスの上三角部分を-infで埋めます。

2.4 PyTorch実装

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor = None,
    dropout_p: float = 0.0,
) -> tuple:
    """
    Scaled Dot-Product Attention

    Args:
        query: (batch, heads, seq_len, d_k)
        key:   (batch, heads, seq_len, d_k)
        value: (batch, heads, seq_len, d_v)
        mask:  (batch, 1, 1, seq_len) or (batch, 1, seq_len, seq_len)
        dropout_p: ドロップアウト確率

    Returns:
        output: (batch, heads, seq_len, d_v)
        attn_weights: (batch, heads, seq_len, seq_len)
    """
    d_k = query.size(-1)

    # Q * K^T / sqrt(d_k): (batch, heads, seq_len, seq_len)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # マスクの適用
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Softmax Attentionウェイト
    attn_weights = F.softmax(scores, dim=-1)

    # NaNの処理(すべての位置がマスクされている場合)
    attn_weights = torch.nan_to_num(attn_weights, nan=0.0)

    # ドロップアウト
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)

    # 加重和: (batch, heads, seq_len, d_v)
    output = torch.matmul(attn_weights, value)

    return output, attn_weights


# クイックテスト
batch_size = 2
num_heads = 8
seq_len = 10
d_k = 64

q = torch.randn(batch_size, num_heads, seq_len, d_k)
k = torch.randn(batch_size, num_heads, seq_len, d_k)
v = torch.randn(batch_size, num_heads, seq_len, d_k)

# 因果マスク(下三角)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

output, weights = scaled_dot_product_attention(q, k, v, mask=causal_mask)
print(f"Output shape: {output.shape}")       # (2, 8, 10, 64)
print(f"Weights shape: {weights.shape}")     # (2, 8, 10, 10)

3. Multi-Head Attention

3.1 なぜ複数のヘッドが必要なのか?

単一ヘッドAttentionは1つの表現部分空間のみで入力を処理します。Multi-Head Attentionは複数の独立したAttention操作を並列で実行し、それぞれがトークン関係の異なる側面を学習します:

  • ヘッド1: 構文的関係(主語-動詞の一致)
  • ヘッド2: 意味的類似性(同義語、関連概念)
  • ヘッド3: 位置的関係(隣接トークン)
  • ヘッド4: 照応(代名詞とその参照先)

各ヘッドはd_k = d_model / num_headsを使用するため、総計算量は単一の完全サイズAttentionと同様です。

3.2 数式

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O

head_i = Attention(Q * W_Q_i, K * W_K_i, V * W_V_i)

3.3 完全なPyTorch実装

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # 効率のための単一の大きな射影
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self):
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.xavier_uniform_(module.weight)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)"""
        batch_size, seq_len, _ = x.size()
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def combine_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)"""
        batch_size, _, seq_len, _ = x.size()
        x = x.transpose(1, 2).contiguous()
        return x.view(batch_size, seq_len, self.d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> tuple:
        """
        Args:
            query, key, value: (batch, seq_len, d_model)
            mask: (batch, 1, 1, seq_len) or (batch, 1, seq_len, seq_len)
        Returns:
            output: (batch, seq_len, d_model)
            attn_weights: (batch, num_heads, seq_len, seq_len)
        """
        Q = self.split_heads(self.W_q(query))
        K = self.split_heads(self.W_k(key))
        V = self.split_heads(self.W_v(value))

        attn_output, attn_weights = scaled_dot_product_attention(
            Q, K, V, mask=mask, dropout_p=self.dropout.p if self.training else 0.0
        )

        output = self.combine_heads(attn_output)
        output = self.W_o(output)

        return output, attn_weights


# テスト
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
out, weights = mha(x, x, x)
print(f"MHA output: {out.shape}")      # (2, 10, 512)
print(f"Attn weights: {weights.shape}")  # (2, 8, 10, 10)

4. Positional Encoding

4.1 なぜ位置情報が重要なのか?

AttentionはPermutation-invariant(置換不変)です:すべての入力トークンをシャッフルしても、Attentionスコアは同じです。Positional Encodingがないと、「I ate rice」と「Rice ate I」は同じAttentionパターンを生成します。シーケンス順序情報を注入する必要があります。

4.2 Sinusoidal Positional Encoding

元のTransformerは固定の正弦波関数を使用しています — 学習パラメータはありません:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

posはトークンの位置、iは次元インデックスです。

利点:

  • 学習するパラメータなし
  • 学習時より長いシーケンスへ外挿可能
  • PE(pos+k)はPE(pos)の線形関数として表現でき、相対位置をエンコード
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()

        # 10000^(2i/d_model) = exp(2i * ln(10000) / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数次元: sin
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数次元: cos

        pe = pe.unsqueeze(0)  # バッチブロードキャスト用に(1, max_seq_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (batch, seq_len, d_model)"""
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

4.3 RoPE(Rotary Position Embedding)

RoPEはモダンLLM(LLaMA、GPT-NeoX、PaLM、Mistral)の標準的な位置エンコーディングになっています。主なアイデアは、QとKのベクトルに回転行列を適用することで位置をエンコードすることです。

主要な特性:

  • 絶対位置ではなく相対位置をエンコード
  • QとKの内積が自動的に相対位置に依存
  • より長いシーケンスへの優れた外挿性
  • QとKにのみ適用 — Vには適用しない
def precompute_freqs_cis(dim: int, max_seq_len: int, theta: float = 10000.0):
    """RoPE周波数行列の事前計算"""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len)
    freqs = torch.outer(t, freqs)  # (max_seq_len, dim/2)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # 複素単位ベクトル
    return freqs_cis


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """QueryとKeyテンソルにRoPEを適用"""
    xq_r = xq.float().reshape(*xq.shape[:-1], -1, 2)
    xk_r = xk.float().reshape(*xk.shape[:-1], -1, 2)

    xq_complex = torch.view_as_complex(xq_r)
    xk_complex = torch.view_as_complex(xk_r)

    freqs_cis = freqs_cis[:xq.shape[1]]
    xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(-2)

    return xq_out.type_as(xq), xk_out.type_as(xk)

4.4 ALiBi(Attention with Linear Biases)

ALiBiはAttentionスコアに線形位置バイアスを追加します:

Score = Q * K^T / sqrt(d_k) - m * |i - j|

mはヘッド固有の傾きです。位置埋め込みは不要で、学習時より長いシーケンスへの汎化性が非常に高いです。


5. TransformerエンコーダEncoder

5.1 エンコーダのアーキテクチャ

エンコーダはN個の同一レイヤーを積み重ねたもので、各レイヤーには2つのサブレイヤーがあります:

  1. Multi-Head Self-Attention
  2. Position-wise Feed-Forward Network

各サブレイヤーは残差接続と層正規化でラップされています。

5.2 Feed-Forward Network

FFNは各位置に独立して適用される2層のMLPです:

FFN(x) = max(0, x * W_1 + b_1) * W_2 + b_2

元の論文ではd_model=512、d_ff=2048を使用しています。モダンLLMはSwiGLUとd_ff ≈ 2.67 * d_modelを使用しています。

5.3 Pre-LN vs Post-LN

# Post-LN(元のTransformer)
x = LayerNorm(x + Sublayer(x))

# Pre-LN(現代の標準)
x = x + Sublayer(LayerNorm(x))

Pre-LNはより学習が安定しており、学習率のウォームアップが不要なため、モダンモデルはこれを採用しています。

5.4 完全なエンコーダ実装

class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # Pre-LN Self-Attention
        attn_out, _ = self.self_attn(
            self.norm1(x), self.norm1(x), self.norm1(x), mask=mask
        )
        x = x + self.dropout(attn_out)

        # Pre-LN Feed-Forward
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x


class TransformerEncoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        max_seq_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """x: (batch, seq_len) トークンインデックス"""
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

6. TransformerデコーダDecoder

6.1 デコーダのアーキテクチャ

デコーダはレイヤーごとに3つのサブレイヤーを持っています:

  1. Masked Multi-Head Self-Attention(因果マスク付き)
  2. Multi-Head Cross-Attention(エンコーダの出力に注目)
  3. Feed-Forward Network

6.2 Cross-Attention

Cross-Attentionでは:

  • QueryはデコーダーDecoder現在の状態から来る
  • KeyとValueはエンコーダの出力から来る

これにより、デコーダは各出力トークンを生成する際に、ソースシーケンスのどの部分に注目するかを学習します。

6.3 自己回帰生成

推論時、デコーダは一度に1つのトークンを生成します:

  1. [BOS]トークンから始める
  2. これまでに生成したすべてのトークンを入力として使用
  3. 次のトークンを予測
  4. [EOS]が生成されるまで繰り返す
class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        # 1. Masked Self-Attention(因果的)
        self_attn_out, _ = self.self_attn(
            self.norm1(x), self.norm1(x), self.norm1(x), mask=tgt_mask
        )
        x = x + self.dropout(self_attn_out)

        # 2. エンコーダ出力へのCross-Attention
        cross_attn_out, _ = self.cross_attn(
            self.norm2(x), encoder_output, encoder_output, mask=src_mask
        )
        x = x + self.dropout(cross_attn_out)

        # 3. Feed-Forward
        x = x + self.dropout(self.ffn(self.norm3(x)))
        return x

7. 完全なTransformer実装

class Transformer(nn.Module):
    """完全なEncoder-Decoder Transformer"""

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        d_ff: int = 2048,
        max_seq_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model

        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)

        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_encoder_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ])

        self.encoder_norm = nn.LayerNorm(d_model)
        self.decoder_norm = nn.LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)

        self._init_weights()

    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor = None) -> torch.Tensor:
        x = self.src_embedding(src) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return self.encoder_norm(x)

    def decode(
        self,
        tgt: torch.Tensor,
        encoder_output: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.decoder_norm(x)

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        return self.output_projection(decoder_output)

    @torch.no_grad()
    def generate(
        self,
        src: torch.Tensor,
        bos_token_id: int,
        eos_token_id: int,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """グリーディデコーディング"""
        self.eval()
        device = src.device

        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        encoder_output = self.encode(src, src_mask)
        tgt = torch.tensor([[bos_token_id]], device=device)

        for _ in range(max_new_tokens):
            seq_len = tgt.size(1)
            tgt_mask = torch.tril(
                torch.ones(seq_len, seq_len, device=device)
            ).unsqueeze(0).unsqueeze(0)

            logits = self.decode(tgt, encoder_output, src_mask, tgt_mask)
            next_token_logits = logits[:, -1, :] / temperature
            next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token], dim=1)

            if next_token.item() == eos_token_id:
                break

        return tgt


# モデルの作成とテスト
model = Transformer(src_vocab_size=32000, tgt_vocab_size=32000)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")  # ~44M

src = torch.randint(1, 32000, (2, 20))
tgt = torch.randint(1, 32000, (2, 15))
logits = model(src, tgt)
print(f"Logits shape: {logits.shape}")  # (2, 15, 32000)

8. Transformerのバリアント

8.1 BERT(エンコーダのみ)

BERT(Bidirectional Encoder Representations from Transformers)はエンコーダのみを使用し、2つの事前学習目標を持っています:

Masked Language Modeling(MLM): トークンの15%がランダムにマスクされます。モデルは双方向コンテキストを使用してマスクされたトークンを予測します。これによりBERTは理解タスクに優れています。

Next Sentence Prediction(NSP): 2つの文が連続しているかどうかを予測します。後の研究ではNSPの貢献は少なく、RoBERTaはパフォーマンス損失なしにNSPを削除しました。

BERTは分類、QA、NERに優れていますが、直接テキストを生成することはできません。

8.2 GPT(デコーダのみ)

GPT(Generative Pre-trained Transformer)はデコーダスタックのみを使用し、因果言語モデリングを行います:すべての前のトークンから次のトークンを予測します。

アーキテクチャはCross-Attentionなしの簡略化されたデコーダで、Masked Self-AttentionとFFNのみです。GPT-2、GPT-3、GPT-4、LLaMA、Mistral、ほとんどのモダンLLMはこのデコーダのみの設計に従っています。

8.3 T5とBART(エンコーダ-デコーダ)

T5(Text-To-Text Transfer Transformer)はすべてのNLPタスクをtext-to-textとして統一します。翻訳、要約、分類、QAはすべて同一の入出力フォーマットを使用します。

BARTはトークンマスキング、文シャッフル、ドキュメント回転など、様々な破損戦略を持つノイズ除去オートエンコーダとして事前学習されます。

8.4 Vision Transformer(ViT)

ViTは画像を16×16のパッチに分割し、各パッチをトークン埋め込みに線形射影し、位置埋め込みを追加して、シーケンスを標準のTransformerエンコーダに入力します。

大規模な事前学習により、ViTは画像分類ベンチマークでCNNに匹敵し、それを超えます。


9. モダンLLMの最適化

9.1 RMSNorm

モダンLLMはLayerNormをRMSNormに置き換えます:

RMSNorm(x) = x / RMS(x) * g
RMS(x) = sqrt(mean(x^2) + epsilon)

平均の減算が不要で、同等のパフォーマンスでより高速です。LLaMA、Mistral、Gemmaで使用されています。

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
        return x / rms * self.weight

9.2 SwiGLU活性化

SwiGLUはSwishとGated Linear Unitsを組み合わせます:

SwiGLU(x, W, V) = Swish(x * W) * (x * V)
Swish(x) = x * sigmoid(x) = x * sigma(x)

FFNの次元は通常d_ff = (2/3) _ 4 _ d_model ≈ 2.67 * d_modelに調整されます。

class SwiGLUFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int = None):
        super().__init__()
        if d_ff is None:
            d_ff = int(2 * 4 * d_model / 3)
            d_ff = ((d_ff + 63) // 64) * 64  # 64の倍数に丸める

        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

9.3 Grouped Query Attention(GQA)

MHA: すべてのヘッドが独立したQ、K、Vを持つ。 MQA: すべてのヘッドがK、Vを共有。 GQA: Gグループが各グループ内でK、Vを共有。

LLaMA 2、Mistral、GemmaはKVキャッシュサイズを削減しながら品質を維持するためにGQAを使用しています。

class GroupedQueryAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_kv_heads: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        assert num_heads % num_kv_heads == 0

        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_rep = num_heads // num_kv_heads
        self.d_head = d_model // num_heads

        self.wq = nn.Linear(d_model, num_heads * self.d_head, bias=False)
        self.wk = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.wv = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.wo = nn.Linear(num_heads * self.d_head, d_model, bias=False)
        self.dropout = dropout

    def repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """KVヘッドをQヘッド数に合わせて繰り返す"""
        if self.num_rep == 1:
            return x
        batch, n_kv, seq_len, d_head = x.shape
        return x.unsqueeze(2).expand(
            batch, n_kv, self.num_rep, seq_len, d_head
        ).reshape(batch, n_kv * self.num_rep, seq_len, d_head)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        batch, seq_len, _ = x.shape

        xq = self.wq(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        xk = self.wk(x).view(batch, seq_len, self.num_kv_heads, self.d_head).transpose(1, 2)
        xv = self.wv(x).view(batch, seq_len, self.num_kv_heads, self.d_head).transpose(1, 2)

        xk = self.repeat_kv(xk)
        xv = self.repeat_kv(xv)

        scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(self.d_head)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = F.dropout(attn, p=self.dropout, training=self.training)

        out = torch.matmul(attn, xv)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.wo(out)

9.4 KVキャッシュ

自己回帰生成中、すべてのステップで前のすべてのトークンのKとVを再計算するのは無駄です。KVキャッシュは以前に計算されたKとVのテンソルを保存して再利用します:

class KVCacheAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)
        self.cache_k = None
        self.cache_v = None

    def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        batch, seq_len, _ = x.shape

        xq = self.wq(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        xk = self.wk(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        xv = self.wv(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)

        if start_pos > 0 and self.cache_k is not None:
            self.cache_k = torch.cat([self.cache_k, xk], dim=2)
            self.cache_v = torch.cat([self.cache_v, xv], dim=2)
        else:
            self.cache_k = xk
            self.cache_v = xv

        scores = torch.matmul(xq, self.cache_k.transpose(-2, -1)) / math.sqrt(self.d_head)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, self.cache_v)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.wo(out)

10. Flash Attention

10.1 標準Attentionのメモリ問題

AttentionマトリックスはメモリにおいてO(N^2)です。N=8192の場合、Attentionマトリックスだけで8192 x 8192 x 4バイト ≈ FP32で256MBが必要です。このマトリックスをGPUのHBMに書き込み・読み出すことで帯域幅のボトルネックが生まれます。

モダンGPUはFLOP束縛(演算多い)です:データを移動させるよりもはるかに多くの演算を実行できます。標準AttentionはメモリBW束縛で、FLOPカウントが示すよりもはるかに遅く動作します。

10.2 IO-Awareアルゴリズム

Flash Attention(Daoら、2022年)は完全なAttentionマトリックスをHBMにマテリアライズせずに正確なAttentionを計算します。

コアアイデア:タイリング

Q、K、VをSRAM(高速オンチップメモリ)に収まるブロックに分割します。Online Softmaxアルゴリズムを使用して一度に1ブロックを処理し、すべてのスコアを一度に見ることなく正しいsoftmax結果を累積します。

アルゴリズムの概要:

  1. Q_iブロックをSRAMに読み込む
  2. 各K_j、V_jブロックについて:SRAMに読み込み、S_ij = Q_i * K_j^Tを計算
  3. Running max/sumを使用してsoftmaxを段階的に更新
  4. 出力O_iに累積

計算量:

  • メモリ:O(N)(O(N^2)ではない)— 完全なAttentionマトリックスを保存しない
  • FLOP:O(N^2) — 標準と同じ
  • ウォールクロック速度:A100でHBMの読み書きが大幅に少なくなるため、2〜4倍高速

10.3 Flash Attentionのバージョン

Flash Attention 1(2022年): 最初のIO-Aware Attentionの公式化。フォワードとバックワードパスのためのカスタムCUDAカーネル。

Flash Attention 2(2023年): Q上の外側ループ、K/V上の内側ループ(より良い並列性)。最適化されたワープレベルの作業分割。約2倍の追加高速化。

Flash Attention 3(2024年): H100のWGMMA(Warp Group Matrix Multiply-Accumulate)とTMA(Tensor Memory Accelerator)の非同期コピーを活用。さらに1.5〜2倍の高速化。

10.4 使い方

import torch
import torch.nn.functional as F

# PyTorch 2.0+の組み込みFlash Attention
# CUDAでは可能な場合、自動的にFlash Attentionを使用
q = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)

# is_causal=Trueは因果マスクを効率的に使用
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    output = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(f"Output shape: {output.shape}")  # (2, 8, 1024, 64)

# flash-attnパッケージの直接使用
# pip install flash-attn --no-build-isolation
try:
    from flash_attn import flash_attn_qkvpacked_func

    qkv = torch.randn(2, 1024, 3, 8, 64, device='cuda', dtype=torch.float16)
    out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
    print(f"Flash Attention output: {out.shape}")
except ImportError:
    print("flash-attn not installed")


class FlashAttentionMHA(nn.Module):
    """PyTorchの組み込みFlash AttentionをつかったMHA"""

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.wqkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)
        self.dropout = dropout

    def forward(self, x: torch.Tensor, is_causal: bool = False) -> torch.Tensor:
        batch, seq_len, d_model = x.shape

        qkv = self.wqkv(x).view(batch, seq_len, 3, self.num_heads, self.d_head)
        q, k, v = qkv.unbind(dim=2)

        q = q.transpose(1, 2)  # (batch, heads, seq_len, d_head)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        out = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        out = out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        return self.wo(out)

11. Mixture of Experts(MoE)

11.1 コアアイデア

Mixture of Expertsは、フォワードパスごとの計算量を一定に保ちながら、モデル容量(パラメータ数)をスケールします。重要なのは:各トークンは少数の「エキスパート」ネットワークのサブセットにのみルーティングされます。

Top-2ルーティングを持つ8エキスパートのMoE:

  • 総パラメータ数:8エキスパートのFFN = 密なモデルのFFNの8倍
  • トークンごとのアクティブパラメータ:2エキスパート = 密なモデルと同じ計算量

これにより有利なトレードオフが実現:推論コストを比例的に増やすことなく、より大きなモデル(より多くの容量/知識)が得られます。

11.2 Top-kルーティング

class MoELayer(nn.Module):
    """Mixture of Experts FFNレイヤー"""

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int = 8,
        top_k: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        # ルーター / ゲーティングネットワーク
        self.router = nn.Linear(d_model, num_experts, bias=False)

        # エキスパート
        self.experts = nn.ModuleList([
            SwiGLUFeedForward(d_model, d_ff)
            for _ in range(num_experts)
        ])

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> tuple:
        batch, seq_len, d_model = x.shape
        x_flat = x.view(-1, d_model)

        router_logits = self.router(x_flat)
        router_probs = F.softmax(router_logits, dim=-1)

        topk_probs, topk_indices = router_probs.topk(self.top_k, dim=-1)
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)

        output = torch.zeros_like(x_flat)

        for i in range(self.top_k):
            expert_idx = topk_indices[:, i]
            prob = topk_probs[:, i:i+1]

            for e in range(self.num_experts):
                token_mask = (expert_idx == e)
                if token_mask.any():
                    expert_input = x_flat[token_mask]
                    expert_output = self.experts[e](expert_input)
                    output[token_mask] += prob[token_mask] * expert_output

        aux_loss = self._load_balancing_loss(router_probs)
        return self.dropout(output.view(batch, seq_len, d_model)), aux_loss

    def _load_balancing_loss(self, router_probs: torch.Tensor) -> torch.Tensor:
        """エキスパート崩壊を防ぐSwitch Transformer補助損失"""
        num_tokens = router_probs.size(0)
        avg_probs = router_probs.mean(dim=0)
        top1_indices = router_probs.argmax(dim=-1)
        expert_counts = torch.bincount(top1_indices, minlength=self.num_experts).float()
        expert_fractions = expert_counts / num_tokens
        return self.num_experts * (avg_probs * expert_fractions).sum()

11.3 MixtralとDeepSeek

Mixtral 8x7B:

  • 8エキスパートのFFN、Top-2ルーティング
  • 総パラメータ数:46.7B;トークンごとのアクティブパラメータ:12.9B
  • 標準MHA + RoPE + SwiGLU + グループ化エキスパートを使用

DeepSeek MoE:

  • 細粒度のエキスパート分割:より多くの、より小さなエキスパート
  • 共有エキスパート:すべてのトークンを処理するベースエキスパート + ルーティングされたエキスパート
  • 高度なエキスパート崩壊防止戦略

11.4 MoEのトレードオフ

利点:

  • FLOPsあたりの大きなモデル容量
  • 異なるトークンタイプへのエキスパートの専門化
  • 効率的なスケーリング

欠点:

  • アクティブなのは2つだけでも、すべてのエキスパートパラメータがメモリに収まる必要がある
  • エキスパートがGPU間でシャーディングされる場合のデバイス間通信オーバーヘッド(エキスパート並列性)
  • 補助損失なしでは不均衡なルーティングによる学習不安定性

まとめ

Transformerはseq2seq翻訳モデルから、実質的にすべての最先端AIシステムを動かす普遍的なアーキテクチャへと進化しました。

取り上げた主要な概念:

  1. Scaled Dot-Product Attention: Q/K/Vを使ったソフトデータベース検索
  2. Multi-Head Attention: 複数の部分空間での並列Attention
  3. Positional Encoding: 正弦波PE → RoPE → ALiBi
  4. エンコーダ/デコーダ: BERTとGPTファミリーの基礎
  5. モダン最適化: RMSNorm、SwiGLU、GQA、KVキャッシュ
  6. Flash Attention: IO-Awareメモリ効率の高い正確なAttention
  7. MoE: 効率的なスケーリングのためのスパース活性化

次に探求すべきトピック:

  • 推論高速化のためのSpeculative Decoding
  • LoRA/QLoRAファインチューニング
  • アライメント技術(RLHF、DPO)
  • vLLMとTensorRT-LLMを使った本番サービング

参考資料