Torch-Titan Complete Guide: Everything About Large-Scale Distributed Training with PyTorch
Introduction

Training large language models (LLMs) is one of the most complex challenges in modern AI engineering. Training Llama 3 70B on a single GPU is simply impossible. Dozens to thousands of GPUs must be utilized efficiently, requiring a variety of parallelization strategies.

**torchtitan**, developed by Meta's PyTorch team, is a reference implementation for exactly this kind of complex large-scale LLM training. Built using the latest PyTorch features in a clean and scalable way, it is designed so that AI researchers and engineers can learn and apply best practices for distributed training.

This guide covers everything about torchtitan, from theoretical foundations to hands-on installation, advanced parallelization strategies, and performance optimization.

1. Introducing Torch-Titan

What is torchtitan?

torchtitan is a production-quality reference implementation for large-scale LLM training, open-sourced by Meta's PyTorch team. You can find it on GitHub as `pytorch/torchtitan`.

Many existing LLM training codebases (Megatron-LM, DeepSpeed, NeMo, etc.) are feature-rich but also highly complex. torchtitan takes a different philosophy:

- **Clarity**: Uses only PyTorch native APIs with minimal abstraction

- **Modularity**: Each parallelization technique can be turned on/off independently

- **Modernity**: Actively leverages the latest features of PyTorch 2.x

- **Reproducibility**: Structure that makes reproducing training experiments easy

Supported Models

Models natively supported by torchtitan:

- Llama 2 (7B, 13B, 34B, 70B)

- Llama 3 (8B, 70B, 405B)

- Llama 3.1, 3.2 family

While Llama is the default, any transformer-based model can be added by following the structure.

Differences from Existing Frameworks

**Megatron-LM (NVIDIA)**:

- The most mature solution for GPU clusters

- Highly optimized but tied to the NVIDIA ecosystem

- Complex codebase makes customization difficult

**DeepSpeed (Microsoft)**:

- Famous for the ZeRO optimizer

- Supports both training and inference

- Relies on C++ custom kernels; sometimes difficult to fully integrate with PyTorch

**torchtitan (Meta/PyTorch)**:

- Pure PyTorch native API

- Uses torch.compile, FSDP2, torch.distributed.tensor from PyTorch 2.x

- Clear structure suited for educational purposes

- Tracks PyTorch version updates quickly

Repository Structure

```text
torchtitan/
├── torchtitan/
│   ├── models/
│   │   ├── llama/                 # Llama model definitions
│   │   └── __init__.py
│   ├── parallelisms/
│   │   ├── parallelize_llama.py   # Parallelism application logic
│   │   ├── pipeline_llama.py      # Pipeline parallelism
│   │   └── __init__.py
│   ├── optimizer.py               # Optimizer configuration
│   ├── checkpoint.py              # Checkpointing
│   ├── profiling.py               # Profiling
│   └── utils.py
├── train.py                       # Training entry point
├── train_configs/                 # TOML configuration files
│   ├── llama3_8b.toml
│   ├── llama3_70b.toml
│   └── llama3_405b.toml
└── estimation.py                  # Memory/FLOPs estimation tool
```

2. A Refresher on Distributed Training Paradigms

There are four main ways to train large models across multiple GPUs. torchtitan supports all four and can even combine them simultaneously (4D parallelism).

Data Parallelism (DP)

The most basic form of parallelism. The model is copied to each GPU, and the data batch is split so that each GPU processes it independently.

```text
GPU 0: [Batch 0-15]  → Gradient 0 ┐
GPU 1: [Batch 16-31] → Gradient 1 ├→ All-Reduce → Sync
GPU 2: [Batch 32-47] → Gradient 2 ┘
```

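PyTorch's `DistributedDataParallel` (DDP) is the standard implementation. However, every GPU must hold a full copy of the model, along with its gradients and optimizer state. The weights of a 70B-parameter model in FP16 alone take 140 GB (70 × 10⁹ parameters × 2 bytes), which is more than any single GPU offers.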
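Averaging gradients keeps all replicas in sync. The minimal single-process simulation below shows why. It is plain Python with no GPUs or `torch.distributed`, and the function names are illustrative.

Each simulated rank computes the gradient of a mean-squared-error loss on its own 16-sample shard, as in the diagram above. An averaging all-reduce then reproduces the full-batch gradient, provided the shards are the same size.

```python
import math


def grad_mse(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient d/dw of mean((w * x - y)^2) over one shard."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values: list[float]) -> float:
    """Stand-in for an all-reduce (sum across ranks) divided by world size."""
    return sum(values) / len(values)


def data_parallel_grad(w: float, xs: list[float], ys: list[float], world_size: int) -> float:
    shard = len(xs) // world_size  # assumes the batch divides evenly
    local_grads = [
        grad_mse(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r in range(world_size)
    ]
    return all_reduce_mean(local_grads)


xs = [float(i) for i in range(48)]  # 48 samples, as in the diagram
ys = [3.0 * x + 1.0 for x in xs]
full_batch = grad_mse(0.5, xs, ys)
three_gpus = data_parallel_grad(0.5, xs, ys, world_size=3)
print(math.isclose(full_batch, three_gpus))  # True
```

If the shards have different sizes, a plain average of the local gradients is no longer the full-batch gradient. Each rank's contribution would then need weighting by its sample count.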

Tensor Parallelism (TP)

Tensor parallelism distributes individual model layers across multiple GPUs, splitting matrix operations along the column or row dimension.

```text
Linear layer (d_model=4096, d_ff=16384) column-split across 4 GPUs:
GPU 0: columns 0-4095      (4096 x 4096 slice)
GPU 1: columns 4096-8191
GPU 2: columns 8192-12287
GPU 3: columns 12288-16383
```

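Each GPU holds and computes only 1/4 of the layer, and an All-Reduce or All-Gather is needed to combine the partial results. These collectives run inside every layer, so TP is very sensitive to interconnect bandwidth and benefits directly from fast links such as NVLink.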

Pipeline Parallelism (PP)

Layers are distributed across GPUs in sequence. The first GPU processes the first N layers and passes its output to the next GPU.

```text
GPU 0: Layer 0-9   → activation →
GPU 1: Layer 10-19 → activation →
GPU 2: Layer 20-29 → activation →
GPU 3: Layer 30-39 → loss
```

A naive implementation processes only one microbatch at a time, resulting in low GPU utilization (pipeline bubbles). GPipe, 1F1B, and Interleaved schedules address this.

Sequence Parallelism (SP)

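Distributes the sequence dimension of the attention computation across multiple GPUs. This addresses the O(N²) growth of attention memory when training with long contexts (128K+ tokens). PyTorch and recent torchtitan versions call this variant Context Parallelism (CP), to distinguish it from the Megatron-style sequence parallelism that accompanies TP (Section 4).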

```text
Sequence length 4096 distributed across 4 GPUs:
GPU 0: tokens 0-1023
GPU 1: tokens 1024-2047
GPU 2: tokens 2048-3071
GPU 3: tokens 3072-4095
```

Algorithms such as Ring Attention then compute exact attention over the whole sequence by passing key/value blocks from GPU to GPU in a ring.

4D Parallelism: Combining Everything

A core strength of torchtitan is its support for 4D parallelism, which combines all four strategies simultaneously.

4D parallelism example (64 GPUs):

- DP = 2 (data parallel, 2 replicas)

- TP = 8 (tensor parallel, 8 GPUs per layer)

- PP = 4 (pipeline parallel, 4 stages)

- SP = enabled alongside TP

Total GPUs = DP x TP x PP = 2 x 8 x 4 = 64

SP reuses the TP process group, so it does not change the GPU count. Adding a separate context-parallel (CP) dimension would multiply the count further.

Each parallelism has different communication patterns, so finding the optimal configuration for your hardware topology matters. Typically:

- TP is applied within a single server where GPUs are connected via NVLink

- PP is applied across servers

- DP is the outermost dimension
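Which GPUs talk to each other along each dimension depends on how ranks are laid out in the device mesh. This sketch makes two assumptions:

- Ranks 0..N-1 are arranged row-major over the mesh shape, with the last dimension varying fastest. This is how `init_device_mesh` lays them out.
- Each node holds 8 consecutive ranks.

The `mesh_groups` helper is hypothetical. With the 64-GPU example ordered as (dp, pp, tp), every TP group is one 8-GPU node, while PP and DP groups span nodes.

```python
from itertools import product


def mesh_groups(shape: dict[str, int]) -> dict[str, list[list[int]]]:
    """For each mesh dimension, list the rank groups that communicate along it.

    Ranks are laid out row-major over `shape` (last dimension varies fastest).
    """
    names = list(shape)
    sizes = [shape[n] for n in names]

    # Row-major strides, e.g. sizes (2, 4, 8) -> strides (32, 8, 1)
    strides = [1] * len(sizes)
    for i in range(len(sizes) - 2, -1, -1):
        strides[i] = strides[i + 1] * sizes[i + 1]

    groups: dict[str, list[list[int]]] = {}
    for i, name in enumerate(names):
        others = [j for j in range(len(names)) if j != i]
        groups[name] = []
        # Fix the coordinates of every other dimension and vary dimension i
        for coords in product(*(range(sizes[j]) for j in others)):
            base = sum(c * strides[j] for c, j in zip(coords, others))
            groups[name].append([base + k * strides[i] for k in range(sizes[i])])
    return groups


groups = mesh_groups({"dp": 2, "pp": 4, "tp": 8})  # 2 x 4 x 8 = 64 GPUs
print(groups["tp"][0])  # [0, 1, 2, 3, 4, 5, 6, 7] -> one node (NVLink)
print(groups["pp"][0])  # [0, 8, 16, 24]           -> one stage per node
print(groups["dp"][0])  # [0, 32]                  -> replicas on different nodes
```

Putting TP last in the mesh keeps its bandwidth-hungry collectives inside a node.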

3. FSDP2 (Fully Sharded Data Parallel v2)

The Relationship Between ZeRO and FSDP

Microsoft DeepSpeed's ZeRO (Zero Redundancy Optimizer) is an innovative approach to distributing parameters, gradients, and optimizer states across multiple GPUs. PyTorch's FSDP (Fully Sharded Data Parallel) is the PyTorch-native implementation of this idea.

ZeRO stages:

- **ZeRO-1**: Only optimizer state is sharded

- **ZeRO-2**: Optimizer state + gradient sharding

- **ZeRO-3**: Optimizer state + gradient + parameter sharding

FSDP corresponds to ZeRO-3.
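The per-GPU cost of each ZeRO stage follows from simple byte counting. The sketch below uses the ZeRO paper's accounting for mixed-precision Adam, per parameter:

- 2 bytes for 16-bit weights
- 2 bytes for 16-bit gradients
- 12 bytes for FP32 master weights plus the two Adam moments

It ignores activations, buffers, and fragmentation, so real usage is higher. `zero_memory_gb` is a hypothetical helper.

```python
def zero_memory_gb(num_params: float, num_gpus: int, stage: int) -> float:
    """Per-GPU memory (GB = 1e9 bytes) for weights, gradients and optimizer state."""
    weights, grads, optim = 2.0, 2.0, 12.0  # bytes per parameter
    if stage == 0:    # plain data parallelism: everything replicated
        per_param = weights + grads + optim
    elif stage == 1:  # shard optimizer state
        per_param = weights + grads + optim / num_gpus
    elif stage == 2:  # shard optimizer state + gradients
        per_param = weights + (grads + optim) / num_gpus
    elif stage == 3:  # shard everything (FSDP)
        per_param = (weights + grads + optim) / num_gpus
    else:
        raise ValueError("stage must be 0, 1, 2 or 3")
    return num_params * per_param / 1e9


for stage in range(4):
    print(f"7.5B params, 64 GPUs, ZeRO-{stage}: {zero_memory_gb(7.5e9, 64, stage):.1f} GB")
# 120.0, 31.4, 16.6 and 1.9 GB respectively
```

These numbers match the example given in the ZeRO paper for the same setting. The same count explains why a 70B model needs more than 4 GPUs even with full sharding: `zero_memory_gb(70e9, 4, 3)` comes to 280 GB per GPU before activations.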

FSDP1 vs FSDP2

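Through PyTorch 2.3, `torch.distributed.fsdp.FullyShardedDataParallel` (FSDP1) was the standard. With PyTorch 2.4+, FSDP2 is recommended. It is built on the `fully_shard` API in `torch.distributed._composable.fsdp`, and newer releases also expose it publicly as `torch.distributed.fsdp.fully_shard`.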

Key differences:

| Feature           | FSDP1         | FSDP2          |
| ----------------- | ------------- | -------------- |
| API style         | Wrapper-based | Composable API |
| TP integration    | Limited       | Native         |
| Memory efficiency | Good          | Better         |
| torch.compile     | Partial       | Full           |
| Code readability  | Complex       | Clear          |

```python
# FSDP1 style (legacy)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(model, ...)
```

```python
# FSDP2 style (used in torchtitan)
import torch
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)

# Apply FSDP2 to each transformer layer
for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)

# Apply to the full model too (shards the remaining parameters, e.g. embeddings)
fully_shard(model, mp_policy=mp_policy)
```

How FSDP Works

FSDP behaves differently during the forward pass, backward pass, and weight update.

**Forward pass**:

1. Before layer execution: All-Gather reconstructs sharded parameters on all GPUs

2. Layer execution: Computes with full parameters

3. After layer execution: Parameters are discarded (memory savings)

**Backward pass**:

1. Before layer backprop: All-Gather reconstructs parameters

2. Gradient computation

3. Reduce-Scatter distributes gradients to each GPU as shards

4. Parameters discarded

**Weight update**:

- Each GPU only updates the parameters, gradients, and optimizer state for its own shard

Example with a 70B model on 4 GPUs:

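- Memory savings: 140 GB of BF16 weights / 4 = 35 GB of weights per GPU. Gradients and optimizer state are sharded the same way, but with mixed-precision AdamW they are several times larger than the weights, so a real 70B run is sharded over many more GPUs.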

- Trade-off: communication overhead (All-Gather, Reduce-Scatter)

CPU Offload

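When GPU memory is insufficient, FSDP2 can offload the sharded parameters and gradients to CPU RAM. The optimizer step then runs on the CPU, so the optimizer state lives there as well.

```python
from torch.distributed._composable.fsdp import fully_shard, CPUOffloadPolicy

# FSDP2's CPUOffloadPolicy keeps sharded params/grads in CPU memory;
# pin_memory=True speeds up host<->device copies
cpu_offload = CPUOffloadPolicy(pin_memory=True)

for layer in model.layers:
    fully_shard(layer, offload_policy=cpu_offload)
fully_shard(model, offload_policy=cpu_offload)
```

CPU offload dramatically reduces GPU memory usage. The cost is speed: PCIe transfers can slow training by a large factor (often quoted as 2-5x). Use it as a last resort when memory is absolutely tight.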

Mixed Precision with FSDP2

FSDP2 lets you choose the precision per module, because each `fully_shard` call takes its own policy.

```python
import torch
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

# Standard mixed precision configuration
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,  # Parameters: BF16 (communication efficiency)
    reduce_dtype=torch.float32,  # Gradient reduction: FP32 (stability)
)

# Keep specific modules in FP32 (e.g., the embedding layer);
# the attribute name depends on the model definition
fully_shard(model.embed_tokens)  # No mp_policy -> params keep their original (FP32) dtype

for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)

fully_shard(model, mp_policy=mp_policy)
```

4. Tensor Parallelism

Tensor Parallelism in Transformers

Tensor parallelism, first systematized in Megatron-LM, is applied to two core modules in transformers.

**MLP (Feed-Forward Network)**:

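```text
FFN(x) = GELU(x * W1) * W2        # W1: d_model x d_ff, W2: d_ff x d_model

Column Parallel (split W1 by columns):
GPU 0: x * W1[:, 0:d_ff/4]      -> GELU -> y[:, 0:d_ff/4]
GPU 1: x * W1[:, d_ff/4:d_ff/2] -> GELU -> y[:, d_ff/4:d_ff/2]
...

Row Parallel (split W2 by rows + All-Reduce):
GPU 0: y[:, 0:d_ff/4]      * W2[0:d_ff/4, :]      -> partial sum 0
GPU 1: y[:, d_ff/4:d_ff/2] * W2[d_ff/4:d_ff/2, :] -> partial sum 1
...
All-Reduce -> final output
```

GELU is elementwise, so each GPU can apply it to its own column slice. The only communication in the MLP is therefore the final All-Reduce.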

**Self-Attention**: Query, Key, Value matrices are split by attention head.
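The column-then-row split of the MLP can be checked numerically. This pure-Python sketch uses tiny random matrices and no real devices; the helper names are illustrative. It computes the GELU MLP in two ways:

- once on a single "device";
- once split across 4 simulated TP ranks, where rank r owns a column slice of W1 and the matching row slice of W2.

Summing the partial outputs (the All-Reduce) recovers the reference result up to floating-point round-off. Llama's SwiGLU MLP splits the same way, with w1 and w3 both column-parallel.

```python
import math
import random


def matmul(a, b):
    """Plain-Python matrix product of lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def gelu(v):
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def mlp_reference(x, w1, w2):
    h = [[gelu(v) for v in row] for row in matmul(x, w1)]
    return matmul(h, w2)


def mlp_tensor_parallel(x, w1, w2, tp):
    d_ff = len(w1[0])
    chunk = d_ff // tp  # assumes d_ff divides evenly by tp
    partials = []
    for rank in range(tp):
        lo, hi = rank * chunk, (rank + 1) * chunk
        w1_cols = [row[lo:hi] for row in w1]          # column-parallel shard of W1
        h_local = [[gelu(v) for v in row] for row in matmul(x, w1_cols)]
        partials.append(matmul(h_local, w2[lo:hi]))   # row-parallel shard of W2
    # All-Reduce: elementwise sum of the partial outputs from every rank
    return [[sum(p[i][j] for p in partials) for j in range(len(w2[0]))]
            for i in range(len(x))]


def random_matrix(rows, cols):
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
x, w1, w2 = random_matrix(3, 4), random_matrix(4, 8), random_matrix(8, 4)  # d_model=4, d_ff=8
ref = mlp_reference(x, w1, w2)
out = mlp_tensor_parallel(x, w1, w2, tp=4)
max_err = max(abs(a - b) for ra, rb in zip(ref, out) for a, b in zip(ra, rb))
print(f"max abs difference: {max_err:.2e}")
```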

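```python
# TP application in torchtitan (DTensor-based), applied to one transformer block
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)

# Define parallelization plan (keys are submodule names of the block)
plan = {
    # Norms run sequence-parallel on sequence-sharded activations (Shard(1))
    "attention_norm": SequenceParallel(),
    # Attention takes (x, freqs_cis); gather x along the sequence dim first
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1), None),
        desired_input_layouts=(Replicate(), None),
    ),
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "ffn_norm": SequenceParallel(),
    "feed_forward": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w3": ColwiseParallel(),
}

parallelize_module(model_layer, tp_mesh, plan)
```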

DTensor: The Foundation of Distributed Tensors

torchtitan implements tensor parallelism based on PyTorch's `torch.distributed.tensor` (DTensor). A DTensor is a new abstraction that is logically a single tensor but is physically distributed across multiple devices.

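```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard

# Create 1D mesh (for TP) in an 8-process job
tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

# Alternatively, create a 2D mesh (DP + TP) in a 16-process job;
# mesh_2d["tp"] returns the TP sub-mesh
mesh_2d = init_device_mesh(
    "cuda", (2, 8), mesh_dim_names=("dp", "tp")
)

# Shard(0): shard along dim 0 (row-wise)
# Shard(1): shard along dim 1 (column-wise)
# Replicate(): full copy on every device (no sharding)
```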
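To make the placements concrete, the sketch below mimics in plain Python which piece of a 4 x 4 matrix each of 2 ranks would hold. Strings stand in for the real `Shard`/`Replicate` objects, and only even splits are handled; real DTensors also support uneven sizes. `local_shard` is a hypothetical helper, not a PyTorch API.

```python
def local_shard(matrix, placement, rank, world_size):
    """Return the block of `matrix` held by `rank` under a 1-D placement."""
    if placement == "Replicate()":
        return [row[:] for row in matrix]               # every rank has a full copy
    if placement == "Shard(0)":
        n = len(matrix) // world_size                   # split rows (dim 0)
        return [row[:] for row in matrix[rank * n:(rank + 1) * n]]
    if placement == "Shard(1)":
        n = len(matrix[0]) // world_size                # split columns (dim 1)
        return [row[rank * n:(rank + 1) * n] for row in matrix]
    raise ValueError(f"unknown placement: {placement}")


m = [[4 * i + j for j in range(4)] for i in range(4)]
print(local_shard(m, "Shard(0)", rank=1, world_size=2))  # [[8, 9, 10, 11], [12, 13, 14, 15]]
print(local_shard(m, "Shard(1)", rank=0, world_size=2))  # [[0, 1], [4, 5], [8, 9], [12, 13]]
```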

Sequence Parallelism

A natural extension of tensor parallelism that parallelizes LayerNorm and Dropout operations along the sequence dimension.

```text
Standard TP flow:
Input(replicated) -> LayerNorm(replicated) -> Attention(TP) -> [All-Reduce] -> Output(replicated)

SP + TP flow:
Input(seq-sharded) -> LayerNorm(seq-sharded) ->
  [All-Gather] -> Attention(TP) -> [Reduce-Scatter] ->
  Output(seq-sharded)
```

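SP replaces each All-Reduce with an All-Gather plus a Reduce-Scatter. A ring All-Reduce is itself implemented as a Reduce-Scatter followed by an All-Gather, so the total communication volume is the same. Memory use drops, however: LayerNorm and Dropout activations are split across the N GPUs instead of being replicated on each.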
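The equivalence behind SP's communication pattern can be checked with a small simulation in which each list is one rank's buffer. The two steps work as follows:

- `reduce_scatter` leaves each rank with the summed values for its own 1/N slice. This is where SP runs LayerNorm/Dropout on sequence shards.
- `all_gather` then rebuilds the full tensor on every rank.

Composed, they give exactly what `all_reduce` produces. In a ring implementation, each step sends (N-1)/N of the buffer per rank, which matches the total of a ring All-Reduce. These are illustrative helpers, not the `torch.distributed` functions of the same names.

```python
def all_reduce(buffers):
    """Every rank ends up with the elementwise sum of all buffers."""
    total = [sum(vals) for vals in zip(*buffers)]
    return [total[:] for _ in buffers]


def reduce_scatter(buffers):
    """Rank r ends up with only slice r of the elementwise sum."""
    world = len(buffers)
    n = len(buffers[0]) // world  # assumes even division
    total = [sum(vals) for vals in zip(*buffers)]
    return [total[r * n:(r + 1) * n] for r in range(world)]


def all_gather(shards):
    """Every rank ends up with the concatenation of all shards."""
    full = [v for shard in shards for v in shard]
    return [full[:] for _ in shards]


# 4 ranks, 8 elements each (e.g. 8 tokens of a sequence)
buffers = [[rank * 10 + i for i in range(8)] for rank in range(4)]
shards = reduce_scatter(buffers)
print(shards[0])  # [60, 64]: rank 0 holds 2 of the 8 summed elements
print(all_gather(shards) == all_reduce(buffers))  # True
```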

5. Pipeline Parallelism

The Pipeline Bubble Problem

Naive pipeline parallelism introduces significant inefficiency.

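```text
Naive pipeline (4 GPUs, one microbatch in flight at a time):
Time ->
GPU 0: [M0 F]                                    [M0 B]
GPU 1:       [M0 F]                        [M0 B]
GPU 2:             [M0 F]            [M0 B]
GPU 3:                   [M0 F][M0 B]
```

F = Forward, B = Backward. GPU 0 sits idle from the end of its forward pass until the gradient returns from GPU 3, and this repeats for every microbatch.

Bubble ratio = (PP - 1) / (micro_batches + PP - 1), where micro_batches is the number of microbatches fed through the pipeline between flushes. With a single microbatch and PP = 4, each GPU is idle 3/4 of the time.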
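The bubble formula comes from counting time slots. Assume p stages, m microbatches, and equal forward/backward time per stage. A flush-based schedule then takes m + p - 1 slots per phase, and each stage is busy for m of them.

The short table below, produced by a hypothetical helper, shows why pipeline schedules want many more microbatches than stages.

```python
def bubble_fraction(pp: int, micro_batches: int) -> float:
    """Idle fraction (pp - 1) / (micro_batches + pp - 1), assuming equal stage times."""
    return (pp - 1) / (micro_batches + pp - 1)


for m in (1, 4, 16, 64):
    print(f"PP=4, micro_batches={m:>2}: bubble = {bubble_fraction(4, m):.1%}")
# PP=4, micro_batches= 1: bubble = 75.0%
# PP=4, micro_batches= 4: bubble = 42.9%
# PP=4, micro_batches=16: bubble = 15.8%
# PP=4, micro_batches=64: bubble = 4.5%
```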

GPipe Schedule

GPipe injects multiple microbatches into the pipeline to reduce bubbles.

```text
GPipe (4 GPUs, 4 microbatches):
Time ->
GPU 0: [M0F][M1F][M2F][M3F]                              [M3B][M2B][M1B][M0B]
GPU 1:      [M0F][M1F][M2F][M3F]                    [M3B][M2B][M1B][M0B]
GPU 2:           [M0F][M1F][M2F][M3F]          [M3B][M2B][M1B][M0B]
GPU 3:                [M0F][M1F][M2F][M3F][M3B][M2B][M1B][M0B]
```

All forward passes complete before any backward pass starts. The downside is high activation memory: each stage must hold the activations of every microbatch until its backward pass runs.

1F1B (One Forward One Backward) Schedule

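1F1B is a more memory-efficient schedule. After a short warmup of forward passes, each GPU alternates one forward and one backward pass. A stage therefore holds activations for at most PP microbatches instead of all of them. Its bubble is the same size as GPipe's; the gain is memory.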
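The operation order of non-interleaved 1F1B can be generated in a few lines. This sketch models only the per-stage order of operations, not timing. Stage s of p proceeds in three phases:

1. It runs p - s - 1 warmup forwards.
2. It alternates one forward and one backward.
3. It drains the remaining backwards.

Counting live activations confirms the memory claim: stage 0 peaks at p microbatches and the last stage at 1, whereas GPipe would hold all m. `one_f_one_b` and `peak_in_flight` are hypothetical helpers.

```python
def one_f_one_b(stage: int, num_stages: int, num_microbatches: int) -> list[str]:
    """Operation order ("F3" = forward of microbatch 3) for one pipeline stage."""
    warmup = min(num_stages - stage - 1, num_microbatches)
    ops = [f"F{i}" for i in range(warmup)]
    next_f, next_b = warmup, 0
    # Steady state: one forward, then one backward
    while next_f < num_microbatches:
        ops.append(f"F{next_f}")
        next_f += 1
        ops.append(f"B{next_b}")
        next_b += 1
    # Cooldown: drain the remaining backwards
    while next_b < num_microbatches:
        ops.append(f"B{next_b}")
        next_b += 1
    return ops


def peak_in_flight(ops: list[str]) -> int:
    """Max number of microbatches whose activations are held at once."""
    live = peak = 0
    for op in ops:
        live += 1 if op.startswith("F") else -1
        peak = max(peak, live)
    return peak


ops = one_f_one_b(stage=0, num_stages=4, num_microbatches=8)
print(" ".join(ops))
# F0 F1 F2 F3 B0 F4 B1 F5 B2 F6 B3 F7 B4 B5 B6 B7
print(peak_in_flight(ops), peak_in_flight(one_f_one_b(3, 4, 8)))  # 4 1
```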

```toml
# Pipeline parallelism configuration in torchtitan (TOML)
[experimental]
pipeline_parallel_degree = 4
pipeline_parallel_schedule = "1f1b"
```

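```python
# Python-level setup (the pipeline_llama signature has changed
# across torchtitan releases; this matches the layout shown above)
from torchtitan.parallelisms.pipeline_llama import pipeline_llama

# Create pipeline stages
stages, model_parts = pipeline_llama(
    model,
    pp_mesh,
    parallel_dims,
    job_config,
    device,
    model_config,
)
```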

Interleaved Schedule

The interleaved schedule assigns each GPU several non-adjacent pipeline stages (virtual stages). With v stages per GPU, the bubble shrinks by roughly a factor of v. The costs are more point-to-point communication and a more complex schedule.

```text
Interleaved 1F1B (4 GPUs, 2 stages/GPU):
GPU 0: [layers 0-4] and [layers 20-24]
GPU 1: [layers 5-9] and [layers 25-29]
...
```
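The layer assignment in the diagram follows a round-robin rule. The model is split into (GPUs x stages per GPU) equal virtual stages, and stage s goes to GPU s mod num_gpus. The sketch below assumes the 40-layer model from the earlier pipeline example and layer counts that divide evenly; `interleaved_layers` is a hypothetical helper.

```python
def interleaved_layers(
    num_layers: int, num_gpus: int, stages_per_gpu: int
) -> dict[int, list[tuple[int, int]]]:
    """Map each GPU to the (first, last) layer index of each virtual stage it owns."""
    num_stages = num_gpus * stages_per_gpu
    per_stage = num_layers // num_stages  # assumes even division
    assignment: dict[int, list[tuple[int, int]]] = {g: [] for g in range(num_gpus)}
    for stage in range(num_stages):
        gpu = stage % num_gpus  # round-robin over GPUs
        first = stage * per_stage
        assignment[gpu].append((first, first + per_stage - 1))
    return assignment


for gpu, ranges in interleaved_layers(40, 4, 2).items():
    print(f"GPU {gpu}: {ranges}")
# GPU 0: [(0, 4), (20, 24)]
# GPU 1: [(5, 9), (25, 29)]
# GPU 2: [(10, 14), (30, 34)]
# GPU 3: [(15, 19), (35, 39)]
```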

6. Installing and Using torchtitan

System Requirements

- Python 3.10+

- PyTorch 2.5+ (latest nightly recommended)

- CUDA 12.1+

- GPU: H100 or A100 recommended (minimum 40GB VRAM)

Installation

```bash
# Clone the repository
git clone https://github.com/pytorch/torchtitan
cd torchtitan

# Install PyTorch nightly (includes latest features)
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124

# Install dependencies
pip install -r requirements.txt

# Install torchtitan package
pip install -e .

# Download tokenizer
python torchtitan/datasets/download_tokenizer.py \
    --repo_id meta-llama/Meta-Llama-3-8B \
    --tokenizer_path "original" \
    --hf_token YOUR_HF_TOKEN
```

Configuration Files (TOML)

torchtitan uses TOML-format configuration files.

```toml
# train_configs/llama3_8b.toml
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 8
seq_len = 2048
warmup_steps = 200
max_norm = 1.0
steps = 1000
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1  # Auto: uses remaining GPU count
tensor_parallel_degree = 1
enable_loss_parallel = false

[experimental]
pipeline_parallel_degree = 1

[checkpoint]
enable_checkpoint = true
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"

[activation_checkpoint]
mode = "selective"  # full, selective, none
selective_ac_option = "op"

[float8]
enable_float8_linear = false
```
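All parallel degrees must multiply to the number of launched processes, and `-1` for `data_parallel_shard_degree` means "whatever is left over". torchtitan performs a similar validation at startup in its `ParallelDims` helper, which also accounts for context parallelism. The function below is a simplified, hypothetical sketch of that rule, not torchtitan's code.

```python
def resolve_shard_degree(
    world_size: int,
    dp_replicate: int = 1,
    dp_shard: int = -1,
    tp: int = 1,
    pp: int = 1,
) -> int:
    """Return the FSDP shard degree, filling in -1 from the world size."""
    others = dp_replicate * tp * pp
    if dp_shard == -1:
        if world_size % others != 0:
            raise ValueError(f"world_size {world_size} is not divisible by {others}")
        dp_shard = world_size // others
    if dp_replicate * dp_shard * tp * pp != world_size:
        raise ValueError("parallel degrees must multiply to world_size")
    return dp_shard


print(resolve_shard_degree(8))                         # 8: pure FSDP on one node
print(resolve_shard_degree(8, tp=4))                   # 2: TP=4 leaves 2-way FSDP
print(resolve_shard_degree(32, dp_replicate=2, tp=4))  # 4: the 70B layout in Section 11
```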

Running Llama 3 8B Training

```bash
# Single node, 8 GPU training of Llama 3 8B
torchrun --nproc_per_node=8 \
    train.py \
    --job.config_file train_configs/llama3_8b.toml

# With TP=4, DP=2
torchrun --nproc_per_node=8 \
    train.py \
    --job.config_file train_configs/llama3_8b.toml \
    --training.tensor_parallel_degree 4 \
    --training.data_parallel_shard_degree 2

# Multi-node training (2 nodes, 8 GPUs each); run the same command on every node
torchrun \
    --nproc_per_node=8 \
    --nnodes=2 \
    --rdzv_id=101 \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR:29400 \
    train.py \
    --job.config_file train_configs/llama3_70b.toml
```

Memory/FLOPs Estimation Tool

Before starting training, estimate memory and compute requirements upfront.

```bash
# Estimate memory and FLOPs
python estimation.py \
    --job.config_file train_configs/llama3_8b.toml
```

Sample output:

```text
Estimated model size: 15.01 GB
Estimated optimizer state size: 30.02 GB
Total estimated GPU memory: 65.23 GB
Estimated FLOP per step: 1.61e+14
```

7. Flash Attention Integration

What is Flash Attention?

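Flash Attention is an attention algorithm published in 2022 by Tri Dao and colleagues at Stanford. It reduces the extra memory needed by attention from O(N²) to O(N) in the sequence length N. It also cuts reads and writes to GPU high-bandwidth memory (HBM), which typically speeds up the attention computation itself by roughly 2-4x.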

The problem with standard attention:

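```python
import numpy as np

# Standard attention (one head): Q, K, V have shape (seq_len, head_dim)
S = Q @ K.T / np.sqrt(Q.shape[-1])  # (seq_len, seq_len) matrix -- memory explosion!
P = np.exp(S - S.max(axis=-1, keepdims=True))
P /= P.sum(axis=-1, keepdims=True)  # row-wise softmax
O = P @ V

# seq_len=4096:  S = 4096 x 4096 x 4 bytes = 64 MB (FP32)
# seq_len=32768: S = 32768 x 32768 x 4 bytes = 4 GB (per head, per sequence, one layer!)
```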

Flash Attention's key ideas:

1. Process Q, K, V in tiles (tiling)

2. Compute in SRAM (shared memory) instead of HBM

3. Never materialize the full attention matrix in HBM

4. Numerically exact (not an approximation)
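The key trick that lets Flash Attention avoid materializing the score matrix is the "online softmax". Keys and values are processed tile by tile, with a running maximum and a running denominator. Whenever the maximum grows, everything accumulated so far is rescaled.

The pure-Python sketch below shows this for a single query row and checks it against naive attention. It has no GPU, no SRAM/HBM distinction, and no causal mask, and the function names are illustrative. The same merge rule is what lets Ring Attention combine key/value blocks arriving from other GPUs.

```python
import math
import random


def naive_attention(q, keys, values):
    """softmax(q . k / sqrt(d)) weighted sum of values, all scores at once."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def tiled_attention(q, keys, values, tile=2):
    """Same result, visiting only `tile` keys/values at a time."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), tile):
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys[start:start + tile]]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new maximum (exp(-inf) == 0.0 on the first tile)
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, values[start:start + tile]):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


random.seed(1)
d, n = 4, 7  # head_dim, sequence length (n need not be a multiple of the tile)
q = [random.gauss(0, 1) for _ in range(d)]
keys = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
values = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
ref, tiled = naive_attention(q, keys, values), tiled_attention(q, keys, values, tile=3)
print(max(abs(a - b) for a, b in zip(ref, tiled)))  # floating-point round-off only
```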

Flash Attention 2 and 3

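**Flash Attention 2** (2023):

- Better parallelism: work is also split along the sequence length, not only across batch and heads

- Better work partitioning between warps and fewer non-matmul FLOPs

- About 2x faster than FlashAttention-1 on A100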

**Flash Attention 3** (2024):

- Tailored for H100 Hopper architecture

- Leverages WGMMA (Warpgroup Matrix Multiply Accumulate)

- 1.5-2x additional improvement over FA2 on H100

- FP8 support

Using Flash Attention in torchtitan

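```python
# Attention implementation in torchtitan/models/llama/model.py (simplified;
# assumes n_kv_heads == n_heads, whereas torchtitan also repeats KV heads for GQA)
import torch
import torch.nn.functional as F


def forward(
    self,
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    bs, seqlen, _ = x.shape
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

    # Split the hidden dim into heads: (bs, seqlen, n_heads, head_dim)
    xq = xq.view(bs, seqlen, -1, self.head_dim)
    xk = xk.view(bs, seqlen, -1, self.head_dim)
    xv = xv.view(bs, seqlen, -1, self.head_dim)

    # Apply RoPE embeddings
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

    # SDPA expects (bs, n_heads, seqlen, head_dim)
    xq, xk, xv = (t.transpose(1, 2) for t in (xq, xk, xv))

    # Flash Attention via PyTorch SDPA:
    # scaled_dot_product_attention picks a Flash Attention kernel when it can
    output = F.scaled_dot_product_attention(
        xq, xk, xv,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=True,  # Causal mask for language models
    )

    # Back to (bs, seqlen, n_heads * head_dim) for the output projection
    output = output.transpose(1, 2).contiguous()
    return self.wo(output.view(bs, seqlen, -1))
```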

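Since PyTorch 2.0, `torch.nn.functional.scaled_dot_product_attention` (SDPA) automatically dispatches to a Flash Attention kernel when the inputs qualify, for example FP16/BF16 tensors on a supported GPU. Otherwise it falls back to the memory-efficient or plain math implementation. No separate package installation is needed.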

To use the original Flash Attention package directly:

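```bash
pip install flash-attn --no-build-isolation
```

```python
from flash_attn import flash_attn_func

# q, k, v: (batch, seqlen, n_heads, head_dim), FP16/BF16 on GPU
# (a different layout from SDPA, which expects heads before seqlen)
output = flash_attn_func(
    q, k, v,
    dropout_p=0.0,
    softmax_scale=None,  # Auto: 1/sqrt(head_dim)
    causal=True,
)
```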

8. Async Tensor Parallelism

Why Async TP?

In standard tensor parallelism, All-Reduce or All-Gather/Reduce-Scatter block the next layer's computation. GPUs sit idle while waiting for communication.

```text
Synchronous TP:
[GPU compute] -> [wait for comm] -> [GPU compute] -> [wait for comm] ...
```

Async TP overlaps communication and computation:

```text
Async TP:
[GPU compute A] ────────────────────>
        [comm (A result)] ─────>
                [GPU compute B] ->
                        [comm (B)]
```

Async TP in torchtitan

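Enable async TP in the config file. It takes effect only together with `torch.compile` (`--training.compile`), because the overlap is implemented as a compiler pass:

```toml
[experimental]
enable_async_tensor_parallel = true
```

```python
# Code reference
from torchtitan.parallelisms.parallelize_llama import parallelize_llama

# With async TP enabled, torchtitan turns on PyTorch's micro-pipelined TP
# (a torch.compile/Inductor pass backed by symmetric memory), which splits
# the TP collectives and neighboring matmuls into chunks that overlap
model = parallelize_llama(
    model,
    world_mesh,
    parallel_dims,
    job_config,
)
```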

Async TP pays off most at high TP degrees (8+), where TP communication is a large share of step time. Reported throughput gains for H100 + NVLink setups are on the order of 5-15%, depending on model size and configuration.

9. Checkpointing and Resuming

Distributed Checkpoint (dcp)

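In large-scale distributed training, checkpointing is far more than saving a model. Model and optimizer state are sharded across thousands of GPUs, and every rank must save and load its own shards in parallel.

PyTorch's `torch.distributed.checkpoint` (dcp) natively supports distributed checkpointing, including loading a checkpoint into a different parallel layout (resharding).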

```python
# torchtitan-style checkpointing with DCP
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
)


# Save: every rank calls this, and each writes only its own shards
def save_checkpoint(model, optimizer, step, output_dir):
    state_dict = {
        "model": get_model_state_dict(model),
        "optimizer": get_optimizer_state_dict(model, optimizer),
        "extra": {"step": step},
    }
    dcp.save(
        state_dict,
        checkpoint_id=f"{output_dir}/step-{step}",
    )


# Load: dcp.load fills the given state_dict in place, so every key
# to be restored (including "step") must already be present
def load_checkpoint(model, optimizer, checkpoint_path):
    state_dict = {
        "model": get_model_state_dict(model),
        "optimizer": get_optimizer_state_dict(model, optimizer),
        "extra": {"step": 0},  # placeholder, overwritten by dcp.load
    }
    dcp.load(
        state_dict,
        checkpoint_id=checkpoint_path,
    )
    set_model_state_dict(
        model,
        model_state_dict=state_dict["model"],
        options=StateDictOptions(strict=True),
    )
    set_optimizer_state_dict(
        model,
        optimizer,
        optim_state_dict=state_dict["optimizer"],
    )
    return state_dict["extra"]["step"]
```

Async Checkpointing

Checkpointing large models can take minutes, during which GPUs stall. Async checkpointing blocks training only while the state is copied (staged) to CPU memory. The write to storage then happens in the background while training continues.

```toml
# Async checkpointing configuration
[checkpoint]
async_mode = "async"  # "disabled", "async", "async_with_pinned_mem"
```

```python
from torchtitan.checkpoint import CheckpointManager

checkpoint_manager = CheckpointManager(
    dataloader=train_dataloader,
    model_parts=model_parts,
    optimizers=optimizers,
    lr_schedulers=lr_schedulers,
    states={"train_state": train_state},
    job_config=job_config,
)

# In the training loop
for step in range(1, num_steps + 1):
    ...  # training code

    # Saves only when the configured interval is reached (or force=True);
    # with async_mode enabled, the write runs in the background
    checkpoint_manager.save(curr_step=step, force=False)
```
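The idea behind async checkpointing can be shown without any GPUs, in three moves:

1. Take a quick in-memory snapshot. This is the only blocking part.
2. Hand the snapshot to a background thread that writes it.
3. Keep training.

The plain-Python sketch below uses `copy.deepcopy` as the "staging" step and JSON files as storage, and it illustrates the pattern only. In torchtitan/DCP, staging instead copies GPU tensors to CPU (optionally into pinned memory), and `torch.distributed.checkpoint.async_save` performs the background write.

```python
import copy
import json
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor


def write_checkpoint(snapshot: dict, path: str) -> str:
    """Slow part: runs in a background thread while training continues."""
    with open(path, "w") as f:
        json.dump(snapshot, f)
    return path


ckpt_dir = tempfile.mkdtemp()
state = {"step": 0, "weights": [0.0, 0.0]}
executor = ThreadPoolExecutor(max_workers=1)
pending = None

for step in range(1, 7):
    state["step"] = step
    state["weights"] = [w + 0.5 for w in state["weights"]]  # stand-in for an optimizer update
    if step % 3 == 0:  # checkpoint interval
        if pending is not None:
            pending.result()  # never run two saves at once
        snapshot = copy.deepcopy(state)  # staging: blocking but fast
        path = os.path.join(ckpt_dir, f"step-{step}.json")
        pending = executor.submit(write_checkpoint, snapshot, path)

pending.result()  # make sure the last save has finished before exiting
executor.shutdown()

with open(os.path.join(ckpt_dir, "step-3.json")) as f:
    print(json.load(f))  # {'step': 3, 'weights': [1.5, 1.5]}
```

Because the snapshot is taken before the write starts, later updates to `state` never leak into an in-progress checkpoint.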

Checkpoint Format Conversion

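Distributed training checkpoints can be consolidated into a single file and then converted to HuggingFace format. PyTorch provides `torch.distributed.checkpoint.format_utils.dcp_to_torch_save` for the consolidation step. The conversion script's name and flags depend on the torchtitan version; an invocation looks like this:

```bash
# Convert distributed checkpoint to single HuggingFace model
python scripts/convert_checkpoint.py \
    --checkpoint_path outputs/checkpoint/step-1000 \
    --output_path outputs/hf_model \
    --model_type llama3 \
    --model_flavor 8B
```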

10. Performance Profiling

Using PyTorch Profiler

To find performance bottlenecks, you first need to know where time is being spent. PyTorch Profiler is a powerful tool for this.

```python
from torch.profiler import profile, record_function, ProfilerActivity

# Basic profiling (`model` and `inputs` are assumed to exist)
with profile(
    activities=[
        ProfilerActivity.CPU,
        ProfilerActivity.CUDA,
    ],
    record_shapes=True,   # Record tensor shapes
    profile_memory=True,  # Record memory usage
    with_stack=True,      # Record call stacks
) as prof:
    with record_function("model_inference"):
        output = model(inputs)

# Print results
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20,
))

# Save Chrome Trace (visualize at chrome://tracing or https://ui.perfetto.dev)
prof.export_chrome_trace("trace.json")
```

torchtitan Built-in Profiling

torchtitan controls profiling via configuration files.

```toml
[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100             # Profile every 100 steps
enable_memory_snapshot = true  # Save memory snapshots
```

```python
# Reference: simplified version of torchtitan/profiling.py internals
from contextlib import contextmanager

import torch.profiler as profiler


@contextmanager
def maybe_enable_profiling(config, global_step=0):
    if not config.profiling.enable_profiling:
        yield None
        return

    with profiler.profile(
        activities=[
            profiler.ProfilerActivity.CPU,
            profiler.ProfilerActivity.CUDA,
        ],
        schedule=profiler.schedule(
            skip_first=10,
            wait=5,
            warmup=5,
            active=1,
            repeat=1,
        ),
        on_trace_ready=profiler.tensorboard_trace_handler(
            config.profiling.save_traces_folder
        ),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as p:
        yield p  # the training loop must call p.step() once per step
```
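With the schedule above, only a narrow window of steps is actually traced. The sketch below re-implements the documented `torch.profiler.schedule` logic in plain Python to show which calls to `p.step()` get recorded. That logic skips `skip_first` steps, then cycles through `wait`, `warmup`, and `active` steps, at most `repeat` times.

It is a simplified model: PyTorch additionally distinguishes the last active step, when the trace is saved, from the others.

```python
def profiler_action(step, *, wait, warmup, active, repeat=0, skip_first=0):
    """Simplified model of what torch.profiler.schedule returns for a given step."""
    if step < skip_first:
        return "NONE"
    step -= skip_first
    cycle = wait + warmup + active
    if repeat > 0 and step // cycle >= repeat:
        return "NONE"  # all requested cycles are done
    position = step % cycle
    if position < wait:
        return "NONE"
    if position < wait + warmup:
        return "WARMUP"
    return "RECORD"


actions = [profiler_action(s, wait=5, warmup=5, active=1, repeat=1, skip_first=10)
           for s in range(30)]
print([s for s, a in enumerate(actions) if a == "RECORD"])  # [20]
print([s for s, a in enumerate(actions) if a == "WARMUP"])  # [15, 16, 17, 18, 19]
```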

TensorBoard Integration

```python
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/experiment_1")

# In the training loop (train_step is a placeholder for your own step function)
for step, (inputs, target) in enumerate(dataloader):
    loss = train_step(model, optimizer, inputs, target)
    writer.add_scalar("Loss/train", loss.item(), step)
    writer.add_scalar("LR", optimizer.param_groups[0]["lr"], step)

    # GPU memory monitoring
    writer.add_scalar(
        "GPU/memory_allocated_GB",
        torch.cuda.memory_allocated() / 1e9,
        step,
    )
    writer.add_scalar(
        "GPU/memory_reserved_GB",
        torch.cuda.memory_reserved() / 1e9,
        step,
    )

writer.close()
```

Launch TensorBoard:

```bash
tensorboard --logdir=runs/
```

Memory Usage Analysis

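```python
import torch

# GPU memory snapshot: start recording allocation history
torch.cuda.memory._record_memory_history(max_entries=100000)

# Run part of training (train_step and batch are placeholders)
for step in range(100):
    loss = train_step(model, optimizer, batch)

# Save memory snapshot, then stop recording
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
```

To visualize the snapshot, upload the `.pickle` file at https://pytorch.org/memory_viz.

```python
# Per-layer memory analysis from a profiler trace
# (prof comes from a torch.profiler run with profile_memory=True,
#  record_shapes=True and with_stack=True)
from torch.cuda._memory_viz import profile_plot

with open("memory_profile.html", "w") as f:
    f.write(profile_plot(prof))
```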

Training Efficiency Metric: MFU

MFU (Model FLOPs Utilization) is the ratio of the FLOPs per second the model actually achieves to the hardware's theoretical peak.

```python
def compute_mfu(
    model_num_params: float,
    batch_size: int,        # Global batch size (sequences per step, all GPUs)
    seq_len: int,
    elapsed_time: float,    # Time per step (seconds)
    num_gpus: int,
    gpu_peak_flops: float,  # Per-GPU theoretical peak FLOPS
) -> float:
    """
    Compute MFU for LLM training.
    Reference: PaLM paper (Chowdhery et al., 2022)
    """
    # Forward FLOPs ~ 2 * num_params * tokens
    # Backward is ~2x forward -> total ~6 * num_params * tokens
    # (ignores attention-score FLOPs, which grow with seq_len)
    flops_per_step = 6 * model_num_params * batch_size * seq_len
    achieved_flops = flops_per_step / elapsed_time
    peak_flops_total = gpu_peak_flops * num_gpus
    return achieved_flops / peak_flops_total


# Example usage
mfu = compute_mfu(
    model_num_params=8e9,   # 8B params
    batch_size=8,           # 8 sequences across all 8 GPUs
    seq_len=2048,
    elapsed_time=0.5,       # 0.5 seconds per step
    num_gpus=8,
    gpu_peak_flops=989e12,  # H100 SXM dense BF16: ~989 TFLOPS
)

print(f"MFU: {mfu:.1%}")  # MFU: 19.9%
```

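Commonly cited MFU ranges for large dense models are listed below. Actual values depend on model size, sequence length, and interconnect.

- Good implementation: 40-60%

- Optimized (torchtitan, Megatron): 50-65%

- Practical ceiling: roughly 70%, since communication and memory overhead can never be fully hidden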

11. Practical Training Configuration Examples

Llama 3 8B: Single Node with 8x H100

```toml
# train_configs/llama3_8b_h100x8.toml
[job]
dump_folder = "./outputs/llama3_8b"

[model]
name = "llama3"
flavor = "8B"

[training]
batch_size = 4
seq_len = 8192
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 8  # FSDP2: 8 GPUs
tensor_parallel_degree = 1

[activation_checkpoint]
mode = "selective"

[float8]
enable_float8_linear = true  # Enable FP8 training (needs torchao and an FP8-capable GPU such as H100)

[optimizer]
name = "AdamW"
lr = 1e-4
```

```bash
torchrun --nproc_per_node=8 train.py \
    --job.config_file train_configs/llama3_8b_h100x8.toml
```

Llama 3 70B: 4 Nodes x 8 H100 (32 GPUs Total)

```toml
# train_configs/llama3_70b_32gpu.toml
[training]
batch_size = 2
seq_len = 4096
data_parallel_replicate_degree = 2  # DDP: 2 replicas
data_parallel_shard_degree = 4      # FSDP: 4-GPU sharding
tensor_parallel_degree = 4          # TP: 4 GPUs
# Total GPUs = 2 x 4 x 4 = 32

[experimental]
pipeline_parallel_degree = 1  # PP disabled
```

Llama 3 405B: Large-Scale Cluster

```toml
# train_configs/llama3_405b_large.toml
[training]
batch_size = 1
seq_len = 2048
data_parallel_replicate_degree = 4  # DDP
data_parallel_shard_degree = 8      # FSDP
tensor_parallel_degree = 8          # TP

[experimental]
pipeline_parallel_degree = 8  # PP
# Total GPUs = 4 x 8 x 8 x 8 = 2048
```

Conclusion: What torchtitan Teaches Us

torchtitan goes beyond being a simple training tool. It is an educational resource that clearly demonstrates the best practices of modern distributed LLM training.

Key takeaways from this guide:

1. **4D Parallelism**: Combining DP, TP, PP, and SP to efficiently utilize thousands of GPUs

2. **FSDP2**: ZeRO-3-level memory efficiency with PyTorch native APIs

3. **Flash Attention**: Attention memory drops from O(N²) to O(N), and attention itself runs roughly 2-4x faster

4. **Async Checkpointing**: Save checkpoints without stopping training

5. **MFU optimization**: Targeting 40-65% of theoretical GPU performance

Distributed training is a complex domain where hardware, software, and algorithms intersect. torchtitan exposes this complexity in the most transparent way possible, making learning and experimentation accessible. Run the code yourself, experiment with different parallelization combinations, and develop your intuition for distributed training.

References

- [torchtitan GitHub Repository](https://github.com/pytorch/torchtitan)

- [PyTorch FSDP Documentation](https://pytorch.org/docs/stable/fsdp.html)

- [PyTorch FSDP Tutorial](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html)

- [Megatron-LM: Training Multi-Billion Parameter Language Models](https://arxiv.org/abs/1909.08053)

- [FlashAttention: Fast and Memory-Efficient Exact Attention (Dao et al., 2022)](https://arxiv.org/abs/2205.14135)

- [FlashAttention-2 (Dao, 2023)](https://arxiv.org/abs/2307.08691)

- [Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/abs/2205.05198)

- [PyTorch DTensor Documentation](https://pytorch.org/docs/stable/distributed.tensor.html)

- [PyTorch Distributed Checkpoint Documentation](https://pytorch.org/docs/stable/distributed.checkpoint.html)
