Torch-Titan Complete Guide: Everything About Large-Scale Distributed Training with PyTorch

Introduction

Training large language models (LLMs) is one of the most complex challenges in modern AI engineering. Training Llama 3 70B on a single GPU is simply impossible. Dozens to thousands of GPUs must be utilized efficiently, requiring a variety of parallelization strategies.

torchtitan, developed by Meta's PyTorch team, is a reference implementation for exactly this kind of complex large-scale LLM training. Built using the latest PyTorch features in a clean and scalable way, it is designed so that AI researchers and engineers can learn and apply best practices for distributed training.

This guide covers everything about torchtitan — from theoretical foundations to hands-on installation, advanced parallelization strategies, and performance optimization.


1. Introducing Torch-Titan

What is torchtitan?

torchtitan is a production-quality reference implementation for large-scale LLM training, open-sourced by Meta's PyTorch team. You can find it on GitHub as pytorch/torchtitan.

Many existing LLM training codebases (Megatron-LM, DeepSpeed, NeMo, etc.) are feature-rich but also highly complex. torchtitan takes a different philosophy:

  • Clarity: Uses only PyTorch native APIs with minimal abstraction
  • Modularity: Each parallelization technique can be turned on/off independently
  • Modernity: Actively leverages the latest features of PyTorch 2.x
  • Reproducibility: Structure that makes reproducing training experiments easy

Supported Models

Models natively supported by torchtitan:

  • Llama 2 (7B, 13B, 70B)
  • Llama 3 (8B, 70B, 405B)
  • Llama 3.1, 3.2 family

While Llama is the default, any transformer-based model can be added by following the same structure.

Differences from Existing Frameworks

Megatron-LM (NVIDIA):

  • The most mature solution for GPU clusters
  • Highly optimized but tied to the NVIDIA ecosystem
  • Complex codebase makes customization difficult

DeepSpeed (Microsoft):

  • Famous for the ZeRO optimizer
  • Supports both training and inference
  • Relies on C++ custom kernels; sometimes difficult to fully integrate with PyTorch

torchtitan (Meta/PyTorch):

  • Pure PyTorch native API
  • Uses torch.compile, FSDP2, torch.distributed.tensor from PyTorch 2.x
  • Clear structure suited for educational purposes
  • Tracks PyTorch version updates quickly

Repository Structure

torchtitan/
├── torchtitan/
│   ├── models/
│   │   ├── llama/          # Llama model definitions
│   │   └── __init__.py
│   ├── parallelisms/
│   │   ├── parallelize_llama.py  # Parallelism application logic
│   │   ├── pipeline_llama.py     # Pipeline parallelism
│   │   └── __init__.py
│   ├── optimizer.py         # Optimizer configuration
│   ├── checkpoint.py        # Checkpointing
│   ├── profiling.py         # Profiling
│   └── utils.py
├── train.py                 # Training entry point
├── train_configs/           # TOML configuration files
│   ├── llama3_8b.toml
│   ├── llama3_70b.toml
│   └── llama3_405b.toml
└── estimation.py            # Memory/FLOPs estimation tool

2. A Refresher on Distributed Training Paradigms

There are four main ways to train large models across multiple GPUs. torchtitan supports all four and can even combine them simultaneously (4D parallelism).

Data Parallelism (DP)

The most basic form of parallelism. The model is copied to each GPU, and the data batch is split so that each GPU processes it independently.

GPU 0: [Batch 0-15]  → Gradient 0 ─┐
GPU 1: [Batch 16-31] → Gradient 1 ─┼→ All-Reduce → Sync
GPU 2: [Batch 32-47] → Gradient 2 ─┘

PyTorch's DistributedDataParallel (DDP) is the standard implementation. However, the model must fit in GPU memory. A 70B parameter model in FP16 alone requires 140GB — impossible on a single GPU.
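The DDP pattern above can be sketched in a few lines. This is a single-process illustration (world_size=1, CPU with the gloo backend) of the API shape only; in real training, torchrun sets the rendezvous environment variables and spawns one process per GPU.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process sketch (world_size=1, CPU/gloo); torchrun normally sets
# MASTER_ADDR/MASTER_PORT and launches one process per GPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(16, 4)
ddp_model = DDP(model)  # registers All-Reduce hooks on gradient buckets

x = torch.randn(8, 16)
loss = ddp_model(x).sum()
loss.backward()  # gradients are averaged across ranks here

dist.destroy_process_group()
```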

Tensor Parallelism (TP)

Tensor parallelism distributes individual model layers across multiple GPUs, splitting matrix operations along the column or row dimension.

Linear layer (d_model=4096, d_ff=16384) distributed across 4 GPUs:

GPU 0: columns 0-4095    (4096 x 4096)
GPU 1: columns 4096-8191
GPU 2: columns 8192-12287
GPU 3: columns 12288-16383

Each GPU processes only 1/4 of the full layer. All-Reduce or All-Gather is needed to combine results. The higher the NVLink bandwidth, the more efficient this is.
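The column split above can be verified numerically. A minimal sketch simulating the 4-way split on one device, where the concatenation stands in for the All-Gather:

```python
import torch

# Column-parallel linear: split W along d_ff across 4 "GPUs" (simulated as
# chunks on one device); concatenating partial outputs equals the full matmul.
d_model, d_ff, n_shards = 4096, 16384, 4
x = torch.randn(2, d_model)
w = torch.randn(d_model, d_ff) / d_model**0.5

full = x @ w
shards = torch.chunk(w, n_shards, dim=1)  # columns 0-4095, 4096-8191, ...
partials = [x @ s for s in shards]        # each "GPU" computes 1/4 of the output
combined = torch.cat(partials, dim=1)     # stands in for All-Gather

assert torch.allclose(full, combined, atol=1e-3)
```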

Pipeline Parallelism (PP)

Layers are distributed across GPUs in sequence. The first GPU processes the first N layers and passes its output to the next GPU.

GPU 0: Layer 0-9   → activation →
GPU 1: Layer 10-19 → activation →
GPU 2: Layer 20-29 → activation →
GPU 3: Layer 30-39 → loss

A naive implementation processes only one microbatch at a time, resulting in low GPU utilization (pipeline bubbles). GPipe, 1F1B, and Interleaved schedules address this.

Sequence Parallelism (SP)

Distributes the sequence dimension of attention layers across multiple GPUs. Solves the problem of attention matrix memory exploding as O(N²) when training with long contexts (128K+ tokens).

Sequence length 4096 distributed across 4 GPUs:
GPU 0: tokens 0-1023
GPU 1: tokens 1024-2047
GPU 2: tokens 2048-3071
GPU 3: tokens 3072-4095

Algorithms like Ring Attention compute the full attention across the distributed sequence.

4D Parallelism: Combining Everything

A core strength of torchtitan is support for 4D parallelism — combining all four strategies simultaneously.

4D parallelism example (64 GPUs):
- DP = 2  (data parallel, 2 replicas)
- TP = 8  (tensor parallel, 8 GPUs per layer)
- PP = 4  (pipeline parallel, 4 stages)
- SP = activated alongside TP (shares the TP GPU group)

Total GPUs = DP x TP x PP = 2 x 8 x 4 = 64
(SP does not add a dimension, since it reuses the TP group)

Each parallelism has different communication patterns, so finding the optimal configuration for your hardware topology matters. Typically:

  • TP is applied within a single server where GPUs are connected via NVLink
  • PP is applied across servers
  • DP is the outermost dimension
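Under these conventions, a flat rank maps to parallel-group coordinates with TP as the fastest-varying dimension. A small sketch for a hypothetical DP=2, PP=4, TP=8 layout (64 GPUs):

```python
# Map a flat rank to (dp, pp, tp) coordinates for DP=2, PP=4, TP=8.
# TP varies fastest, so a TP group is 8 consecutive ranks (NVLink neighbors
# within one server); DP is the outermost dimension.
DP, PP, TP = 2, 4, 8

def rank_to_coords(rank: int) -> tuple[int, int, int]:
    tp = rank % TP
    pp = (rank // TP) % PP
    dp = rank // (TP * PP)
    return dp, pp, tp

assert rank_to_coords(7) == (0, 0, 7)    # ranks 0-7 form one TP group
assert rank_to_coords(8) == (0, 1, 0)    # rank 8 starts the next PP stage
assert rank_to_coords(63) == (1, 3, 7)   # last of 2 * 4 * 8 = 64 GPUs
```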

3. FSDP2 (Fully Sharded Data Parallel v2)

The Relationship Between ZeRO and FSDP

Microsoft DeepSpeed's ZeRO (Zero Redundancy Optimizer) is an innovative approach to distributing parameters, gradients, and optimizer states across multiple GPUs. PyTorch's FSDP (Fully Sharded Data Parallel) is the PyTorch-native implementation of this idea.

ZeRO stages:

  • ZeRO-1: Only optimizer state is sharded
  • ZeRO-2: Optimizer state + gradient sharding
  • ZeRO-3: Optimizer state + gradient + parameter sharding

FSDP corresponds to ZeRO-3.
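The memory impact of each stage follows from simple arithmetic. A rough sketch assuming BF16 parameters and gradients (2 bytes each) and FP32 Adam state (master weights + momentum + variance, 12 bytes per parameter):

```python
def zero_memory_gb(n_params: float, n_gpus: int, stage: int) -> float:
    """Approximate per-GPU memory (GB) for params + grads + Adam state."""
    params = 2 * n_params    # BF16 weights
    grads = 2 * n_params     # BF16 gradients
    optim = 12 * n_params    # FP32 master copy + momentum + variance
    if stage >= 1:
        optim /= n_gpus      # ZeRO-1: shard optimizer state
    if stage >= 2:
        grads /= n_gpus      # ZeRO-2: also shard gradients
    if stage >= 3:
        params /= n_gpus     # ZeRO-3 / FSDP: also shard parameters
    return (params + grads + optim) / 1e9

# 70B model on 8 GPUs: 1120 GB per GPU unsharded vs 140 GB with ZeRO-3
assert zero_memory_gb(70e9, 8, 0) == 1120.0
assert zero_memory_gb(70e9, 8, 3) == 140.0
```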

FSDP1 vs FSDP2

Up to PyTorch 2.0, torch.distributed.fsdp.FullyShardedDataParallel (FSDP1) was the standard. With PyTorch 2.4+, FSDP2 using fully_shard from torch.distributed._composable.fsdp is recommended.

Key differences:

Feature             FSDP1           FSDP2
API style           Wrapper-based   Composable API
TP integration      Limited         Native
Memory efficiency   Good            Better
torch.compile       Partial         Full
Code readability    Complex         Clear

# FSDP1 style (legacy)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
model = FSDP(model, ...)

# FSDP2 style (used in torchtitan)
import torch
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32
)

# Apply FSDP2 to each transformer layer
for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)

# Apply to the full model too
fully_shard(model, mp_policy=mp_policy)

How FSDP Works

FSDP behaves differently during the forward pass, backward pass, and weight update.

Forward pass:

  1. Before layer execution: All-Gather reconstructs sharded parameters on all GPUs
  2. Layer execution: Computes with full parameters
  3. After layer execution: Parameters are discarded (memory savings)

Backward pass:

  1. Before layer backprop: All-Gather reconstructs parameters
  2. Gradient computation
  3. Reduce-Scatter distributes gradients to each GPU as shards
  4. Parameters discarded

Weight update:

  • Each GPU only updates the parameters, gradients, and optimizer state for its own shard

Example with a 70B model on 4 GPUs:

  • Memory savings: 140GB weights / 4 = 35GB per GPU (+ distributed optimizer state)
  • Trade-off: communication overhead (All-Gather, Reduce-Scatter)

CPU Offload

When GPU memory is insufficient, optimizer state and gradients can be offloaded to CPU RAM.

from torch.distributed._composable.fsdp import fully_shard, CPUOffloadPolicy

cpu_offload = CPUOffloadPolicy(offload_params=True)

for layer in model.layers:
    fully_shard(
        layer,
        offload_policy=cpu_offload
    )

CPU offload dramatically reduces memory usage but slows training (2-5x) due to PCIe transfers. Use as a last resort when memory is absolutely tight.

Mixed Precision with FSDP2

FSDP2 can apply different precision per layer.

import torch
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

# Standard mixed precision configuration
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # Parameters: BF16 (communication efficiency)
    reduce_dtype=torch.float32    # Gradient reduction: FP32 (stability)
)

# Keep specific layers in FP32 (e.g., embedding layer)
fully_shard(model.embed_tokens)  # No mp_policy → keeps original dtype (FP32 here)
for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)

4. Tensor Parallelism

Tensor Parallelism in Transformers

Tensor parallelism, first systematized in Megatron-LM, is applied to two core modules in transformers.

MLP (Feed-Forward Network):

FFN(x) = GELU(x * W1) * W2

Column Parallel (split W1):
GPU 0: x * W1[0:d/4]   -> GELU -> y[0:d/4]
GPU 1: x * W1[d/4:d/2] -> GELU -> y[d/4:d/2]
...

Row Parallel (split W2 + All-Reduce):
GPU 0: y[0:d/4] * W2[0:d/4, :] -> partial sum 0
GPU 1: y[d/4:d/2] * W2[d/4:d/2, :] -> partial sum 1
...
All-Reduce -> final output

Self-Attention: Query, Key, Value matrices are split by attention head.

# TP application in torchtitan (DTensor-based)
from torch.distributed.tensor import Shard, Replicate
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    PrepareModuleInput,
)

# Define parallelization plan
plan = {
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w3": ColwiseParallel(),
}

parallelize_module(model_layer, tp_mesh, plan)

DTensor: The Foundation of Distributed Tensors

torchtitan implements tensor parallelism based on PyTorch's torch.distributed.tensor (DTensor). A DTensor is a new abstraction that is logically a single tensor but is physically distributed across multiple devices.

from torch.distributed.tensor import DTensor, Shard, Replicate
from torch.distributed.device_mesh import init_device_mesh

# Create 1D mesh (for TP)
tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

# Create 2D mesh (DP + TP)
mesh_2d = init_device_mesh(
    "cuda", (2, 8), mesh_dim_names=("dp", "tp")
)

# Shard(0): row-wise sharding
# Shard(1): column-wise sharding
# Replicate(): replicated (no sharding)

Sequence Parallelism

A natural extension of tensor parallelism that parallelizes LayerNorm and Dropout operations along the sequence dimension.

Standard TP flow:
Input(replicated) -> LayerNorm(replicated) -> Attention(TP) -> [All-Reduce] -> Output(replicated)

SP + TP flow:
Input(seq-sharded) -> LayerNorm(seq-sharded) ->
[All-Gather] -> Attention(TP) -> [Reduce-Scatter] ->
Output(seq-sharded)

SP uses All-Gather + Reduce-Scatter instead of All-Reduce. This is the same total communication volume but more memory-efficient — LayerNorm activations are distributed across N GPUs, saving memory.
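The claim that Reduce-Scatter followed by All-Gather reproduces All-Reduce can be checked with plain tensors. A sketch simulating 4 ranks on one device:

```python
import torch

# Reduce-Scatter then All-Gather == All-Reduce (simulated for 4 ranks).
n_ranks = 4
inputs = [torch.randn(8) for _ in range(n_ranks)]  # one partial result per rank

all_reduce = sum(inputs)  # reference: full elementwise sum on every rank

# Reduce-Scatter: rank i ends up with only the reduced i-th shard
shards = [sum(t.chunk(n_ranks)[i] for t in inputs) for i in range(n_ranks)]
# All-Gather: concatenating all shards restores the fully reduced tensor
all_gather = torch.cat(shards)

assert torch.allclose(all_reduce, all_gather, atol=1e-5)
```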


5. Pipeline Parallelism

The Pipeline Bubble Problem

Naive pipeline parallelism introduces significant inefficiency.

Simple pipeline (4 GPUs, 1 microbatch):

Time ->
GPU 0: [M0 F]                                          [M0 B]
GPU 1:        [M0 F]                            [M0 B]
GPU 2:               [M0 F]              [M0 B]
GPU 3:                      [M0 F][M0 B]

F = Forward, B = Backward. GPU 0 waits a long time after processing M0. Bubble ratio = (PP - 1) / (micro_batches + PP - 1).
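The bubble-ratio formula makes the benefit of more microbatches concrete:

```python
# Bubble ratio = (PP - 1) / (micro_batches + PP - 1), from the schedule above.
def bubble_ratio(pp: int, micro_batches: int) -> float:
    return (pp - 1) / (micro_batches + pp - 1)

print(f"{bubble_ratio(4, 4):.0%}")   # 43% idle with 4 microbatches
print(f"{bubble_ratio(4, 32):.0%}")  # 9% with 32 — more microbatches shrink it
```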

GPipe Schedule

GPipe injects multiple microbatches into the pipeline to reduce bubbles.

GPipe (4 GPUs, 4 microbatches):

Time ->
GPU 0: [M0F][M1F][M2F][M3F]               [M3B][M2B][M1B][M0B]
GPU 1:      [M0F][M1F][M2F][M3F]      [M3B][M2B][M1B][M0B]
GPU 2:           [M0F][M1F][M2F][M3F][M3B][M2B][M1B][M0B]
GPU 3:                [M0F][M1F][M2F][M3F][M3B][M2B][M1B][M0B]

All forward passes complete before any backward passes. Downside: high activation memory usage.

1F1B (One Forward One Backward) Schedule

1F1B is a memory-efficient schedule. Each GPU alternates forward and backward passes on microbatches.

# Pipeline parallelism configuration in torchtitan (TOML)
# [experimental]
# pipeline_parallel_degree = 4
# pipeline_parallel_schedule = "1f1b"

# Python-level setup
from torchtitan.parallelisms.pipeline_llama import pipeline_llama

# Create pipeline stages
stages, model_parts = pipeline_llama(
    model,
    pp_mesh,
    parallel_dims,
    job_config,
    device,
    model_config
)

Interleaved Schedule

The interleaved schedule assigns each GPU responsibility for multiple pipeline stages. It further reduces bubbles but is more complex to implement.

Interleaved 1F1B (4 GPUs, 2 stages/GPU):

GPU 0 handles: [layers 0-4] and [layers 20-24]
GPU 1: [layers 5-9] and [layers 25-29]
...

6. Installing and Using torchtitan

System Requirements

  • Python 3.10+
  • PyTorch 2.5+ (latest nightly recommended)
  • CUDA 12.1+
  • GPU: H100 or A100 recommended (minimum 40GB VRAM)

Installation

# Clone the repository
git clone https://github.com/pytorch/torchtitan
cd torchtitan

# Install PyTorch nightly (includes latest features)
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124

# Install dependencies
pip install -r requirements.txt

# Install torchtitan package
pip install -e .

# Download tokenizer
python torchtitan/datasets/download_tokenizer.py \
    --repo_id meta-llama/Meta-Llama-3-8B \
    --tokenizer_path "original" \
    --hf_token YOUR_HF_TOKEN

Configuration Files (TOML)

torchtitan uses TOML-format configuration files.

# train_configs/llama3_8b.toml

[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 8
seq_len = 2048
warmup_steps = 200
max_norm = 1.0
steps = 1000
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1  # Auto: uses remaining GPU count
tensor_parallel_degree = 1
enable_loss_parallel = false

[experimental]
pipeline_parallel_degree = 1

[checkpoint]
enable_checkpoint = true
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"

[activation_checkpoint]
mode = "selective"  # full, selective, none
selective_ac_option = "op"

[float8]
enable_float8_linear = false

Running Llama 3 8B Training

# Single node, 8 GPU training of Llama 3 8B
torchrun --nproc_per_node=8 \
    train.py \
    --job.config_file train_configs/llama3_8b.toml

# With TP=4, DP=2
torchrun --nproc_per_node=8 \
    train.py \
    --job.config_file train_configs/llama3_8b.toml \
    --training.tensor_parallel_degree 4 \
    --training.data_parallel_shard_degree 2

# Multi-node training (2 nodes, 8 GPUs each)
torchrun \
    --nproc_per_node=8 \
    --nnodes=2 \
    --rdzv_id=101 \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR:29400 \
    train.py \
    --job.config_file train_configs/llama3_70b.toml

Memory/FLOPs Estimation Tool

Before starting training, estimate memory and compute requirements upfront.

# Estimate memory and FLOPs
python estimation.py \
    --job.config_file train_configs/llama3_8b.toml

# Sample output:
# Estimated model size: 15.01 GB
# Estimated optimizer state size: 30.02 GB
# Total estimated GPU memory: 65.23 GB
# Estimated FLOP per step: 1.61e+14

7. Flash Attention Integration

What is Flash Attention?

Flash Attention is an attention algorithm published in 2022 by Tri Dao and colleagues at Stanford. It reduces the memory complexity of standard attention from O(N²) to O(N) and delivers 2-4x real-world speedups.

The problem with standard attention:

Standard attention:
S = Q * K^T    # (seq_len, seq_len) matrix -- memory explosion!
P = softmax(S)
O = P * V

seq_len=4096:  S matrix = 4096 x 4096 x 4 bytes = 64MB (FP32)
seq_len=32768: S matrix = 32768 x 32768 x 4 bytes = 4GB (single layer!)
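The figures above follow directly from the size of the score matrix:

```python
# Memory of the full (seq_len x seq_len) attention score matrix in FP32.
def score_matrix_bytes(seq_len: int, dtype_bytes: int = 4) -> int:
    return seq_len * seq_len * dtype_bytes

assert score_matrix_bytes(4096) == 64 * 1024**2   # 64 MB
assert score_matrix_bytes(32768) == 4 * 1024**3   # 4 GB, for a single matrix
```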

Flash Attention's key ideas:

  1. Process Q, K, V in tiles (tiling)
  2. Compute in SRAM (shared memory) instead of HBM
  3. Never materialize the full attention matrix in HBM
  4. Numerically exact (not an approximation)
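The heart of ideas 1-3 is the online softmax: processing K/V block by block while carrying a running max and normalizer gives exactly the same result as materializing the full score row. A single-query sketch of the idea (not the real kernel, which also tiles Q and runs in SRAM):

```python
import torch

def online_softmax_attention(q, k, v, block=32):
    """Blockwise softmax(k @ q) @ v for one query q: (d,), k/v: (N, d)."""
    m = torch.tensor(float("-inf"))  # running max of scores seen so far
    l = torch.tensor(0.0)            # running softmax normalizer
    acc = torch.zeros_like(v[0])     # running weighted sum of V rows
    for i in range(0, k.shape[0], block):
        s = k[i:i + block] @ q                  # scores for this K/V block
        m_new = torch.maximum(m, s.max())
        scale = torch.exp(m - m_new)            # rescale previous partials
        p = torch.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ v[i:i + block]
        m = m_new
    return acc / l

q = torch.randn(16)
k, v = torch.randn(128, 16), torch.randn(128, 16)
reference = torch.softmax(k @ q, dim=0) @ v     # materializes all 128 scores
assert torch.allclose(online_softmax_attention(q, k, v), reference, atol=1e-5)
```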

Flash Attention 2 and 3

Flash Attention 2 (2023):

  • Improved attention computation parallelism
  • More efficient masking
  • Up to ~2x faster than the original Flash Attention on A100

Flash Attention 3 (2024):

  • Tailored for H100 Hopper architecture
  • Leverages WGMMA (Warpgroup Matrix Multiply Accumulate)
  • 1.5-2x additional improvement over FA2 on H100
  • FP8 support

Using Flash Attention in torchtitan

# Attention implementation in torchtitan/models/llama/model.py
import torch.nn.functional as F

def forward(
    self,
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    bs, seqlen, _ = x.shape
    # (split into heads / transpose to (bs, n_heads, seqlen, head_dim) omitted for brevity)
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

    # Apply RoPE embeddings
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

    # Flash Attention via PyTorch SDPA
    # torch.nn.functional.scaled_dot_product_attention
    # automatically uses Flash Attention
    output = F.scaled_dot_product_attention(
        xq, xk, xv,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=True,  # Causal mask for language models
    )

    return self.wo(output.view(bs, seqlen, -1))

Since PyTorch 2.0, torch.nn.functional.scaled_dot_product_attention (SDPA) dispatches to Flash Attention automatically when the hardware and input shapes support it. No separate package installation is needed.

To use the original Flash Attention package directly:

pip install flash-attn --no-build-isolation

from flash_attn import flash_attn_func

output = flash_attn_func(
    q, k, v,
    dropout_p=0.0,
    softmax_scale=None,  # Auto: 1/sqrt(head_dim)
    causal=True,
)

8. Async Tensor Parallelism

Why Async TP?

In standard tensor parallelism, All-Reduce or All-Gather/Reduce-Scatter block the next layer's computation. GPUs sit idle while waiting for communication.

Synchronous TP:
[GPU compute] -> [wait for comm] -> [GPU compute] -> [wait for comm] ...

Async TP overlaps communication and computation:

Async TP:
[GPU compute A] ────────────────────>
[comm (A result)] ─────>
                    [GPU compute B] ->
                          [comm (B)]

Async TP in torchtitan

# Enable async TP in the config file:
# [training]
# enable_async_tensor_parallel = true

# Code reference
from torchtitan.parallelisms import parallelize_llama

# async_tp internally uses torch.distributed.tensor.parallel's
# async_all_gather capability
model = parallelize_llama(
    model,
    world_mesh,
    parallel_dims,
    job_config
)

Async TP is most effective at high TP degrees (8+) through compute-communication overlap. Reports show 5-15% additional throughput gains in H100 + NVLink environments.


9. Checkpointing and Resuming

Distributed Checkpoint (dcp)

In large-scale distributed training, checkpointing is far more than serializing a model: thousands of ranks must save and load their own shards in a coordinated way.

PyTorch's torch.distributed.checkpoint (dcp) natively supports distributed checkpointing.

# torchtitan checkpointing
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)

# Save
def save_checkpoint(model, optimizer, step, output_dir):
    state_dict = {
        "model": get_model_state_dict(model),
        "optimizer": get_optimizer_state_dict(model, optimizer),
        "extra": {"step": step},
    }

    dcp.save(
        state_dict,
        checkpoint_id=f"{output_dir}/step-{step}",
    )

# Load
def load_checkpoint(model, optimizer, checkpoint_path):
    state_dict = {
        "model": get_model_state_dict(model),
        "optimizer": get_optimizer_state_dict(model, optimizer),
        "extra": {},
    }

    dcp.load(
        state_dict,
        checkpoint_id=checkpoint_path,
    )

    set_model_state_dict(
        model,
        model_state_dict=state_dict["model"],
        options=StateDictOptions(strict=True),
    )
    set_optimizer_state_dict(
        model,
        optimizer,
        optim_state_dict=state_dict["optimizer"],
    )

    return state_dict["extra"]["step"]

Async Checkpointing

Checkpointing large models can take minutes, during which GPUs stall. Async checkpointing saves in the background while training continues.
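The core idea can be sketched with a background thread. This illustrates the concept only, not torchtitan's CheckpointManager (which builds on torch.distributed.checkpoint): snapshot the state to CPU, then write it while the training loop continues.

```python
import os
import tempfile
import threading
import torch

def async_save(model: torch.nn.Module, path: str) -> threading.Thread:
    # Snapshot to CPU first so training can keep mutating the live weights,
    # then write the copy on a background thread (non-blocking).
    cpu_state = {k: v.detach().to("cpu", copy=True)
                 for k, v in model.state_dict().items()}
    t = threading.Thread(target=torch.save, args=(cpu_state, path))
    t.start()
    return t  # join() before the next save or at shutdown

model = torch.nn.Linear(4, 4)
path = os.path.join(tempfile.gettempdir(), "ckpt_demo.pt")
thread = async_save(model, path)
thread.join()                         # in real training, overlap with compute
loaded = torch.load(path, weights_only=True)
assert set(loaded) == {"weight", "bias"}
```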

# Async checkpointing configuration (TOML):
# [checkpoint]
# async_mode = "async"  # "disabled", "async", "async_with_pinned_mem"

from torchtitan.checkpoint import CheckpointManager

checkpoint_manager = CheckpointManager(
    dataloader=train_dataloader,
    model_parts=model_parts,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    states={"train_state": train_state},
    job_config=job_config,
)

# In the training loop
for step in range(num_steps):
    # ... training code ...

    # Save checkpoint asynchronously (non-blocking)
    checkpoint_manager.save(curr_step=step, force=False)

Checkpoint Format Conversion

Distributed training checkpoints can be converted to a single file or HuggingFace format.

# Convert distributed checkpoint to single HuggingFace model
python scripts/convert_checkpoint.py \
    --checkpoint_path outputs/checkpoint/step-1000 \
    --output_path outputs/hf_model \
    --model_type llama3 \
    --model_flavor 8B

10. Performance Profiling

Using PyTorch Profiler

To find performance bottlenecks, you first need to know where time is being spent. PyTorch Profiler is a powerful tool for this.

import torch
from torch.profiler import profile, record_function, ProfilerActivity

# Basic profiling
with profile(
    activities=[
        ProfilerActivity.CPU,
        ProfilerActivity.CUDA,
    ],
    record_shapes=True,    # Record tensor shapes
    profile_memory=True,   # Record memory usage
    with_stack=True,       # Record call stacks
) as prof:
    with record_function("model_inference"):
        output = model(input)

# Print results
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20
))

# Save Chrome Trace (visualize at chrome://tracing)
prof.export_chrome_trace("trace.json")

torchtitan Built-in Profiling

torchtitan controls profiling via configuration files.

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100             # Profile every 100 steps
enable_memory_snapshot = true  # Save memory snapshots

# Reference: torchtitan/profiling.py internals
from contextlib import contextmanager
import torch.profiler as profiler

@contextmanager
def maybe_enable_profiling(config, global_step=0):
    if not config.profiling.enable_profiling:
        yield
        return

    with profiler.profile(
        activities=[
            profiler.ProfilerActivity.CPU,
            profiler.ProfilerActivity.CUDA,
        ],
        schedule=profiler.schedule(
            skip_first=10,
            wait=5,
            warmup=5,
            active=1,
            repeat=1,
        ),
        on_trace_ready=profiler.tensorboard_trace_handler(
            config.profiling.save_traces_folder
        ),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as p:
        yield p

TensorBoard Integration

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/experiment_1")

# In the training loop
for step, (input, target) in enumerate(dataloader):
    loss = train_step(model, optimizer, input, target)

    writer.add_scalar("Loss/train", loss.item(), step)
    writer.add_scalar("LR", optimizer.param_groups[0]["lr"], step)

    # GPU memory monitoring
    writer.add_scalar(
        "GPU/memory_allocated_GB",
        torch.cuda.memory_allocated() / 1e9,
        step
    )
    writer.add_scalar(
        "GPU/memory_reserved_GB",
        torch.cuda.memory_reserved() / 1e9,
        step
    )

# Launch TensorBoard:
# tensorboard --logdir=runs/

Memory Usage Analysis

# GPU memory snapshot
torch.cuda.memory._record_memory_history(max_entries=100000)

# Run part of training
for step in range(100):
    loss = train_step(model, optimizer, batch)

# Save memory snapshot
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# Visualize with analysis tool:
# Upload .pickle at https://pytorch.org/memory_viz

# Memory profiler for per-layer analysis (uses `prof` from the profiler section above)
from torch.cuda._memory_viz import profile_plot

with open("memory_profile.html", "w") as f:
    f.write(profile_plot(prof))

Training Efficiency Metric: MFU

MFU (Model FLOP Utilization) is the ratio of actual GPU performance to the theoretical maximum.

def compute_mfu(
    model_num_params: int,
    batch_size: int,
    seq_len: int,
    elapsed_time: float,   # Time per step (seconds)
    num_gpus: int,
    gpu_peak_flops: float, # GPU theoretical peak FLOPS
) -> float:
    """
    Compute MFU for LLM training.
    Reference: PaLM paper (Chowdhery et al., 2022)
    """
    # Forward FLOPs = 2 * num_params * batch_size * seq_len
    # Backward is ~2x forward -> total 6 * num_params * batch_size * seq_len
    flops_per_step = 6 * model_num_params * batch_size * seq_len
    achieved_flops = flops_per_step / elapsed_time
    peak_flops_total = gpu_peak_flops * num_gpus

    return achieved_flops / peak_flops_total

# Example usage
mfu = compute_mfu(
    model_num_params=8e9,     # 8B params
    batch_size=8,
    seq_len=2048,
    elapsed_time=0.5,         # 0.5 seconds per step
    num_gpus=8,
    gpu_peak_flops=989e12,    # H100 BF16: ~989 TFLOPS
)
print(f"MFU: {mfu:.1%}")  # MFU: 19.9% for these example values

# Typically achievable MFU:
# - Good implementation: 40-60%
# - Optimized (torchtitan, Megatron): 50-65%
# - Theoretical max: ~70% (communication/memory overhead unavoidable)

11. Practical Training Configuration Examples

Llama 3 8B: Single Node with 8x H100

# train_configs/llama3_8b_h100x8.toml

[job]
dump_folder = "./outputs/llama3_8b"

[model]
name = "llama3"
flavor = "8B"

[training]
batch_size = 4
seq_len = 8192
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 8  # FSDP2: 8 GPUs
tensor_parallel_degree = 1

[activation_checkpoint]
mode = "selective"

[float8]
enable_float8_linear = true  # Enable FP8 training

[optimizer]
name = "AdamW"
lr = 1e-4

torchrun --nproc_per_node=8 train.py \
    --job.config_file train_configs/llama3_8b_h100x8.toml

Llama 3 70B: 4 Nodes x 8 H100 (32 GPUs Total)

# train_configs/llama3_70b_32gpu.toml

[training]
batch_size = 2
seq_len = 4096
data_parallel_replicate_degree = 2  # DDP: 2 replicas
data_parallel_shard_degree = 4      # FSDP: 4-GPU sharding
tensor_parallel_degree = 4          # TP: 4 GPUs

# Total GPUs = 2 x 4 x 4 = 32

[experimental]
pipeline_parallel_degree = 1  # PP disabled

Llama 3 405B: Large-Scale Cluster

# train_configs/llama3_405b_large.toml

[training]
batch_size = 1
seq_len = 2048
data_parallel_replicate_degree = 4   # DDP
data_parallel_shard_degree = 8       # FSDP
tensor_parallel_degree = 8           # TP

[experimental]
pipeline_parallel_degree = 8         # PP

# Total GPUs = 4 x 8 x 8 x 8 = 2048

Conclusion: What torchtitan Teaches Us

torchtitan goes beyond being a simple training tool — it is an educational resource that clearly demonstrates the best practices of modern LLM distributed training.

Key takeaways from this guide:

  1. 4D Parallelism: Combining DP, TP, PP, and SP to efficiently utilize thousands of GPUs
  2. FSDP2: ZeRO-3-level memory efficiency with PyTorch native APIs
  3. Flash Attention: O(N²) memory becomes O(N); 2-4x speedup
  4. Async Checkpointing: Save checkpoints without stopping training
  5. MFU optimization: Targeting 40-65% of theoretical GPU performance

Distributed training is a complex domain where hardware, software, and algorithms intersect. torchtitan exposes this complexity in the most transparent way possible, making learning and experimentation accessible. Run the code yourself, experiment with different parallelization combinations, and develop your intuition for distributed training.

