Skip to content

Split View: Torch-Titan 완전 가이드: PyTorch 대규모 분산 학습의 모든 것

|

Torch-Titan 완전 가이드: PyTorch 대규모 분산 학습의 모든 것

시작하며

대규모 언어 모델(LLM)을 학습하는 것은 현대 AI 엔지니어링의 가장 복잡한 도전 중 하나입니다. Llama 3 70B를 단일 GPU로 학습하는 것은 불가능합니다. 수십 개에서 수천 개의 GPU를 효율적으로 활용해야 하며, 이를 위해 다양한 병렬화 전략이 필요합니다.

Meta PyTorch 팀이 개발한 torchtitan은 이런 복잡한 대규모 LLM 학습을 위한 참조 구현체(reference implementation)입니다. 최신 PyTorch 기능들을 활용해 깨끗하고 확장 가능한 방식으로 구현되었으며, AI 연구자와 엔지니어가 분산 학습의 best practice를 배우고 적용할 수 있도록 설계되었습니다.

이 가이드에서는 torchtitan의 모든 것을 다룹니다. 기초 이론부터 실전 설치, 고급 병렬화 전략, 성능 최적화까지 한 번에 정리했습니다.


1. Torch-Titan 소개

torchtitan이란?

torchtitan은 Meta의 PyTorch 팀이 공개한 대규모 LLM 학습을 위한 프로덕션급 레퍼런스 구현체입니다. GitHub에서 pytorch/torchtitan으로 찾을 수 있습니다.

기존의 많은 LLM 학습 코드베이스(Megatron-LM, DeepSpeed, NeMo 등)는 기능이 풍부하지만 복잡성도 높습니다. torchtitan은 다른 철학으로 접근합니다:

  • 명료성(Clarity): PyTorch native API만 사용, 최소한의 추상화
  • 모듈성(Modularity): 각 병렬화 기법을 독립적으로 켜고 끌 수 있음
  • 현대성(Modernity): PyTorch 2.x의 최신 기능들 적극 활용
  • 재현성(Reproducibility): 학습 실험의 재현이 쉬운 구조

지원하는 모델

현재 torchtitan이 기본 지원하는 모델:

  • Llama 2 (7B, 13B, 34B, 70B)
  • Llama 3 (8B, 70B, 405B)
  • Llama 3.1, 3.2 계열

Llama 계열이 기본이지만, 트랜스포머 기반 모델이라면 구조를 따라 커스텀 모델을 쉽게 추가할 수 있습니다.

기존 프레임워크와의 차이점

Megatron-LM (NVIDIA):

  • GPU 클러스터에서 가장 성숙한 솔루션
  • 매우 최적화되어 있지만 NVIDIA 생태계에 종속
  • 코드베이스가 복잡해 커스터마이징이 어려움

DeepSpeed (Microsoft):

  • ZeRO 옵티마이저로 유명
  • 학습과 추론 모두 지원
  • C++ 커스텀 커널에 의존, PyTorch와 완전 통합이 어려운 경우 있음

torchtitan (Meta/PyTorch):

  • 순수 PyTorch native API
  • PyTorch 2.x의 torch.compile, FSDP2, torch.distributed.tensor 활용
  • 교육적 목적에도 적합한 명료한 구조
  • 빠른 PyTorch 버전 업데이트 추적

레포지토리 구조

torchtitan/
├── torchtitan/
│   ├── models/
│   │   ├── llama/          # Llama 모델 정의
│   │   └── __init__.py
│   ├── parallelisms/
│   │   ├── parallelize_llama.py  # 병렬화 적용 로직
│   │   ├── pipeline_llama.py     # 파이프라인 병렬화
│   │   └── __init__.py
│   ├── optimizer.py         # 옵티마이저 설정
│   ├── checkpoint.py        # 체크포인팅
│   ├── profiling.py         # 프로파일링
│   └── utils.py
├── train.py                 # 학습 진입점
├── train_configs/           # TOML 설정 파일들
│   ├── llama3_8b.toml
│   ├── llama3_70b.toml
│   └── llama3_405b.toml
└── estimation.py            # 메모리/FLOPs 추정 도구

2. 분산 학습 패러다임 복습

대규모 모델을 여러 GPU에 학습시키는 방법은 크게 네 가지로 나뉩니다. torchtitan은 이 네 가지를 모두 지원하고, 심지어 동시에 조합(4D 병렬화)할 수 있습니다.

데이터 병렬화 (Data Parallelism, DP)

가장 기본적인 병렬화 방식입니다. 모델을 각 GPU에 복사하고, 데이터 배치를 나눠서 각 GPU가 독립적으로 처리합니다.

GPU 0: [배치 0~15]  →  그래디언트 0GPU 1: [배치 16~31] →  그래디언트 1  ├→ All-Reduce → 동기화
GPU 2: [배치 32~47] →  그래디언트 2

PyTorch의 DistributedDataParallel(DDP)이 표준 구현입니다. 하지만 모델이 GPU 메모리에 들어가야 한다는 제약이 있습니다. 70B 파라미터 모델은 FP16로만 해도 140GB → 단일 GPU에 불가능합니다.

모델 병렬화 (Tensor Parallelism, TP)

텐서 병렬화는 모델의 개별 레이어를 여러 GPU에 걸쳐 분산합니다. 행렬 연산을 열(column) 또는 행(row) 방향으로 분할합니다.

Linear 레이어 (d_model=4096, d_ff=16384)4GPU에 분산:

GPU 0:0~4095   (4096 x 4096)
GPU 1:4096~8191
GPU 2:8192~12287
GPU 3:12288~16383

각 GPU는 전체 레이어의 1/4만 처리합니다. 결과를 합치기 위해 All-Reduce 또는 All-Gather가 필요합니다. NVLink 대역폭이 높을수록 효율적입니다.

파이프라인 병렬화 (Pipeline Parallelism, PP)

모델의 레이어를 순서대로 여러 GPU에 나눕니다. 첫 번째 GPU가 첫 N개 레이어를 처리하고, 그 출력을 다음 GPU로 전달하는 방식입니다.

GPU 0: Layer 0~9   → activation →
GPU 1: Layer 10~19 → activation →
GPU 2: Layer 20~29 → activation →
GPU 3: Layer 30~39 → loss

단순한 구현은 한 번에 하나의 마이크로배치만 처리해 GPU 활용률이 낮습니다(파이프라인 버블). GPipe, 1F1B, Interleaved 스케줄 등으로 이를 개선합니다.

시퀀스 병렬화 (Sequence Parallelism, SP)

어텐션 레이어에서 시퀀스 차원을 여러 GPU에 나눕니다. 긴 컨텍스트(128K+ 토큰) 학습 시 어텐션 행렬의 메모리가 O(n²)으로 폭발하는 문제를 해결합니다.

시퀀스 길이 40964GPU에 분산:
GPU 0: 토큰 0~1023
GPU 1: 토큰 1024~2047
GPU 2: 토큰 2048~3071
GPU 3: 토큰 3072~4095

Ring Attention 등의 알고리즘으로 전체 어텐션을 계산합니다.

4D 병렬화: 모든 것을 합치다

torchtitan의 핵심 강점은 이 네 가지 병렬화를 동시에 조합하는 4D 병렬화를 지원한다는 점입니다.

4D 병렬화 예시 (128 GPU):
- DP = 2 (데이터 병렬, 2개의 복제본)
- TP = 8 (텐서 병렬, 8GPU가 하나의 레이어)
- PP = 4 (파이프라인 병렬, 4단계 파이프라인)
- SP = TP와 함께 활성화

GPU = DP × TP × PP = 2 × 8 × 4 = 64
(또는 SP 포함시 더 많은 구성 가능)

각 병렬화는 서로 다른 통신 패턴을 가지므로, 하드웨어 토폴로지에 맞게 최적 구성을 찾는 것이 중요합니다. 일반적으로:

  • TP는 NVLink로 연결된 같은 서버 내 GPU들에 적용
  • PP는 서버 간에 적용
  • DP는 가장 외부 차원에 적용

3. FSDP2 (Fully Sharded Data Parallel v2)

ZeRO와 FSDP의 관계

Microsoft DeepSpeed의 ZeRO(Zero Redundancy Optimizer)는 파라미터, 그래디언트, 옵티마이저 상태를 여러 GPU에 분산하는 혁신적인 방법론입니다. PyTorch의 FSDP(Fully Sharded Data Parallel)는 이 아이디어를 PyTorch native로 구현한 것입니다.

ZeRO 3단계:

  • ZeRO-1: 옵티마이저 상태만 샤딩
  • ZeRO-2: 옵티마이저 상태 + 그래디언트 샤딩
  • ZeRO-3: 옵티마이저 상태 + 그래디언트 + 파라미터 샤딩

FSDP는 ZeRO-3에 해당합니다.

FSDP1과 FSDP2의 차이

PyTorch 2.0까지는 torch.distributed.fsdp.FullyShardedDataParallel(FSDP1)이었습니다. PyTorch 2.4+에서는 torch.distributed._composable.fsdpfully_shard를 사용하는 FSDP2가 권장됩니다.

주요 차이점:

특성FSDP1FSDP2
API 스타일래퍼(wrapper) 방식composable API
TP 통합제한적네이티브 통합
메모리 효율좋음더 좋음
torch.compile부분 지원완전 지원
코드 가독성복잡명료
# FSDP1 방식 (구식)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
model = FSDP(model, ...)

# FSDP2 방식 (torchtitan 사용)
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32
)

# 각 트랜스포머 레이어에 FSDP2 적용
for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)

# 전체 모델에도 적용
fully_shard(model, mp_policy=mp_policy)

FSDP의 동작 원리

FSDP는 순전파, 역전파, 가중치 업데이트의 세 단계에서 다르게 동작합니다.

순전파 (Forward pass):

  1. 레이어 실행 전: All-Gather로 샤딩된 파라미터를 모든 GPU에서 재구성
  2. 레이어 실행: 완전한 파라미터로 계산
  3. 레이어 실행 후: 파라미터를 다시 버림 (메모리 절약)

역전파 (Backward pass):

  1. 레이어 역전파 전: All-Gather로 파라미터 재구성
  2. 그래디언트 계산
  3. Reduce-Scatter로 그래디언트를 샤딩하여 각 GPU에 분산
  4. 파라미터 버림

가중치 업데이트:

  • 각 GPU는 자신이 담당하는 샤드의 파라미터와 그래디언트, 옵티마이저 상태만 업데이트

4개 GPU에서 70B 모델 학습 예시:

  • GPU 메모리 절약: 가중치 140GB / 4 = 각 GPU 35GB (+ 옵티마이저 상태 분산)
  • 대신 통신 오버헤드 발생 (All-Gather, Reduce-Scatter)

CPU Offload

GPU 메모리가 부족할 때, 옵티마이저 상태와 그래디언트를 CPU RAM으로 오프로드할 수 있습니다.

from torch.distributed._composable.fsdp import fully_shard, CPUOffloadPolicy

# CPU 오프로드 정책
cpu_offload = CPUOffloadPolicy(offload_params=True)

for layer in model.layers:
    fully_shard(
        layer,
        offload_policy=cpu_offload
    )

CPU 오프로드는 메모리 사용량을 크게 줄이지만 PCIe 전송으로 인해 학습 속도가 느려집니다 (2-5배). 메모리가 절대적으로 부족할 때 마지막 수단으로 사용합니다.

혼합 정밀도와 FSDP2

FSDP2는 레이어별로 다른 정밀도를 적용할 수 있습니다.

from torch.distributed._composable.fsdp import MixedPrecisionPolicy

# 기본 혼합 정밀도 설정
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # 파라미터: BF16 (통신 효율)
    reduce_dtype=torch.float32    # 그래디언트 리덕션: FP32 (안정성)
)

# 특정 레이어는 FP32로 유지 (예: 임베딩 레이어)
fully_shard(model.embed_tokens)  # mp_policy 없이 → FP32
for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)

4. Tensor Parallelism (텐서 병렬화)

트랜스포머에서의 텐서 병렬화

Megatron-LM에서 처음 체계화된 텐서 병렬화는 트랜스포머의 두 핵심 모듈에 적용됩니다.

MLP (Feed-Forward Network):

FFN(x) = GELU(x * W1) * W2

Column Parallel (W1 분할):
GPU 0: x * W1[0:d/4]GELU → y[0:d/4]
GPU 1: x * W1[d/4:d/2]GELU → y[d/4:d/2]
...

Row Parallel (W2 분할 + All-Reduce):
GPU 0: y[0:d/4] * W2[0:d/4, :] → partial sum 0
GPU 1: y[d/4:d/2] * W2[d/4:d/2, :] → partial sum 1
...
All-Reduce → 최종 출력

Self-Attention: Query, Key, Value 행렬을 헤드 단위로 분할합니다.

# torchtitan에서의 TP 적용 (DTensor 기반)
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    PrepareModuleInput,
)

# 병렬화 계획 정의
plan = {
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w3": ColwiseParallel(),
}

parallelize_module(model_layer, tp_mesh, plan)

DTensor: 분산 텐서의 기반

torchtitan은 PyTorch의 torch.distributed.tensor(DTensor)를 기반으로 텐서 병렬화를 구현합니다. DTensor는 논리적으로 하나의 텐서이지만 실제로는 여러 디바이스에 분산된 새로운 추상화입니다.

from torch.distributed.tensor import DTensor, Shard, Replicate
import torch.distributed as dist

# 1D 메시 생성 (TP용)
tp_mesh = dist.init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

# 2D 메시 생성 (DP + TP)
mesh_2d = dist.init_device_mesh(
    "cuda", (2, 8), mesh_dim_names=("dp", "tp")
)

# 텐서를 특정 차원으로 샤딩
# Shard(0): 행 방향 샤딩
# Shard(1): 열 방향 샤딩
# Replicate(): 복제 (샤딩 없음)

시퀀스 병렬화 (Sequence Parallelism)

텐서 병렬화의 자연스러운 확장으로, LayerNorm과 Dropout 연산을 시퀀스 차원으로 병렬화합니다.

표준 TP 흐름:
Input(복제)LayerNorm(복제)Attention(TP)[All-Reduce]Output(복제)

SP + TP 흐름:
Input(시퀀스 샤딩)LayerNorm(시퀀스 샤딩)[All-Gather]Attention(TP)[Reduce-Scatter]Output(시퀀스 샤딩)

SP는 TP 통신과 연계되어 All-Reduce 대신 All-Gather + Reduce-Scatter를 사용합니다. 이는 동일한 양의 통신이지만 메모리 측면에서 효율적입니다. LayerNorm 활성화를 N개 GPU에 분산하여 메모리 절약.


5. Pipeline Parallelism (파이프라인 병렬화)

파이프라인 버블 문제

단순한 파이프라인 병렬화는 심각한 비효율을 초래합니다.

단순 파이프라인 (4 GPU, 4 마이크로배치):

시간 →
GPU 0: [M0 F][M0 B]      [      버블      ]
GPU 1:        [M0 F][M0 B][      버블      ]
GPU 2:               [M0 F][M0 B][  버블  ]
GPU 3:                      [M0 F][M0 B]

F = Forward, B = Backward. GPU 0은 M0을 처리한 후 오래 기다립니다. 버블 비율 = (PP - 1) / (micro_batches + PP - 1).

GPipe 스케줄

GPipe는 여러 마이크로배치를 파이프라인에 주입하여 버블을 줄입니다.

GPipe (4 GPU, 4 마이크로배치):

시간 →
GPU 0: [M0F][M1F][M2F][M3F]               [M3B][M2B][M1B][M0B]
GPU 1:      [M0F][M1F][M2F][M3F]      [M3B][M2B][M1B][M0B]
GPU 2:           [M0F][M1F][M2F][M3F][M3B][M2B][M1B][M0B]
GPU 3:                [M0F][M1F][M2F][M3F][M3B][M2B][M1B][M0B]

모든 순전파 후 역전파를 수행합니다. 활성화 메모리가 많이 필요하다는 단점.

1F1B (One Forward One Backward) 스케줄

1F1B는 메모리 효율을 개선한 스케줄입니다. 각 GPU가 마이크로배치를 번갈아가며 순전파와 역전파를 수행합니다.

# torchtitan의 파이프라인 병렬화 설정 (TOML)
# [experimental]
# pipeline_parallel_degree = 4
# pipeline_parallel_schedule = "1f1b"

# Python 코드에서의 설정
from torchtitan.parallelisms.pipeline_llama import pipeline_llama

# 파이프라인 스테이지 생성
stages, model_parts = pipeline_llama(
    model,
    pp_mesh,
    parallel_dims,
    job_config,
    device,
    model_config
)

Interleaved 스케줄

인터리브드 스케줄은 각 GPU가 여러 파이프라인 스테이지를 담당합니다. 버블을 더욱 줄일 수 있지만 구현이 복잡합니다.

Interleaved 1F1B (4 GPU, 2 스테이지/GPU):

GPU 0이 담당하는 레이어: [레이어 0~4] and [레이어 20~24]
GPU 1: [레이어 5~9] and [레이어 25~29]
...

6. torchtitan 설치 및 기본 사용

시스템 요구사항

  • Python 3.10+
  • PyTorch 2.5+ (최신 nightly 권장)
  • CUDA 12.1+
  • GPU: H100 또는 A100 권장 (최소 40GB VRAM)

설치

# 저장소 클론
git clone https://github.com/pytorch/torchtitan
cd torchtitan

# PyTorch nightly 설치 (최신 기능 포함)
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124

# 의존성 설치
pip install -r requirements.txt

# torchtitan 패키지 설치
pip install -e .

# 학습 데이터 토크나이저 다운로드
python torchtitan/datasets/download_tokenizer.py \
    --repo_id meta-llama/Meta-Llama-3-8B \
    --tokenizer_path "original" \
    --hf_token YOUR_HF_TOKEN

설정 파일 (TOML)

torchtitan은 TOML 형식의 설정 파일을 사용합니다.

# train_configs/llama3_8b.toml

[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 8
seq_len = 2048
warmup_steps = 200
max_norm = 1.0
steps = 1000
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1  # 자동으로 남은 GPU 수
tensor_parallel_degree = 1
enable_loss_parallel = false

[experimental]
pipeline_parallel_degree = 1

[checkpoint]
enable_checkpoint = true
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"

[activation_checkpoint]
mode = "selective"  # full, selective, none
selective_ac_option = "op"  # 선택적 체크포인팅 옵션

[float8]
enable_float8_linear = false

Llama 3 8B 학습 실행

# 단일 노드, 8개 GPU로 Llama 3 8B 학습
torchrun --nproc_per_node=8 \
    train.py \
    --job.config_file train_configs/llama3_8b.toml

# TP=4, DP=2 설정으로 학습
torchrun --nproc_per_node=8 \
    train.py \
    --job.config_file train_configs/llama3_8b.toml \
    --training.tensor_parallel_degree 4 \
    --training.data_parallel_shard_degree 2

# 멀티노드 학습 (2 노드, 각 8 GPU)
torchrun \
    --nproc_per_node=8 \
    --nnodes=2 \
    --rdzv_id=101 \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR:29400 \
    train.py \
    --job.config_file train_configs/llama3_70b.toml

메모리/FLOPs 추정 도구

학습을 시작하기 전에 메모리와 연산량을 미리 추정할 수 있습니다.

# 메모리 및 FLOP 추정
python estimation.py \
    --job.config_file train_configs/llama3_8b.toml

# 출력 예시:
# Estimated model size: 15.01 GB
# Estimated optimizer state size: 30.02 GB
# Total estimated GPU memory: 65.23 GB
# Estimated FLOP per step: 1.61e+14

7. Flash Attention 통합

Flash Attention이란?

Flash Attention은 Stanford 대학의 Tri Dao 등이 2022년 발표한 어텐션 알고리즘입니다. 표준 어텐션의 메모리 복잡도를 O(N²)에서 O(N)으로 줄이고, 실제 속도도 2-4배 향상시킵니다.

표준 어텐션의 문제:

표준 어텐션 연산:
S = Q * K^T      # (seq_len, seq_len) 행렬 - 메모리 폭발!
P = softmax(S)
O = P * V

seq_len=4096: S 행렬 = 4096 × 4096 × 4 bytes = 64MB (FP32)
seq_len=32768: S 행렬 = 32768 × 32768 × 4 bytes = 4GB (단일 레이어!)

Flash Attention의 핵심 아이디어:

  1. Q, K, V를 블록 단위로 처리 (tiling)
  2. HBM 대신 SRAM(공유 메모리)에서 계산
  3. attention 행렬을 HBM에 저장하지 않음
  4. 수치적으로 정확한 결과 (근사값 아님)

Flash Attention 2와 3

Flash Attention 2 (2023):

  • 어텐션 계산 병렬화 개선
  • 마스킹 처리 효율화
  • A100 대비 2-4배 빠른 어텐션

Flash Attention 3 (2024):

  • H100 Hopper 아키텍처 특화
  • WGMMA(Warpgroup Matrix Multiply Accumulate) 활용
  • H100에서 FA2 대비 1.5-2배 추가 향상
  • FP8 지원

torchtitan에서의 Flash Attention 사용

# torchtitan/models/llama/model.py의 어텐션 구현
import torch.nn.functional as F

def forward(
    self,
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    bs, seqlen, _ = x.shape
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

    # RoPE 임베딩 적용
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

    # Flash Attention 사용 (PyTorch 내장 SDPA)
    # torch.nn.functional.scaled_dot_product_attention은
    # 자동으로 Flash Attention을 사용
    output = F.scaled_dot_product_attention(
        xq, xk, xv,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=True,  # 인과적 마스크 (언어 모델)
    )

    return self.wo(output.view(bs, seqlen, -1))

PyTorch 2.0부터 torch.nn.functional.scaled_dot_product_attention(SDPA)이 Flash Attention을 자동으로 사용합니다. 별도의 패키지 설치 없이 Flash Attention의 이점을 누릴 수 있습니다.

Flash Attention 원본 패키지를 직접 사용하려면:

pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func

# Flash Attention 직접 호출
output = flash_attn_func(
    q, k, v,
    dropout_p=0.0,
    softmax_scale=None,  # 1/sqrt(head_dim)로 자동 설정
    causal=True,
)

8. Async Tensor Parallelism

비동기 TP의 필요성

표준 텐서 병렬화에서 All-Reduce 또는 All-Gather/Reduce-Scatter는 다음 레이어 계산을 블록합니다. GPU들이 통신을 기다리는 동안 유휴(idle) 상태가 됩니다.

동기 TP:
[GPU 계산][통신 대기][GPU 계산][통신 대기] ...

비동기 TP는 통신과 계산을 겹쳐서 수행합니다:

비동기 TP:
[GPU 계산 A] ──────────────────→
[통신 (A 결과)]──────→
                  [GPU 계산 B] ─→
                       [통신(B)]

torchtitan의 Async TP 구현

# 설정 파일에서 비동기 TP 활성화
# [training]
# enable_async_tensor_parallel = true

# 코드에서 확인
from torchtitan.parallelisms import parallelize_llama

# async_tp는 내부적으로 torch.distributed.tensor.parallel의
# async_all_gather를 활용
model = parallelize_llama(
    model,
    world_mesh,
    parallel_dims,
    job_config
)

비동기 TP는 compute-communication overlap을 통해 특히 높은 TP degree (8 이상)에서 효과적입니다. H100 + NVLink 환경에서 5-15% 추가 성능 향상을 보고하는 경우가 있습니다.


9. 체크포인팅과 재시작

Distributed Checkpoint (dcp)

대규모 분산 학습에서 체크포인팅은 단순히 모델을 저장하는 것 이상입니다. 수천 개의 GPU에서 동시에 저장하고 로드해야 합니다.

PyTorch의 torch.distributed.checkpoint (dcp)는 분산 체크포인팅을 네이티브로 지원합니다.

# torchtitan의 체크포인팅
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)

# 저장
def save_checkpoint(model, optimizer, step, output_dir):
    state_dict = {
        "model": get_model_state_dict(model),
        "optimizer": get_optimizer_state_dict(model, optimizer),
        "extra": {"step": step},
    }

    dcp.save(
        state_dict,
        checkpoint_id=f"{output_dir}/step-{step}",
    )

# 로드
def load_checkpoint(model, optimizer, checkpoint_path):
    state_dict = {
        "model": get_model_state_dict(model),
        "optimizer": get_optimizer_state_dict(model, optimizer),
        "extra": {},
    }

    dcp.load(
        state_dict,
        checkpoint_id=checkpoint_path,
    )

    set_model_state_dict(
        model,
        model_state_dict=state_dict["model"],
        options=StateDictOptions(strict=True),
    )
    set_optimizer_state_dict(
        model,
        optimizer,
        optim_state_dict=state_dict["optimizer"],
    )

    return state_dict["extra"]["step"]

비동기 체크포인팅

대용량 모델의 체크포인팅은 수 분이 걸릴 수 있습니다. 이 동안 GPU가 멈춥니다. 비동기 체크포인팅은 학습을 계속하면서 백그라운드에서 저장합니다.

# 비동기 체크포인팅 설정
[checkpoint]
async_mode = "async"  # "disabled", "async", "async_with_pinned_mem"
# 비동기 체크포인팅은 내부적으로 별도 스레드에서 실행
# 학습 루프는 체크포인트 저장을 기다리지 않음
# 단, 다음 체크포인트 저장 시점에 이전 저장이 완료되었는지 확인

from torchtitan.checkpoint import CheckpointManager

checkpoint_manager = CheckpointManager(
    dataloader=train_dataloader,
    model_parts=model_parts,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    states={"train_state": train_state},
    job_config=job_config,
)

# 학습 루프에서
for step in range(num_steps):
    # ... 학습 코드 ...

    # 비동기로 체크포인트 저장 (블로킹 없음)
    checkpoint_manager.save(curr_step=step, force=False)

체크포인트 형식 변환

분산 학습으로 저장된 체크포인트를 단일 파일로 변환하거나 HuggingFace 형식으로 변환할 수 있습니다.

# 분산 체크포인트를 단일 HuggingFace 모델로 변환
python scripts/convert_checkpoint.py \
    --checkpoint_path outputs/checkpoint/step-1000 \
    --output_path outputs/hf_model \
    --model_type llama3 \
    --model_flavor 8B

10. 성능 프로파일링

PyTorch Profiler 활용

성능 병목을 찾으려면 먼저 어디에 시간이 소요되는지 알아야 합니다. PyTorch Profiler는 이를 위한 강력한 도구입니다.

import torch
from torch.profiler import profile, record_function, ProfilerActivity

# 기본 프로파일링
with profile(
    activities=[
        ProfilerActivity.CPU,
        ProfilerActivity.CUDA,
    ],
    record_shapes=True,    # 텐서 shape 기록
    profile_memory=True,   # 메모리 사용량 기록
    with_stack=True,       # 콜 스택 기록
) as prof:
    with record_function("model_inference"):
        output = model(input)

# 결과 출력
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20
))

# Chrome Trace로 저장 (chrome://tracing에서 시각화)
prof.export_chrome_trace("trace.json")

torchtitan 내장 프로파일링

torchtitan은 설정 파일로 프로파일링을 제어합니다.

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100      # 100 스텝마다 프로파일링
enable_memory_snapshot = true  # 메모리 스냅샷 저장
# torchtitan/profiling.py 내부 구현 참고
from contextlib import contextmanager
import torch.profiler as profiler

@contextmanager
def maybe_enable_profiling(config, global_step=0):
    if not config.profiling.enable_profiling:
        yield
        return

    with profiler.profile(
        activities=[
            profiler.ProfilerActivity.CPU,
            profiler.ProfilerActivity.CUDA,
        ],
        schedule=profiler.schedule(
            skip_first=10,
            wait=5,
            warmup=5,
            active=1,
            repeat=1,
        ),
        on_trace_ready=profiler.tensorboard_trace_handler(
            config.profiling.save_traces_folder
        ),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as p:
        yield p

TensorBoard 통합

# 메트릭 로깅
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/experiment_1")

# 학습 루프에서
for step, (input, target) in enumerate(dataloader):
    loss = train_step(model, optimizer, input, target)

    writer.add_scalar("Loss/train", loss.item(), step)
    writer.add_scalar("LR", optimizer.param_groups[0]["lr"], step)

    # GPU 메모리 모니터링
    writer.add_scalar(
        "GPU/memory_allocated_GB",
        torch.cuda.memory_allocated() / 1e9,
        step
    )
    writer.add_scalar(
        "GPU/memory_reserved_GB",
        torch.cuda.memory_reserved() / 1e9,
        step
    )

# TensorBoard 실행
# tensorboard --logdir=runs/

메모리 사용량 분석

# GPU 메모리 스냅샷
torch.cuda.memory._record_memory_history(max_entries=100000)

# 학습 일부 실행
for step in range(100):
    loss = train_step(model, optimizer, batch)

# 메모리 스냅샷 저장
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# 분석 도구로 시각화
# https://pytorch.org/memory_viz 에서 .pickle 파일 업로드
# 메모리 프로파일러로 레이어별 메모리 사용 분석
from torch.cuda._memory_viz import profile_plot

# 프로파일 결과를 HTML로 저장
with open("memory_profile.html", "w") as f:
    f.write(profile_plot(prof))

학습 효율 지표: MFU

MFU(Model FLOP Utilization)는 실제 GPU 성능 대비 이론적 최대 성능의 비율입니다.

def compute_mfu(
    model_num_params: int,
    batch_size: int,
    seq_len: int,
    elapsed_time: float,  # 스텝 소요 시간 (초)
    num_gpus: int,
    gpu_peak_flops: float,  # GPU 이론적 최대 FLOPS
) -> float:
    """
    LLM 학습에서 MFU 계산.
    참고: PaLM 논문 (Chowdhery et al., 2022)
    """
    # 순전파 FLOPs = 2 * num_params * batch_size * seq_len
    # 역전파는 순전파의 약 2배 → 총 6 * num_params * batch_size * seq_len
    flops_per_step = 6 * model_num_params * batch_size * seq_len
    achieved_flops = flops_per_step / elapsed_time
    peak_flops_total = gpu_peak_flops * num_gpus

    return achieved_flops / peak_flops_total

# 사용 예시
mfu = compute_mfu(
    model_num_params=8e9,     # 8B 파라미터
    batch_size=8,
    seq_len=2048,
    elapsed_time=0.5,         # 0.5초/스텝
    num_gpus=8,
    gpu_peak_flops=989e12,    # H100 BF16: ~989 TFLOPS
)
print(f"MFU: {mfu:.1%}")  # 예: MFU: 45.2%

# 일반적으로 달성 가능한 MFU:
# - 좋은 구현: 40-60%
# - 최적화된 구현 (torchtitan, Megatron): 50-65%
# - 이론적 최대: ~70% (통신/메모리 오버헤드 불가피)

11. 실전 학습 설정 예시

Llama 3 8B: 단일 노드 8× H100

# train_configs/llama3_8b_h100x8.toml

[job]
dump_folder = "./outputs/llama3_8b"

[model]
name = "llama3"
flavor = "8B"

[training]
batch_size = 4
seq_len = 8192
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 8  # FSDP2: 8 GPU
tensor_parallel_degree = 1

[activation_checkpoint]
mode = "selective"

[float8]
enable_float8_linear = true  # FP8 학습 활성화

[optimizer]
name = "AdamW"
lr = 1e-4
torchrun --nproc_per_node=8 train.py \
    --job.config_file train_configs/llama3_8b_h100x8.toml

Llama 3 70B: 4 노드 × 8 H100 (총 32 GPU)

# train_configs/llama3_70b_32gpu.toml

[training]
batch_size = 2
seq_len = 4096
data_parallel_replicate_degree = 2  # DDP: 2개 복제본
data_parallel_shard_degree = 4      # FSDP: 4 GPU 샤딩
tensor_parallel_degree = 4          # TP: 4 GPU

# 총 GPU = 2 × 4 × 4 = 32

[experimental]
pipeline_parallel_degree = 1  # PP 비활성화

Llama 3 405B: 대규모 클러스터

# train_configs/llama3_405b_large.toml

[training]
batch_size = 1
seq_len = 2048
data_parallel_replicate_degree = 4   # DDP
data_parallel_shard_degree = 8       # FSDP
tensor_parallel_degree = 8           # TP

[experimental]
pipeline_parallel_degree = 8         # PP

# 총 GPU = 4 × 8 × 8 × 8 = 2048

마무리: torchtitan으로 배우는 것들

torchtitan은 단순한 학습 도구를 넘어서 현대적인 LLM 분산 학습의 베스트 프랙티스를 명료하게 보여주는 교육적 리소스입니다.

이 가이드를 통해 배운 핵심 내용들:

  1. 4D 병렬화: DP, TP, PP, SP를 조합하여 수천 개의 GPU를 효율적으로 활용
  2. FSDP2: ZeRO-3 수준의 메모리 효율을 PyTorch native API로
  3. Flash Attention: O(N²) 메모리를 O(N)으로, 2-4배 속도 향상
  4. 비동기 체크포인팅: 학습 중단 없이 체크포인트 저장
  5. MFU 최적화: 이론적 GPU 성능 대비 40-65% 달성 목표

분산 학습은 하드웨어, 소프트웨어, 알고리즘이 교차하는 복잡한 영역입니다. torchtitan은 이 복잡성을 최대한 단순하게 드러내어 학습과 실험을 용이하게 합니다. 직접 코드를 실행하고, 다양한 병렬화 조합을 실험해보면서 분산 학습의 감을 익히시기 바랍니다.


참고 자료

Torch-Titan Complete Guide: Everything About Large-Scale Distributed Training with PyTorch

Introduction

Training large language models (LLMs) is one of the most complex challenges in modern AI engineering. Training Llama 3 70B on a single GPU is simply impossible. Dozens to thousands of GPUs must be utilized efficiently, requiring a variety of parallelization strategies.

torchtitan, developed by Meta's PyTorch team, is a reference implementation for exactly this kind of complex large-scale LLM training. Built using the latest PyTorch features in a clean and scalable way, it is designed so that AI researchers and engineers can learn and apply best practices for distributed training.

This guide covers everything about torchtitan — from theoretical foundations to hands-on installation, advanced parallelization strategies, and performance optimization.


1. Introducing Torch-Titan

What is torchtitan?

torchtitan is a production-quality reference implementation for large-scale LLM training, open-sourced by Meta's PyTorch team. You can find it on GitHub as pytorch/torchtitan.

Many existing LLM training codebases (Megatron-LM, DeepSpeed, NeMo, etc.) are feature-rich but also highly complex. torchtitan takes a different philosophy:

  • Clarity: Uses only PyTorch native APIs with minimal abstraction
  • Modularity: Each parallelization technique can be turned on/off independently
  • Modernity: Actively leverages the latest features of PyTorch 2.x
  • Reproducibility: Structure that makes reproducing training experiments easy

Supported Models

Models natively supported by torchtitan:

  • Llama 2 (7B, 13B, 34B, 70B)
  • Llama 3 (8B, 70B, 405B)
  • Llama 3.1, 3.2 family

While Llama is the default, any transformer-based model can be added by following the structure.

Differences from Existing Frameworks

Megatron-LM (NVIDIA):

  • The most mature solution for GPU clusters
  • Highly optimized but tied to the NVIDIA ecosystem
  • Complex codebase makes customization difficult

DeepSpeed (Microsoft):

  • Famous for the ZeRO optimizer
  • Supports both training and inference
  • Relies on C++ custom kernels; sometimes difficult to fully integrate with PyTorch

torchtitan (Meta/PyTorch):

  • Pure PyTorch native API
  • Uses torch.compile, FSDP2, torch.distributed.tensor from PyTorch 2.x
  • Clear structure suited for educational purposes
  • Tracks PyTorch version updates quickly

Repository Structure

torchtitan/
├── torchtitan/
│   ├── models/
│   │   ├── llama/          # Llama model definitions
│   │   └── __init__.py
│   ├── parallelisms/
│   │   ├── parallelize_llama.py  # Parallelism application logic
│   │   ├── pipeline_llama.py     # Pipeline parallelism
│   │   └── __init__.py
│   ├── optimizer.py         # Optimizer configuration
│   ├── checkpoint.py        # Checkpointing
│   ├── profiling.py         # Profiling
│   └── utils.py
├── train.py                 # Training entry point
├── train_configs/           # TOML configuration files
│   ├── llama3_8b.toml
│   ├── llama3_70b.toml
│   └── llama3_405b.toml
└── estimation.py            # Memory/FLOPs estimation tool

2. A Refresher on Distributed Training Paradigms

There are four main ways to train large models across multiple GPUs. torchtitan supports all four and can even combine them simultaneously (4D parallelism).

Data Parallelism (DP)

The most basic form of parallelism. The model is copied to each GPU, and the data batch is split so that each GPU processes it independently.

GPU 0: [Batch 0-15]Gradient 0GPU 1: [Batch 16-31]Gradient 1  ├→ All-ReduceSync
GPU 2: [Batch 32-47]Gradient 2

PyTorch's DistributedDataParallel (DDP) is the standard implementation. However, the model must fit in GPU memory. A 70B parameter model in FP16 alone requires 140GB — impossible on a single GPU.

Tensor Parallelism (TP)

Tensor parallelism distributes individual model layers across multiple GPUs, splitting matrix operations along the column or row dimension.

Linear layer (d_model=4096, d_ff=16384) distributed across 4 GPUs:

GPU 0: columns 0-4095    (4096 x 4096)
GPU 1: columns 4096-8191
GPU 2: columns 8192-12287
GPU 3: columns 12288-16383

Each GPU processes only 1/4 of the full layer. All-Reduce or All-Gather is needed to combine results. The higher the NVLink bandwidth, the more efficient this is.

Pipeline Parallelism (PP)

Layers are distributed across GPUs in sequence. The first GPU processes the first N layers and passes its output to the next GPU.

GPU 0: Layer 0-9   → activation →
GPU 1: Layer 10-19 → activation →
GPU 2: Layer 20-29 → activation →
GPU 3: Layer 30-39 → loss

A naive implementation processes only one microbatch at a time, resulting in low GPU utilization (pipeline bubbles). GPipe, 1F1B, and Interleaved schedules address this.

Sequence Parallelism (SP)

Distributes the sequence dimension of attention layers across multiple GPUs. Solves the problem of attention matrix memory exploding as O(N²) when training with long contexts (128K+ tokens).

Sequence length 4096 distributed across 4 GPUs:
GPU 0: tokens 0-1023
GPU 1: tokens 1024-2047
GPU 2: tokens 2048-3071
GPU 3: tokens 3072-4095

Algorithms like Ring Attention compute the full attention across the distributed sequence.

4D Parallelism: Combining Everything

A core strength of torchtitan is support for 4D parallelism — combining all four strategies simultaneously.

4D parallelism example (128 GPUs):
- DP = 2  (data parallel, 2 replicas)
- TP = 8  (tensor parallel, 8 GPUs per layer)
- PP = 4  (pipeline parallel, 4 stages)
- SP = activated alongside TP

Total GPUs = DP x TP x PP = 2 x 8 x 4 = 64
(More configurations possible with SP)

Each parallelism has different communication patterns, so finding the optimal configuration for your hardware topology matters. Typically:

  • TP is applied within a single server where GPUs are connected via NVLink
  • PP is applied across servers
  • DP is the outermost dimension

3. FSDP2 (Fully Sharded Data Parallel v2)

The Relationship Between ZeRO and FSDP

Microsoft DeepSpeed's ZeRO (Zero Redundancy Optimizer) is an innovative approach to distributing parameters, gradients, and optimizer states across multiple GPUs. PyTorch's FSDP (Fully Sharded Data Parallel) is the PyTorch-native implementation of this idea.

ZeRO stages:

  • ZeRO-1: Only optimizer state is sharded
  • ZeRO-2: Optimizer state + gradient sharding
  • ZeRO-3: Optimizer state + gradient + parameter sharding

FSDP corresponds to ZeRO-3.

FSDP1 vs FSDP2

Up to PyTorch 2.0, torch.distributed.fsdp.FullyShardedDataParallel (FSDP1) was the standard. With PyTorch 2.4+, FSDP2 using fully_shard from torch.distributed._composable.fsdp is recommended.

Key differences:

FeatureFSDP1FSDP2
API styleWrapper-basedComposable API
TP integrationLimitedNative
Memory efficiencyGoodBetter
torch.compilePartialFull
Code readabilityComplexClear
# FSDP1 style (legacy)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
model = FSDP(model, ...)

# FSDP2 style (used in torchtitan)
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32
)

# Apply FSDP2 to each transformer layer
for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)

# Apply to the full model too
fully_shard(model, mp_policy=mp_policy)

How FSDP Works

FSDP behaves differently during the forward pass, backward pass, and weight update.

Forward pass:

  1. Before layer execution: All-Gather reconstructs sharded parameters on all GPUs
  2. Layer execution: Computes with full parameters
  3. After layer execution: Parameters are discarded (memory savings)

Backward pass:

  1. Before layer backprop: All-Gather reconstructs parameters
  2. Gradient computation
  3. Reduce-Scatter distributes gradients to each GPU as shards
  4. Parameters discarded

Weight update:

  • Each GPU only updates the parameters, gradients, and optimizer state for its own shard

Example with a 70B model on 4 GPUs:

  • Memory savings: 140GB weights / 4 = 35GB per GPU (+ distributed optimizer state)
  • Trade-off: communication overhead (All-Gather, Reduce-Scatter)

CPU Offload

When GPU memory is insufficient, optimizer state and gradients can be offloaded to CPU RAM.

from torch.distributed._composable.fsdp import fully_shard, CPUOffloadPolicy

cpu_offload = CPUOffloadPolicy(offload_params=True)

for layer in model.layers:
    fully_shard(
        layer,
        offload_policy=cpu_offload
    )

CPU offload dramatically reduces memory usage but slows training (2-5x) due to PCIe transfers. Use as a last resort when memory is absolutely tight.

Mixed Precision with FSDP2

FSDP2 can apply different precision per layer.

from torch.distributed._composable.fsdp import MixedPrecisionPolicy

# Standard mixed precision configuration
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # Parameters: BF16 (communication efficiency)
    reduce_dtype=torch.float32    # Gradient reduction: FP32 (stability)
)

# Keep specific layers in FP32 (e.g., embedding layer)
fully_shard(model.embed_tokens)  # No mp_policy → FP32
for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)

4. Tensor Parallelism

Tensor Parallelism in Transformers

Tensor parallelism, first systematized in Megatron-LM, is applied to two core modules in transformers.

MLP (Feed-Forward Network):

FFN(x) = GELU(x * W1) * W2

Column Parallel (split W1):
GPU 0: x * W1[0:d/4]   -> GELU -> y[0:d/4]
GPU 1: x * W1[d/4:d/2] -> GELU -> y[d/4:d/2]
...

Row Parallel (split W2 + All-Reduce):
GPU 0: y[0:d/4] * W2[0:d/4, :] -> partial sum 0
GPU 1: y[d/4:d/2] * W2[d/4:d/2, :] -> partial sum 1
...
All-Reduce -> final output

Self-Attention: Query, Key, Value matrices are split by attention head.

# TP application in torchtitan (DTensor-based)
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    PrepareModuleInput,
)

# Define parallelization plan
plan = {
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w3": ColwiseParallel(),
}

parallelize_module(model_layer, tp_mesh, plan)

DTensor: The Foundation of Distributed Tensors

torchtitan implements tensor parallelism based on PyTorch's torch.distributed.tensor (DTensor). A DTensor is a new abstraction that is logically a single tensor but is physically distributed across multiple devices.

from torch.distributed.tensor import DTensor, Shard, Replicate
import torch.distributed as dist

# Create 1D mesh (for TP)
tp_mesh = dist.init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

# Create 2D mesh (DP + TP)
mesh_2d = dist.init_device_mesh(
    "cuda", (2, 8), mesh_dim_names=("dp", "tp")
)

# Shard(0): row-wise sharding
# Shard(1): column-wise sharding
# Replicate(): replicated (no sharding)

Sequence Parallelism

A natural extension of tensor parallelism that parallelizes LayerNorm and Dropout operations along the sequence dimension.

Standard TP flow:
Input(replicated) -> LayerNorm(replicated) -> Attention(TP) -> [All-Reduce] -> Output(replicated)

SP + TP flow:
Input(seq-sharded) -> LayerNorm(seq-sharded) ->
[All-Gather] -> Attention(TP) -> [Reduce-Scatter] ->
Output(seq-sharded)

SP uses All-Gather + Reduce-Scatter instead of All-Reduce. This is the same total communication volume but more memory-efficient — LayerNorm activations are distributed across N GPUs, saving memory.


5. Pipeline Parallelism

The Pipeline Bubble Problem

Naive pipeline parallelism introduces significant inefficiency.

Simple pipeline (4 GPUs, 4 microbatches):

Time ->
GPU 0: [M0 F][M0 B]           [      bubble      ]
GPU 1:        [M0 F][M0 B]    [      bubble      ]
GPU 2:               [M0 F][M0 B][    bubble    ]
GPU 3:                     [M0 F][M0 B]

F = Forward, B = Backward. GPU 0 waits a long time after processing M0. Bubble ratio = (PP - 1) / (micro_batches + PP - 1).

GPipe Schedule

GPipe injects multiple microbatches into the pipeline to reduce bubbles.

GPipe (4 GPUs, 4 microbatches):

Time ->
GPU 0: [M0F][M1F][M2F][M3F]               [M3B][M2B][M1B][M0B]
GPU 1:      [M0F][M1F][M2F][M3F]      [M3B][M2B][M1B][M0B]
GPU 2:           [M0F][M1F][M2F][M3F][M3B][M2B][M1B][M0B]
GPU 3:                [M0F][M1F][M2F][M3F][M3B][M2B][M1B][M0B]

All forward passes complete before any backward passes. Downside: high activation memory usage.

1F1B (One Forward One Backward) Schedule

1F1B is a memory-efficient schedule. Each GPU alternates forward and backward passes on microbatches.

# Pipeline parallelism configuration in torchtitan (TOML)
# [experimental]
# pipeline_parallel_degree = 4
# pipeline_parallel_schedule = "1f1b"

# Python-level setup
from torchtitan.parallelisms.pipeline_llama import pipeline_llama

# Create pipeline stages
stages, model_parts = pipeline_llama(
    model,
    pp_mesh,
    parallel_dims,
    job_config,
    device,
    model_config
)

Interleaved Schedule

The interleaved schedule assigns each GPU responsibility for multiple pipeline stages. It further reduces bubbles but is more complex to implement.

Interleaved 1F1B (4 GPUs, 2 stages/GPU):

GPU 0 handles: [layers 0-4] and [layers 20-24]
GPU 1: [layers 5-9] and [layers 25-29]
...

6. Installing and Using torchtitan

System Requirements

  • Python 3.10+
  • PyTorch 2.5+ (latest nightly recommended)
  • CUDA 12.1+
  • GPU: H100 or A100 recommended (minimum 40GB VRAM)

Installation

# Clone the repository
git clone https://github.com/pytorch/torchtitan
cd torchtitan

# Install PyTorch nightly (includes latest features)
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124

# Install dependencies
pip install -r requirements.txt

# Install torchtitan package
pip install -e .

# Download tokenizer
python torchtitan/datasets/download_tokenizer.py \
    --repo_id meta-llama/Meta-Llama-3-8B \
    --tokenizer_path "original" \
    --hf_token YOUR_HF_TOKEN

Configuration Files (TOML)

torchtitan uses TOML-format configuration files.

# train_configs/llama3_8b.toml

[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 8
seq_len = 2048
warmup_steps = 200
max_norm = 1.0
steps = 1000
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1  # Auto: uses remaining GPU count
tensor_parallel_degree = 1
enable_loss_parallel = false

[experimental]
pipeline_parallel_degree = 1

[checkpoint]
enable_checkpoint = true
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"

[activation_checkpoint]
mode = "selective"  # full, selective, none
selective_ac_option = "op"

[float8]
enable_float8_linear = false

Running Llama 3 8B Training

# Single node, 8 GPU training of Llama 3 8B
torchrun --nproc_per_node=8 \
    train.py \
    --job.config_file train_configs/llama3_8b.toml

# With TP=4, DP=2
torchrun --nproc_per_node=8 \
    train.py \
    --job.config_file train_configs/llama3_8b.toml \
    --training.tensor_parallel_degree 4 \
    --training.data_parallel_shard_degree 2

# Multi-node training (2 nodes, 8 GPUs each)
torchrun \
    --nproc_per_node=8 \
    --nnodes=2 \
    --rdzv_id=101 \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR:29400 \
    train.py \
    --job.config_file train_configs/llama3_70b.toml

Memory/FLOPs Estimation Tool

Before starting training, estimate memory and compute requirements upfront.

# Estimate memory and FLOPs
python estimation.py \
    --job.config_file train_configs/llama3_8b.toml

# Sample output:
# Estimated model size: 15.01 GB
# Estimated optimizer state size: 30.02 GB
# Total estimated GPU memory: 65.23 GB
# Estimated FLOP per step: 1.61e+14

7. Flash Attention Integration

What is Flash Attention?

Flash Attention is an attention algorithm published in 2022 by Tri Dao and colleagues at Stanford. It reduces the memory complexity of standard attention from O(N²) to O(N) and delivers 2-4x real-world speedups.

The problem with standard attention:

Standard attention:
S = Q * K^T    # (seq_len, seq_len) matrix -- memory explosion!
P = softmax(S)
O = P * V

seq_len=4096:  S matrix = 4096 x 4096 x 4 bytes = 64MB (FP32)
seq_len=32768: S matrix = 32768 x 32768 x 4 bytes = 4GB (single layer!)

Flash Attention's key ideas:

  1. Process Q, K, V in tiles (tiling)
  2. Compute in SRAM (shared memory) instead of HBM
  3. Never materialize the full attention matrix in HBM
  4. Numerically exact (not an approximation)

Flash Attention 2 and 3

Flash Attention 2 (2023):

  • Improved attention computation parallelism
  • More efficient masking
  • 2-4x faster attention vs A100

Flash Attention 3 (2024):

  • Tailored for H100 Hopper architecture
  • Leverages WGMMA (Warpgroup Matrix Multiply Accumulate)
  • 1.5-2x additional improvement over FA2 on H100
  • FP8 support

Using Flash Attention in torchtitan

# Attention implementation in torchtitan/models/llama/model.py
import torch.nn.functional as F

def forward(
    self,
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    bs, seqlen, _ = x.shape
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

    # Apply RoPE embeddings
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

    # Flash Attention via PyTorch SDPA
    # torch.nn.functional.scaled_dot_product_attention
    # automatically uses Flash Attention
    output = F.scaled_dot_product_attention(
        xq, xk, xv,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=True,  # Causal mask for language models
    )

    return self.wo(output.view(bs, seqlen, -1))

Since PyTorch 2.0, torch.nn.functional.scaled_dot_product_attention (SDPA) uses Flash Attention automatically. No separate package installation needed.

To use the original Flash Attention package directly:

pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func

output = flash_attn_func(
    q, k, v,
    dropout_p=0.0,
    softmax_scale=None,  # Auto: 1/sqrt(head_dim)
    causal=True,
)

8. Async Tensor Parallelism

Why Async TP?

In standard tensor parallelism, All-Reduce or All-Gather/Reduce-Scatter block the next layer's computation. GPUs sit idle while waiting for communication.

Synchronous TP:
[GPU compute] -> [wait for comm] -> [GPU compute] -> [wait for comm] ...

Async TP overlaps communication and computation:

Async TP:
[GPU compute A] ────────────────────>
[comm (A result)] ─────>
                    [GPU compute B] ->
                          [comm (B)]

Async TP in torchtitan

# Enable async TP in the config file:
# [training]
# enable_async_tensor_parallel = true

# Code reference
from torchtitan.parallelisms import parallelize_llama

# async_tp internally uses torch.distributed.tensor.parallel's
# async_all_gather capability
model = parallelize_llama(
    model,
    world_mesh,
    parallel_dims,
    job_config
)

Async TP is most effective at high TP degrees (8+) through compute-communication overlap. Reports show 5-15% additional throughput gains in H100 + NVLink environments.


9. Checkpointing and Resuming

Distributed Checkpoint (dcp)

In large-scale distributed training, checkpointing is far more than saving a model. You need to save and load simultaneously across thousands of GPUs.

PyTorch's torch.distributed.checkpoint (dcp) natively supports distributed checkpointing.

# torchtitan checkpointing
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)

# Save
def save_checkpoint(model, optimizer, step, output_dir):
    state_dict = {
        "model": get_model_state_dict(model),
        "optimizer": get_optimizer_state_dict(model, optimizer),
        "extra": {"step": step},
    }

    dcp.save(
        state_dict,
        checkpoint_id=f"{output_dir}/step-{step}",
    )

# Load
def load_checkpoint(model, optimizer, checkpoint_path):
    state_dict = {
        "model": get_model_state_dict(model),
        "optimizer": get_optimizer_state_dict(model, optimizer),
        "extra": {},
    }

    dcp.load(
        state_dict,
        checkpoint_id=checkpoint_path,
    )

    set_model_state_dict(
        model,
        model_state_dict=state_dict["model"],
        options=StateDictOptions(strict=True),
    )
    set_optimizer_state_dict(
        model,
        optimizer,
        optim_state_dict=state_dict["optimizer"],
    )

    return state_dict["extra"]["step"]

Async Checkpointing

Checkpointing large models can take minutes, during which GPUs stall. Async checkpointing saves in the background while training continues.

# Async checkpointing configuration
[checkpoint]
async_mode = "async"  # "disabled", "async", "async_with_pinned_mem"
from torchtitan.checkpoint import CheckpointManager

checkpoint_manager = CheckpointManager(
    dataloader=train_dataloader,
    model_parts=model_parts,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    states={"train_state": train_state},
    job_config=job_config,
)

# In the training loop
for step in range(num_steps):
    # ... training code ...

    # Save checkpoint asynchronously (non-blocking)
    checkpoint_manager.save(curr_step=step, force=False)

Checkpoint Format Conversion

Distributed training checkpoints can be converted to a single file or HuggingFace format.

# Convert distributed checkpoint to single HuggingFace model
python scripts/convert_checkpoint.py \
    --checkpoint_path outputs/checkpoint/step-1000 \
    --output_path outputs/hf_model \
    --model_type llama3 \
    --model_flavor 8B

10. Performance Profiling

Using PyTorch Profiler

To find performance bottlenecks, you first need to know where time is being spent. PyTorch Profiler is a powerful tool for this.

import torch
from torch.profiler import profile, record_function, ProfilerActivity

# Basic profiling
with profile(
    activities=[
        ProfilerActivity.CPU,
        ProfilerActivity.CUDA,
    ],
    record_shapes=True,    # Record tensor shapes
    profile_memory=True,   # Record memory usage
    with_stack=True,       # Record call stacks
) as prof:
    with record_function("model_inference"):
        output = model(input)

# Print results
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20
))

# Save Chrome Trace (visualize at chrome://tracing)
prof.export_chrome_trace("trace.json")

torchtitan Built-in Profiling

torchtitan controls profiling via configuration files.

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100             # Profile every 100 steps
enable_memory_snapshot = true  # Save memory snapshots
# Reference: torchtitan/profiling.py internals
from contextlib import contextmanager
import torch.profiler as profiler

@contextmanager
def maybe_enable_profiling(config, global_step=0):
    if not config.profiling.enable_profiling:
        yield
        return

    with profiler.profile(
        activities=[
            profiler.ProfilerActivity.CPU,
            profiler.ProfilerActivity.CUDA,
        ],
        schedule=profiler.schedule(
            skip_first=10,
            wait=5,
            warmup=5,
            active=1,
            repeat=1,
        ),
        on_trace_ready=profiler.tensorboard_trace_handler(
            config.profiling.save_traces_folder
        ),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as p:
        yield p

TensorBoard Integration

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/experiment_1")

# In the training loop
for step, (input, target) in enumerate(dataloader):
    loss = train_step(model, optimizer, input, target)

    writer.add_scalar("Loss/train", loss.item(), step)
    writer.add_scalar("LR", optimizer.param_groups[0]["lr"], step)

    # GPU memory monitoring
    writer.add_scalar(
        "GPU/memory_allocated_GB",
        torch.cuda.memory_allocated() / 1e9,
        step
    )
    writer.add_scalar(
        "GPU/memory_reserved_GB",
        torch.cuda.memory_reserved() / 1e9,
        step
    )

# Launch TensorBoard:
# tensorboard --logdir=runs/

Memory Usage Analysis

# GPU memory snapshot
torch.cuda.memory._record_memory_history(max_entries=100000)

# Run part of training
for step in range(100):
    loss = train_step(model, optimizer, batch)

# Save memory snapshot
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# Visualize with analysis tool:
# Upload .pickle at https://pytorch.org/memory_viz
# Memory profiler for per-layer analysis
from torch.cuda._memory_viz import profile_plot

with open("memory_profile.html", "w") as f:
    f.write(profile_plot(prof))

Training Efficiency Metric: MFU

MFU (Model FLOP Utilization) is the ratio of actual GPU performance to the theoretical maximum.

def compute_mfu(
    model_num_params: int,
    batch_size: int,
    seq_len: int,
    elapsed_time: float,   # Time per step (seconds)
    num_gpus: int,
    gpu_peak_flops: float, # GPU theoretical peak FLOPS
) -> float:
    """
    Compute MFU for LLM training.
    Reference: PaLM paper (Chowdhery et al., 2022)
    """
    # Forward FLOPs = 2 * num_params * batch_size * seq_len
    # Backward is ~2x forward -> total 6 * num_params * batch_size * seq_len
    flops_per_step = 6 * model_num_params * batch_size * seq_len
    achieved_flops = flops_per_step / elapsed_time
    peak_flops_total = gpu_peak_flops * num_gpus

    return achieved_flops / peak_flops_total

# Example usage
mfu = compute_mfu(
    model_num_params=8e9,     # 8B params
    batch_size=8,
    seq_len=2048,
    elapsed_time=0.5,         # 0.5 seconds per step
    num_gpus=8,
    gpu_peak_flops=989e12,    # H100 BF16: ~989 TFLOPS
)
print(f"MFU: {mfu:.1%}")  # e.g., MFU: 45.2%

# Typically achievable MFU:
# - Good implementation: 40-60%
# - Optimized (torchtitan, Megatron): 50-65%
# - Theoretical max: ~70% (communication/memory overhead unavoidable)

11. Practical Training Configuration Examples

Llama 3 8B: Single Node with 8x H100

# train_configs/llama3_8b_h100x8.toml

[job]
dump_folder = "./outputs/llama3_8b"

[model]
name = "llama3"
flavor = "8B"

[training]
batch_size = 4
seq_len = 8192
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 8  # FSDP2: 8 GPUs
tensor_parallel_degree = 1

[activation_checkpoint]
mode = "selective"

[float8]
enable_float8_linear = true  # Enable FP8 training

[optimizer]
name = "AdamW"
lr = 1e-4
torchrun --nproc_per_node=8 train.py \
    --job.config_file train_configs/llama3_8b_h100x8.toml

Llama 3 70B: 4 Nodes x 8 H100 (32 GPUs Total)

# train_configs/llama3_70b_32gpu.toml

[training]
batch_size = 2
seq_len = 4096
data_parallel_replicate_degree = 2  # DDP: 2 replicas
data_parallel_shard_degree = 4      # FSDP: 4-GPU sharding
tensor_parallel_degree = 4          # TP: 4 GPUs

# Total GPUs = 2 x 4 x 4 = 32

[experimental]
pipeline_parallel_degree = 1  # PP disabled

Llama 3 405B: Large-Scale Cluster

# train_configs/llama3_405b_large.toml

[training]
batch_size = 1
seq_len = 2048
data_parallel_replicate_degree = 4   # DDP
data_parallel_shard_degree = 8       # FSDP
tensor_parallel_degree = 8           # TP

[experimental]
pipeline_parallel_degree = 8         # PP

# Total GPUs = 4 x 8 x 8 x 8 = 2048

Conclusion: What torchtitan Teaches Us

torchtitan goes beyond being a simple training tool — it is an educational resource that clearly demonstrates the best practices of modern LLM distributed training.

Key takeaways from this guide:

  1. 4D Parallelism: Combining DP, TP, PP, and SP to efficiently utilize thousands of GPUs
  2. FSDP2: ZeRO-3-level memory efficiency with PyTorch native APIs
  3. Flash Attention: O(N²) memory becomes O(N); 2-4x speedup
  4. Async Checkpointing: Save checkpoints without stopping training
  5. MFU optimization: Targeting 40-65% of theoretical GPU performance

Distributed training is a complex domain where hardware, software, and algorithms intersect. torchtitan exposes this complexity in the most transparent way possible, making learning and experimentation accessible. Run the code yourself, experiment with different parallelization combinations, and develop your intuition for distributed training.


References