Torch-Titan Complete Guide: Everything About Large-Scale Distributed Training with PyTorch
Author: Youngju Kim (@fjvbn20031)
Introduction
Training large language models (LLMs) is one of the most complex challenges in modern AI engineering. Training Llama 3 70B on a single GPU is simply impossible. Dozens to thousands of GPUs must be utilized efficiently, requiring a variety of parallelization strategies.
torchtitan, developed by Meta's PyTorch team, is a reference implementation for exactly this kind of complex large-scale LLM training. Built using the latest PyTorch features in a clean and scalable way, it is designed so that AI researchers and engineers can learn and apply best practices for distributed training.
This guide covers everything about torchtitan — from theoretical foundations to hands-on installation, advanced parallelization strategies, and performance optimization.
1. Introducing Torch-Titan
What is torchtitan?
torchtitan is a production-quality reference implementation for large-scale LLM training, open-sourced by Meta's PyTorch team. You can find it on GitHub as pytorch/torchtitan.
Many existing LLM training codebases (Megatron-LM, DeepSpeed, NeMo, etc.) are feature-rich but also highly complex. torchtitan follows a different philosophy:
- Clarity: Uses only PyTorch native APIs with minimal abstraction
- Modularity: Each parallelization technique can be turned on/off independently
- Modernity: Actively leverages the latest features of PyTorch 2.x
- Reproducibility: Structure that makes reproducing training experiments easy
Supported Models
Models natively supported by torchtitan:
- Llama 2 (7B, 13B, 70B)
- Llama 3 (8B, 70B, 405B)
- Llama 3.1, 3.2 family
While Llama is the default, any transformer-based model can be added by following the structure.
Differences from Existing Frameworks
Megatron-LM (NVIDIA):
- The most mature solution for GPU clusters
- Highly optimized but tied to the NVIDIA ecosystem
- Complex codebase makes customization difficult
DeepSpeed (Microsoft):
- Famous for the ZeRO optimizer
- Supports both training and inference
- Relies on C++ custom kernels; sometimes difficult to fully integrate with PyTorch
torchtitan (Meta/PyTorch):
- Pure PyTorch native API
- Uses torch.compile, FSDP2, torch.distributed.tensor from PyTorch 2.x
- Clear structure suited for educational purposes
- Tracks PyTorch version updates quickly
Repository Structure
torchtitan/
├── torchtitan/
│   ├── models/
│   │   ├── llama/                   # Llama model definitions
│   │   └── __init__.py
│   ├── parallelisms/
│   │   ├── parallelize_llama.py     # Parallelism application logic
│   │   ├── pipeline_llama.py        # Pipeline parallelism
│   │   └── __init__.py
│   ├── optimizer.py                 # Optimizer configuration
│   ├── checkpoint.py                # Checkpointing
│   ├── profiling.py                 # Profiling
│   └── utils.py
├── train.py                         # Training entry point
├── train_configs/                   # TOML configuration files
│   ├── llama3_8b.toml
│   ├── llama3_70b.toml
│   └── llama3_405b.toml
└── estimation.py                    # Memory/FLOPs estimation tool
2. A Refresher on Distributed Training Paradigms
There are four main ways to train large models across multiple GPUs. torchtitan supports all four and can even combine them simultaneously (4D parallelism).
Data Parallelism (DP)
The most basic form of parallelism. The model is copied to each GPU, and the data batch is split so that each GPU processes it independently.
GPU 0: [Batch 0-15] → Gradient 0 ┐
GPU 1: [Batch 16-31] → Gradient 1 ├→ All-Reduce → Sync
GPU 2: [Batch 32-47] → Gradient 2 ┘
PyTorch's DistributedDataParallel (DDP) is the standard implementation. However, the model must fit in GPU memory. A 70B parameter model in FP16 alone requires 140GB — impossible on a single GPU.
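For reference, here is a minimal DDP sketch (illustrative, not torchtitan code) that matches the picture above. It assumes a launch via torchrun so that RANK and LOCAL_RANK are set by the launcher:

# Minimal DDP sketch (illustrative, not torchtitan code).
# Assumes launch via `torchrun --nproc_per_node=N`, which sets RANK/LOCAL_RANK.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(1024, 1024).cuda()   # toy model standing in for an LLM
model = DDP(model, device_ids=[local_rank])  # gradients are all-reduced automatically

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
x = torch.randn(8, 1024, device="cuda")      # each rank sees its own slice of the batch
loss = model(x).pow(2).mean()
loss.backward()                              # all-reduce happens during backward
optimizer.step()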
Tensor Parallelism (TP)
Tensor parallelism distributes individual model layers across multiple GPUs, splitting matrix operations along the column or row dimension.
Linear layer (d_model=4096, d_ff=16384) distributed across 4 GPUs:
GPU 0: columns 0-4095 (4096 x 4096)
GPU 1: columns 4096-8191
GPU 2: columns 8192-12287
GPU 3: columns 12288-16383
Each GPU processes only 1/4 of the full layer. All-Reduce or All-Gather is needed to combine results. The higher the NVLink bandwidth, the more efficient this is.
Pipeline Parallelism (PP)
Layers are distributed across GPUs in sequence. The first GPU processes the first N layers and passes its output to the next GPU.
GPU 0: Layer 0-9 → activation →
GPU 1: Layer 10-19 → activation →
GPU 2: Layer 20-29 → activation →
GPU 3: Layer 30-39 → loss
A naive implementation processes only one microbatch at a time, resulting in low GPU utilization (pipeline bubbles). GPipe, 1F1B, and Interleaved schedules address this.
Sequence Parallelism (SP)
Splits the sequence dimension of the input across multiple GPUs so that no single GPU has to hold the full context. This addresses the O(N²) growth of attention memory when training with long contexts (128K+ tokens); this sequence-splitting form is also commonly called context parallelism.
Sequence length 4096 distributed across 4 GPUs:
GPU 0: tokens 0-1023
GPU 1: tokens 1024-2047
GPU 2: tokens 2048-3071
GPU 3: tokens 3072-4095
Algorithms like Ring Attention compute the full attention across the distributed sequence.
4D Parallelism: Combining Everything
A core strength of torchtitan is support for 4D parallelism — combining all four strategies simultaneously.
4D parallelism example (64 GPUs):
- DP = 2 (data parallel, 2 replicas)
- TP = 8 (tensor parallel, 8 GPUs per layer)
- PP = 4 (pipeline parallel, 4 stages)
- SP = activated alongside TP
Total GPUs = DP x TP x PP = 2 x 8 x 4 = 64
(SP reuses the TP group, so it does not add to the GPU count)
Each parallelism has different communication patterns, so finding the optimal configuration for your hardware topology matters. Typically:
- TP is applied within a single server where GPUs are connected via NVLink
- PP is applied across servers
- DP is the outermost dimension
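As an illustration of how this maps onto a device mesh, here is a sketch (not torchtitan's exact setup code; the dimension ordering is an assumption for illustration) where the innermost tp dimension stays inside a node and pp/dp span nodes:

# Sketch: a 3D device mesh for 64 GPUs (dp=2, pp=4, tp=8), assuming 8 GPUs per node.
# The innermost dimension ("tp") maps to ranks within a node (NVLink),
# while "pp" and "dp" span nodes.
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh(
    "cuda",
    (2, 4, 8),                 # dp x pp x tp = 64 ranks
    mesh_dim_names=("dp", "pp", "tp"),
)
tp_mesh = world_mesh["tp"]     # used for tensor parallelism
pp_mesh = world_mesh["pp"]     # used for pipeline parallelism
dp_mesh = world_mesh["dp"]     # used for FSDP / DDP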
3. FSDP2 (Fully Sharded Data Parallel v2)
The Relationship Between ZeRO and FSDP
Microsoft DeepSpeed's ZeRO (Zero Redundancy Optimizer) is an innovative approach to distributing parameters, gradients, and optimizer states across multiple GPUs. PyTorch's FSDP (Fully Sharded Data Parallel) is the PyTorch-native implementation of this idea.
ZeRO stages:
- ZeRO-1: Only optimizer state is sharded
- ZeRO-2: Optimizer state + gradient sharding
- ZeRO-3: Optimizer state + gradient + parameter sharding
FSDP corresponds to ZeRO-3.
FSDP1 vs FSDP2
Through the early PyTorch 2.x releases, torch.distributed.fsdp.FullyShardedDataParallel (FSDP1) was the standard. From PyTorch 2.4 onward, FSDP2, exposed as fully_shard in torch.distributed._composable.fsdp, is the recommended path.
Key differences:
| Feature | FSDP1 | FSDP2 |
|---|---|---|
| API style | Wrapper-based | Composable API |
| TP integration | Limited | Native |
| Memory efficiency | Good | Better |
| torch.compile | Partial | Full |
| Code readability | Complex | Clear |
# FSDP1 style (legacy)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(model, ...)

# FSDP2 style (used in torchtitan)
import torch
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)

# Apply FSDP2 to each transformer layer
for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)

# Apply to the full model too
fully_shard(model, mp_policy=mp_policy)
How FSDP Works
FSDP behaves differently during the forward pass, backward pass, and weight update.
Forward pass:
- Before layer execution: All-Gather reconstructs sharded parameters on all GPUs
- Layer execution: Computes with full parameters
- After layer execution: Parameters are discarded (memory savings)
Backward pass:
- Before layer backprop: All-Gather reconstructs parameters
- Gradient computation
- Reduce-Scatter distributes gradients to each GPU as shards
- Parameters discarded
Weight update:
- Each GPU only updates the parameters, gradients, and optimizer state for its own shard
Example with a 70B model on 4 GPUs:
- Memory savings: 140GB weights / 4 = 35GB per GPU (+ distributed optimizer state)
- Trade-off: communication overhead (All-Gather, Reduce-Scatter)
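To make these numbers concrete, here is a rough back-of-the-envelope helper, a sketch under simple assumptions (BF16 parameters and gradients, FP32 Adam moments plus an FP32 master copy, activations and buffers ignored):

# Rough per-GPU memory estimate for a model under FSDP (ZeRO-3-style sharding).
# Assumptions: BF16 params/grads (2 + 2 bytes), FP32 Adam m/v + FP32 master copy (12 bytes);
# activations and buffers are ignored.
def fsdp_memory_per_gpu_gb(num_params: float, num_gpus: int) -> float:
    bytes_per_param = 2 + 2 + 12          # params + grads + optimizer state
    return num_params * bytes_per_param / num_gpus / 1e9

print(fsdp_memory_per_gpu_gb(70e9, 4))    # ~280 GB per GPU: 4 GPUs are not enough once optimizer state is counted
print(fsdp_memory_per_gpu_gb(70e9, 64))   # ~17.5 GB per GPU: feasible on 80 GB cards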
CPU Offload
When GPU memory is insufficient, optimizer state and gradients can be offloaded to CPU RAM.
from torch.distributed._composable.fsdp import fully_shard, CPUOffloadPolicy

# CPUOffloadPolicy offloads parameters, gradients, and optimizer state to CPU
cpu_offload = CPUOffloadPolicy()

for layer in model.layers:
    fully_shard(
        layer,
        offload_policy=cpu_offload,
    )
CPU offload dramatically reduces memory usage but slows training (2-5x) due to PCIe transfers. Use as a last resort when memory is absolutely tight.
Mixed Precision with FSDP2
FSDP2 can apply different precision per layer.
import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

# Standard mixed precision configuration
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # Parameters: BF16 (communication efficiency)
    reduce_dtype=torch.float32,   # Gradient reduction: FP32 (stability)
)

# Keep specific layers in FP32 (e.g., embedding layer)
fully_shard(model.embed_tokens)   # No mp_policy -> parameters stay in FP32

for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)

fully_shard(model, mp_policy=mp_policy)
4. Tensor Parallelism
Tensor Parallelism in Transformers
Tensor parallelism, first systematized in Megatron-LM, is applied to two core modules in transformers.
MLP (Feed-Forward Network):
FFN(x) = GELU(x * W1) * W2
Column Parallel (split W1):
GPU 0: x * W1[0:d/4] -> GELU -> y[0:d/4]
GPU 1: x * W1[d/4:d/2] -> GELU -> y[d/4:d/2]
...
Row Parallel (split W2 + All-Reduce):
GPU 0: y[0:d/4] * W2[0:d/4, :] -> partial sum 0
GPU 1: y[d/4:d/2] * W2[d/4:d/2, :] -> partial sum 1
...
All-Reduce -> final output
Self-Attention: Query, Key, Value matrices are split by attention head.
# TP application in torchtitan (DTensor-based)
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    PrepareModuleInput,
)

# Define parallelization plan
plan = {
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w3": ColwiseParallel(),
}

parallelize_module(model_layer, tp_mesh, plan)
DTensor: The Foundation of Distributed Tensors
torchtitan implements tensor parallelism based on PyTorch's torch.distributed.tensor (DTensor). A DTensor is a new abstraction that is logically a single tensor but is physically distributed across multiple devices.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, Replicate

# Create 1D mesh (for TP)
tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

# Create 2D mesh (DP + TP)
mesh_2d = init_device_mesh(
    "cuda", (2, 8), mesh_dim_names=("dp", "tp"),
)

# Shard(0): shard along dim 0 (row-wise)
# Shard(1): shard along dim 1 (column-wise)
# Replicate(): replicated (no sharding)
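For completeness, a minimal sketch of building a DTensor by hand (illustrative; torchtitan constructs DTensors through parallelize_module and fully_shard rather than directly, and the 8-way tp_mesh from the snippet above is assumed):

# Sketch: manually sharding a weight across the "tp" mesh dimension.
# distribute_tensor splits the local tensor and returns a DTensor handle.
import torch
from torch.distributed.tensor import distribute_tensor

weight = torch.randn(4096, 16384, device="cuda")       # full (unsharded) weight
w_sharded = distribute_tensor(weight, tp_mesh, placements=[Shard(1)])

print(w_sharded.shape)              # logical shape: (4096, 16384)
print(w_sharded.to_local().shape)   # local shard on each of 8 ranks: (4096, 2048)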
Sequence Parallelism
A natural extension of tensor parallelism that parallelizes LayerNorm and Dropout operations along the sequence dimension.
Standard TP flow:
Input(replicated) -> LayerNorm(replicated) -> Attention(TP) -> [All-Reduce] -> Output(replicated)
SP + TP flow:
Input(seq-sharded) -> LayerNorm(seq-sharded) ->
[All-Gather] -> Attention(TP) -> [Reduce-Scatter] ->
Output(seq-sharded)
SP uses All-Gather + Reduce-Scatter instead of All-Reduce. This is the same total communication volume but more memory-efficient — LayerNorm activations are distributed across N GPUs, saving memory.
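In torchtitan this is expressed as additional entries in the per-layer TP plan. A minimal sketch, reusing plan, model_layer, and tp_mesh from the TP example above (the module names follow torchtitan's Llama layer definition):

# Sketch: extending the per-layer TP plan with sequence parallelism.
# SequenceParallel runs the norm layers on sequence-sharded activations (Shard(1));
# the PrepareModuleInput entry above performs the all-gather before attention,
# and RowwiseParallel's Shard(1) output acts as the reduce-scatter.
from torch.distributed.tensor.parallel import SequenceParallel

plan.update({
    "attention_norm": SequenceParallel(),   # RMSNorm before attention, seq-sharded
    "ffn_norm": SequenceParallel(),         # RMSNorm before the MLP, seq-sharded
})

parallelize_module(model_layer, tp_mesh, plan)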
5. Pipeline Parallelism
The Pipeline Bubble Problem
Naive pipeline parallelism introduces significant inefficiency.
Naive pipeline (4 GPUs, one batch at a time):
Time ->
GPU 0: [F0]                          [B0]
GPU 1:      [F1]                [B1]
GPU 2:           [F2]      [B2]
GPU 3:                [F3][B3]
F = Forward, B = Backward. Each GPU sits idle (the "bubble") while the others work. With microbatching, the bubble fraction is (PP - 1) / (micro_batches + PP - 1).
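A quick calculation of this ratio for a few microbatch counts (a small throwaway helper, not torchtitan code):

# Bubble fraction for a GPipe/1F1B-style schedule: (pp - 1) / (m + pp - 1)
def bubble_fraction(pp_stages: int, micro_batches: int) -> float:
    return (pp_stages - 1) / (micro_batches + pp_stages - 1)

for m in (1, 4, 16, 64):
    print(f"PP=4, microbatches={m}: bubble = {bubble_fraction(4, m):.0%}")
# PP=4, m=1  -> 75%   (the naive case above)
# PP=4, m=4  -> ~43%
# PP=4, m=16 -> ~16%
# PP=4, m=64 -> ~4%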
GPipe Schedule
GPipe injects multiple microbatches into the pipeline to reduce bubbles.
GPipe (4 GPUs, 4 microbatches):
Time ->
GPU 0: [M0F][M1F][M2F][M3F]                              [M3B][M2B][M1B][M0B]
GPU 1:      [M0F][M1F][M2F][M3F]                    [M3B][M2B][M1B][M0B]
GPU 2:           [M0F][M1F][M2F][M3F]          [M3B][M2B][M1B][M0B]
GPU 3:                [M0F][M1F][M2F][M3F][M3B][M2B][M1B][M0B]
All forward passes complete before any backward passes. Downside: high activation memory usage.
1F1B (One Forward One Backward) Schedule
1F1B is a memory-efficient schedule. Each GPU alternates forward and backward passes on microbatches.
# Pipeline parallelism configuration in torchtitan (TOML):
# [experimental]
# pipeline_parallel_degree = 4
# pipeline_parallel_schedule = "1f1b"

# Python-level setup
from torchtitan.parallelisms.pipeline_llama import pipeline_llama

# Create pipeline stages
stages, model_parts = pipeline_llama(
    model,
    pp_mesh,
    parallel_dims,
    job_config,
    device,
    model_config,
)
Interleaved Schedule
The interleaved schedule assigns each GPU responsibility for multiple pipeline stages. It further reduces bubbles but is more complex to implement.
Interleaved 1F1B (4 GPUs, 2 stages/GPU):
GPU 0 handles: [layers 0-4] and [layers 20-24]
GPU 1: [layers 5-9] and [layers 25-29]
...
6. Installing and Using torchtitan
System Requirements
- Python 3.10+
- PyTorch 2.5+ (latest nightly recommended)
- CUDA 12.1+
- GPU: H100 or A100 recommended (minimum 40GB VRAM)
Installation
# Clone the repository
git clone https://github.com/pytorch/torchtitan
cd torchtitan
# Install PyTorch nightly (includes latest features)
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124
# Install dependencies
pip install -r requirements.txt
# Install torchtitan package
pip install -e .
# Download tokenizer
python torchtitan/datasets/download_tokenizer.py \
--repo_id meta-llama/Meta-Llama-3-8B \
--tokenizer_path "original" \
--hf_token YOUR_HF_TOKEN
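Before launching a distributed run, a quick sanity check that the nightly build and GPUs are visible can save time (a trivial snippet, not part of torchtitan):

# Quick environment sanity check (not part of torchtitan)
import torch

print(torch.__version__)               # expect a 2.x (nightly) build
print(torch.cuda.is_available())       # True if CUDA is usable
print(torch.cuda.device_count())       # number of visible GPUs
print(torch.cuda.get_device_name(0))   # e.g. "NVIDIA H100 80GB HBM3"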
Configuration Files (TOML)
torchtitan uses TOML-format configuration files.
# train_configs/llama3_8b.toml
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"
[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100
[metrics]
log_freq = 10
enable_tensorboard = true
[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
[optimizer]
name = "AdamW"
lr = 3e-4
[training]
batch_size = 8
seq_len = 2048
warmup_steps = 200
max_norm = 1.0
steps = 1000
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1 # Auto: uses remaining GPU count
tensor_parallel_degree = 1
enable_loss_parallel = false
[experimental]
pipeline_parallel_degree = 1
[checkpoint]
enable_checkpoint = true
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"
[activation_checkpoint]
mode = "selective" # full, selective, none
selective_ac_option = "op"
[float8]
enable_float8_linear = false
Running Llama 3 8B Training
# Single node, 8 GPU training of Llama 3 8B
torchrun --nproc_per_node=8 \
train.py \
--job.config_file train_configs/llama3_8b.toml
# With TP=4, DP=2
torchrun --nproc_per_node=8 \
train.py \
--job.config_file train_configs/llama3_8b.toml \
--training.tensor_parallel_degree 4 \
--training.data_parallel_shard_degree 2
# Multi-node training (2 nodes, 8 GPUs each)
torchrun \
--nproc_per_node=8 \
--nnodes=2 \
--rdzv_id=101 \
--rdzv_backend=c10d \
--rdzv_endpoint=$MASTER_ADDR:29400 \
train.py \
--job.config_file train_configs/llama3_70b.toml
Memory/FLOPs Estimation Tool
Before launching a run, you can estimate memory and compute requirements.
# Estimate memory and FLOPs
python estimation.py \
--job.config_file train_configs/llama3_8b.toml
# Sample output:
# Estimated model size: 15.01 GB
# Estimated optimizer state size: 30.02 GB
# Total estimated GPU memory: 65.23 GB
# Estimated FLOP per step: 1.61e+14
7. Flash Attention Integration
What is Flash Attention?
Flash Attention is an attention algorithm published in 2022 by Tri Dao and colleagues at Stanford. It reduces the memory complexity of standard attention from O(N²) to O(N) and delivers 2-4x real-world speedups.
The problem with standard attention:
Standard attention:
S = Q * K^T # (seq_len, seq_len) matrix -- memory explosion!
P = softmax(S)
O = P * V
seq_len=4096: S matrix = 4096 x 4096 x 4 bytes = 64MB (FP32)
seq_len=32768: S matrix = 32768 x 32768 x 4 bytes = 4GB (single layer!)
Flash Attention's key ideas:
- Process Q, K, V in tiles (tiling)
- Compute in SRAM (shared memory) instead of HBM
- Never materialize the full attention matrix in HBM
- Numerically exact (not an approximation)
Flash Attention 2 and 3
Flash Attention 2 (2023):
- Better parallelization of the attention computation across the sequence dimension
- More efficient handling of causal masking
- Roughly 2x faster attention kernels than the original Flash Attention on A100
Flash Attention 3 (2024):
- Tailored for H100 Hopper architecture
- Leverages WGMMA (Warpgroup Matrix Multiply Accumulate)
- 1.5-2x additional improvement over FA2 on H100
- FP8 support
Using Flash Attention in torchtitan
# Attention implementation in torchtitan/models/llama/model.py (simplified)
import torch
import torch.nn.functional as F

def forward(
    self,
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    bs, seqlen, _ = x.shape
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
    # (Reshaping into attention heads is omitted here for brevity.)

    # Apply RoPE embeddings
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

    # Flash Attention via PyTorch SDPA:
    # torch.nn.functional.scaled_dot_product_attention
    # automatically dispatches to Flash Attention when possible
    output = F.scaled_dot_product_attention(
        xq, xk, xv,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=True,  # Causal mask for language models
    )
    return self.wo(output.view(bs, seqlen, -1))
Since PyTorch 2.0, torch.nn.functional.scaled_dot_product_attention (SDPA) uses Flash Attention automatically. No separate package installation needed.
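To check or force which SDPA backend is actually used, PyTorch provides a context manager (a small sketch; the sdpa_kernel API shown here lives in torch.nn.attention in PyTorch 2.3+):

# Restrict SDPA to the Flash Attention backend.
# If Flash Attention cannot be used (dtype, head_dim, mask shape), this raises an error
# instead of silently falling back to another backend.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)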
To use the original Flash Attention package directly:
pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func

output = flash_attn_func(
    q, k, v,
    dropout_p=0.0,
    softmax_scale=None,  # Default: 1/sqrt(head_dim)
    causal=True,
)
8. Async Tensor Parallelism
Why Async TP?
In standard tensor parallelism, All-Reduce or All-Gather/Reduce-Scatter block the next layer's computation. GPUs sit idle while waiting for communication.
Synchronous TP:
[GPU compute] -> [wait for comm] -> [GPU compute] -> [wait for comm] ...
Async TP overlaps communication and computation:
Async TP:
[GPU compute A] ────────────────────>
[comm (A result)] ─────>
[GPU compute B] ->
[comm (B)]
Async TP in torchtitan
# Enable async TP in the config file:
# [training]
# enable_async_tensor_parallel = true

# Code reference
from torchtitan.parallelisms import parallelize_llama

# When the flag is set, parallelize_llama applies async TP internally,
# overlapping each layer's all-gather / reduce-scatter with matmul computation
model = parallelize_llama(
    model,
    world_mesh,
    parallel_dims,
    job_config,
)
Async TP is most effective at high TP degrees (8+) through compute-communication overlap. Reports show 5-15% additional throughput gains in H100 + NVLink environments.
9. Checkpointing and Resuming
Distributed Checkpoint (dcp)
In large-scale distributed training, checkpointing is far more than calling torch.save on a model: state that is sharded across thousands of GPUs has to be saved and restored consistently, ideally in parallel.
PyTorch's torch.distributed.checkpoint (dcp) natively supports distributed checkpointing.
# torchtitan checkpointing
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)

# Save
def save_checkpoint(model, optimizer, step, output_dir):
    state_dict = {
        "model": get_model_state_dict(model),
        "optimizer": get_optimizer_state_dict(model, optimizer),
        "extra": {"step": step},
    }
    dcp.save(
        state_dict,
        checkpoint_id=f"{output_dir}/step-{step}",
    )

# Load
def load_checkpoint(model, optimizer, checkpoint_path):
    # Build a state_dict skeleton with the same keys that were saved;
    # dcp.load fills it in place (the placeholder step value gets overwritten)
    state_dict = {
        "model": get_model_state_dict(model),
        "optimizer": get_optimizer_state_dict(model, optimizer),
        "extra": {"step": 0},
    }
    dcp.load(
        state_dict,
        checkpoint_id=checkpoint_path,
    )
    set_model_state_dict(
        model,
        model_state_dict=state_dict["model"],
        options=StateDictOptions(strict=True),
    )
    set_optimizer_state_dict(
        model,
        optimizer,
        optim_state_dict=state_dict["optimizer"],
    )
    return state_dict["extra"]["step"]
Async Checkpointing
Checkpointing large models can take minutes, during which GPUs stall. Async checkpointing saves in the background while training continues.
# Async checkpointing configuration (TOML)
[checkpoint]
async_mode = "async"  # "disabled", "async", "async_with_pinned_mem"

from torchtitan.checkpoint import CheckpointManager

checkpoint_manager = CheckpointManager(
    dataloader=train_dataloader,
    model_parts=model_parts,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    states={"train_state": train_state},
    job_config=job_config,
)

# In the training loop
for step in range(num_steps):
    # ... training code ...

    # Save checkpoint asynchronously (non-blocking)
    checkpoint_manager.save(curr_step=step, force=False)
Checkpoint Format Conversion
Distributed training checkpoints can be converted to a single file or HuggingFace format.
# Convert distributed checkpoint to single HuggingFace model
python scripts/convert_checkpoint.py \
--checkpoint_path outputs/checkpoint/step-1000 \
--output_path outputs/hf_model \
--model_type llama3 \
--model_flavor 8B
10. Performance Profiling
Using PyTorch Profiler
To find performance bottlenecks, you first need to know where time is being spent. PyTorch Profiler is a powerful tool for this.
import torch
from torch.profiler import profile, record_function, ProfilerActivity

# Basic profiling
with profile(
    activities=[
        ProfilerActivity.CPU,
        ProfilerActivity.CUDA,
    ],
    record_shapes=True,    # Record tensor shapes
    profile_memory=True,   # Record memory usage
    with_stack=True,       # Record call stacks
) as prof:
    with record_function("model_inference"):
        output = model(input)

# Print results
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20,
))

# Save Chrome Trace (visualize at chrome://tracing)
prof.export_chrome_trace("trace.json")
torchtitan Built-in Profiling
torchtitan controls profiling via configuration files.
[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100 # Profile every 100 steps
enable_memory_snapshot = true # Save memory snapshots
# Reference: torchtitan/profiling.py internals (simplified)
from contextlib import contextmanager

import torch.profiler as profiler

@contextmanager
def maybe_enable_profiling(config, global_step=0):
    if not config.profiling.enable_profiling:
        yield
        return

    with profiler.profile(
        activities=[
            profiler.ProfilerActivity.CPU,
            profiler.ProfilerActivity.CUDA,
        ],
        schedule=profiler.schedule(
            skip_first=10,
            wait=5,
            warmup=5,
            active=1,
            repeat=1,
        ),
        on_trace_ready=profiler.tensorboard_trace_handler(
            config.profiling.save_traces_folder
        ),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as p:
        yield p
TensorBoard Integration
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/experiment_1")

# In the training loop
for step, (input, target) in enumerate(dataloader):
    loss = train_step(model, optimizer, input, target)

    writer.add_scalar("Loss/train", loss.item(), step)
    writer.add_scalar("LR", optimizer.param_groups[0]["lr"], step)

    # GPU memory monitoring
    writer.add_scalar(
        "GPU/memory_allocated_GB",
        torch.cuda.memory_allocated() / 1e9,
        step,
    )
    writer.add_scalar(
        "GPU/memory_reserved_GB",
        torch.cuda.memory_reserved() / 1e9,
        step,
    )

# Launch TensorBoard:
# tensorboard --logdir=runs/
Memory Usage Analysis
# GPU memory snapshot
torch.cuda.memory._record_memory_history(max_entries=100000)

# Run part of training
for step in range(100):
    loss = train_step(model, optimizer, batch)

# Save memory snapshot
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# Visualize with the analysis tool:
# upload the .pickle file at https://pytorch.org/memory_viz

# Memory profiler for per-layer analysis (prof comes from the profiler example above)
from torch.cuda._memory_viz import profile_plot

with open("memory_profile.html", "w") as f:
    f.write(profile_plot(prof))
Training Efficiency Metric: MFU
MFU (Model FLOP Utilization) is the ratio of actual GPU performance to the theoretical maximum.
def compute_mfu(
    model_num_params: float,
    batch_size: int,
    seq_len: int,
    elapsed_time: float,     # Time per step (seconds)
    num_gpus: int,
    gpu_peak_flops: float,   # GPU theoretical peak FLOPS
) -> float:
    """
    Compute MFU for LLM training.
    Reference: PaLM paper (Chowdhery et al., 2022)
    """
    # Forward FLOPs ~= 2 * num_params * batch_size * seq_len
    # Backward is ~2x forward -> total 6 * num_params * batch_size * seq_len
    flops_per_step = 6 * model_num_params * batch_size * seq_len
    achieved_flops = flops_per_step / elapsed_time
    peak_flops_total = gpu_peak_flops * num_gpus
    return achieved_flops / peak_flops_total

# Example usage
mfu = compute_mfu(
    model_num_params=8e9,      # 8B params
    batch_size=8,
    seq_len=2048,
    elapsed_time=0.5,          # 0.5 seconds per step
    num_gpus=8,
    gpu_peak_flops=989e12,     # H100 BF16 dense: ~989 TFLOPS
)
print(f"MFU: {mfu:.1%}")       # ~19.9% with the numbers above

# Typically achievable MFU:
# - Good implementation: 40-60%
# - Optimized (torchtitan, Megatron): 50-65%
# - Theoretical max: ~70% (communication/memory overhead unavoidable)
11. Practical Training Configuration Examples
Llama 3 8B: Single Node with 8x H100
# train_configs/llama3_8b_h100x8.toml
[job]
dump_folder = "./outputs/llama3_8b"
[model]
name = "llama3"
flavor = "8B"
[training]
batch_size = 4
seq_len = 8192
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 8 # FSDP2: 8 GPUs
tensor_parallel_degree = 1
[activation_checkpoint]
mode = "selective"
[float8]
enable_float8_linear = true # Enable FP8 training
[optimizer]
name = "AdamW"
lr = 1e-4
torchrun --nproc_per_node=8 train.py \
--job.config_file train_configs/llama3_8b_h100x8.toml
Llama 3 70B: 4 Nodes x 8 H100 (32 GPUs Total)
# train_configs/llama3_70b_32gpu.toml
[training]
batch_size = 2
seq_len = 4096
data_parallel_replicate_degree = 2 # DDP: 2 replicas
data_parallel_shard_degree = 4 # FSDP: 4-GPU sharding
tensor_parallel_degree = 4 # TP: 4 GPUs
# Total GPUs = 2 x 4 x 4 = 32
[experimental]
pipeline_parallel_degree = 1 # PP disabled
Llama 3 405B: Large-Scale Cluster
# train_configs/llama3_405b_large.toml
[training]
batch_size = 1
seq_len = 2048
data_parallel_replicate_degree = 4 # DDP
data_parallel_shard_degree = 8 # FSDP
tensor_parallel_degree = 8 # TP
[experimental]
pipeline_parallel_degree = 8 # PP
# Total GPUs = 4 x 8 x 8 x 8 = 2048
Conclusion: What torchtitan Teaches Us
torchtitan goes beyond being a simple training tool — it is an educational resource that clearly demonstrates the best practices of modern LLM distributed training.
Key takeaways from this guide:
- 4D Parallelism: Combining DP, TP, PP, and SP to efficiently utilize thousands of GPUs
- FSDP2: ZeRO-3-level memory efficiency with PyTorch native APIs
- Flash Attention: O(N²) memory becomes O(N); 2-4x speedup
- Async Checkpointing: Save checkpoints without stopping training
- MFU optimization: Targeting 40-65% of theoretical GPU performance
Distributed training is a complex domain where hardware, software, and algorithms intersect. torchtitan exposes this complexity in the most transparent way possible, making learning and experimentation accessible. Run the code yourself, experiment with different parallelization combinations, and develop your intuition for distributed training.
References
- torchtitan GitHub Repository
- PyTorch FSDP Documentation
- PyTorch FSDP Tutorial
- Megatron-LM: Training Multi-Billion Parameter Language Models
- FlashAttention: Fast and Memory-Efficient Exact Attention (Dao et al., 2022)
- FlashAttention-2 (Dao, 2023)
- Reducing Activation Recomputation in Large Transformer Models
- PyTorch DTensor Documentation
- PyTorch Distributed Checkpoint Documentation