Skip to content
Published on

分散学習 & GPUインフラ 2026 ディープダイブ — DeepSpeed、FSDP2、Megatron-LM、Ray Train、JAX、TorchTitan、Blackwell GB200、MI325X、TPU v5p 総まとめ

Authors

2026年5月、LLM学習インフラはようやく「スタック」と呼べる形になった。NVIDIA Blackwell GB200 NVL72ラック(72-GPU NVLinkドメイン)が本格出荷され、AMD Instinct MI325Xは256GB HBM3EでGPT-4級モデルのシングルノード・フルパラメータSFTを可能にする。ソフトウェア面でもPyTorch FSDP2が正式安定化し、torch.compileと初めて安定して結合できるようになった。NVIDIA Megatron-Coreは外部学習フレームワーク(NeMo、TorchTitan、MosaicML Composer)に一級市民として組み込まれている。本稿は分散学習フレームワーク、並列化戦略、ハードウェア選択、障害モードまで一度に整理する。

なぜ2026年に分散学習を改めて整理すべきか

2024年まではH100 80GBが事実上の標準で、トークン当たりの学習コストはモデルサイズに比例して急上昇していた。2026年5月現在、B200(192GB HBM3E)とGB200 NVL72の登場で同じモデルを学習する際のウォールクロックが30〜50%短縮された。MI325XとGaudi 3が価格/性能でNVIDIA単一供給依存を崩し始め、同時にFSDP2 + torch.compile、DeepSpeed ZeRO++ stage 3、Megatron-Coreのコンテキスト並列(CP)がすべてプロダクションレディになった。今や問題は「どのフレームワークが速いか」ではなく、「どの組み合わせがfault-toleranceまで含めて合理的か」である。

データ並列、テンソル並列、パイプライン並列、シーケンス並列

分散学習の4軸はデータ並列(DP/DDP)、テンソル並列(TP)、パイプライン並列(PP)、シーケンス並列(SP)である。DDPは同一モデルのレプリカを複数GPUに置き勾配をAll-Reduceする最もシンプルな方式で、モデルが1枚のGPUに収まるとき最も効率的である。TPは行列乗算を次元方向に分割しAttentionのQKVやMLPを複数GPUに分散させる。PPはモデルのレイヤースタックをノード群で切り分け、SPは長文(128K以上)学習のメモリを減らすためシーケンス方向に活性化を分割する。実用的な70B+ LLM学習はTP=8、PP=8、DP=16のような3D並列にSPを重ねる構成が標準である。

エキスパート並列(EP)とMoE学習の難しさ

Mixtral 8x22BやDeepSeek-V3のようなMoEモデルはルーターがトークンをK個のエキスパートに送り、それらがGPUに分散していると必ずAll-to-Allが発生する。このAll-to-AllはNCCLのncclSend/ncclRecvを対にして実装されるが、トークンの偏り(一部エキスパートに集中し他はアイドル)が起きるとGPU稼働率は30%まで落ちる。2026年の標準はcapacity factorを1.25に設定し、補助ロードバランス損失を0.01程度加える方法で、Megatron-CoreもDeepSpeed MoEもエキスパート並列を標準サポートしている。

ZeRO 1/2/3とFSDPの等価性

DeepSpeed ZeROはオプティマイザ状態(stage 1)、勾配(stage 2)、パラメータ(stage 3)の順にシャーディングする戦略である。PyTorch FSDP1とFSDP2は本質的にZeRO stage 3をPyTorchネイティブで実装したものに近く、同じモデル・同じクラスタ・同じオプティマイザ設定なら学習曲線はほぼ一致する。FSDP2の差別化点はper-parameter sharding(FSDP1のflat-parameterではなく)、torch.compileとの安定した合成、DTensorベースの明快なモデル分散である。2026年のほとんどの新規学習ジョブはDeepSpeedからFSDP2に移行した。

FSDP2 + torch.compile の実例

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FSDPModule, fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="nccl")
mesh = init_device_mesh("cuda", (8,))  # 8-way data parallel

model = build_llama3_70b()  # nn.Module
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

for layer in model.layers:
    fully_shard(layer, mesh=mesh, mp_policy=mp_policy)
fully_shard(model, mesh=mesh, mp_policy=mp_policy)

model = torch.compile(model, mode="reduce-overhead", fullgraph=False)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=True)
for batch in dataloader:
    out = model(batch.input_ids)
    loss = out.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

上記コードの鍵はfully_shardをレイヤー単位で呼び、per-blockのall-gather/reduce-scatterを可能にする点にある。FSDP1のモノリシックなFlatParameterの代わりに、FSDP2はnn.ParameterをDTensorに変換する。

DeepSpeed ZeRO-3 + ZeRO-Infinity 設定

DeepSpeedは依然としてZeRO++通信最適化とNVMeオフロード(ZeRO-Infinity)で優位を持つ。以下は405B級モデルを32-GPU H100 1ノードで学習する典型的な構成である。

zero_optimization:
  stage: 3
  offload_optimizer:
    device: cpu
    pin_memory: true
  offload_param:
    device: nvme
    nvme_path: /mnt/local_nvme
    buffer_count: 5
    buffer_size: 1e9
  stage3_max_live_parameters: 1e9
  stage3_max_reuse_distance: 1e9
  stage3_prefetch_bucket_size: 5e8
  stage3_param_persistence_threshold: 1e6
  contiguous_gradients: true
  reduce_bucket_size: 5e8
  allgather_bucket_size: 5e8
bf16:
  enabled: true
gradient_clipping: 1.0
train_micro_batch_size_per_gpu: 1
gradient_accumulation_steps: 16

NVMeオフロードはGen5 NVMeディスク(7GB/s以上)があって初めて意味を持つ。Gen4やSATA SSDではスループットが足りずGPUが飢える。

NVIDIA Megatron-LM と Megatron-Core

Megatron-LMはNVIDIAが2019年から維持しているリファレンスLLM学習コードベースである。2024年以降は中核の分散プリミティブがMegatron-Coreライブラリに切り出され、NeMo、TorchTitan、MosaicML等に組み込まれる。TP/PP/SP/CP/EPがすべて一級市民で、オープンソースとして最初にfp8学習を安定化したスタックでもある。70Bモデル学習のランチャー例:

torchrun --nproc_per_node=8 --nnodes=64 \
  --rdzv_id=meg70b --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29500 \
  pretrain_gpt.py \
  --num-layers 80 --hidden-size 8192 --num-attention-heads 64 \
  --seq-length 8192 --max-position-embeddings 8192 \
  --micro-batch-size 1 --global-batch-size 1024 \
  --tensor-model-parallel-size 8 \
  --pipeline-model-parallel-size 8 \
  --sequence-parallel \
  --context-parallel-size 2 \
  --fp8-format hybrid --fp8-margin 0 \
  --use-distributed-optimizer \
  --recompute-activations --recompute-granularity selective \
  --tokenizer-type Llama3Tokenizer \
  --data-path /data/tokens \
  --save /checkpoints/meg70b --save-interval 500 \
  --train-iters 100000

TorchTitan、Composer、Lightning Fabric

Metaが公開したTorchTitanは「Megatron的なリファレンス学習コードはPyTorchネイティブでも書ける」というメッセージを掲げる。FSDP2 + Tensor Parallel + Pipeline Parallel + Selective Activation Checkpointingを純粋なPyTorch DTensor APIで実装し、Llama 70B/405Bのレシピも同梱する。Databricksに買収されたMosaicML ComposerはBPEトークナイザからLRスケジューラまで抽象化したトレーナーで、MPT/DBRX学習にそのまま使われた。Lightning FabricはTrainer抽象を解き、任意のPyTorchコードに分散学習を部分的に導入するのに向く。

Ray Train と Ray Tune の位置付け

AnyscaleのRay TrainはPyTorch/JAX/Hugging Face AccelerateのワーカーをRay Actor上でオーケストレートする。自前の分散アルゴリズムは持たず、トレーナーの起動と終了を整える役割に集中する。クラスタが1000-GPU級になるとSlurmやKubernetes単体ではelastic再起動とfault recoveryが弱く、その隙間をRayが埋める。次はRay Train経由でFSDP2学習を立ち上げる例。

import ray
from ray.train.torch import TorchTrainer, TorchConfig
from ray.train import ScalingConfig, RunConfig, FailureConfig

def train_func(config):
    import torch.distributed as dist
    from torch.distributed.fsdp import fully_shard
    dist.init_process_group("nccl")
    model = build_model()
    fully_shard(model)
    train_loop(model, config)

trainer = TorchTrainer(
    train_func,
    train_loop_config={"lr": 3e-4, "batch": 1},
    scaling_config=ScalingConfig(num_workers=512, use_gpu=True),
    torch_config=TorchConfig(backend="nccl", timeout_s=1800),
    run_config=RunConfig(
        storage_path="s3://ray-train/llama70b",
        failure_config=FailureConfig(max_failures=5),
    ),
)
result = trainer.fit()

Hugging Face Accelerate、trl、axolotl、unsloth、torchtune、LLaMA-Factory

Accelerateはユーザーコードにほとんど手を入れずに単/複数GPU・複数ノード学習を有効にする薄いラッパーで、DeepSpeed、FSDP、Megatronのいずれのバックエンドも呼び出せる。trl(Transformers RL)はSFT、DPO、GRPO、RLOOのようなアラインメント学習用。axolotlはYAML 1枚でLoRA/QLoRAフル学習を回すコミュニティ標準ツールである。unslothはTritonカーネルを手書きすることで7BモデルのLoRA学習速度を2倍に伸ばす無料ツール、PyTorch公式のtorchtuneは「フレームワーク不要の学習レシピ」を志向する。アリババのms-swiftとLLaMA-Factoryは中国エコシステムでデファクト標準である。

JAX、Flax、Equinox、MaxText、Pax、Levanter

Google系の分散学習はJAXのpjit/jit + Sharding APIで実質すべての並列化を表現する。TP・PP・DPは別ライブラリではなく単一のSharding仕様に集約される。MaxTextはGoogleが公開したリファレンスLLM学習コード(JAX、TPUとGPU両対応)、PaxはPaLM/Geminiの学習に使われたGoogle社内フレームワークの一部、LevanterはStanford CRFMが作った再現性・合成データ重視の学習ライブラリである。最小のJAX pjit例:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()  # 例えば256 TPU chips
mesh = Mesh(devices.reshape(8, 32), axis_names=("dp", "mp"))

def init_params(key):
    w = jax.random.normal(key, (16384, 16384))
    return w

w_sharding = NamedSharding(mesh, P(None, "mp"))
w = jax.device_put(init_params(jax.random.PRNGKey(0)), w_sharding)

@jax.jit
def train_step(w, x, y):
    logits = x @ w
    loss = jnp.mean((logits - y) ** 2)
    grads = jax.grad(lambda w: jnp.mean((x @ w - y) ** 2))(w)
    return w - 0.01 * grads, loss

混合精度 — fp16、bf16、fp8、NF4、MXFP4

Hopper(H100)以降、bf16が事実上の学習デフォルトdtypeになった。fp8(E4M3とE5M2)学習はHopperで初めて実用化され、Blackwell B200/B100とMegatron-Core・Transformer Engineの組み合わせでは70B+でも安定収束が報告されている。NF4(NormalFloat 4-bit)は推論・QLoRA量子化、MXFP4(microscaling fp4)はBlackwellで新たに導入されたforward-only活性化量子化形式である。fp4で学習そのものを回すのは2026年でも実験的だが、活性化チェックポイント保存と推論段階では標準になりつつある。EleutherAIのGPT-NeoX(Pythiaシリーズを育てたDeepSpeed+Megatron学習フレームワーク)のような既存コードベースも2026年にはMegatron-CoreまたはTorchTitanへ移行してfp8をオンにする流れである。Transformer Engineの利用例:

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,  # E4M3 forward, E5M2 backward
    amax_history_len=1024,
    amax_compute_algo="max",
)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output = te_linear_layer(input_tensor)
    loss = compute_loss(output, target)
loss.backward()

長文(128K〜1Mトークン)学習ではパラメータメモリより活性化メモリが支配的になる。Activation checkpointingはforward中間活性化を捨ててbackwardで再計算するトレードオフで、メモリは半分以下になるがウォールクロックは25〜33%伸びる。Megatron-Coreのselective recomputeはAttentionのようなメモリ重量級ブロックだけを選んで再計算し、時間損失を5〜10%に抑える。

CUDA Graphs、NCCLチューニング、クラスタネットワーク

CUDA Graphsはforward/backwardのカーネル発行オーバーヘッドを丸ごとキャプチャし1枚のグラフとして再生する。micro-batch=1のような小バッチでスループットを1.3〜1.5倍に引き上げる。NCCLチューニングはNCCL_ALGO(Ring/Tree/CollNet)、NCCL_PROTO(LL/LL128/Simple)、NCCL_BUFFSIZE、NCCL_NSOCKS_PERTHREADをクラスタトポロジに合わせて設定するのが要である。ネットワークファブリックは概ね3種類で、NVIDIA Quantum-X800 InfiniBand(SHARP in-network reduction)、Ethernetベースの RoCE v2(PFC/ECNチューニングが必要)、HPE Slingshot 11(Frontier・El Capitanで実証)、AWS EFA(独自SRDプロトコル)。GB200 NVL72クラスタで一般的な環境変数:

export NCCL_DEBUG=WARN
export NCCL_IB_HCA=mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7
export NCCL_IB_GID_INDEX=3
export NCCL_IB_TIMEOUT=22
export NCCL_IB_RETRY_CNT=7
export NCCL_SOCKET_IFNAME=ib0
export NCCL_NET_GDR_LEVEL=PHB
export NCCL_P2P_LEVEL=NVL
export NCCL_NVLS_ENABLE=1
export NCCL_BUFFSIZE=8388608
export NCCL_ALGO=NVLSTree

NVIDIA Blackwell B100/B200/GB200 NVL72

Blackwell B200は192GB HBM3E、8TB/sメモリ帯域、FP8 9 PFLOPS・FP4 18 PFLOPSを提供する。GB200 Grace Hopper SuperchipはNVLink-C2CでGrace CPUとB200ダイ2基を直接接続し、GB200 NVL72ラックは72個のB200を単一のNVLinkドメインで束ねて1.4 EFLOPS FP4を出す。H100 DGXに比べトークン当たり学習コストが30〜40%下がり、FSDP2との組み合わせで70Bクラスの1兆トークン学習が約1週間に圧縮される(2025年のH100 80GB DGX 32-nodeを基準として)。

AMD Instinct MI300X/MI325X と ROCm

MI300Xは192GB HBM3、MI325Xは256GB HBM3Eを搭載し、シングルGPUメモリの大きさが学習中のメモリ圧迫を直接緩和する。ソフトウェア面でPyTorchはROCm上でもほぼ同一のAPIで動作し、Megatron-LMとDeepSpeedはROCmを正式サポートする。短所は最新のFlashAttentionやTritonカーネル対応がNVIDIA比6〜12ヶ月遅れる点、そしてfp8学習の安定性がまだNVIDIA水準に達していない点である。

Intel Gaudi 3 と SynapseAI

Intel Habana Gaudi 3は128GB HBM2E、1.835 PFLOPS BF16を提供し、価格がH100比30〜40%安いのが主な訴求点である。SynapseAIはPyTorchと密に統合するのではなくPyTorchモジュールをグラフコンパイラに吸収するアプローチで、慣れ親しんだコードをそのまま動かすのは難しい。それでもIntel Tiber AI Cloudと組み合わせるとBERTやLlama-7BのようなカノニカルワークロードではコストパフォーマンスがHopperを脅かす。

AWS Trainium 2 と Google TPU v5p、v6e Trillium

AWS Trainium 2(Trn2)は1インスタンス64 Trainiumチップ、1.5TB HBM、EFAで結ぶUltraCluster構成が主力である。AnthropicのClaude学習はTrainium 2 + GPUのハイブリッドで進められていることが公開されている。Google TPU v5pは8960チップポッド、v6e TrillimuはH100比4.7倍のピーク演算を主張し、Gemini学習の主力である。JAXと組み合わせると単一のsharding仕様でおよそ1万チップまでそのままスケールする。

分散チェックポイントと障害復旧

学習ジョブは「失敗しないもの」ではなく「頻繁に失敗するもの」と仮定すべきである。1000-GPUクラスタで1週間学習するとGPU/ネットワーク障害が平均1.5件発生するという報告が一般的である。PyTorchはDCP(torch.distributed.checkpoint)とTorchSnapshotで非同期分散チェックポイントを標準化し、Megatron-LMは独自のzarrベース形式でPP/TPトポロジ変更後も同じチェックポイントをロードできる。学習ループは(1)N stepごとの非同期保存、(2)NCCL_TIMEOUT超過時のgraceful abort、(3)torchrunまたはRayによるelastic restartで最新チェックポイントから再開、というパターンに従う。

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

def save_async(model, optimizer, step, ckpt_dir):
    state = {"model": model, "optim": optimizer, "step": step}
    storage_writer = dcp.FileSystemWriter(f"{ckpt_dir}/step_{step}", thread_count=8)
    return dcp.async_save(state, storage_writer=storage_writer)

def load_latest(model, optimizer, ckpt_dir):
    state = {"model": model, "optim": optimizer}
    storage_reader = dcp.FileSystemReader(ckpt_dir)
    dcp.load(state, storage_reader=storage_reader)
    return state.get("step", 0)

学習中の障害モード — CUDA OOM、NCCLハング、発散ロス

最も一般的な3つの障害はCUDA OOM、NCCL集合通信のハング、ロス発散である。OOMは多くの場合シーケンス長の二乗で活性化メモリが膨らむのが原因なので、まずrecompute granularityとmicro-batchを調整する。NCCLハングはNCCL_DEBUG=INFOとtimeout_sの設定でどのrankが停止したかを確認し、ネットワークインターフェース・ファイアウォール・対角通信障害を点検する。ロス発散はLRが大きすぎる、gradient clipping漏れ、fp8 amax統計の更新失敗が典型的な原因である。

GPUクラウド — Lambda、CoreWeave、Modal、Together、Crusoe、RunPod、Lepton

GPUクラウド市場は2026年にtier-1/tier-2の形に二分された。Tier-1(CoreWeave、Lambda Labs、Crusoe、Together AI、Nebius)はInfiniBandまたはRoCEで結ばれた数千GPU規模のスーパーポッドを時間単位の予約で売る。Tier-2(RunPod、Vast.ai、LeptonAI、Fal)は単一ノード〜小規模クラスタを分単位で貸す。Modalはコード関数からGPUを即座に立ち上げるモデルで、ファインチューニングとサービングで人気を集める。価格はおおよそH100 80GBが1.99〜3.5 USD/h、B200が4.5〜6.5 USD/h、MI300Xが1.7〜2.8 USD/h(2026年5月時点の表示価格)。

コスト・電力・カーボン — トークン当たり学習コストの分解

70Bモデルの1兆トークン学習コストはおおむね(総GPU時間)×(時間単価)である。32-node H100 80GB DGX(256 GPU)で約25日学習なら256×25×24=153,600 GPU時間、時間3 USDなら約46万 USD。B200 NVL72ラック1本(72 GPU)で同じ学習が約14日なら72×14×24=24,192 GPU時間、時間6 USDなら約14.5万 USDに落ちる。この差が2026年の新規学習ジョブが圧倒的にBlackwellへ流れる理由である。電力は1ラック当たり120kW級なのでデータセンターのPUEと冷却設計が直接予算に効いてくる。

韓国事例 — LG AI Research EXAONE と Naver HyperCLOVA X

LG AI ResearchはEXAONE 3.5/4.0の学習に自社H100クラスタ(推定1000+ GPU)とMegatron-LMベースのスタックを使うとされる。EXAONE 3.5 32Bは13Bトークンの韓・英混合学習に韓国語トークナイザと追加の指示チューニングを重ねた。Naver HyperCLOVA Xは韓国内IDCでSamsung Heavy Industries / Naver Cloudの協業で学習され、韓国のPIPAデータコンプライアンスに合わせ海外データを最小化するパイプラインを持つ。KTは自社1兆パラメータ級のMi:dmモデルでMegatron + DeepSpeedの組み合わせを公にしている。

日本事例 — Sakana AI、Preferred Networks PFCC、ABEJA

東京拠点のSakana AIは学習コストを回避する「進化的マージ」アプローチを取る。フルスクラッチ事前学習ではなく既存モデルを組み合わせ・進化させ、インフラ負担を大幅に下げている。Preferred Networksは自社MN-Core / MN-Core 2アクセラレータとH100を混在させたハイブリッドクラスタ(PFCC)でPLaMoシリーズを学習する。ABEJAはInsightシリーズの学習にAWS Trainium 2を採用し、リファレンスアーキテクチャを公開している。日本では経済産業省(METI)の「AIスーパーコンピュータ補助金」が国内の学習クラスタ整備の大きな推進力となっている。

どのスタックを選ぶか — 意思決定ツリー

70B未満のSFT/LoRA: torchtune + unsloth、またはaxolotl。70Bフルパラメータ事前学習: TorchTitan(FSDP2 + TP + PP)またはMegatron-Core。MoE 405B+: Megatron-Core(EP/CP強力)またはDeepSpeed MoE。JAX/TPU環境: MaxTextまたはLevanter。クラスタ運用: Ray Train(elastic) + SlurmまたはKubernetes。アラインメント(SFT/DPO/GRPO): trl。高速なSOTA追従: バックエンドを抽象化したラッパーとしてHF Accelerateを維持。

2027年の展望 — その先

Blackwell B300、AMD MI355X、Trainium 3、TPU v6pが2026年後半〜2027年初に立て続けに登場する。ソフトウェアでは(1)torch.compile + FSDP2がデフォルトとなり、DeepSpeedはZeRO++通信やNVMeオフロードのようなニッチに収まる、(2)Megatron-Coreがより多くの外部トレーナーに組み込まれ事実上のリファレンス計算層になる、(3)JAXがGPUでもシェアを取り戻す、の3点が見えてくる動きである。分散学習インフラはいまや「どう動かすか」より「どう運用するか」の問題に移りつつある。

References