Skip to content
Published on

Torch-Titan完全ガイド:PyTorchによる大規模分散学習のすべて

Authors

はじめに

大規模言語モデル(LLM)のトレーニングは、現代のAIエンジニアリングにおける最も複雑な課題の一つです。Llama 3 70Bを単一GPUでトレーニングすることは不可能です。数十から数千のGPUを効率的に活用するために、さまざまな並列化戦略が必要になります。

MetaのPyTorchチームが開発したtorchtitanは、まさにこの種の複雑な大規模LLMトレーニングのリファレンス実装です。最新のPyTorch機能をクリーンでスケーラブルな方法で使用して構築されており、AIの研究者やエンジニアが分散トレーニングのベストプラクティスを学び、適用できるように設計されています。

このガイドでは、torchtitanのすべてを網羅しています。理論的な基礎から実践的なインストール、高度な並列化戦略、パフォーマンスの最適化まで説明します。


1. Torch-Titanの紹介

torchtitanとは?

torchtitanは、MetaのPyTorchチームがオープンソース化した大規模LLMトレーニングのプロダクション品質のリファレンス実装です。GitHubではpytorch/torchtitanとして公開されています。

既存のLLMトレーニングコードベース(Megatron-LM、DeepSpeed、NeMoなど)は機能が豊富ですが、非常に複雑でもあります。torchtitanは異なる哲学を取っています。

  • 明確さ:最小限の抽象化でPyTorchネイティブAPIのみを使用
  • モジュール性:各並列化技術を独立してオン/オフ可能
  • モダン性:PyTorch 2.xの最新機能を積極的に活用
  • 再現性:トレーニング実験の再現を容易にする構造

サポートモデル

torchtitanがネイティブにサポートするモデル:

  • Llama 2(7B、13B、34B、70B)
  • Llama 3(8B、70B、405B)
  • Llama 3.1、3.2ファミリー

Llamaがデフォルトですが、構造に従えば任意のトランスフォーマーベースのモデルを追加できます。

既存フレームワークとの違い

Megatron-LM(NVIDIA)

  • GPUクラスタ向けの最も成熟したソリューション
  • 高度に最適化されているがNVIDIAエコシステムに依存
  • 複雑なコードベースでカスタマイズが困難

DeepSpeed(Microsoft)

  • ZeROオプティマイザで有名
  • トレーニングと推論の両方をサポート
  • C++カスタムカーネルに依存し、PyTorchとの完全な統合が難しい場合がある

torchtitan(Meta/PyTorch)

  • 純粋なPyTorchネイティブAPI
  • PyTorch 2.xのtorch.compile、FSDP2、torch.distributed.tensorを使用
  • 教育目的に適した明確な構造
  • PyTorchバージョンの更新に迅速に対応

リポジトリ構造

torchtitan/
├── torchtitan/
│   ├── models/
│   │   ├── llama/          # Llamaモデル定義
│   │   └── __init__.py
│   ├── parallelisms/
│   │   ├── parallelize_llama.py  # 並列化適用ロジック
│   │   ├── pipeline_llama.py     # パイプライン並列
│   │   └── __init__.py
│   ├── optimizer.py         # オプティマイザ設定
│   ├── checkpoint.py        # チェックポインティング
│   ├── profiling.py         # プロファイリング
│   └── utils.py
├── train.py                 # トレーニングエントリーポイント
├── train_configs/           # TOML設定ファイル
│   ├── llama3_8b.toml
│   ├── llama3_70b.toml
│   └── llama3_405b.toml
└── estimation.py            # メモリ/FLOPs推定ツール

2. 分散トレーニングパラダイムの復習

複数のGPUで大規模モデルをトレーニングするには4つの主な方法があります。torchtitanはこれら4つすべてをサポートし、同時に組み合わせることもできます(4D並列)。

データ並列(DP)

最も基本的な並列化形式。モデルは各GPUにコピーされ、データバッチが分割されて各GPUが独立して処理します。

GPU 0: [Batch 0-15]Gradient 0GPU 1: [Batch 16-31]Gradient 1  ├→ All-Reduce → 同期
GPU 2: [Batch 32-47]Gradient 2

PyTorchのDistributedDataParallel(DDP)が標準実装です。ただし、モデルはGPUメモリに収まる必要があります。FP16の70Bパラメータモデルだけで140GBが必要となり、単一GPUでは不可能です。

テンソル並列(TP)

テンソル並列は個々のモデル層を複数のGPUに分散させ、行列演算を列または行次元に沿って分割します。

線形層(d_model=4096、d_ff=16384)を4つのGPUに分散:

GPU 0:0-40954096 x 4096GPU 1:4096-8191
GPU 2:8192-12287
GPU 3:12288-16383

各GPUは全層の1/4のみを処理します。結果を結合するためにAll-ReduceまたはAll-Gatherが必要です。NVLinkの帯域幅が高いほど効率的です。

パイプライン並列(PP)

層がGPUに順番に分散されます。最初のGPUが最初のN層を処理し、その出力を次のGPUに渡します。

GPU 0: Layer 0-9   → activation →
GPU 1: Layer 10-19 → activation →
GPU 2: Layer 20-29 → activation →
GPU 3: Layer 30-39 → loss

単純な実装では一度に1つのマイクロバッチのみを処理するため、GPU利用率が低くなります(パイプラインバブル)。GPipe、1F1B、インターリーブドスケジュールがこの問題に対処します。

シーケンス並列(SP)

アテンション層のシーケンス次元を複数のGPUに分散します。長いコンテキスト(128Kトークン以上)でトレーニングする際のアテンション行列メモリがO(N²)で爆発する問題を解決します。

シーケンス長40964つのGPUに分散:
GPU 0: トークン 0-1023
GPU 1: トークン 1024-2047
GPU 2: トークン 2048-3071
GPU 3: トークン 3072-4095

Ring Attentionなどのアルゴリズムが分散シーケンス全体のフルアテンションを計算します。

4D並列:すべての組み合わせ

torchtitanの核心的な強みは、4つの戦略を同時に組み合わせる4D並列のサポートです。

4D並列の例(128 GPU):
- DP = 2  (データ並列、2レプリカ)
- TP = 8  (テンソル並列、層あたり8 GPU- PP = 4  (パイプライン並列、4ステージ)
- SP = TPと並行して有効化

GPU= DP x TP x PP = 2 x 8 x 4 = 64
SPを使用するとさらに多くの設定が可能)

各並列化は異なる通信パターンを持つため、ハードウェアトポロジに対して最適な設定を見つけることが重要です。一般的に:

  • TPは、NVLinkで接続されたGPUを持つ単一サーバー内に適用
  • PPはサーバー間に適用
  • DPは最外次元

3. FSDP2(Fully Sharded Data Parallel v2)

ZeROとFSDPの関係

Microsoft DeepSpeedのZeRO(Zero Redundancy Optimizer)は、パラメータ、勾配、オプティマイザの状態を複数のGPUに分散させる革新的なアプローチです。PyTorchのFSDP(Fully Sharded Data Parallel)は、このアイデアのPyTorchネイティブ実装です。

ZeROステージ:

  • ZeRO-1:オプティマイザ状態のみをシャーディング
  • ZeRO-2:オプティマイザ状態 + 勾配のシャーディング
  • ZeRO-3:オプティマイザ状態 + 勾配 + パラメータのシャーディング

FSPDはZeRO-3に対応します。

FSDP1とFSDP2の比較

PyTorch 2.0までは、torch.distributed.fsdp.FullyShardedDataParallel(FSDP1)が標準でした。PyTorch 2.4+では、torch.distributed._composable.fsdpfully_shardを使ったFSDP2が推奨されています。

主な違い:

機能FSDP1FSDP2
APIスタイルラッパーベースコンポーザブルAPI
TP統合限定的ネイティブ
メモリ効率良好より良好
torch.compile部分的完全
コード可読性複雑明確
# FSDP1スタイル(レガシー)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
model = FSDP(model, ...)

# FSDP2スタイル(torchtitanで使用)
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32
)

# 各トランスフォーマー層にFSDP2を適用
for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)

# モデル全体にも適用
fully_shard(model, mp_policy=mp_policy)

FSPDの仕組み

FSPDはフォワードパス、バックワードパス、重み更新でそれぞれ異なる動作をします。

フォワードパス

  1. 層の実行前:All-Gatherが全GPUでシャーディングされたパラメータを再構築
  2. 層の実行:フルパラメータで計算
  3. 層の実行後:パラメータが破棄(メモリ節約)

バックワードパス

  1. 層の逆伝播前:All-Gatherがパラメータを再構築
  2. 勾配の計算
  3. Reduce-Scatterが勾配を各GPUにシャードとして分散
  4. パラメータの破棄

重み更新

  • 各GPUは自分のシャードのパラメータ、勾配、オプティマイザ状態のみを更新

4つのGPUで70Bモデルを使用した例:

  • メモリ節約:140GBの重み / 4 = GPU当たり35GB(+ 分散オプティマイザ状態)
  • トレードオフ:通信オーバーヘッド(All-Gather、Reduce-Scatter)

CPUオフロード

GPUメモリが不足している場合、オプティマイザの状態と勾配をCPU RAMにオフロードできます。

from torch.distributed._composable.fsdp import fully_shard, CPUOffloadPolicy

cpu_offload = CPUOffloadPolicy(offload_params=True)

for layer in model.layers:
    fully_shard(
        layer,
        offload_policy=cpu_offload
    )

CPUオフロードはメモリ使用量を劇的に削減しますが、PCIe転送によりトレーニングが遅くなります(2〜5倍)。メモリが絶対的に不足している場合の最終手段として使用してください。

FSDP2による混合精度

FSDP2は層ごとに異なる精度を適用できます。

from torch.distributed._composable.fsdp import MixedPrecisionPolicy

# 標準的な混合精度設定
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # パラメータ:BF16(通信効率)
    reduce_dtype=torch.float32    # 勾配削減:FP32(安定性)
)

# 特定の層をFP32に保つ(例:埋め込み層)
fully_shard(model.embed_tokens)  # mp_policyなし → FP32
for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)

4. テンソル並列

トランスフォーマーにおけるテンソル並列

Megatron-LMで最初に体系化されたテンソル並列は、トランスフォーマーの2つのコアモジュールに適用されます。

MLP(フィードフォワードネットワーク)

FFN(x) = GELU(x * W1) * W2

列並列(W1を分割):
GPU 0: x * W1[0:d/4]   -> GELU -> y[0:d/4]
GPU 1: x * W1[d/4:d/2] -> GELU -> y[d/4:d/2]
...

行並列(W2を分割 + All-Reduce):
GPU 0: y[0:d/4] * W2[0:d/4, :] -> 部分和 0
GPU 1: y[d/4:d/2] * W2[d/4:d/2, :] -> 部分和 1
...
All-Reduce -> 最終出力

セルフアテンション:Query、Key、Valueマトリックスをアテンションヘッドで分割。

# torchtitanでのTP適用(DTensorベース)
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    PrepareModuleInput,
)

# 並列化プランの定義
plan = {
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w3": ColwiseParallel(),
}

parallelize_module(model_layer, tp_mesh, plan)

DTensor:分散テンソルの基盤

torchtitanはPyTorchのtorch.distributed.tensor(DTensor)に基づいてテンソル並列を実装しています。DTensorは論理的には単一のテンソルですが、物理的には複数のデバイスに分散された新しい抽象概念です。

from torch.distributed.tensor import DTensor, Shard, Replicate
import torch.distributed as dist

# 1Dメッシュの作成(TP用)
tp_mesh = dist.init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

# 2Dメッシュの作成(DP + TP)
mesh_2d = dist.init_device_mesh(
    "cuda", (2, 8), mesh_dim_names=("dp", "tp")
)

# Shard(0):行方向シャーディング
# Shard(1):列方向シャーディング
# Replicate():レプリケート(シャーディングなし)

シーケンス並列

テンソル並列の自然な拡張で、LayerNormとDropout演算をシーケンス次元に沿って並列化します。

標準的なTPフロー:
Input(レプリケート)-> LayerNorm(レプリケート)-> Attention(TP-> [All-Reduce] -> Output(レプリケート)

SP + TPフロー:
Input(シーケンスシャード)-> LayerNorm(シーケンスシャード)->
[All-Gather] -> Attention(TP-> [Reduce-Scatter] ->
Output(シーケンスシャード)

SPはAll-ReduceではなくAll-Gather + Reduce-Scatterを使用します。通信量は同じですが、LayerNorm活性化がN個のGPUに分散されるためメモリ効率が向上します。


5. パイプライン並列

パイプラインバブル問題

単純なパイプライン並列は大きな非効率をもたらします。

単純なパイプライン(4 GPU4マイクロバッチ):

時間 ->
GPU 0: [M0 F][M0 B]           [      バブル      ]
GPU 1:        [M0 F][M0 B]    [      バブル      ]
GPU 2:               [M0 F][M0 B][    バブル    ]
GPU 3:                     [M0 F][M0 B]

F = フォワード、B = バックワード。GPU 0はM0を処理した後、長い時間待機します。 バブル比 = (PP - 1) / (マイクロバッチ数 + PP - 1)。

GPipeスケジュール

GPipeは複数のマイクロバッチをパイプラインに投入してバブルを削減します。

GPipe(4 GPU4マイクロバッチ):

時間 ->
GPU 0: [M0F][M1F][M2F][M3F]               [M3B][M2B][M1B][M0B]
GPU 1:      [M0F][M1F][M2F][M3F]      [M3B][M2B][M1B][M0B]
GPU 2:           [M0F][M1F][M2F][M3F][M3B][M2B][M1B][M0B]
GPU 3:                [M0F][M1F][M2F][M3F][M3B][M2B][M1B][M0B]

すべてのフォワードパスが完了してからバックワードパスが開始されます。欠点:活性化メモリ使用量が多い。

1F1B(One Forward One Backward)スケジュール

1F1Bはメモリ効率の良いスケジュールです。各GPUはマイクロバッチのフォワードとバックワードを交互に実行します。

# torchtitanでのパイプライン並列設定(TOML)
# [experimental]
# pipeline_parallel_degree = 4
# pipeline_parallel_schedule = "1f1b"

# Pythonレベルのセットアップ
from torchtitan.parallelisms.pipeline_llama import pipeline_llama

# パイプラインステージの作成
stages, model_parts = pipeline_llama(
    model,
    pp_mesh,
    parallel_dims,
    job_config,
    device,
    model_config
)

インターリーブドスケジュール

インターリーブドスケジュールは各GPUに複数のパイプラインステージの責任を割り当てます。バブルをさらに削減しますが、実装はより複雑です。

インターリーブド1F1B(4 GPUGPU当たり2ステージ):

GPU 0が担当:[レイヤー0-4][レイヤー20-24]
GPU 1[レイヤー5-9][レイヤー25-29]
...

6. torchtitanのインストールと使用

システム要件

  • Python 3.10+
  • PyTorch 2.5+(最新のナイトリー版を推奨)
  • CUDA 12.1+
  • GPU:H100またはA100推奨(最低40GB VRAM)

インストール

# リポジトリのクローン
git clone https://github.com/pytorch/torchtitan
cd torchtitan

# PyTorchナイトリーのインストール(最新機能を含む)
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124

# 依存関係のインストール
pip install -r requirements.txt

# torchtitanパッケージのインストール
pip install -e .

# トークナイザーのダウンロード
python torchtitan/datasets/download_tokenizer.py \
    --repo_id meta-llama/Meta-Llama-3-8B \
    --tokenizer_path "original" \
    --hf_token YOUR_HF_TOKEN

設定ファイル(TOML)

torchtitanはTOML形式の設定ファイルを使用します。

# train_configs/llama3_8b.toml

[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 8
seq_len = 2048
warmup_steps = 200
max_norm = 1.0
steps = 1000
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1  # 自動:残りのGPU数を使用
tensor_parallel_degree = 1
enable_loss_parallel = false

[experimental]
pipeline_parallel_degree = 1

[checkpoint]
enable_checkpoint = true
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"

[activation_checkpoint]
mode = "selective"  # full、selective、none
selective_ac_option = "op"

[float8]
enable_float8_linear = false

Llama 3 8Bトレーニングの実行

# シングルノード、8 GPUでLlama 3 8Bをトレーニング
torchrun --nproc_per_node=8 \
    train.py \
    --job.config_file train_configs/llama3_8b.toml

# TP=4、DP=2の設定
torchrun --nproc_per_node=8 \
    train.py \
    --job.config_file train_configs/llama3_8b.toml \
    --training.tensor_parallel_degree 4 \
    --training.data_parallel_shard_degree 2

# マルチノードトレーニング(2ノード、各8 GPU)
torchrun \
    --nproc_per_node=8 \
    --nnodes=2 \
    --rdzv_id=101 \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR:29400 \
    train.py \
    --job.config_file train_configs/llama3_70b.toml

メモリ/FLOPs推定ツール

トレーニングを開始する前に、メモリとコンピューティング要件を事前に推定します。

# メモリとFLOPsを推定
python estimation.py \
    --job.config_file train_configs/llama3_8b.toml

# サンプル出力:
# Estimated model size: 15.01 GB
# Estimated optimizer state size: 30.02 GB
# Total estimated GPU memory: 65.23 GB
# Estimated FLOP per step: 1.61e+14

7. Flash Attentionの統合

Flash Attentionとは?

Flash Attentionは2022年にスタンフォード大学のTri Daoらが発表したアテンションアルゴリズムです。標準的なアテンションのメモリ複雑度をO(N²)からO(N)に削減し、実際のスピードアップが2〜4倍になります。

標準的なアテンションの問題:

標準的なアテンション:
S = Q * K^T    # (seq_len, seq_len)行列 -- メモリが爆発!
P = softmax(S)
O = P * V

seq_len=4096:  Sマトリックス = 4096 x 4096 x 4バイト = 64MB(FP32seq_len=32768: Sマトリックス = 32768 x 32768 x 4バイト = 4GB(単一レイヤー!)

Flash Attentionの主なアイデア:

  1. Q、K、Vをタイル(タイリング)で処理
  2. HBMではなくSRAM(共有メモリ)で計算
  3. HBMにフルアテンションマトリックスを実体化しない
  4. 数値的に正確(近似ではない)

Flash Attention 2と3

Flash Attention 2(2023年):

  • アテンション計算の並列性を改善
  • より効率的なマスキング
  • A100に対してアテンションが2〜4倍高速化

Flash Attention 3(2024年):

  • H100 Hopperアーキテクチャ向けに調整
  • WGMMA(Warpgroup Matrix Multiply Accumulate)を活用
  • H100でFA2よりさらに1.5〜2倍の改善
  • FP8サポート

torchtitanでのFlash Attentionの使用

# torchtitan/models/llama/model.pyでのアテンション実装
import torch.nn.functional as F

def forward(
    self,
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    bs, seqlen, _ = x.shape
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

    # RoPE埋め込みの適用
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

    # PyTorch SDPAによるFlash Attention
    # torch.nn.functional.scaled_dot_product_attention
    # は自動的にFlash Attentionを使用
    output = F.scaled_dot_product_attention(
        xq, xk, xv,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=True,  # 言語モデルのためのCausalマスク
    )

    return self.wo(output.view(bs, seqlen, -1))

PyTorch 2.0以降、torch.nn.functional.scaled_dot_product_attention(SDPA)は自動的にFlash Attentionを使用します。別のパッケージをインストールする必要はありません。

オリジナルのFlash Attentionパッケージを直接使用するには:

pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func

output = flash_attn_func(
    q, k, v,
    dropout_p=0.0,
    softmax_scale=None,  # 自動:1/sqrt(head_dim)
    causal=True,
)

8. 非同期テンソル並列

なぜ非同期TPか?

標準的なテンソル並列では、All-ReduceまたはAll-Gather/Reduce-Scatterが次の層の計算をブロックします。GPUは通信を待っている間、アイドル状態になります。

同期TP[GPU計算] -> [通信待機] -> [GPU計算] -> [通信待機] ...

非同期TPは通信と計算をオーバーラップさせます:

非同期TP[GPU計算 A] ────────────────────>
[通信(A結果)] ─────>
                    [GPU計算 B] ->
                          [通信(B]

torchtitanでの非同期TP

# 設定ファイルで非同期TPを有効化:
# [training]
# enable_async_tensor_parallel = true

# コードリファレンス
from torchtitan.parallelisms import parallelize_llama

# async_tpは内部的にtorch.distributed.tensor.parallelの
# async_all_gather機能を使用
model = parallelize_llama(
    model,
    world_mesh,
    parallel_dims,
    job_config
)

非同期TPは、計算と通信のオーバーラップにより、高いTP度(8以上)で最も効果的です。H100 + NVLink環境で5〜15%のスループット向上が報告されています。


9. チェックポインティングと再開

分散チェックポイント(dcp)

大規模分散トレーニングでは、チェックポインティングはモデルを保存するだけではありません。数千のGPU間で同時に保存・読み込みを行う必要があります。

PyTorchのtorch.distributed.checkpoint(dcp)は分散チェックポインティングをネイティブにサポートします。

# torchtitanのチェックポインティング
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)

# 保存
def save_checkpoint(model, optimizer, step, output_dir):
    state_dict = {
        "model": get_model_state_dict(model),
        "optimizer": get_optimizer_state_dict(model, optimizer),
        "extra": {"step": step},
    }

    dcp.save(
        state_dict,
        checkpoint_id=f"{output_dir}/step-{step}",
    )

# 読み込み
def load_checkpoint(model, optimizer, checkpoint_path):
    state_dict = {
        "model": get_model_state_dict(model),
        "optimizer": get_optimizer_state_dict(model, optimizer),
        "extra": {},
    }

    dcp.load(
        state_dict,
        checkpoint_id=checkpoint_path,
    )

    set_model_state_dict(
        model,
        model_state_dict=state_dict["model"],
        options=StateDictOptions(strict=True),
    )
    set_optimizer_state_dict(
        model,
        optimizer,
        optim_state_dict=state_dict["optimizer"],
    )

    return state_dict["extra"]["step"]

非同期チェックポインティング

大規模モデルのチェックポインティングには数分かかる場合があり、その間GPUが停止します。非同期チェックポインティングはトレーニングを継続しながらバックグラウンドで保存します。

# 非同期チェックポインティング設定
[checkpoint]
async_mode = "async"  # "disabled"、"async"、"async_with_pinned_mem"
from torchtitan.checkpoint import CheckpointManager

checkpoint_manager = CheckpointManager(
    dataloader=train_dataloader,
    model_parts=model_parts,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    states={"train_state": train_state},
    job_config=job_config,
)

# トレーニングループ内
for step in range(num_steps):
    # ... トレーニングコード ...

    # 非同期にチェックポイントを保存(ノンブロッキング)
    checkpoint_manager.save(curr_step=step, force=False)

チェックポイント形式の変換

分散トレーニングチェックポイントを単一ファイルやHuggingFace形式に変換できます。

# 分散チェックポイントを単一HuggingFaceモデルに変換
python scripts/convert_checkpoint.py \
    --checkpoint_path outputs/checkpoint/step-1000 \
    --output_path outputs/hf_model \
    --model_type llama3 \
    --model_flavor 8B

10. パフォーマンスプロファイリング

PyTorchプロファイラーの使用

パフォーマンスのボトルネックを見つけるには、まずどこに時間がかかっているかを知る必要があります。PyTorchプロファイラーはそのための強力なツールです。

import torch
from torch.profiler import profile, record_function, ProfilerActivity

# 基本的なプロファイリング
with profile(
    activities=[
        ProfilerActivity.CPU,
        ProfilerActivity.CUDA,
    ],
    record_shapes=True,    # テンソルシェイプを記録
    profile_memory=True,   # メモリ使用量を記録
    with_stack=True,       # コールスタックを記録
) as prof:
    with record_function("model_inference"):
        output = model(input)

# 結果の表示
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20
))

# Chromeトレースの保存(chrome://tracingで可視化)
prof.export_chrome_trace("trace.json")

torchtitan組み込みプロファイリング

torchtitanは設定ファイルでプロファイリングを制御します。

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100             # 100ステップごとにプロファイリング
enable_memory_snapshot = true  # メモリスナップショットを保存
# リファレンス:torchtitan/profiling.pyの内部
from contextlib import contextmanager
import torch.profiler as profiler

@contextmanager
def maybe_enable_profiling(config, global_step=0):
    if not config.profiling.enable_profiling:
        yield
        return

    with profiler.profile(
        activities=[
            profiler.ProfilerActivity.CPU,
            profiler.ProfilerActivity.CUDA,
        ],
        schedule=profiler.schedule(
            skip_first=10,
            wait=5,
            warmup=5,
            active=1,
            repeat=1,
        ),
        on_trace_ready=profiler.tensorboard_trace_handler(
            config.profiling.save_traces_folder
        ),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as p:
        yield p

TensorBoard統合

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/experiment_1")

# トレーニングループ内
for step, (input, target) in enumerate(dataloader):
    loss = train_step(model, optimizer, input, target)

    writer.add_scalar("Loss/train", loss.item(), step)
    writer.add_scalar("LR", optimizer.param_groups[0]["lr"], step)

    # GPUメモリのモニタリング
    writer.add_scalar(
        "GPU/memory_allocated_GB",
        torch.cuda.memory_allocated() / 1e9,
        step
    )
    writer.add_scalar(
        "GPU/memory_reserved_GB",
        torch.cuda.memory_reserved() / 1e9,
        step
    )

# TensorBoardの起動:
# tensorboard --logdir=runs/

メモリ使用量分析

# GPUメモリスナップショット
torch.cuda.memory._record_memory_history(max_entries=100000)

# トレーニングの一部を実行
for step in range(100):
    loss = train_step(model, optimizer, batch)

# メモリスナップショットの保存
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# 分析ツールで可視化:
# https://pytorch.org/memory_vizに.pickleをアップロード
# レイヤーごとの分析のためのメモリプロファイラー
from torch.cuda._memory_viz import profile_plot

with open("memory_profile.html", "w") as f:
    f.write(profile_plot(prof))

トレーニング効率指標:MFU

MFU(Model FLOP Utilization)は実際のGPU性能と理論的な最大値の比率です。

def compute_mfu(
    model_num_params: int,
    batch_size: int,
    seq_len: int,
    elapsed_time: float,   # ステップあたりの時間(秒)
    num_gpus: int,
    gpu_peak_flops: float, # GPUの理論的ピークFLOPS
) -> float:
    """
    LLMトレーニングのMFUを計算。
    リファレンス:PaLM論文(Chowdhery et al., 2022)
    """
    # フォワードFLOPs = 2 * num_params * batch_size * seq_len
    # バックワードはフォワードの約2倍 -> 合計6 * num_params * batch_size * seq_len
    flops_per_step = 6 * model_num_params * batch_size * seq_len
    achieved_flops = flops_per_step / elapsed_time
    peak_flops_total = gpu_peak_flops * num_gpus

    return achieved_flops / peak_flops_total

# 使用例
mfu = compute_mfu(
    model_num_params=8e9,     # 8Bパラメータ
    batch_size=8,
    seq_len=2048,
    elapsed_time=0.5,         # ステップあたり0.5秒
    num_gpus=8,
    gpu_peak_flops=989e12,    # H100 BF16:約989 TFLOPS
)
print(f"MFU: {mfu:.1%}")  # 例:MFU: 45.2%

# 一般的に達成可能なMFU:
# - 良好な実装:40〜60%
# - 最適化済み(torchtitan、Megatron):50〜65%
# - 理論的最大:約70%(通信/メモリオーバーヘッドは避けられない)

11. 実践的なトレーニング設定例

Llama 3 8B:8x H100シングルノード

# train_configs/llama3_8b_h100x8.toml

[job]
dump_folder = "./outputs/llama3_8b"

[model]
name = "llama3"
flavor = "8B"

[training]
batch_size = 4
seq_len = 8192
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 8  # FSDP2:8 GPU
tensor_parallel_degree = 1

[activation_checkpoint]
mode = "selective"

[float8]
enable_float8_linear = true  # FP8トレーニングを有効化

[optimizer]
name = "AdamW"
lr = 1e-4
torchrun --nproc_per_node=8 train.py \
    --job.config_file train_configs/llama3_8b_h100x8.toml

Llama 3 70B:4ノード × 8 H100(合計32 GPU)

# train_configs/llama3_70b_32gpu.toml

[training]
batch_size = 2
seq_len = 4096
data_parallel_replicate_degree = 2  # DDP:2レプリカ
data_parallel_shard_degree = 4      # FSDP:4GPUシャーディング
tensor_parallel_degree = 4          # TP:4 GPU

# 総GPU数 = 2 x 4 x 4 = 32

[experimental]
pipeline_parallel_degree = 1  # PP無効

Llama 3 405B:大規模クラスター

# train_configs/llama3_405b_large.toml

[training]
batch_size = 1
seq_len = 2048
data_parallel_replicate_degree = 4   # DDP
data_parallel_shard_degree = 8       # FSDP
tensor_parallel_degree = 8           # TP

[experimental]
pipeline_parallel_degree = 8         # PP

# 総GPU数 = 4 x 8 x 8 x 8 = 2048

まとめ:torchtitanから学ぶこと

torchtitanは単なるトレーニングツールを超えて、現代のLLM分散トレーニングのベストプラクティスを明確に示す教育リソースです。

このガイドの主なポイント:

  1. 4D並列:DP、TP、PP、SPを組み合わせて数千のGPUを効率的に活用
  2. FSDP2:PyTorchネイティブAPIによるZeRO-3レベルのメモリ効率
  3. Flash Attention:O(N²)のメモリがO(N)に;2〜4倍のスピードアップ
  4. 非同期チェックポインティング:トレーニングを停止せずにチェックポイントを保存
  5. MFU最適化:GPU理論性能の40〜65%を目標にする

分散トレーニングは、ハードウェア、ソフトウェア、アルゴリズムが交差する複雑な領域です。torchtitanはこの複雑さを可能な限り透明な方法で公開し、学習と実験を身近なものにしています。コードを自分で実行し、さまざまな並列化の組み合わせを試して、分散トレーニングの直感を養いましょう。


参考文献