Skip to content
Published on

大規模モデルトレーニング完全ガイド:1000億パラメータ超LLM事前学習の実践戦略

Authors

はじめに

LLaMA-3 405B、GPT-4、Falcon 180B——これらのモデルは実際にどのようにして学習されたのでしょうか?「GPUをたくさん使えばよい」という答えは、実際の複雑さを著しく単純化しています。1000億パラメータ超のLLMを学習するには、スケーリング則に基づいたハイパーパラメータ設計、高度な分散学習戦略、学習安定性の確保、そして膨大なコンピューティングリソースの効率的な管理が必要です。

このガイドでは、大規模LLMの事前学習における重要な側面を実践的な観点からすべて網羅します。


1. スケーリング則

1.1 Kaplan et al.(OpenAI)のスケーリング則

2020年、Kaplan et al.の論文「Scaling Laws for Neural Language Models」は、言語モデルの性能が3つの主要因子のべき乗則に従って向上することを示しました。

  • N: モデルパラメータ数
  • D: 学習データのトークン数
  • C: 総計算予算(FLOPs)

主な発見:

  1. モデルサイズ優先(計算効率の観点): 固定した計算予算においては、データ増加よりもモデルサイズ増加の方が効率的
  2. データ効率: 同じモデルサイズをより長く学習しても収穫逓減が生じる
  3. べき乗則: 損失Lはおよそ L(N) ≈ (Nc/N)^αN に従う

この法則は、最適よりも少ないデータで計算予算の大部分を大型モデルに投資することを示唆しており、GPT-3(175B)が比較的少ないトークン数(300B)で学習された理由でもあります。

1.2 Chinchilla最適スケーリング(Hoffmann et al. 2022)

2022年、DeepMindのHoffmann et al.は「Training Compute-Optimal Large Language Models」を発表し、Kaplanの結論を覆す重要な発見をもたらしました。

核心的主張: 既存の大型モデルは著しく学習不足であった。

Chinchilla実験:

  • 従来のアプローチ:Gopher(280Bパラメータ、300Bトークン)
  • Chinchilla:70Bパラメータ、1.4Tトークン
  • 結果:Chinchillaは小型でありながらGopherを圧倒

Chinchilla最適スケーリング則:

計算予算Cが与えられた場合、最適なモデルサイズNとデータD:

N_optimal0.1174 × C^0.4999  (パラメータ数)
D_optimal1.6972 × C^0.5001  (トークン数)

NとDは計算量に対してほぼ等しくスケールすべきです。経験則として:

D_optimal20 × N

7Bパラメータモデル → 最低140Bトークン必要 70Bパラメータモデル → 最低1.4Tトークン必要

1.3 実践的な計算式

FLOPs推定:

def estimate_training_flops(
    num_params,       # パラメータ数
    num_tokens,       # 学習トークン数
    include_backward=True
):
    """
    C ≈ 6 × N × D (順伝播 + 逆伝播)
    逆伝播は順伝播の約2倍のコスト
    """
    flops_per_token = 6 * num_params
    total_flops = flops_per_token * num_tokens
    return total_flops

# 例: LLaMA-2 7B, 2Tトークン
flops = estimate_training_flops(7e9, 2e12)
print(f"Total FLOPs: {flops:.2e}")
# Total FLOPs: 8.40e+22

# GPU学習時間推定 (A100 312 TFLOPs, MFU 50%想定)
a100_flops = 312e12   # 312 TFLOPs
mfu = 0.5             # Model FLOP Utilization
training_seconds = flops / (a100_flops * mfu)
training_hours = training_seconds / 3600

# 1000台のA100使用時
num_gpus = 1000
gpu_hours_per_gpu = training_hours / num_gpus
print(f"Training time: {gpu_hours_per_gpu:.1f} GPU-hours per GPU")
print(f"Total GPU-hours: {training_hours:.0f} GPU-hours")

1.4 Chinchilla後のトレンド:「過剰学習」戦略

実際には、Chinchilla最適値を超えて学習するトレンドがあります。

理由: 学習コスト対推論コストのトレードオフ

  • 学習は一度だけだが、推論は何百万回も実行される
  • 小さいモデルを長く学習 → 推論コストの低下

例:LLaMA-3 8Bは15Tトークンで学習(Chinchilla最適値の約100倍)


2. 事前学習データパイプライン

2.1 データソース構成

高品質な事前学習データはLLM性能の核心です。

主要データソース:

ソース特徴品質
Common Crawlウェブクロール、PBスケールノイズ多、フィルタリング必須
Books3/Gutenberg書籍、高品質多様性が限定的
Wikipedia百科事典、検証済み量が限定的
GitHubコード、推論能力向上ライセンス考慮必要
ArXiv/PubMed学術論文専門性が高い
StackExchangeQ&A、実践的知識品質良好

典型的な混合比率(推定値、Llama-2ベース):

data_mixture = {
    "Common Crawl (filtered)": 0.67,   # 67%
    "Books": 0.14,                     # 14%
    "GitHub": 0.045,                   # 4.5%
    "Wikipedia": 0.045,                # 4.5%
    "Gutenberg": 0.025,                # 2.5%
    "ArXiv": 0.025,                    # 2.5%
    "StackExchange": 0.02,             # 2%
}

2.2 データクリーニングパイプライン

import re
from typing import List, Optional
from dataclasses import dataclass

@dataclass
class DocumentFilter:
    min_tokens: int = 50
    max_tokens: int = 100000
    min_avg_word_length: float = 3.0
    max_symbol_ratio: float = 0.1
    min_alpha_ratio: float = 0.7

def filter_document(text: str, config: DocumentFilter) -> Optional[str]:
    """基本的な品質フィルター"""
    tokens = text.split()
    token_count = len(tokens)

    # 長さフィルター
    if not (config.min_tokens <= token_count <= config.max_tokens):
        return None

    # 平均単語長
    avg_word_len = sum(len(t) for t in tokens) / token_count
    if avg_word_len < config.min_avg_word_length:
        return None

    # アルファベット文字の比率
    alpha_chars = sum(1 for c in text if c.isalpha())
    if alpha_chars / len(text) < config.min_alpha_ratio:
        return None

    return text

def deduplicate_documents(texts: List[str], n_gram_size: int = 13) -> List[str]:
    """
    MinHash LSHベースの重複排除
    (実際の実装はdatasketchライブラリを使用)
    """
    from datasketch import MinHash, MinHashLSH

    lsh = MinHashLSH(threshold=0.8, num_perm=128)
    unique_texts = []

    for i, text in enumerate(texts):
        minhash = MinHash(num_perm=128)
        words = text.lower().split()
        for j in range(len(words) - n_gram_size + 1):
            ngram = " ".join(words[j:j+n_gram_size])
            minhash.update(ngram.encode("utf-8"))

        if not lsh.query(minhash):
            lsh.insert(str(i), minhash)
            unique_texts.append(text)

    return unique_texts

2.3 トークナイザーの学習

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel

def train_tokenizer(
    corpus_files: List[str],
    vocab_size: int = 32000,
    output_path: str = "tokenizer.json"
):
    """BPEトークナイザーの学習(SentencePieceスタイル)"""
    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)

    trainer = BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=2,
        special_tokens=["[UNK]", "[BOS]", "[EOS]", "[PAD]"],
        show_progress=True
    )

    tokenizer.train(files=corpus_files, trainer=trainer)
    tokenizer.save(output_path)
    return tokenizer

def count_tokens(file_path: str, tokenizer) -> int:
    total = 0
    with open(file_path, "r") as f:
        for line in f:
            tokens = tokenizer.encode(line.strip())
            total += len(tokens.ids)
    return total

3. Megatron-LM

3.1 NVIDIA Megatronの紹介

Megatron-LMはNVIDIAが開発した大規模言語モデル学習フレームワークです。GPT、BERT、T5などの大型モデルの学習に特化した並列化技術を提供します。

主要機能:

  • テンソル並列: 行列演算をGPUに分割
  • パイプライン並列: 層を順番にGPUに分散
  • シーケンス並列: シーケンス次元でも分散(Megatron v2)
  • Flash Attention統合: メモリ効率的なアテンション計算

3.2 テンソル並列の実装原理

トランスフォーマーの核心演算は行列乗算です。これをどのように分割するかがテンソル並列の本質です。

Multi-Head Attentionのテンソル分割:

Q, K, V 射影:  [d_model, d_head x n_heads]
GPU: [d_model, d_head x (n_heads/tp_size)]

FFN層のテンソル分割:

1線形層:   [d_model, d_ffn]  → 列方向分割
2線形層:  [d_ffn, d_model]  → 行方向分割
GPU: [d_model, d_ffn/tp_size] + [d_ffn/tp_size, d_model]
# Megatron ColumnParallelLinear の概念的実装(簡略版)
class ColumnParallelLinear(nn.Module):
    """重み行列を列方向に分割"""
    def __init__(self, in_features, out_features, tp_size):
        super().__init__()
        self.tp_size = tp_size
        assert out_features % tp_size == 0
        local_out = out_features // tp_size
        # 各GPUが全重みの1/tp_sizeを保持
        self.weight = nn.Parameter(torch.empty(local_out, in_features))

    def forward(self, x):
        # 入力はフルサイズ、出力は分割済み
        return torch.nn.functional.linear(x, self.weight)
        # 後続のAll-reduceまたはAll-gatherが必要

class RowParallelLinear(nn.Module):
    """重み行列を行方向に分割"""
    def __init__(self, in_features, out_features, tp_size):
        super().__init__()
        self.tp_size = tp_size
        assert in_features % tp_size == 0
        local_in = in_features // tp_size
        # 各GPUが入力次元の1/tp_sizeを処理
        self.weight = nn.Parameter(torch.empty(out_features, local_in))

    def forward(self, x):
        # x: [batch, seq, in_features/tp_size]
        local_output = torch.nn.functional.linear(x, self.weight)
        # 各GPUからの部分結果を合算するAll-reduce
        dist.all_reduce(local_output)
        return local_output

3.3 シーケンス並列

シーケンス並列はLayerNormとDropoutをシーケンス次元で分散します。

アテンション入力: [batch, seq/tp_size, d_model] (GPU)
All-gather: [batch, seq, d_model]
Self-Attention (列分割)
Reduce-scatter: [batch, seq/tp_size, d_model]
FFN (行分割)

これにより、LayerNormのメモリもtp_size分削減されます。

3.4 Megatron設定例

#!/bin/bash
# Megatron-LMでGPT-3 175Bを学習する例

GPUS_PER_NODE=8
NNODES=64
TP_SIZE=8    # テンソル並列
PP_SIZE=16   # パイプライン並列
DP_SIZE=$((GPUS_PER_NODE * NNODES / TP_SIZE / PP_SIZE))
# DP_SIZE = 8 * 64 / 8 / 16 = 4

torchrun \
    --nproc_per_node $GPUS_PER_NODE \
    --nnodes $NNODES \
    pretrain_gpt.py \
    --num-layers 96 \
    --hidden-size 12288 \
    --num-attention-heads 96 \
    --seq-length 2048 \
    --max-position-embeddings 2048 \
    --global-batch-size 1536 \
    --train-iters 300000 \
    --lr 6e-5 \
    --lr-decay-style cosine \
    --min-lr 6e-6 \
    --lr-warmup-fraction 0.001 \
    --weight-decay 0.1 \
    --tensor-model-parallel-size $TP_SIZE \
    --pipeline-model-parallel-size $PP_SIZE \
    --micro-batch-size 1 \
    --bf16 \
    --use-flash-attn \
    --sequence-parallel \
    --data-path /data/pile_text_document \
    --vocab-file /data/gpt2-vocab.json \
    --merge-file /data/gpt2-merges.txt \
    --save /checkpoints/gpt3-175b \
    --load /checkpoints/gpt3-175b

4. 3D並列化(DP + TP + PP)

4.1 3種類の並列化の組み合わせ

大規模LLMの学習では、3種類の並列化を同時に使用します。

並列化の種類分散対象通信パターンオーバーヘッド
データ並列(DP)バッチ勾配のAll-reduce
テンソル並列(TP)層内行列層毎のAll-reduce中(レイテンシに敏感)
パイプライン並列(PP)層グループポイントツーポイントパイプラインバブルオーバーヘッド

4.2 最適な並列化設定の探索

def find_optimal_parallelism(
    world_size: int,
    num_layers: int,
    hidden_size: int,
    gpu_memory_gb: float = 80.0
) -> dict:
    """
    最適な3D並列化設定を探索する。
    一般的な規則:
    - TP: 単一ノード内(NVLinkを活用)、通常4または8
    - PP: ノード間、層数で決定
    - DP: 残りのGPU
    """
    configs = []
    for tp in [1, 2, 4, 8]:
        for pp in range(1, world_size + 1):
            dp = world_size // (tp * pp)
            if dp < 1:
                continue
            if world_size != dp * tp * pp:
                continue
            if num_layers % pp != 0:
                continue

            # パイプラインバブル率の計算
            # mマイクロバッチとpパイプラインステージの場合: bubble = (p-1)/(m+p-1)
            micro_batch = 4   # 仮定
            global_batch = 2048   # 仮定
            m = global_batch // (dp * micro_batch)
            if m <= 0:
                continue
            bubble_rate = (pp - 1) / (m + pp - 1)

            configs.append({
                "tp": tp, "pp": pp, "dp": dp,
                "bubble_rate": bubble_rate,
                "layers_per_stage": num_layers // pp
            })

    # 低バブル率と適切なTPサイズを優先
    configs.sort(key=lambda x: (x["bubble_rate"], x["tp"]))
    return configs[:5]

# 例: 512 GPU、GPT-3スケール(96層)
best_configs = find_optimal_parallelism(512, 96, 12288)
for c in best_configs:
    print(f"TP={c['tp']}, PP={c['pp']}, DP={c['dp']}, "
          f"Bubble={c['bubble_rate']:.2%}, Layers/Stage={c['layers_per_stage']}")

4.3 DeepSpeed + Megatronの組み合わせ

# Megatron-DeepSpeed統合設定
deepspeed_config = {
    "train_batch_size": 2048,
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 256,   # 2048 / (8 DP * 1 micro)
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 1,   # TP+PPと併用してZeRO-1を使用
    },
    "gradient_clipping": 1.0,
    "wall_clock_breakdown": True,
    "steps_per_print": 10
}

4.4 通信トポロジーの最適化

# NVLink対InfiniBandの帯域幅を考慮
# NVLink 3.0: 600 GB/s (双方向)
# InfiniBand HDR: 200 Gb/s = ポート毎25 GB/s

# 通信グループ設計原則:
# TPグループ: 同じノード内のGPU(NVLinkを活用)
# PPグループ: ノード間(InfiniBandを活用)
# DPグループ: ノード間(ZeROで通信を最小化)

def setup_process_groups(tp_size, pp_size, dp_size):
    """プロセスグループの設定"""
    world_size = tp_size * pp_size * dp_size
    rank = torch.distributed.get_rank()

    # テンソル並列グループ(単一ノード内推奨)
    for dp_rank in range(dp_size):
        for pp_rank in range(pp_size):
            ranks = [
                dp_rank * tp_size * pp_size + pp_rank * tp_size + tp_rank
                for tp_rank in range(tp_size)
            ]
            group = torch.distributed.new_group(ranks)
            if rank in ranks:
                tp_group = group

    return tp_group

5. 学習安定性

5.1 損失スパイクへの対処

大規模学習では突然の損失スパイクが頻繁に発生します。

根本原因の分析:

  1. 勾配ノルムの過大: 特定バッチで異常に大きな勾配
  2. 学習率が高すぎる: 損失のランドスケープトポロジーの急激な変化
  3. 悪質なデータバッチ: 外れ値または誤って処理されたデータ
  4. 数値不安定性: bf16/fp16でのオーバーフロー/アンダーフロー

モニタリングコード:

import torch
import wandb
from collections import deque

class TrainingMonitor:
    def __init__(self, spike_threshold: float = 3.0, window_size: int = 100):
        self.loss_history = deque(maxlen=window_size)
        self.grad_norm_history = deque(maxlen=window_size)
        self.spike_threshold = spike_threshold
        self.spike_count = 0

    def check_loss_spike(self, current_loss: float) -> bool:
        if len(self.loss_history) < 10:
            self.loss_history.append(current_loss)
            return False

        mean_loss = sum(self.loss_history) / len(self.loss_history)
        if current_loss > mean_loss * self.spike_threshold:
            self.spike_count += 1
            print(f"SPIKE detected: current={current_loss:.4f}, mean={mean_loss:.4f}")
            return True

        self.loss_history.append(current_loss)
        return False

    def compute_grad_norm(self, model) -> float:
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        return total_norm ** 0.5

    def log_step(self, step: int, loss: float, model, lr: float):
        grad_norm = self.compute_grad_norm(model)
        is_spike = self.check_loss_spike(loss)

        wandb.log({
            "train/loss": loss,
            "train/grad_norm": grad_norm,
            "train/lr": lr,
            "train/is_spike": int(is_spike),
            "train/spike_count": self.spike_count,
        }, step=step)

        return grad_norm, is_spike

5.2 勾配クリッピングとノイズスケール

def adaptive_gradient_clipping(
    model,
    optimizer,
    max_norm: float = 1.0,
    clip_coef: float = 0.01
):
    """
    適応的勾配クリッピング(AGC)
    パラメータノルム比に基づいて層毎にクリッピング
    """
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if param.grad is None:
                continue

            param_norm = param.data.norm(2)
            grad_norm = param.grad.data.norm(2)

            # クリッピング閾値: clip_coef × パラメータノルム
            max_grad_norm = clip_coef * param_norm

            if grad_norm > max_grad_norm and grad_norm > 0:
                param.grad.data.mul_(max_grad_norm / grad_norm)

def compute_gradient_noise_scale(model, world_size: int) -> float:
    """
    勾配ノイズスケール(GNS)の計算
    GNS = tr(S) / ||g||^2
    高GNS → 大きなバッチサイズが許容可能
    """
    grads = [p.grad.data for p in model.parameters() if p.grad is not None]

    g_norm_sq = sum(g.norm(2).item() ** 2 for g in grads)
    variance = sum((g - g.mean()).norm(2).item() ** 2 for g in grads)

    gns = variance / g_norm_sq if g_norm_sq > 0 else 0
    return gns

5.3 学習率スケジューリング戦略

import math

def cosine_schedule_with_warmup(
    current_step: int,
    warmup_steps: int,
    total_steps: int,
    max_lr: float,
    min_lr: float
) -> float:
    """
    線形ウォームアップ付きコサイン減衰。
    ほとんどのLLM学習で標準的に使用される。
    """
    if current_step < warmup_steps:
        # 線形ウォームアップ
        return max_lr * current_step / warmup_steps
    else:
        # コサイン減衰
        progress = (current_step - warmup_steps) / (total_steps - warmup_steps)
        cosine_val = 0.5 * (1 + math.cos(math.pi * progress))
        return min_lr + (max_lr - min_lr) * cosine_val

# WSDスケジュール(Warmup-Stable-Decay)— 最近の研究で人気
def wsd_schedule(
    current_step: int,
    warmup_steps: int,
    stable_steps: int,
    decay_steps: int,
    max_lr: float,
    min_lr: float
) -> float:
    """
    3フェーズスケジュール: ウォームアップ -> 安定 -> 減衰
    MiniCPM、LLaMA-3などで使用される。
    継続的学習に有利。
    """
    if current_step < warmup_steps:
        return max_lr * current_step / warmup_steps
    elif current_step < warmup_steps + stable_steps:
        return max_lr
    else:
        decay_progress = (current_step - warmup_steps - stable_steps) / decay_steps
        return max_lr - (max_lr - min_lr) * min(decay_progress, 1.0)

5.4 バッチサイズスケジューリング

# バッチサイズランプ戦略(OpenAI GPT-3論文での提案)
def get_batch_size_schedule(
    current_tokens: int,
    final_batch_tokens: int = 4_000_000,   # 最終バッチサイズ(トークン)
    initial_batch_tokens: int = 32_000,    # 初期バッチサイズ
    ramp_tokens: int = 1_200_000_000       # ランプアップ区間(トークン)
) -> int:
    """
    トークンスループットに対してバッチサイズを線形増加。
    序盤: 小バッチで高速収束。
    終盤: 大バッチで安定した学習。
    """
    if current_tokens < ramp_tokens:
        progress = current_tokens / ramp_tokens
        batch_tokens = initial_batch_tokens + (
            final_batch_tokens - initial_batch_tokens
        ) * progress
        return int(batch_tokens)
    return final_batch_tokens

6. チェックポイント戦略

6.1 分散チェックポイント

import os
import torch
import torch.distributed as dist
from pathlib import Path

class DistributedCheckpointer:
    def __init__(self, save_dir: str, max_checkpoints: int = 5):
        self.save_dir = Path(save_dir)
        self.max_checkpoints = max_checkpoints
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def save(
        self,
        model,
        optimizer,
        scheduler,
        step: int,
        rank: int,
        world_size: int
    ):
        """分散チェックポイントの保存"""
        checkpoint_dir = self.save_dir / f"step_{step:08d}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # 各ランクが自分のシャードを保存
        rank_path = checkpoint_dir / f"rank_{rank:04d}_of_{world_size:04d}.pt"

        model_state = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict() if scheduler else None,
            "step": step,
            "rank": rank,
            "world_size": world_size,
        }

        torch.save(model_state, rank_path)

        # ランク0がメタデータを保存
        if rank == 0:
            meta = {
                "step": step,
                "world_size": world_size,
                "timestamp": __import__("time").time()
            }
            torch.save(meta, checkpoint_dir / "meta.pt")

        # 古いチェックポイントをクリーンアップ
        if rank == 0:
            self._cleanup_old_checkpoints()

    def _cleanup_old_checkpoints(self):
        checkpoints = sorted(self.save_dir.glob("step_*"))
        while len(checkpoints) > self.max_checkpoints:
            oldest = checkpoints.pop(0)
            import shutil
            shutil.rmtree(oldest)

6.2 非同期チェックポイント保存

import threading
from queue import Queue

class AsyncCheckpointer:
    """別スレッドでチェックポイントを保存(学習を中断しない)"""
    def __init__(self, save_dir: str):
        self.save_dir = save_dir
        self.queue = Queue(maxsize=2)
        self.worker = threading.Thread(target=self._worker, daemon=True)
        self.worker.start()

    def _worker(self):
        while True:
            item = self.queue.get()
            if item is None:
                break
            state_dict, path = item
            torch.save(state_dict, path)
            self.queue.task_done()

    def save_async(self, model, step: int, rank: int):
        """メインスレッドをブロックせずに保存キューに追加"""
        # CPUにコピー(GPUメモリを解放)
        state_dict = {
            k: v.cpu().clone() for k, v in model.state_dict().items()
        }
        path = os.path.join(self.save_dir, f"step_{step}_rank_{rank}.pt")

        if not self.queue.full():
            self.queue.put((state_dict, path))
        else:
            # キューが満杯の場合: 同期的に保存
            torch.save(state_dict, path)

7. 学習モニタリング

7.1 コアメトリクスのトラッキング

import wandb
import numpy as np

class LLMTrainingTracker:
    def __init__(self, project_name: str, run_name: str):
        wandb.init(project=project_name, name=run_name)
        self.step = 0
        self.token_count = 0

    def log_training_step(
        self,
        loss: float,
        learning_rate: float,
        grad_norm: float,
        batch_tokens: int,
        elapsed_seconds: float
    ):
        tokens_per_second = batch_tokens / elapsed_seconds

        self.token_count += batch_tokens
        self.step += 1

        wandb.log({
            # 学習メトリクス
            "train/loss": loss,
            "train/perplexity": np.exp(min(loss, 20)),
            "train/grad_norm": grad_norm,
            "train/learning_rate": learning_rate,
            # 効率メトリクス
            "throughput/tokens_per_second": tokens_per_second,
            "throughput/samples_per_second": tokens_per_second / 2048,
            # 進捗
            "progress/total_tokens": self.token_count,
            "progress/step": self.step,
        }, step=self.step)

    def log_evaluation(self, eval_loss: float, perplexities: dict):
        wandb.log({
            "eval/loss": eval_loss,
            "eval/perplexity": np.exp(eval_loss),
            **{f"eval/{k}_ppl": v for k, v in perplexities.items()}
        }, step=self.step)

7.2 損失曲線の解釈

健全な損失曲線の特徴:

  • ウォームアップフェーズ: 急速な減少
  • 安定フェーズ: ゆっくりだが着実な減少
  • 後期フェーズ: 極めて緩やかな減少

警告シグナル:

  • 完全なプラトー: 学習率が低すぎる、データ枯渇
  • 急なスパイク: 悪質なバッチ、勾配爆発
  • NaN/Inf: 数値不安定性、fp16オーバーフロー
def analyze_loss_curve(losses: list, window: int = 100) -> dict:
    """損失曲線の自動分析"""
    if len(losses) < window * 2:
        return {}

    recent = losses[-window:]
    previous = losses[-2*window:-window]

    recent_mean = np.mean(recent)
    previous_mean = np.mean(previous)
    improvement = (previous_mean - recent_mean) / previous_mean

    global_mean = np.mean(losses)
    global_std = np.std(losses)
    spikes = [l for l in losses if l > global_mean + 3 * global_std]

    return {
        "recent_loss": recent_mean,
        "improvement_rate": improvement,
        "spike_count": len(spikes),
        "is_stagnant": improvement < 0.001,
        "recommendation": "Consider increasing LR" if improvement < 0.001 else "Normal"
    }

8. オープンソースLLM学習コードベース

8.1 GPT-NeoX(EleutherAI)

# GPT-NeoXのインストールと実行
git clone https://github.com/EleutherAI/gpt-neox
cd gpt-neox
pip install -r requirements/requirements.txt

# 設定ファイル(configs/20B.yml)
# 学習開始
python deepy.py train configs/20B.yml

GPT-NeoXはMegatronベースのパイプライン並列とDeepSpeedを組み合わせています。EleutherAIのPythiaモデルシリーズはこのコードベースで学習されました。

8.2 OLMo(Allen AI)

# OLMo学習設定(簡略版)
from olmo import TrainConfig, ModelConfig, OptimizerConfig

train_config = TrainConfig(
    model=ModelConfig(
        d_model=4096,
        n_heads=32,
        n_layers=32,
        mlp_ratio=8/3,
        vocab_size=50280,
        max_sequence_length=2048,
        attention_type="flash",
    ),
    optimizer=OptimizerConfig(
        name="adamw",
        learning_rate=3e-4,
        weight_decay=0.1,
        betas=(0.9, 0.95),
    ),
    max_duration="300000ba",    # 300,000バッチ
    global_train_batch_size=2048,
    device_train_microbatch_size=2,
    precision="bf16",
    fsdp_config={
        "wrapping_strategy": "by_block",
        "precision": "bf16",
        "sharding_strategy": "FULL_SHARD",
    },
)

OLMoは学習データ、コード、中間チェックポイントを公開する完全透明性を目指しています。

8.3 torchtitan

# torchtitan - PyTorchネイティブLLM学習
# Metaが開発したモダンな事前学習フレームワーク

# torchtitan設定(TOML形式)
config = """
[model]
name = "llama3"
flavor = "8B"
tokenizer_path = "./original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 8
seq_len = 8192
max_norm = 1.0
steps = 10000
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1  # FSDP2
tensor_parallel_degree = 1
compile = true   # torch.compileを有効化

[checkpoint]
enable_checkpoint = true
folder = "outputs/checkpoint"
interval_type = "steps"
interval = 500
"""

torchtitanはFSDP2、テンソル並列、パイプライン並列をPyTorchでネイティブに実装しています。

8.4 FSDP2(Fully Sharded Data Parallel)

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# FSDPラップポリシー
auto_wrap_policy = transformer_auto_wrap_policy(
    transformer_layer_cls={LlamaDecoderLayer}
)

# FSDP2(torch 2.xの新API)
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)

for layer in model.model.layers:
    fully_shard(
        layer,
        mp_policy=mp_policy,
        reshard_after_forward=True,   # ZeRO-3に類似
    )

fully_shard(model, mp_policy=mp_policy)

9. コスト推定と効率化戦略

9.1 GPU時間の計算

def estimate_training_cost(
    model_params: float,   # パラメータ数(例:7e9)
    training_tokens: float,  # 学習トークン数(例:2e12)
    num_gpus: int,
    gpu_type: str = "A100-80GB",
    cloud_provider: str = "AWS"
) -> dict:
    """学習コストの推定"""
    # GPU仕様(理論ピーク値、bf16)
    gpu_specs = {
        "A100-80GB": {"tflops": 312, "memory_gb": 80, "nvlink": True},
        "H100-80GB": {"tflops": 989, "memory_gb": 80, "nvlink": True},
        "A10G-24GB": {"tflops": 125, "memory_gb": 24, "nvlink": False},
        "V100-32GB": {"tflops": 130, "memory_gb": 32, "nvlink": True},
    }

    # クラウド価格(概算USD/時間)
    cloud_prices = {
        "AWS": {"A100-80GB": 3.97, "H100-80GB": 8.0, "A10G-24GB": 1.006},
        "GCP": {"A100-80GB": 3.67, "H100-80GB": 7.0},
        "Azure": {"A100-80GB": 3.40, "H100-80GB": 6.0},
        "Lambda": {"A100-80GB": 1.99, "H100-80GB": 3.29},
    }

    spec = gpu_specs[gpu_type]
    price_per_hour = cloud_prices.get(cloud_provider, {}).get(gpu_type, 0)

    # FLOPs計算
    total_flops = 6 * model_params * training_tokens

    # 仮定のMFU(通常35-55%)
    mfu = 0.45
    effective_tflops = spec["tflops"] * 1e12 * mfu

    # 総学習時間
    total_seconds = total_flops / (effective_tflops * num_gpus)
    total_hours = total_seconds / 3600
    total_days = total_hours / 24

    # コスト計算
    total_cost = price_per_hour * num_gpus * total_hours

    return {
        "total_flops": f"{total_flops:.2e}",
        "training_hours": f"{total_hours:.1f}",
        "training_days": f"{total_days:.1f}",
        "gpu_hours": f"{total_hours * num_gpus:.0f}",
        "estimated_cost_usd": f"{total_cost:,.0f}",
        "mfu": mfu,
    }

# 例
result = estimate_training_cost(
    model_params=7e9,
    training_tokens=2e12,
    num_gpus=256,
    gpu_type="A100-80GB",
    cloud_provider="Lambda"
)
for k, v in result.items():
    print(f"{k}: {v}")

9.2 MFU(Model FLOP Utilization)の最適化

def measure_mfu(
    model,
    batch_tokens: int,
    elapsed_seconds: float,
    theoretical_peak_tflops: float
) -> float:
    """
    実際のMFUを測定する。
    MFU = 実際のFLOP率 / 理論ピークFLOP率
    """
    num_params = sum(p.numel() for p in model.parameters())
    actual_flops = 6 * num_params * batch_tokens   # 順伝播(2) + 逆伝播(4)

    actual_tflops = actual_flops / elapsed_seconds / 1e12
    mfu = actual_tflops / theoretical_peak_tflops

    return mfu

# MFU改善戦略:
# 1. Flash Attentionの使用(メモリI/Oを最小化)
# 2. torch.compileの有効化(カーネルフュージョン)
# 3. 活性化チェックポイントを慎重に使用(速度とのトレードオフ)
# 4. 最適なバッチサイズの選択(GPU占有率を最大化)
# 5. 通信と計算のオーバーラップ(FSDP/DeepSpeed設定)

9.3 効率化戦略まとめ

戦略メモリ削減速度影響実装難易度
ZeRO-2-5%簡単
ZeRO-3-15%中程度
Flash Attention 2+20%簡単
torch.compileなし+15-30%簡単
活性化チェックポイント-30%簡単
bf16(fp16比)なし安定性向上簡単
シーケンスパッキングなし+20%中程度

10. 完全な事前学習ランチャースクリプト

#!/usr/bin/env python3
"""
大規模LLM事前学習の完全例。
GPTスタイルモデル、DeepSpeed ZeRO-2 + Flash Attention。
"""
import os
import time
import math
import torch
import torch.nn as nn
import deepspeed
from torch.utils.data import DataLoader, IterableDataset
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    get_cosine_schedule_with_warmup,
)

# 設定
TRAINING_CONFIG = {
    "model_name": "meta-llama/Llama-2-7b",
    "total_tokens": 1_000_000_000,    # 1Bトークン(デモ)
    "max_seq_len": 2048,
    "global_batch_size": 1024,        # グローバルバッチ(トークン: 1024 * 2048 = 2M)
    "micro_batch_per_gpu": 4,
    "learning_rate": 3e-4,
    "min_lr": 3e-5,
    "warmup_tokens": 20_000_000,      # 2%ウォームアップ
    "weight_decay": 0.1,
    "grad_clip": 1.0,
    "save_every_steps": 1000,
    "eval_every_steps": 500,
    "log_every_steps": 10,
}

class StreamingTokenDataset(IterableDataset):
    """ストリーミングトークンデータセット(大規模データを処理)"""
    def __init__(self, data_path: str, seq_len: int, rank: int, world_size: int):
        self.data_path = data_path
        self.seq_len = seq_len
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        # 各ランクが異なるデータセグメントを処理
        with open(self.data_path, "rb") as f:
            f.seek(self.rank * self.seq_len * 2)   # uint16基準
            while True:
                chunk = f.read(self.seq_len * 2 * self.world_size)
                if not chunk:
                    break
                tokens = torch.frombuffer(chunk, dtype=torch.uint16).long()
                if len(tokens) < self.seq_len + 1:
                    break
                input_ids = tokens[:self.seq_len]
                labels = tokens[1:self.seq_len + 1]
                yield {"input_ids": input_ids, "labels": labels}

def train():
    deepspeed.init_distributed()
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    # モデルのロード
    config = AutoConfig.from_pretrained(TRAINING_CONFIG["model_name"])
    with deepspeed.zero.Init():
        model = AutoModelForCausalLM.from_config(config)

    # DeepSpeedエンジンの初期化
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        config={
            "train_micro_batch_size_per_gpu": TRAINING_CONFIG["micro_batch_per_gpu"],
            "gradient_accumulation_steps": (
                TRAINING_CONFIG["global_batch_size"]
                // TRAINING_CONFIG["micro_batch_per_gpu"]
                // world_size
            ),
            "bf16": {"enabled": True},
            "zero_optimization": {
                "stage": 2,
                "overlap_comm": True,
                "contiguous_gradients": True,
                "allgather_bucket_size": 2e8,
                "reduce_bucket_size": 2e8,
            },
            "gradient_clipping": TRAINING_CONFIG["grad_clip"],
            "optimizer": {
                "type": "AdamW",
                "params": {
                    "lr": TRAINING_CONFIG["learning_rate"],
                    "betas": [0.9, 0.95],
                    "eps": 1e-8,
                    "weight_decay": TRAINING_CONFIG["weight_decay"],
                }
            },
            "steps_per_print": TRAINING_CONFIG["log_every_steps"],
        }
    )

    # データセット
    train_dataset = StreamingTokenDataset(
        "/data/train_tokens.bin",
        TRAINING_CONFIG["max_seq_len"],
        rank, world_size
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=TRAINING_CONFIG["micro_batch_per_gpu"],
        num_workers=4,
        pin_memory=True,
    )

    # 学習ループ
    total_tokens = 0
    target_tokens = TRAINING_CONFIG["total_tokens"]

    for batch in train_loader:
        if total_tokens >= target_tokens:
            break

        input_ids = batch["input_ids"].to(model_engine.device)
        labels = batch["labels"].to(model_engine.device)

        # 学習率スケジュール
        warmup_tokens = TRAINING_CONFIG["warmup_tokens"]
        lr = cosine_schedule_with_warmup(
            total_tokens, warmup_tokens, target_tokens,
            TRAINING_CONFIG["learning_rate"], TRAINING_CONFIG["min_lr"]
        )
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # 順伝播/逆伝播
        outputs = model_engine(input_ids=input_ids, labels=labels)
        loss = outputs.loss
        model_engine.backward(loss)
        model_engine.step()

        # トークン数を更新
        batch_tokens = input_ids.numel() * world_size
        total_tokens += batch_tokens

        # ログ(ランク0のみ)
        if rank == 0 and model_engine.global_steps % TRAINING_CONFIG["log_every_steps"] == 0:
            print(f"Tokens: {total_tokens/1e9:.2f}B, "
                  f"Loss: {loss.item():.4f}, "
                  f"LR: {lr:.2e}, "
                  f"PPL: {math.exp(loss.item()):.1f}")

        # チェックポイントの保存
        if model_engine.global_steps % TRAINING_CONFIG["save_every_steps"] == 0:
            model_engine.save_checkpoint(
                "./checkpoints",
                tag=f"step_{model_engine.global_steps}"
            )

    if rank == 0:
        print("学習完了!")
        model_engine.save_checkpoint("./checkpoints", tag="final")

def cosine_schedule_with_warmup(step, warmup, total, max_lr, min_lr):
    if step < warmup:
        return max_lr * step / max(warmup, 1)
    progress = (step - warmup) / max(total - warmup, 1)
    return min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

if __name__ == "__main__":
    train()

起動コマンド:

deepspeed --num_nodes=4 --num_gpus=8 \
    --master_addr=node0 --master_port=29500 \
    pretrain.py

まとめ

1000億パラメータ超のLLMの事前学習は、単にコードを実行することではなく、無数の工学的判断とトレードオフの慎重なバランスです。

重要なポイント:

  • スケーリング則: Chinchilla最適化に従ってモデルサイズとデータ量をバランスさせる。ただし推論コストを考慮すると「過剰学習」も合理的
  • データが基本: 高品質なデータキュレーションとバランスの取れた混合がモデル能力を決定する
  • 3D並列化: DP + TP + PPの組み合わせで数千のGPUを効率的に活用
  • 安定性優先: 損失スパイクと勾配爆発を監視し、迅速に対応することが不可欠
  • コスト意識: MFUとGPU使用率を継続的に監視して効率を最適化

オープンソースエコシステム(OLMo、GPT-NeoX、torchtitan)は大規模学習への参入障壁を下げています。このガイドの技術を自分のプロジェクトに適用して、独自のLLMを学習してみてください。

参考文献