
Distributed Training & GPU Infrastructure 2026 Deep-Dive — DeepSpeed, FSDP2, Megatron-LM, Ray Train, JAX, TorchTitan, Blackwell GB200, MI325X, TPU v5p


As of May 2026, LLM training infrastructure finally has a shape that deserves to be called a "stack." NVIDIA Blackwell GB200 NVL72 racks (72-GPU NVLink domains) are shipping in volume, and AMD Instinct MI325X, with 256GB of HBM3E per GPU (2TB per 8-GPU node), makes single-node full-parameter SFT practical for far larger models than 80GB-class GPUs allow. On the software side, PyTorch FSDP2 is now stable and composes cleanly with torch.compile, and NVIDIA Megatron-Core has become a library that other training frameworks, starting with NVIDIA's own NeMo, build on directly. This post covers distributed-training frameworks, parallelism strategies, hardware choices, and failure modes in one pass.

Why revisit distributed training in 2026

As recently as 2024, the H100 80GB was the de facto standard, and per-token training cost climbed steeply with model size. By May 2026, B200 (192GB HBM3E) and GB200 NVL72 have cut wall-clock time for training the same model by 30 to 50 percent, and MI325X and Gaudi 3 have started to offer credible price/performance alternatives to NVIDIA. At the same time, FSDP2 + torch.compile, DeepSpeed ZeRO-3 with ZeRO++, and Megatron-Core context parallelism (CP) are all production-ready. The question is no longer which framework is fastest in microbenchmarks, but which combination makes sense end to end, including fault tolerance.

Data, tensor, pipeline, and sequence parallelism

The four axes of distributed training are data parallel (DP/DDP), tensor parallel (TP), pipeline parallel (PP), and sequence parallel (SP). DDP places identical model replicas on multiple GPUs and all-reduces gradients; it is the simplest scheme and the most efficient when the model fits in a single GPU's memory. TP slices matrix multiplications along a dimension so QKV and MLP are distributed across GPUs, PP cuts the layer stack across node groups, and SP shards activations along the sequence dimension to make long-context (128K+) training memory-feasible. Real 70B+ LLM training typically lands on 3D parallelism, for example TP=8, PP=8, DP=16 (1,024 GPUs), with SP layered on top.
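To make the rank bookkeeping concrete, the sketch below maps a global rank to its (dp, pp, tp) coordinates for the TP=8, PP=8, DP=16 example and lists the ranks it communicates with. It assumes TP is the fastest-varying index, so each TP group sits inside one 8-GPU node where NVLink is fastest; real frameworks (Megatron-Core, PyTorch DeviceMesh) pick their own axis order, and `ParallelLayout` is a hypothetical helper, not a library API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ParallelLayout:
    tp: int  # tensor-parallel size (innermost: consecutive ranks)
    pp: int  # pipeline-parallel size
    dp: int  # data-parallel size (outermost)

    @property
    def world_size(self) -> int:
        return self.tp * self.pp * self.dp

    def coords(self, rank: int) -> tuple[int, int, int]:
        """Map a global rank to (dp, pp, tp) indices, with tp varying fastest."""
        if not 0 <= rank < self.world_size:
            raise ValueError(f"rank {rank} out of range for world size {self.world_size}")
        tp = rank % self.tp
        pp = (rank // self.tp) % self.pp
        dp = rank // (self.tp * self.pp)
        return dp, pp, tp

    def tp_group(self, rank: int) -> list[int]:
        """Ranks sharing this rank's (dp, pp) coordinates; they split each matmul."""
        base = rank - rank % self.tp
        return list(range(base, base + self.tp))

    def dp_group(self, rank: int) -> list[int]:
        """Ranks holding the same model shard on different data; they all-reduce gradients."""
        stride = self.tp * self.pp
        return list(range(rank % stride, self.world_size, stride))


layout = ParallelLayout(tp=8, pp=8, dp=16)
print(layout.world_size)   # 1024
print(layout.coords(523))  # (8, 1, 3)
```

TP groups exchange activations in every layer, so they get the fastest links; DP groups only all-reduce gradients once per step and can comfortably span nodes.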

Expert parallelism (EP) and the awkwardness of MoE training

MoE models such as Mixtral 8x22B and DeepSeek-V3 route each token to its top-K experts, and when those experts are distributed across GPUs, All-to-All communication appears. Under NCCL, the All-to-All is built from grouped ncclSend/ncclRecv pairs, and when token routing is imbalanced (some experts swamped, others idle) GPU utilization can fall as low as 30 percent. A common 2026 configuration is a capacity factor of 1.25 paired with an auxiliary load-balancing loss weighted around 0.01, and both Megatron-Core and DeepSpeed MoE provide expert parallelism out of the box.
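The capacity factor and auxiliary loss can be sketched in pure Python as follows. This follows the Switch Transformer formulation (alpha · N · Σ f_i · P_i), generalized to top-k by counting routing slots; the function names are hypothetical, and Megatron-Core and DeepSpeed MoE each implement their own variants.

```python
import math


def expert_capacity(num_tokens: int, num_experts: int, top_k: int, capacity_factor: float) -> int:
    """Max tokens each expert accepts per batch; overflow tokens are dropped or rerouted."""
    return math.ceil(capacity_factor * num_tokens * top_k / num_experts)


def load_balancing_loss(router_probs, expert_choices, num_experts: int, alpha: float = 0.01) -> float:
    """Switch-style auxiliary loss: alpha * N * sum_i f_i * P_i.

    router_probs:   per-token softmax over experts, shape [T][N]
    expert_choices: per-token list of chosen expert indices, shape [T][k]
    """
    num_tokens = len(router_probs)
    slots = sum(len(c) for c in expert_choices)
    counts = [0] * num_experts
    for choices in expert_choices:
        for e in choices:
            counts[e] += 1
    f = [c / slots for c in counts]  # fraction of routing slots sent to each expert
    p = [sum(row[i] for row in router_probs) / num_tokens for i in range(num_experts)]  # mean router prob
    return alpha * num_experts * sum(fi * pi for fi, pi in zip(f, p))


# 4096 tokens, 8 experts, top-2, capacity factor 1.25 -> 1280 token slots per expert.
print(expert_capacity(4096, 8, 2, 1.25))
```

With perfectly uniform routing the loss equals alpha (0.01 here); if every token piles onto the same two of eight experts it rises to 4x that, which is the gradient signal that pushes the router back toward balance.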

ZeRO 1/2/3 equivalence with FSDP

DeepSpeed ZeRO incrementally shards optimizer state (stage 1), then gradients (stage 2), then parameters (stage 3). PyTorch FSDP1 and FSDP2 with full sharding are essentially a PyTorch-native implementation of ZeRO stage 3, and for the same model, cluster, and optimizer settings their loss curves match closely. FSDP2's differentiators are per-parameter sharding (instead of FSDP1's flat parameters), stable composition with torch.compile, and a cleaner DTensor-based mental model. Many new training jobs in 2026 have migrated from DeepSpeed to FSDP2.
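The memory effect of each stage follows from the ZeRO paper's accounting for mixed-precision Adam: per parameter, 2 bytes of bf16/fp16 weights, 2 bytes of gradients, and 12 bytes of fp32 master weights, momentum, and variance. A sketch (the function name is hypothetical; activations, communication buffers, and fragmentation are ignored):

```python
def zero_model_state_gb(num_params: float, num_gpus: int, stage: int) -> float:
    """Per-GPU model-state memory in GB for mixed-precision Adam under ZeRO stage 0-3."""
    params, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params
    if stage == 0:    # plain DDP: everything replicated
        total = params + grads + optim
    elif stage == 1:  # shard optimizer state
        total = params + grads + optim / num_gpus
    elif stage == 2:  # + shard gradients
        total = params + (grads + optim) / num_gpus
    elif stage == 3:  # + shard parameters (roughly FSDP full sharding)
        total = (params + grads + optim) / num_gpus
    else:
        raise ValueError("stage must be 0-3")
    return total / 1e9


# 70B parameters on 64 GPUs
for stage in range(4):
    print(stage, round(zero_model_state_gb(70e9, 64, stage), 1))
# 0 1120.0
# 1 293.1
# 2 155.3
# 3 17.5
```

Only stage 3 brings 70B-class model state (1.12 TB in total) down to something that leaves room for activations on a single GPU, which is why full sharding is the default for large models.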

FSDP2 + torch.compile in practice

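import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # LOCAL_RANK is set by torchrun
mesh = init_device_mesh("cuda", (8,))  # 1-D mesh: 8-way sharded data parallel

# build_llama3_70b(), init_weights(), and dataloader stand in for your own code.
# Building on the meta device avoids materializing all 70B parameters on every rank.
with torch.device("meta"):
    model = build_llama3_70b()  # nn.Module
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

# Shard each transformer block so all-gather/reduce-scatter run per block, then shard
# the root module to cover the remaining parameters (embeddings, final norm, output head).
for layer in model.layers:
    fully_shard(layer, mesh=mesh, mp_policy=mp_policy)
fully_shard(model, mesh=mesh, mp_policy=mp_policy)

model.to_empty(device="cuda")  # allocate only this rank's local shards
model.init_weights()

# Default mode; CUDA Graph capture ("reduce-overhead") is not the usual choice with FSDP collectives.
model = torch.compile(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=True)
for batch in dataloader:
    out = model(batch.input_ids, labels=batch.labels)  # assumes an output object with .loss
    loss = out.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)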

The key trick above is calling fully_shard per layer to enable per-block all-gather/reduce-scatter. Where FSDP1 used a monolithic FlatParameter, FSDP2 converts each nn.Parameter into a DTensor.
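FSDP2 shards each parameter along dim 0 with torch.chunk-style splitting. The pure-Python sketch below (hypothetical helpers, no PyTorch required) shows how rows divide across ranks in the uneven case; the padded size assumes shards are padded to the largest chunk so every rank contributes an equal-sized buffer to the all-gather.

```python
import math


def dim0_shard_rows(dim0: int, world_size: int) -> list[int]:
    """Rows of dim 0 each rank owns: ceil(dim0 / world_size) per rank, with the
    tail ranks getting a short or empty shard (torch.chunk semantics)."""
    chunk = math.ceil(dim0 / world_size)
    return [max(0, min(chunk, dim0 - rank * chunk)) for rank in range(world_size)]


def padded_shard_numel(shape: tuple[int, ...], world_size: int) -> int:
    """Elements per rank if every shard is padded up to the largest chunk."""
    return math.ceil(shape[0] / world_size) * math.prod(shape[1:])


print(dim0_shard_rows(10, 4))                # [3, 3, 3, 1]
print(dim0_shard_rows(5, 4))                 # [2, 2, 1, 0]
print(padded_shard_numel((8192, 28672), 8))  # 29360128
```

Because sharding is per parameter rather than over one flattened buffer, each rank's shard keeps the parameter's own shape on every dimension except dim 0, which is what lets it be represented as a DTensor.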

DeepSpeed ZeRO-3 + ZeRO-Infinity config

DeepSpeed still has the edge in ZeRO++ communication optimizations and NVMe offload (ZeRO-Infinity). Below is a typical config for training a 405B-class model on four 8-GPU H100 nodes (32 GPUs).

{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "nvme",
      "nvme_path": "/mnt/local_nvme",
      "buffer_count": 5,
      "buffer_size": 1e9
    },
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_prefetch_bucket_size": 5e8,
    "stage3_param_persistence_threshold": 1e6,
    "contiguous_gradients": true,
    "reduce_bucket_size": 5e8,
    "allgather_bucket_size": 5e8
  },
  "bf16": {
    "enabled": true
  },
  "gradient_clipping": 1.0,
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 16
}

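NVMe offload only pays off with fast local NVMe: PCIe Gen5-class drives, ideally several striped together, well above the roughly 7 GB/s a single Gen4 drive tops out at. With Gen4 or SATA SSDs, read bandwidth is too low and the GPUs starve waiting on I/O.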
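A back-of-the-envelope estimate shows why drive bandwidth matters: with parameters on NVMe, every step has to read them back for both the forward and backward passes. The sketch assumes a 405B model in bf16 partitioned evenly over four nodes, with illustrative per-node bandwidths; DeepSpeed prefetches to overlap these reads with compute, so the result is the I/O time that compute has to hide, not a measured step time.

```python
def nvme_param_read_seconds(num_params: float, num_nodes: int, nvme_gb_per_s_per_node: float,
                            bytes_per_param: int = 2, passes: int = 2) -> float:
    """Time per step to stream offloaded parameters from local NVMe, if not overlapped.

    Assumes each node reads its own partition once per pass (forward, backward);
    optimizer-state traffic (CPU-offloaded in the config above) is ignored.
    """
    bytes_per_node = num_params * bytes_per_param / num_nodes
    return passes * bytes_per_node / (nvme_gb_per_s_per_node * 1e9)


# Illustrative bandwidths: one ~7 GB/s drive vs. four drives striped at ~12 GB/s each.
for label, bw in [("single Gen4-class drive", 7.0), ("4x Gen5-class RAID-0", 48.0)]:
    print(f"{label}: {nvme_param_read_seconds(405e9, 4, bw):.1f} s of NVMe reads per step")
```

If that read time exceeds the compute time of a step, no amount of prefetching hides it and the GPUs idle.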

NVIDIA Megatron-LM and Megatron-Core

Megatron-LM is the reference LLM training codebase NVIDIA has maintained since 2019. Its core distributed primitives have since been extracted into the Megatron-Core library, which NeMo and other frameworks build on. TP/PP/SP/CP/EP are all first-class, and it was one of the first open-source stacks to make fp8 training stable. A 70B launch script looks roughly like:

# 64 nodes x 8 GPUs = 512 GPUs; TP8 x PP8 x CP2 = 128 GPUs per model replica, so DP = 4.
# /data/tokenizer is a placeholder for a Hugging Face-format Llama 3 tokenizer directory.
torchrun --nproc_per_node=8 --nnodes=64 \
  --rdzv_id=meg70b --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29500 \
  pretrain_gpt.py \
  --num-layers 80 --hidden-size 8192 --num-attention-heads 64 \
  --seq-length 8192 --max-position-embeddings 8192 \
  --micro-batch-size 1 --global-batch-size 1024 \
  --tensor-model-parallel-size 8 \
  --pipeline-model-parallel-size 8 \
  --sequence-parallel \
  --context-parallel-size 2 \
  --bf16 \
  --fp8-format hybrid --fp8-margin 0 \
  --use-distributed-optimizer \
  --recompute-activations --recompute-granularity selective \
  --tokenizer-type HuggingFaceTokenizer --tokenizer-model /data/tokenizer \
  --data-path /data/tokens \
  --save /checkpoints/meg70b --save-interval 500 \
  --train-iters 100000
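The launch flags are tightly coupled: the world size must divide by TP x PP x CP, and the global batch by micro-batch size x DP. The sketch below derives the implied data-parallel size, gradient-accumulation steps, and pipeline-bubble fraction ((p - 1)/m for a 1F1B schedule, as analyzed in the Megatron-LM scaling paper). It is a simplified, hypothetical helper; Megatron's own argument validation also accounts for expert and virtual-pipeline parallelism.

```python
def megatron_layout(num_gpus: int, tp: int, pp: int, cp: int, global_batch: int,
                    micro_batch: int, num_layers: int, num_heads: int) -> dict:
    """Derive DP size, micro-batches per step, and the 1F1B bubble fraction,
    checking the usual divisibility constraints."""
    model_parallel = tp * pp * cp  # GPUs holding one model replica
    if num_gpus % model_parallel:
        raise ValueError("world size must be divisible by TP * PP * CP")
    dp = num_gpus // model_parallel
    if global_batch % (micro_batch * dp):
        raise ValueError("global batch must be divisible by micro-batch size * DP")
    if num_layers % pp or num_heads % tp:
        raise ValueError("layers must split evenly across PP stages and heads across TP ranks")
    num_microbatches = global_batch // (micro_batch * dp)  # gradient-accumulation steps
    return {
        "dp": dp,
        "num_microbatches": num_microbatches,
        "layers_per_stage": num_layers // pp,
        "heads_per_tp_rank": num_heads // tp,
        # Idle fraction of a 1F1B pipeline relative to ideal compute: (p - 1) / m.
        "bubble_fraction": (pp - 1) / num_microbatches,
    }


print(megatron_layout(num_gpus=64 * 8, tp=8, pp=8, cp=2, global_batch=1024,
                      micro_batch=1, num_layers=80, num_heads=64))
# {'dp': 4, 'num_microbatches': 256, 'layers_per_stage': 10, 'heads_per_tp_rank': 8, 'bubble_fraction': 0.02734375}
```

With DP = 4, each optimizer step pushes 256 micro-batches through every pipeline, which keeps the bubble under 3 percent; halving the global batch would double it.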

TorchTitan, Composer, and Lightning Fabric

TorchTitan, released by Meta's PyTorch team, makes the case that "Megatron-style reference training code can also live entirely in PyTorch." It builds FSDP2 + Tensor Parallel + Pipeline Parallel + Selective Activation Checkpointing on pure PyTorch DTensor APIs and ships Llama 70B/405B recipes. MosaicML Composer (now part of Databricks) is a trainer that abstracts the training loop, from LR schedulers to checkpointing, and together with LLM Foundry it was used to train MPT and DBRX. Lightning Fabric drops the full Trainer abstraction, which makes it well suited to incrementally adding distributed training to arbitrary PyTorch code.

Where Ray Train and Ray Tune fit

Anyscale's Ray Train orchestrates PyTorch/JAX/Hugging Face Accelerate workers as Ray actors. It doesn't implement its own distributed algorithms — it focuses on bringing trainers up and tearing them down cleanly. Once a cluster grows past roughly 1000 GPUs, Slurm or Kubernetes alone provide only weak support for elastic restart and fault recovery, and Ray fills that gap. The following spins up an FSDP2 training job through Ray Train.

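import ray
from ray.train.torch import TorchTrainer, TorchConfig, get_device
from ray.train import ScalingConfig, RunConfig, FailureConfig

def train_func(config):
    from torch.distributed.fsdp import fully_shard
    # Ray Train has already initialized the NCCL process group on each worker
    # (see TorchConfig below), so calling dist.init_process_group() here would fail.
    device = get_device()  # the GPU Ray assigned to this worker
    model = build_model().to(device)  # build_model / train_loop are your own code
    fully_shard(model)  # no mesh given: shard across all workers
    train_loop(model, config)

trainer = TorchTrainer(
    train_func,
    train_loop_config={"lr": 3e-4, "batch": 1},
    scaling_config=ScalingConfig(num_workers=512, use_gpu=True),
    torch_config=TorchConfig(backend="nccl", timeout_s=1800),
    run_config=RunConfig(
        storage_path="s3://ray-train/llama70b",
        # Restarts the worker group up to 5 times; resuming mid-run also requires the
        # loop to report checkpoints (ray.train.report) and reload them on restart.
        failure_config=FailureConfig(max_failures=5),
    ),
)
result = trainer.fit()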

Hugging Face Accelerate, trl, axolotl, unsloth, torchtune, and LLaMA-Factory

Accelerate is a thin wrapper that turns on single-GPU, multi-GPU, or multi-node training with almost no changes to user code, and it can dispatch to DeepSpeed, FSDP, or Megatron-LM backends. trl (Transformer Reinforcement Learning) is the standard for alignment-style training: SFT, DPO, GRPO, RLOO. axolotl is the community standard for full fine-tuning, LoRA, and QLoRA driven by a single YAML file. unsloth's hand-written Triton kernels reportedly double LoRA training throughput on 7B-class models with no extra hardware. PyTorch's official torchtune leans toward "framework-less training recipes." ms-swift (from Alibaba's ModelScope) and LLaMA-Factory are the de facto standards inside the Chinese ecosystem.

JAX, Flax, Equinox, MaxText, Pax, and Levanter

The Google-side distributed-training story expresses essentially all parallelism through jax.jit (which has absorbed pjit) and the Sharding API. TP, PP, and DP are not separate libraries; they all collapse into a single sharding specification. MaxText is Google's open reference LLM training code (JAX, runs on TPU and GPU), Pax (Paxml) is the open-source release of parts of the Google-internal framework used to train PaLM and Gemini, and Levanter is Stanford CRFM's training library built for legibility, scalability, and bitwise reproducibility. A minimal JAX sharding snippet:

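import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# jax.devices() returns a list, so reshape it as a NumPy array: e.g. 256 TPU chips -> 8 x 32.
devices = np.array(jax.devices()).reshape(8, 32)
mesh = Mesh(devices, axis_names=("dp", "mp"))

# Weight columns are split across "mp" (tensor parallel) and replicated across "dp".
w_sharding = NamedSharding(mesh, P(None, "mp"))
# Batch rows are split across "dp" (data parallel): jax.device_put(x, batch_sharding).
batch_sharding = NamedSharding(mesh, P("dp", None))

def init_params(key):
    return jax.random.normal(key, (16384, 16384))

# out_shardings creates each shard directly on its device instead of
# materializing the full matrix on one device first.
w = jax.jit(init_params, out_shardings=w_sharding)(jax.random.PRNGKey(0))

def loss_fn(w, x, y):
    logits = x @ w
    return jnp.mean((logits - y) ** 2)

@jax.jit
def train_step(w, x, y):
    # XLA inserts the required collectives based on the input shardings.
    loss, grads = jax.value_and_grad(loss_fn)(w, x, y)
    return w - 0.01 * grads, loss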

Mixed precision — fp16, bf16, fp8, NF4, MXFP4

Since Ampere (A100), and firmly by Hopper (H100), bf16 has been the default training dtype. fp8 (E4M3 and E5M2) training first became practical on Hopper, and on Blackwell B200/B100 with Megatron-Core plus Transformer Engine, stable convergence has been reported even on 70B+ models. NF4 (NormalFloat 4-bit) is used for inference and QLoRA weight quantization, while MXFP4 is the 4-bit format from the OCP Microscaling (MX) specification (FP4 E2M1 elements sharing one power-of-two scale per 32-element block), which Blackwell tensor cores support natively. Native fp4 training is still experimental in 2026, but fp4 is becoming standard for inference and, reportedly, for storing checkpointed activations. Codebases such as EleutherAI's GPT-NeoX (the DeepSpeed + Megatron training framework behind the Pythia series) are increasingly migrated to Megatron-Core or TorchTitan in 2026 to turn on fp8. A Transformer Engine usage example:

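import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,  # E4M3 forward, E5M2 backward
    amax_history_len=1024,
    amax_compute_algo="max",
)

# Requires a Hopper- or Blackwell-class GPU; sizes are multiples of 16 to meet TE's FP8 shape rules.
te_linear_layer = te.Linear(4096, 4096, params_dtype=torch.bfloat16)
input_tensor = torch.randn(32, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randn(32, 4096, device="cuda", dtype=torch.bfloat16)

# Only the forward pass needs to run under fp8_autocast; TE runs the matching backward in FP8.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output = te_linear_layer(input_tensor)
loss = torch.nn.functional.mse_loss(output.float(), target.float())
loss.backward()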
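To make the microscaling idea behind MXFP4 concrete, here is a pure-Python round trip through an MXFP4-style encoding as described in the OCP MX specification: each block of 32 values shares one power-of-two scale, 2^(floor(log2(max|v|)) - 2), and each element is an FP4 E2M1 number. It is a simplified sketch, not a hardware-exact implementation: ties go to the smaller magnitude instead of round-to-nearest-even, and the 8-bit scale's exponent range is not clamped.

```python
import math

# Representable magnitudes of FP4 E2M1 (1 sign, 2 exponent, 1 mantissa bit).
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
E2M1_EMAX = 2  # exponent of the largest normal value: 6.0 = 1.5 * 2**2
BLOCK = 32


def quantize_block(block):
    """Return (shared_exponent, fp4_values); dequantized value = fp4 * 2**shared_exponent."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 0, [0.0] * len(block)
    shared_exp = math.floor(math.log2(amax)) - E2M1_EMAX
    scale = 2.0 ** shared_exp
    out = []
    for v in block:
        mag = min(abs(v) / scale, E2M1_GRID[-1])  # clamp to the max normal, 6.0
        nearest = min(E2M1_GRID, key=lambda g: abs(g - mag))  # ties -> smaller magnitude
        out.append(math.copysign(nearest, v))
    return shared_exp, out


def mxfp4_roundtrip(values):
    """Quantize to MXFP4-style blocks and dequantize back to floats."""
    result = []
    for start in range(0, len(values), BLOCK):
        exp, q = quantize_block(values[start:start + BLOCK])
        result.extend(x * 2.0 ** exp for x in q)
    return result


vals = [0.1 * i for i in range(-16, 16)]
print(max(abs(a - b) for a, b in zip(vals, mxfp4_roundtrip(vals))))
```

Storing one 8-bit scale per 32 elements costs only 0.25 extra bits per value, which is why the format can track local dynamic range far better than a single per-tensor fp4 scale.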

Long-context (128K to 1M token) training is dominated by activation memory rather than parameter memory. Activation checkpointing drops mid-forward activations and recomputes them in backward: memory drops by half or more, but wall-clock time grows by 25 to 33 percent because most of the forward pass runs twice. Megatron-Core's selective recompute re-runs only the memory-heavy but compute-cheap parts of attention (the score, softmax, and dropout steps), trimming the time cost back to 5 to 10 percent.
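The gap between full and selective recompute is easiest to see with the per-layer activation estimate from Korthikanti et al. (2022), the paper behind Megatron's selective recompute: about s·b·h·(34 + 5·a·s/h)/t bytes per layer with tensor plus sequence parallelism of size t, where s is sequence length, b micro-batch size, h hidden size, and a the number of attention heads. Selective recompute discards the 5·a·s/h attention term. The sketch assumes 16-bit activations and no fused attention kernel (FlashAttention already avoids storing the s x s score matrix).

```python
def activation_bytes_per_layer(seq_len: int, micro_batch: int, hidden: int, heads: int,
                               tp: int = 1, selective_recompute: bool = False) -> float:
    """Per-layer activation memory estimate (Korthikanti et al., 2022) with tensor +
    sequence parallelism of size tp. The 5*a*s/h term is the attention score/softmax/
    dropout state that selective recompute drops and recomputes in backward."""
    sbh = seq_len * micro_batch * hidden
    attention_core = 0 if selective_recompute else 5 * heads * seq_len / hidden
    return sbh * (34 + attention_core) / tp


for selective in (False, True):
    gb = activation_bytes_per_layer(8192, 1, 8192, 64, tp=8, selective_recompute=selective) / 1e9
    print(f"selective_recompute={selective}: {gb:.2f} GB per layer")
```

For s = 8192, h = 8192, a = 64, and t = 8, the attention term is 320 versus 34 for everything else, so selective recompute cuts stored activations per layer by roughly 10x while re-running only cheap elementwise and score work. Since the term grows with s, the saving gets larger as context grows.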

CUDA Graphs, NCCL tuning, and cluster networks

CUDA Graphs record the sequence of kernel launches in forward and backward once and then replay it as a single graph launch, removing per-kernel CPU launch overhead. They can lift throughput 1.3-1.5x with small batches (micro-batch=1), where launch overhead dominates. NCCL tuning is largely about setting NCCL_ALGO (Ring/Tree/CollNet), NCCL_PROTO (LL/LL128/Simple), NCCL_BUFFSIZE, and NCCL_NSOCKS_PERTHREAD to match cluster topology. Network fabrics generally fall into four buckets: NVIDIA Quantum-X800 InfiniBand (SHARP in-network reduction), Ethernet-based RoCE v2 (PFC/ECN tuning required), HPE Slingshot 11 (proven on Frontier and El Capitan), and AWS EFA (its own SRD protocol). A typical env set on a GB200 NVL72 cluster:

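# HCA names, GID index, and interface names are site-specific; check them against your fabric.
export NCCL_DEBUG=WARN            # raise to INFO when debugging hangs
export NCCL_IB_HCA=mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7
export NCCL_IB_GID_INDEX=3        # GID index used in RoCE mode; not needed on native InfiniBand
export NCCL_IB_TIMEOUT=22
export NCCL_IB_RETRY_CNT=7
export NCCL_SOCKET_IFNAME=ib0     # IP interface for bootstrap/socket traffic
export NCCL_NET_GDR_LEVEL=PHB     # GPUDirect RDMA when GPU and NIC share a PCI host bridge
export NCCL_P2P_LEVEL=NVL         # use P2P only over NVLink
export NCCL_NVLS_ENABLE=1         # NVLink SHARP (in-switch reductions)
export NCCL_BUFFSIZE=8388608      # 8 MiB communication buffer (default is 4 MiB)
export NCCL_ALGO=NVLSTree         # pins the algorithm instead of letting NCCL's tuner choose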
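A simple alpha-beta model of a ring all-reduce shows what NCCL_ALGO trades off: the bandwidth term approaches 2x the message size per rank regardless of scale, while the latency term grows linearly with the number of ranks, which is why tree-based algorithms tend to win for small messages at large scale. The 5 µs per-step latency and link bandwidths here are illustrative assumptions; the bus-bandwidth formula follows the convention nccl-tests uses for all-reduce.

```python
def ring_allreduce_seconds(message_bytes: float, num_ranks: int, link_gb_per_s: float,
                           latency_us: float = 5.0) -> float:
    """Alpha-beta estimate for a ring all-reduce: each rank sends and receives
    2*(n-1)/n of the message, over 2*(n-1) sequential steps."""
    n = num_ranks
    bandwidth_term = 2 * (n - 1) / n * message_bytes / (link_gb_per_s * 1e9)
    latency_term = 2 * (n - 1) * latency_us * 1e-6
    return bandwidth_term + latency_term


def bus_bandwidth_gb_per_s(message_bytes: float, num_ranks: int, seconds: float) -> float:
    """nccl-tests-style 'busbw' for all-reduce: algorithm bandwidth * 2*(n-1)/n."""
    alg_bw = message_bytes / seconds / 1e9
    return alg_bw * 2 * (num_ranks - 1) / num_ranks


# 1 GB of gradients vs. a 1 KB message, on 8 and on 1024 ranks, at 100 GB/s per link.
for size in (1e9, 1e3):
    for n in (8, 1024):
        print(f"{size:.0e} B on {n:4d} ranks: {ring_allreduce_seconds(size, n, 100.0) * 1e3:.3f} ms")
```

For the 1 GB message the time barely changes between 8 and 1024 ranks, while for the 1 KB message it is almost pure latency and grows about 146x, the regime where Tree or NVLSTree pays off.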

NVIDIA Blackwell B100/B200/GB200 NVL72

Blackwell B200 delivers 192GB HBM3E, 8 TB/s memory bandwidth, and 9 PFLOPS FP8 / 18 PFLOPS FP4 (NVIDIA's peak figures with sparsity). The GB200 Grace Blackwell Superchip connects one Grace CPU to two Blackwell GPUs over NVLink-C2C, and the GB200 NVL72 rack binds 72 Blackwell GPUs into a single NVLink domain for about 1.4 EFLOPS FP4. Compared with H100 DGX, this drops per-token training cost by 30 to 40 percent, and combined with FSDP2 it compresses a 1-trillion-token Llama-70B-class run to about a week (against a 2025 baseline of 32 H100 80GB DGX nodes).

AMD Instinct MI300X/MI325X and ROCm

MI300X carries 192GB HBM3 and MI325X carries 256GB HBM3E, and that single-GPU memory advantage directly relieves memory pressure during training. On the software side PyTorch runs on ROCm with nearly identical APIs, and Megatron-LM and DeepSpeed officially support ROCm. The downside is that the newest FlashAttention and Triton kernels tend to land 6 to 12 months behind NVIDIA, and fp8 training stability is not yet at NVIDIA's level.

Intel Gaudi 3 and SynapseAI

Intel Habana Gaudi 3 supplies 128GB HBM2E and 1.835 PFLOPS BF16, and its main draw is a price 30 to 40 percent below H100. SynapseAI runs PyTorch models through its own graph compiler rather than executing them eagerly, so running familiar code unchanged is harder. That said, combined with Intel Tiber AI Cloud, the price/performance is attractive for canonical workloads like BERT or Llama-7B.

AWS Trainium 2 and Google TPU v5p, v6e Trillium

AWS Trainium 2 (Trn2) is sold mainly as trn2.48xlarge instances (16 Trainium2 chips, 1.5 TB of HBM) and 64-chip Trn2 UltraServers, deployed in EFA-connected UltraClusters. Anthropic has publicly said Claude training runs on a Trainium 2 + GPU hybrid. Google TPU v5p ships as an 8960-chip pod, and v6e Trillium claims 4.7x the peak compute per chip of TPU v5e; it is the workhorse behind Gemini training. Combined with JAX, a single sharding spec can scale to roughly 10,000 chips without changes.

Distributed checkpointing and failure recovery

Assume training jobs will fail, and often: roughly 1.5 GPU or network incidents per week on a 1000-GPU cluster is a routine report. PyTorch standardizes async distributed checkpointing through DCP (torch.distributed.checkpoint), alongside the separate TorchSnapshot library, while Megatron-LM uses its own distributed checkpoint format (historically zarr-based) that survives PP/TP topology changes. The training-loop pattern is (1) async save every N steps, (2) graceful abort when the NCCL collective timeout (the process-group timeout) trips, and (3) elastic restart through torchrun or Ray that resumes from the most recent checkpoint.

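import os
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

def save_async(model, optimizer, step, ckpt_dir):
    # get_state_dict returns sharded model/optimizer state keyed by parameter FQNs.
    model_sd, optim_sd = get_state_dict(model, optimizer)
    state = {"model": model_sd, "optim": optim_sd, "step": step}
    storage_writer = dcp.FileSystemWriter(f"{ckpt_dir}/step_{step}", thread_count=8)
    # Returns a Future: call .result() on it before starting the next async save.
    # async_save expects a CPU-capable process group, e.g. init_process_group("cpu:gloo,cuda:nccl").
    return dcp.async_save(state, storage_writer=storage_writer)

def load_latest(model, optimizer, ckpt_dir):
    # Only consider step_* directories whose .metadata file exists, i.e. finished saves.
    steps = [int(d[5:]) for d in os.listdir(ckpt_dir)
             if d.startswith("step_") and d[5:].isdigit()
             and os.path.exists(os.path.join(ckpt_dir, d, ".metadata"))]
    if not steps:
        return 0  # nothing saved yet: start from step 0
    latest = f"{ckpt_dir}/step_{max(steps)}"
    model_sd, optim_sd = get_state_dict(model, optimizer)
    state = {"model": model_sd, "optim": optim_sd, "step": 0}
    dcp.load(state, storage_reader=dcp.FileSystemReader(latest))  # fills state in place
    set_state_dict(model, optimizer, model_state_dict=state["model"], optim_state_dict=state["optim"])
    return state["step"]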
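For step (1), a classic way to choose N is Young's approximation, which balances checkpoint overhead against the work lost after a failure: interval ≈ sqrt(2 × C × MTBF), where C is the time a save blocks training. The sketch plugs in the 1.5-incidents-per-week figure above; the 60-second blocking cost and 10-second step time are assumptions for illustration, and the helper names are hypothetical.

```python
import math


def young_checkpoint_interval_s(checkpoint_cost_s: float, mtbf_s: float) -> float:
    """Young's first-order optimum for the time between checkpoints: sqrt(2 * C * MTBF)."""
    return math.sqrt(2 * checkpoint_cost_s * mtbf_s)


def steps_between_checkpoints(checkpoint_cost_s: float, incidents_per_week: float,
                              step_time_s: float) -> int:
    """Convert the optimal interval into a save-every-N-steps setting."""
    mtbf_s = 7 * 24 * 3600 / incidents_per_week
    interval_s = young_checkpoint_interval_s(checkpoint_cost_s, mtbf_s)
    return max(1, round(interval_s / step_time_s))


# 1.5 incidents/week -> MTBF of 112 hours; 60 s blocking save; 10 s per step.
print(steps_between_checkpoints(60, 1.5, 10))  # 696
```

Async saves shrink C to the short blocking (staging) portion of the save, which pushes the optimal interval down: checkpointing more often becomes cheap, and less work is lost per failure.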

In-flight failure modes — CUDA OOM, NCCL hang, divergent loss

The three most common in-flight failures are CUDA OOM, NCCL collective hangs, and divergent loss. OOM usually comes from activation memory blowing up with sequence length (quadratically when attention scores are materialized, linearly with fused kernels such as FlashAttention); tune recompute granularity and micro-batch size first. NCCL hangs are diagnosed with NCCL_DEBUG=INFO and an explicit collective timeout (timeout in init_process_group, or timeout_s in Ray's TorchConfig) so you can see which rank stalled; then check network interfaces, firewalls, and failing links between specific node pairs. Divergent loss is usually an LR set too high, missing gradient clipping, or fp8 amax statistics that failed to update.
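A cheap guard against divergent loss is to watch for non-finite values and for spikes well above a running average before calling optimizer.step(). The class below is a hypothetical sketch; the EMA decay, 2x spike threshold, and warmup length are illustrative values, not defaults from any framework.

```python
import math


class LossSpikeGuard:
    """Flag non-finite losses and losses far above an exponential moving average,
    so the loop can skip the step or roll back to the last checkpoint."""

    def __init__(self, ema_decay: float = 0.99, spike_ratio: float = 2.0, warmup_steps: int = 100):
        self.ema_decay = ema_decay
        self.spike_ratio = spike_ratio
        self.warmup_steps = warmup_steps
        self.ema = None
        self.steps = 0

    def check(self, loss: float) -> str:
        if not math.isfinite(loss):
            return "nonfinite"
        if self.ema is not None and self.steps >= self.warmup_steps and loss > self.spike_ratio * self.ema:
            return "spike"  # leave the spike out of the running average
        self.ema = loss if self.ema is None else self.ema_decay * self.ema + (1 - self.ema_decay) * loss
        self.steps += 1
        return "ok"


guard = LossSpikeGuard(warmup_steps=3)
print([guard.check(x) for x in [2.0, 2.0, 2.0, 2.1, 9.0, float("nan")]])
```

On "spike" or "nonfinite", a typical response is to skip that batch's optimizer step; repeated flags usually mean restoring the last checkpoint and lowering the LR.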

GPU clouds — Lambda, CoreWeave, Modal, Together, Crusoe, RunPod, Lepton

The GPU cloud market split into a tier-1/tier-2 shape in 2026. Tier-1 (CoreWeave, Lambda Labs, Crusoe, Together AI, Nebius) sells multi-thousand-GPU superpods connected by InfiniBand or RoCE, reserved by the hour. Tier-2 (RunPod, Vast.ai, LeptonAI, Fal) sells single nodes to small clusters by the minute. Modal lets you boot GPUs on demand from a code function and has become popular for fine-tuning and serving. Approximate listed pricing as of May 2026: H100 80GB at 1.99-3.50 USD/h, B200 at 4.50-6.50 USD/h, MI300X at 1.70-2.80 USD/h.

Cost, power, carbon — decomposing per-token training cost

Roughly, training a 70B model on 1T tokens costs (total GPU-hours) x (hourly price). On 32 H100 80GB DGX nodes (256 GPUs) over 25 days, that's 256 x 25 x 24 = 153,600 GPU-hours; at 3 USD/h, about 460,000 USD. On a single GB200 NVL72 rack (72 GPUs) the same run takes about 14 days, or 72 x 14 x 24 = 24,192 GPU-hours, which at 6 USD/h comes to roughly 145,000 USD. That gap is a major reason new training jobs in 2026 are routing to Blackwell. Power draw is around 120 kW per rack, so data-center PUE and cooling design feed directly into the budget.
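The 6·N·D rule of thumb (about 6 FLOPs per parameter per training token: 2 for the forward pass, 4 for the backward) gives a quick way to see what sustained per-GPU throughput a schedule implies. The sketch computes it for the two runs above; the function names are hypothetical, and attention FLOPs are ignored.

```python
def training_flops(num_params: float, num_tokens: float) -> float:
    """About 6 FLOPs per parameter per token (2 forward + 4 backward), ignoring attention."""
    return 6 * num_params * num_tokens


def implied_tflops_per_gpu(num_params: float, num_tokens: float, num_gpus: int, days: float) -> float:
    """Sustained TFLOP/s each GPU must deliver to finish the run in the given time."""
    seconds = days * 24 * 3600
    return training_flops(num_params, num_tokens) / (num_gpus * seconds) / 1e12


print(f"256 GPUs, 25 days: {implied_tflops_per_gpu(70e9, 1e12, 256, 25):.0f} TFLOP/s per GPU")
print(f"72 GPUs, 14 days:  {implied_tflops_per_gpu(70e9, 1e12, 72, 14):.0f} TFLOP/s per GPU")
```

Model FLOPs utilization (MFU) is this implied figure divided by the GPU's dense (non-sparse) peak at the training precision; a schedule that implies an MFU near or above 100 percent deserves a second look before it goes into a budget.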

Korea — LG AI Research EXAONE and Naver HyperCLOVA X

LG AI Research is widely understood to train EXAONE 3.5/4.0 on its own H100 cluster (estimated 1000+ GPUs) on a Megatron-LM-based stack. EXAONE 3.5 32B used a Korean-tuned tokenizer and additional instruction tuning over a 13B-token Korean/English mix. Naver HyperCLOVA X is trained inside a Korean IDC via a Samsung Heavy Industries / Naver Cloud collaboration, with a pipeline aligned to Korea's PIPA data-compliance regime that minimizes offshore data. KT has publicly disclosed that its trillion-parameter Mi:dm model uses a Megatron + DeepSpeed combination.

Japan — Sakana AI, Preferred Networks PFCC, and ABEJA

Sakana AI, based in Tokyo, takes an "evolutionary merge" approach to side-step training cost. Rather than pretraining from scratch, they combine and evolve existing models, lowering the infrastructure burden significantly. Preferred Networks trains its PLaMo series on a hybrid cluster (PFCC) that mixes its own MN-Core / MN-Core 2 accelerators with H100. ABEJA has published a reference architecture using AWS Trainium 2 for its Insight series. Japan's METI "AI supercomputer subsidy" program is a major driver of training-cluster buildouts in the country.

Which stack to pick — decision tree

Under 70B SFT/LoRA: torchtune + unsloth, or axolotl. 70B full-parameter pretraining: TorchTitan (FSDP2 + TP + PP) or Megatron-Core. MoE 405B+: Megatron-Core (strong EP/CP) or DeepSpeed MoE. JAX/TPU environments: MaxText or Levanter. Cluster operations: Ray Train (elastic) on top of Slurm or Kubernetes. Alignment (SFT/DPO/GRPO): trl. Fast SOTA tracking: keep Hugging Face Accelerate as a backend-agnostic wrapper.

2027 outlook — what's next

Blackwell B300, AMD MI355X, Trainium 3, and TPU v6p arrive in a tight cluster in late 2026 to early 2027. On software, three movements look likely: (1) torch.compile + FSDP2 becomes the default and DeepSpeed narrows toward ZeRO++ communication and NVMe offload niches, (2) Megatron-Core ends up embedded in more external trainers and effectively becomes the reference compute layer, and (3) JAX recovers share on GPU. Distributed-training infrastructure is shifting from a "how do we make this work" problem to a "how do we operate this" problem.
