Skip to content
Published on

LLMコンテキストウィンドウ拡張技術完全ガイド:RoPE、ALiBi、YaRNからRing Attentionまで

Authors
  • Name
    Twitter

はじめに

GPT-2は1024トークン、GPT-3は2048トークンから始まりましたが、2026年現在、Claudeは200K、Geminiは2Mトークンのコンテキストウィンドウをサポートしています。この爆発的な拡張はどのようにして可能になったのでしょうか?

核心は**位置エンコーディング(Positional Encoding)**の進化にあります。この記事では、絶対位置エンコーディングから最新のRoPE拡張技法まで、数学的原理と実践コードを合わせて解説します。

位置エンコーディングの進化

1. 絶対位置エンコーディング(GPT時代)

import torch
import torch.nn as nn

class AbsolutePositionalEncoding(nn.Module):
    """学習可能な絶対位置埋め込み"""
    def __init__(self, max_len, d_model):
        super().__init__()
        # 学習可能なパラメータ - max_lenは固定
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        positions = torch.arange(x.size(1), device=x.device)
        return x + self.pe(positions)

限界: max_lenを超える位置への汎化が不可能です。

2. Sinusoidal位置エンコーディング(Transformer原論文)

import math

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

数学的にはどんな長さの相対位置も線形変換で表現可能ですが、実際には学習範囲を超えると性能が急激に低下します。

RoPE (Rotary Position Embedding)

Su et al. (2021)が提案したRoPEは、現在ほとんどのLLM(LLaMA、Qwen、DeepSeekなど)が使用する位置エンコーディングです。

数学的原理

RoPEはトークンの位置を複素平面上の回転としてエンコードします。

2D次元ペア(q_{2i}, q_{2i+1})に対して:

RoPE(q, m) = q * e^{im*theta_i}

ここでtheta_i = base^{-2i/d}、デフォルトのbase=10000です。

import torch

def precompute_freqs_cis(dim, max_seq_len, base=10000.0):
    """RoPE周波数の事前計算"""
    # theta_i = base^(-2i/d)
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # 位置ごとの角度: m * theta_i
    t = torch.arange(max_seq_len)
    freqs = torch.outer(t, freqs)  # (max_seq_len, dim/2)
    # 複素数に変換: e^{i*m*theta}
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def apply_rotary_emb(xq, xk, freqs_cis):
    """クエリとキーにRoPEを適用"""
    # 実数テンソルを複素数に変換
    xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # 回転を適用
    freqs_cis = freqs_cis[:xq.shape[1]]
    xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(-2)

    return xq_out.type_as(xq), xk_out.type_as(xk)

RoPEの核心的性質

相対位置依存性: 位置mのクエリと位置nのキーの内積は(m-n)にのみ依存します。

# 証明:
# <RoPE(q,m), RoPE(k,n)>
# = <q*e^{im*theta}, k*e^{in*theta}>
# = <q, k> * e^{i(m-n)*theta}  <- 相対位置 (m-n)にのみ依存!

コンテキスト拡張技法

1. Position Interpolation (PI)

Metaが提案した最もシンプルな方法:位置インデックスをスケールダウンします。

def position_interpolation(freqs_cis, original_max_len, target_max_len):
    """位置を線形補間して拡張"""
    scale = original_max_len / target_max_len
    t = torch.arange(target_max_len) * scale  # 位置を圧縮
    # 残りは同じ
    freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)

# 例: 4K -> 32K拡張
# 位置0〜31999を0〜3999の範囲にマッピング

問題点: 近い位置の解像度が低下します。隣接トークンの区別が困難になります。

2. NTK-aware Scaling

base値を調整して高周波は維持し、低周波のみを拡張します。

def ntk_aware_scaling(dim, max_seq_len, base=10000.0, scale_factor=4):
    """NTK-aware RoPE: baseを調整して高周波を保存"""
    # baseをscale_factorに比例して増加
    new_base = base * (scale_factor ** (dim / (dim - 2)))

    freqs = 1.0 / (new_base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)

# 4K -> 16K (scale_factor=4)
# base: 10000 -> ~161489

3. NTK-by-Parts (Dynamic NTK)

次元ごとに異なるスケーリングを適用します。高周波次元(近い位置の解像度)はほぼそのまま維持し、低周波次元(遠い位置)のみを拡張します。

def ntk_by_parts(dim, max_seq_len, base=10000.0, original_max_len=4096,
                 target_max_len=32768, beta_fast=32, beta_slow=1):
    """次元別の差異的スケーリング"""
    scale = target_max_len / original_max_len

    # 元の周波数
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

    # 各周波数の波長
    wavelengths = 2 * math.pi / freqs

    # 補間比率の計算(ランプ関数)
    low_threshold = original_max_len / beta_fast
    high_threshold = original_max_len / beta_slow

    ramp = (wavelengths - low_threshold) / (high_threshold - low_threshold)
    ramp = ramp.clamp(0, 1)

    # 高周波(ramp=0): 元を維持、低周波(ramp=1): 完全補間
    scaled_freqs = freqs / scale  # PI方式
    new_freqs = (1 - ramp) * freqs + ramp * scaled_freqs

    t = torch.arange(target_max_len)
    freqs_matrix = torch.outer(t, new_freqs)
    return torch.polar(torch.ones_like(freqs_matrix), freqs_matrix)

4. YaRN (Yet another RoPE extensioN)

NTK-by-PartsにAttention Temperature Scalingを追加した方法です。DeepSeek、Qwen、LLaMAなど、ほとんどの最新LLMがYaRNを採用しています。

def yarn_rope(dim, max_seq_len, base=10000.0, original_max_len=4096,
              target_max_len=131072, beta_fast=32, beta_slow=1):
    """YaRN: NTK-by-Parts + Temperature Scaling"""
    scale = target_max_len / original_max_len

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

    wavelengths = 2 * math.pi / freqs
    low_threshold = original_max_len / beta_fast
    high_threshold = original_max_len / beta_slow

    ramp = (wavelengths - low_threshold) / (high_threshold - low_threshold)
    ramp = ramp.clamp(0, 1)

    scaled_freqs = freqs / scale
    new_freqs = (1 - ramp) * freqs + ramp * scaled_freqs

    # 核心: Attention Temperature
    # sqrt(s)補正でattention scoreの分布を保存
    temperature = 0.1 * math.log(scale) + 1.0

    t = torch.arange(target_max_len)
    freqs_matrix = torch.outer(t, new_freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs_matrix), freqs_matrix)

    return freqs_cis, temperature

# Attention計算時:
# attn_weights = (Q @ K.T) / (sqrt(d) * temperature)

YaRNの効率性: 元の学習データの0.1%のみでファインチューニングして、コンテキストを32倍に拡張できます。

ALiBi (Attention with Linear Biases)

RoPEとは異なるアプローチで、位置エンコーディングなしにattention scoreに線形バイアスを加えます。

def alibi_bias(num_heads, max_seq_len):
    """ALiBi: attentionに距離ベースの線形ペナルティ"""
    # ヘッドごとの傾き
    slopes = torch.pow(2, -torch.arange(1, num_heads + 1) * 8.0 / num_heads)

    # 距離行列
    positions = torch.arange(max_seq_len)
    distances = positions.unsqueeze(0) - positions.unsqueeze(1)  # (L, L)
    distances = distances.abs().neg()  # 負の距離

    # ヘッドごとのバイアス
    biases = slopes.unsqueeze(1).unsqueeze(2) * distances.unsqueeze(0)
    return biases  # (num_heads, L, L)

# Attention計算:
# attn = softmax(Q @ K.T / sqrt(d) + alibi_bias)

利点: 追加学習なしで学習長を超える外挿(extrapolation)が可能です。 欠点: 遠い距離のトークンに過度なペナルティを付与するため、長距離依存性が弱くなります。

Ring Attention — ハードウェアレベルの拡張

Ring Attentionはシーケンスを複数のGPUに分散して、メモリの限界を超えるコンテキストを処理します。

# Ring Attention疑似コード
def ring_attention(Q, K, V, num_devices):
    """
    シーケンスをnum_devices個に分割して分散処理。
    各デバイスは自身のQチャンクと循環するKVチャンクでattentionを計算。
    """
    seq_len = Q.shape[1]
    chunk_size = seq_len // num_devices

    # 各デバイスのQチャンク
    Q_local = Q[:, rank * chunk_size:(rank + 1) * chunk_size]

    # KVをリング形態で循環
    K_local = K[:, rank * chunk_size:(rank + 1) * chunk_size]
    V_local = V[:, rank * chunk_size:(rank + 1) * chunk_size]

    output = torch.zeros_like(Q_local)
    log_sum_exp = torch.full_like(Q_local[:, :, :1], float('-inf'))

    for step in range(num_devices):
        # 現在のKVチャンクとattentionを計算
        attn = Q_local @ K_local.transpose(-2, -1) / math.sqrt(d)
        # Online softmaxで累積
        output, log_sum_exp = online_softmax_update(
            output, log_sum_exp, attn, V_local
        )
        # KVを次のデバイスから受け取る(リング通信)
        K_local = ring_send_recv(K_local)
        V_local = ring_send_recv(V_local)

    return output

Ring Attentionにより、Geminiは2Mトークンのコンテキストを実現しました。

実践:HuggingFaceでYaRNを適用

from transformers import AutoModelForCausalLM, AutoTokenizer

# Qwen2.5-7B(デフォルト32K -> 131K拡張)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    rope_scaling={
        "type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 32768,
    }
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B")

# 長文ドキュメントの処理
long_text = open("long_document.txt").read()
inputs = tokenizer(long_text, return_tensors="pt", truncation=True, max_length=131072)
outputs = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0]))

性能比較まとめ

技法学習の必要性拡張倍率長距離性能使用モデル
PIファインチューン4-8倍普通初期LLaMA2
NTK-awareなし/少量4-8倍良好Code Llama
YaRN少量 (0.1%)4-32倍非常に良好Qwen, DeepSeek, LLaMA3
ALiBiなし無制限(理論上)弱いMPT, BLOOM
Ring AttentionなしGPU数に比例非常に良好Gemini

クイズ:LLMコンテキスト拡張の理解度チェック(8問)

Q1. RoPEが絶対位置エンコーディングに対して持つ核心的な利点は?

相対位置依存性:2つのトークン間のattentionは絶対位置ではなく相対距離(m-n)にのみ依存します。

Q2. Position Interpolationの欠点は?

位置を圧縮するため、隣接トークン間の解像度が低下し、近いトークンの区別が困難になります。

Q3. NTK-by-PartsがPIより優れている理由は?

高周波次元(近い位置の解像度)は維持し、低周波次元(遠い位置)のみを選択的に拡張します。

Q4. YaRNにおけるTemperature Scalingの役割は?

コンテキスト拡張時にattention scoreの分布が変化するのを補正し、softmaxのエントロピーを維持します。

Q5. ALiBiの動作原理を一文で説明すると?

attention scoreにトークン距離に比例する負の線形バイアスを加えて、遠いトークンにペナルティを付与します。

Q6. Ring Attentionがメモリの限界を克服する方法は?

シーケンスを複数のGPUに分散し、KVをリング形態で循環させながらonline softmaxで累積計算します。

Q7. YaRNで32Kを131Kに拡張する際に必要な学習データ量は?

元の学習データの約0.1%(約400Mトークン)で十分です。

Q8. RoPEのbase値を大きくするとどのような効果がありますか?

周波数が全体的に低くなり、より遠い位置まで区別できるようになりますが、近い位置の解像度は低下します。