Skip to content
Published on

LLM事前学習とスケーリング則:Chinchillaからフラッシュアテンション・MoEまで

Authors

はじめに

2020年にGPT-3が登場して以来、「モデルを大きくすれば性能が上がる」という直感が支配的でした。しかし2022年のDeepMind社によるChinchillaの研究は、この直感に真っ向から挑戦しました。パラメータ数を4分の1にしてデータを4倍にしたモデルがGPT-3を上回ったのです。本ガイドでは、スケーリング則の数学的基盤から最新の事前学習レシピまで、LLM事前学習の核心を体系的に解説します。


1. スケーリング則

1.1 Kaplan et al. (2020) — OpenAIスケーリング則

Kaplan et al.は言語モデルの損失がパラメータ数(N)、データサイズ(D)、計算量(C)のべき乗則に従うことを実験的に示しました。

L(N)(NcN)αNL(N) \approx \left(\frac{N_c}{N}\right)^{\alpha_N}

L(D)(DcD)αDL(D) \approx \left(\frac{D_c}{D}\right)^{\alpha_D}

重要な発見: 固定の計算予算のもとでは、モデルサイズに大きく投資する方が有利であるというものでした。これがGPT-3(1750億パラメータ)のような巨大モデル開発の理論的根拠になりました。

1.2 Chinchilla (Hoffmann et al. 2022) — 再定義されたスケーリング則

DeepMindチームはより厳密な実験設計によって、Kaplanの結論がデータ側を過小評価していたことを示しました。

Chinchilla則の核心:

計算予算Cが与えられたとき、最適モデルサイズ NoptN_{opt} と最適トークン数 DoptD_{opt} は以下の関係を満たします。

Nopt0.2C0.5N_{opt} \approx 0.2 \cdot C^{0.5}

Dopt10C0.5D_{opt} \approx 10 \cdot C^{0.5}

すなわち、パラメータ1個あたり約20の学習トークンが最適という結論です。

モデルパラメータ数学習トークン数Chinchilla最適比
GPT-3175B300B1:1.7(データ不足)
Chinchilla70B1.4T1:20(最適)
Llama 3.18B15T1:1875(過学習)

Llama 3.1のようにChinchilla最適よりも意図的に多くのデータで学習する場合は推論効率のためです。小さなモデルをより多く学習させることでデプロイコストが削減できます。

1.3 最適な計算リソース配分の計算

import math

def chinchilla_optimal(compute_budget_flops: float):
    """
    Chinchilla則に基づく最適N、Dの計算
    compute_budget_flops: 総FLOPs (6 * N * Dで近似)
    Returns: (optimal_params, optimal_tokens)
    """
    # Hoffmann et al. Table A3の係数を使用
    # C = 6 * N * D の近似から:
    # N_opt = sqrt(C / (6 * 20)), D_opt = 20 * N_opt
    N_opt = math.sqrt(compute_budget_flops / (6 * 20))
    D_opt = 20 * N_opt
    return N_opt, D_opt

# 例: A100 1000台、30日間の学習
# A100: ~312 TFLOPS (bf16)、稼働率40%
flops_per_sec = 1000 * 312e12 * 0.4
duration_sec = 30 * 24 * 3600
total_flops = flops_per_sec * duration_sec

N_opt, D_opt = chinchilla_optimal(total_flops)
print(f"総FLOPs: {total_flops:.2e}")
print(f"最適パラメータ数: {N_opt/1e9:.1f}B")
print(f"最適トークン数: {D_opt/1e12:.1f}T")

2. データ準備

2.1 Common Crawlフィルタリングパイプライン

Common Crawlは月次で数十TBのウェブクロールデータを提供しています。そのまま使うと品質が非常に低いため、多段階フィルタリングが不可欠です。

典型的なフィルタリングパイプライン:

  1. 言語検出 (fastText): 対象言語のみを保持
  2. 品質フィルタ: 最小文数、単語繰り返し比率、特殊文字密度
  3. ドメインブロックリスト: スパム、アダルト、広告ドメインの除去
  4. パープレキシティフィルタ: n-gram言語モデルによる低品質テキストの除去
  5. 重複排除: MinHash LSHによるファジー重複排除

2.2 MinHashによる重複排除

from datasketch import MinHash, MinHashLSH
import re

def get_shingles(text: str, k: int = 5):
    """テキストから文字kシングル集合を生成"""
    text = re.sub(r'\s+', ' ', text.lower())
    return {text[i:i+k] for i in range(len(text) - k + 1)}

def build_minhash(text: str, num_perm: int = 128) -> MinHash:
    """テキストからMinHashシグネチャを作成"""
    m = MinHash(num_perm=num_perm)
    for shingle in get_shingles(text):
        m.update(shingle.encode('utf-8'))
    return m

def deduplicate_corpus(documents: list, threshold: float = 0.8):
    """
    MinHash LSHで類似文書を除去
    threshold: 重複とみなすJaccard類似度の閾値
    """
    lsh = MinHashLSH(threshold=threshold, num_perm=128)
    unique_docs = []

    for idx, doc in enumerate(documents):
        mh = build_minhash(doc)
        result = lsh.query(mh)

        if len(result) == 0:
            lsh.insert(f"doc_{idx}", mh)
            unique_docs.append(doc)
        # 類似文書がすでに存在する場合はスキップ

    print(f"元の文書数: {len(documents)} -> 重複排除後: {len(unique_docs)}")
    return unique_docs

MinHashの核心: 2つの集合のJaccard類似度 J(A,B)=AB/ABJ(A,B) = |A \cap B| / |A \cup B| を複数のハッシュ関数の最小値で近似します。128個のハッシュ関数を使うと約3%の誤差でJaccard類似度を推定できます。

2.3 トークナイザーの学習

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.processors import ByteLevel as ByteLevelProcessor

def train_bpe_tokenizer(
    corpus_files: list,
    vocab_size: int = 32000,
    save_path: str = "tokenizer.json"
):
    """GPT-2スタイルのByte-level BPEトークナイザーを学習"""
    tokenizer = Tokenizer(BPE(unk_token=None))
    tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
    tokenizer.post_processor = ByteLevelProcessor(trim_offsets=False)

    trainer = BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=2,
        special_tokens=["<pad>", "<s>", "</s>", "<unk>"],
        show_progress=True,
    )

    tokenizer.train(files=corpus_files, trainer=trainer)
    tokenizer.save(save_path)
    print(f"トークナイザー保存完了: {save_path} (vocab_size={vocab_size})")
    return tokenizer

2.4 データ混合比率

Llama 3.1の事前学習データ混合比率を参考にすると:

データソース比率
ウェブクロール(Common Crawl等)約80%
コード(GitHub)約8%
学術・科学論文約5%
書籍(Books3等)約4%
多言語データ約3%

3. アーキテクチャの選択

3.1 主要オープンソースLLMの設計決定比較

モデル位置エンコーディングアテンションFFN正規化
GPT-NeoXALiBiMHASwiGLULayerNorm
LLaMA-3RoPEGQASwiGLURMSNorm
Mistral 7BRoPEGQA + SWASwiGLURMSNorm
Mixtral 8x7BRoPEGQAMoE (SwiGLU)RMSNorm

3.2 RoPE(回転位置エンコーディング)

RoPEはトークン埋め込みに絶対位置情報を直接加算する代わりに、アテンションスコア計算時にQuery/Keyベクトルに回転変換を適用します。

位置mのクエリベクトルと位置nのキーベクトルの内積が相対位置(m-n)の関数になります。この性質により、学習時の最大長を超えた長いコンテキストにもある程度の外挿(extrapolation)が可能です。

qmTkn=Re[j(q(j)eimθj)(k(j)einθj)]q_m^T k_n = \text{Re}\left[\sum_{j} (q_{(j)} e^{im\theta_j}) \overline{(k_{(j)} e^{in\theta_j})}\right]

YaRNやLongRoPEなどの拡張手法はこの性質を活用して、コンテキスト長を4Kから128K以上へ拡張します。

3.3 Grouped Query Attention(GQA)

自己回帰的推論において、過去のトークンのKey・Valueテンソルをキャッシュに保存する必要があります。GQAは複数のQueryヘッドが少数のKey/Valueヘッドを共有することで、このキャッシュを削減します。

  • MHA: Hクエリヘッド、Hキーヘッド、Hバリューヘッド
  • MQA: Hクエリヘッド、1キーヘッド、1バリューヘッド
  • GQA: Hクエリヘッド、GキーValueヘッド(G < H)

Llama 3 8Bは32のクエリヘッドに8つのKVヘッドを使用し、KVキャッシュをMHAの1/4に削減しています。

3.4 Mixture of Experts(MoE)

Mixtral 8x7BとDeepSeek-V3はMoEアーキテクチャを使用します。各トークンはN個の専門家(FFNレイヤー)のうちTop-K個のみを活性化します。

import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    """シンプルなMixture of Expertsレイヤー"""
    def __init__(self, d_model: int, d_ff: int, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.SiLU(),
                nn.Linear(d_ff, d_model),
            )
            for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor):
        # x: (batch, seq_len, d_model)
        B, T, D = x.shape
        x_flat = x.view(-1, D)  # (B*T, D)

        # ルーター計算
        router_logits = self.router(x_flat)  # (B*T, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Top-K専門家の選択
        topk_probs, topk_idx = router_probs.topk(self.top_k, dim=-1)
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # 再正規化

        # 負荷分散損失(専門家のバランス維持)
        expert_load = router_probs.mean(0)  # (num_experts,)
        load_balance_loss = self.num_experts * (expert_load * expert_load.mean()).sum()

        # 専門家の出力計算
        output = torch.zeros_like(x_flat)
        for k in range(self.top_k):
            expert_indices = topk_idx[:, k]
            expert_weights = topk_probs[:, k].unsqueeze(-1)
            for e_idx in range(self.num_experts):
                mask = (expert_indices == e_idx)
                if mask.any():
                    expert_out = self.experts[e_idx](x_flat[mask])
                    output[mask] += expert_weights[mask] * expert_out

        return output.view(B, T, D), load_balance_loss

負荷分散損失が必要な理由: ルーターが特定の専門家だけを選ぶよう収束すると、その専門家だけが学習され残りの専門家は学習されません(expert collapse)。負荷分散損失はすべての専門家に均等なトークンが割り振られるよう正則化します。


4. 学習の安定性

4.1 学習率スケジュール:コサインウォームアップ

import math

def cosine_lr_with_warmup(
    optimizer,
    step: int,
    warmup_steps: int,
    total_steps: int,
    max_lr: float,
    min_lr: float = 0.0,
):
    """線形ウォームアップ付きコサインアニーリング"""
    if step < warmup_steps:
        lr = max_lr * step / warmup_steps
    else:
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# 典型的な事前学習設定
# warmup: 全ステップの1~2%
# max_lr: 3e-4 (7Bモデル向け)
# min_lr: max_lr * 0.1

4.2 損失スパイクの検出と復旧

大規模事前学習では突然損失が急増するスパイクが発生することがあります。一般的な対策:

  1. 勾配ノルムクリッピング: ノルムが閾値を超えたらスケールダウン
  2. スパイク検出後のロールバック: 以前のチェックポイントから再開し、問題のあるバッチをスキップ
  3. 損失の指数移動平均の監視: 急激な上昇時にアラート
import torch

def train_step_with_stability(model, optimizer, batch, grad_clip: float = 1.0):
    """勾配ノルムクリッピングとスパイク検出を備えた学習ステップ"""
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()

    # 勾配ノルムの計算とクリッピング
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm=grad_clip
    )

    # 異常に大きな勾配の検出
    if grad_norm > 100 * grad_clip:
        print(f"警告: 異常なgrad_norm={grad_norm:.2f}、ステップをスキップ")
        optimizer.zero_grad()
        return None, grad_norm

    optimizer.step()
    return loss.item(), grad_norm

5. 効率的な事前学習

5.1 Flash Attention 2

Flash AttentionはGPUのSRAMを最大限に活用してアテンション計算のメモリ複雑度を O(N2)O(N^2) から O(N)O(N) に削減します。

# Flash Attention 2の使用例
# pip install flash-attn --no-build-isolation

import torch
from flash_attn import flash_attn_func

def flash_attention_forward(
    q: torch.Tensor,  # (batch, seqlen, nheads, headdim)
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = True,
    softmax_scale: float = None,
):
    """
    Flash Attention 2の順伝播
    causal=True: 自己回帰言語モデル用の因果マスクを適用
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    out = flash_attn_func(
        q, k, v,
        dropout_p=0.0,
        softmax_scale=softmax_scale,
        causal=causal,
    )
    return out  # (batch, seqlen, nheads, headdim)

Flash Attention 2と標準アテンションの比較:

  • メモリ: シーケンス長に線形比例(従来の2次比例と比較して)
  • 速度: A100で2〜4倍高速
  • 数値安定性: 同一結果が保証される

5.2 スライディングウィンドウアテンション(SWA)

Mistral 7Bで使用されるSWAは各トークンが直近W個のトークンにのみattendします。長いシーケンスでアテンション複雑度を O(WN)O(W \cdot N) に削減します。


6. 評価とチェックポインティング

6.1 パープレキシティ曲線の監視

import torch
import math

def compute_perplexity(model, dataloader, device: str = "cuda"):
    """検証セットのパープレキシティを計算"""
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, labels=labels)
            batch_tokens = (labels != -100).sum().item()
            total_loss += outputs.loss.item() * batch_tokens
            total_tokens += batch_tokens

    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    return perplexity

# 典型的な基準値
# GPT-2レベル: PPL 約20 (WebText)
# 良好な7Bモデル: PPL 約6〜8 (検証セット)

6.2 lm-evaluation-harnessベンチマーク

# lm-evaluation-harnessのインストールと実行
pip install lm-eval

# Llama 3 8Bの標準ベンチマーク評価
lm_eval --model hf \
    --model_args pretrained=meta-llama/Meta-Llama-3-8B \
    --tasks mmlu,hellaswag,arc_challenge,winogrande \
    --device cuda:0 \
    --batch_size 8 \
    --output_path results/llama3-8b/

# 学習中の特定チェックポイントを評価
lm_eval --model hf \
    --model_args pretrained=./checkpoints/step-50000 \
    --tasks hellaswag \
    --num_fewshot 10 \
    --device cuda:0

6.3 チェックポイント戦略

チェックポイントの種類保存間隔保持期間
直近N個を維持500ステップごと最新5個のみ
マイルストーン5,000ステップごと永続保存
最高パフォーマンス検証PPL基準永続保存

7. 最新の事前学習レシピ比較(2024-2025)

7.1 Llama 3.1

  • 学習データ: 15Tトークン(Llama 2の7.5倍)
  • コンテキスト長: 128K(RoPE拡張を適用)
  • 語彙サイズ: 128K(tiktoken ベース)
  • 特筆事項: 事前学習後に長コンテキストアニーリングフェーズを追加

7.2 Mistral Large / Mistral 7B

  • Mistral 7B: GQA + SWAで効率を最大化
  • Mixtral 8x7B: 8専門家、トークンあたり2つが有効(フォワードパスでは13Bパラメータを使用)
  • コンテキスト: SWAで32K

7.3 DeepSeek-V3

  • アーキテクチャ: 671B MoE、トークンあたり37Bが有効
  • 学習コスト: H800 2,048台 × 60日(約557万ドル — 競合他社比10倍安)
  • 革新点: Multi-head Latent Attention(MLA)、FP8混合精度学習
  • データ: 14.8Tトークン(中国語・コードの比率が高い)

7.4 phi-4(Microsoft)

  • サイズ: 140億パラメータ
  • 戦略: データ量よりデータ品質を重視 — 合成データを大幅に活用
  • 学習データ: 9.8Tトークン(合成データ約40%)
  • 成果: 数学・推論ベンチマークで70B以上のモデルと競合

クイズ

Q1. Chinchilla則がGPT-3より少ないパラメータで高い性能を出せることを示した核心的な主張とは?

答え: GPT-3は1750億パラメータに対して3000億トークンしか学習していませんでした。Chinchillaはパラメータ数とトークン数を1:20の比率で同時にスケールする必要があることを実験的に証明し、GPT-3が「データ不足」の状態だったと示しました。

解説: Hoffmann et al. 2022は厳密な実験設計により、Kaplanのスケーリング則がデータの価値を過小評価していたことを示しました。同じ計算予算のもとで、700億パラメータを1.4Tトークンで学習したChinchillaが、ほぼすべてのベンチマークでGPT-3を上回りました。核心的な洞察は、最適なトークン数が計算予算の平方根に比例するということです。

Q2. MinHashアルゴリズムが大規模テキストコーパスで重複を効率的に検出する方法は?

答え: 文書をshingle(n-gram集合)として表現し、複数のハッシュ関数の最小値(MinHashシグネチャ)でJaccard類似度を近似推定します。LSH(局所性鋭敏型ハッシュ)で類似シグネチャを同じバケツに集め、候補ペアのみを精密比較します。

解説: 直接のJaccard類似度計算はコーパスサイズNに対してO(N^2)の比較が必要です。MinHashは各文書をkハッシュ最小値の長さkシグネチャに圧縮します。2つのシグネチャの一致率が元のJaccard類似度の不偏推定量となります。LSHはさらに全体の複雑度をO(N)に近づけることができます。

Q3. Grouped Query Attention(GQA)がMulti-Head Attentionより推論メモリを削減する原理は?

答え: GQAはQueryヘッドよりも少数のKey/Valueヘッドを使用するため、KVキャッシュのサイズがQueryヘッド数ではなくKVヘッド数に比例します。

解説: 自己回帰推論時、以前のトークンのKey・ValueをKVキャッシュに保存します。MHAではHヘッドすべてが個別のKVを持つため、KVキャッシュサイズはバッチサイズ × シーケンス長 × H × head_dim × 2になります。GQAでGつのKVヘッドを使用するとキャッシュはG/Hに削減されます。Llama 3 8B(H=32、G=8)はMHA比1/4のKVキャッシュです。

Q4. RoPE(回転位置エンコーディング)が絶対位置エンコーディングより長コンテキスト外挿に有利な理由は?

答え: RoPEはアテンションスコアが絶対位置でなく相対位置(m-n)の関数になるよう設計されており、学習時に見ていない長い距離の相対位置関係にも汎化が可能です。

解説: 絶対位置エンコーディング(正弦波または学習済み)は特定の位置インデックスの埋め込みを直接加算します。学習最大長を超える位置は分布外(out-of-distribution)の入力になります。RoPEはQとKの内積が2つの位置の差(相対位置)のみに依存するよう回転行列を適用します。YaRNやLongRoPEなどの拡張手法はこの性質を活用してコンテキストウィンドウを4Kから128K以上に拡張します。

Q5. MoEにおける専門家ルーティングの負荷分散損失が必要な理由は?

答え: 負荷分散損失がなければ、ルーターは特定の専門家だけを選択するよう収束します(expert collapse)。その専門家だけが最も学習され、さらに選択される正の帰還ループが発生し、他の専門家は勾配を受け取れず未学習のままになります。

解説: ルーターのソフトマックス出力は各専門家の選択確率です。均一な初期化でも学習中に特定の専門家が支配的になります。負荷分散損失は専門家ごとの平均選択確率が均等になるよう補助損失を追加します。これによりすべての専門家が十分に学習され、分散学習時のGPU間の負荷バランスも保たれます。


おわりに

LLM事前学習は、スケーリング則という理論的な羅針盤、データ品質という実践的な課題、そしてメモリ・速度効率化というエンジニアリングの挑戦が交差する領域です。Chinchillaは「大きさだけが答えではない」ことを数学的に証明し、Flash AttentionとGQAは大型モデルを現実的なコストで学習・デプロイできるようにしました。DeepSeek-V3の成功は、MoEアーキテクチャと効率的な実装が組み合わさった時、コスト対効果最高の性能が可能であることを示しています。このガイドの概念を実験しながら、小さなスケールで事前学習を直接体験してみてください。