
Large-Scale Model Training Complete Guide: Strategies for Pre-training 100B+ Parameter LLMs

Introduction

LLaMA-3 405B, GPT-4, Falcon 180B: how were these models actually trained? Saying "just use a lot of GPUs" grossly oversimplifies the real complexity involved. Training a hundred-billion-parameter LLM requires hyperparameter design grounded in scaling laws, sophisticated distributed training strategies, measures to keep training stable, and efficient management of enormous computing resources.

This guide covers every critical aspect of large-scale LLM pre-training from a practical perspective.


1. Scaling Laws

1.1 The Kaplan et al. (OpenAI) Scaling Laws

The 2020 paper "Scaling Laws for Neural Language Models" by Kaplan et al. showed that language-model performance improves as a power law in three quantities.

  • N: number of model parameters
  • D: number of training tokens
  • C: total compute budget (FLOPs)

Key findings:

  1. Model size first (the compute-efficient view): at a fixed compute budget, growing the model is more efficient than growing the data
  2. Data efficiency: at a fixed model size, training longer yields diminishing returns
  3. Power law: the loss follows L(N) ≈ (N_c/N)^α_N

Under these laws, the efficient strategy is to "spend most of the compute budget on a large model and train on less data than is optimal." This is why giant models such as GPT-3 (175B) were trained on relatively few tokens (300B).
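As a quick numerical sketch of that power law (the constants α_N ≈ 0.076 and N_c ≈ 8.8 × 10^13 below are the fits reported by Kaplan et al.; treat them as order-of-magnitude values):

```python
def kaplan_loss(num_params: float, alpha_n: float = 0.076, n_c: float = 8.8e13) -> float:
    """L(N) ≈ (N_c / N)^alpha_N: loss as a power law in parameter count."""
    return (n_c / num_params) ** alpha_n

# Every 10x in parameters shrinks the loss by the same constant factor, 10^alpha_N
for n in (1e9, 10e9, 100e9):
    print(f"N={n:.0e}  L={kaplan_loss(n):.3f}")
```

The key property is the constant multiplicative improvement per decade of scale, which is why loss curves look like straight lines on log-log plots.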

1.2 Chinchilla-Optimal Scaling (Hoffmann et al., 2022)

In the 2022 paper "Training Compute-Optimal Large Language Models," Hoffmann et al. at DeepMind made a key discovery that overturned Kaplan's conclusion.

Key claim: existing large models are **severely undertrained**.

The Chinchilla experiment:

  • Previous approach: Gopher (280B parameters, 300B tokens)
  • Chinchilla: 70B parameters, 1.4T tokens
  • Result: despite being far smaller, Chinchilla decisively outperformed Gopher

The Chinchilla-optimal scaling law:

Given a compute budget C, the optimal model size N and data size D are:

N_optimal ≈ 0.1174 × C^0.4999  (parameters)
D_optimal ≈ 1.6972 × C^0.5001  (tokens)

That is, N and D should scale almost equally with compute. As a rule of thumb:

D_optimal ≈ 20 × N

A 7B-parameter model needs at least ~140B tokens; a 70B-parameter model needs at least ~1.4T tokens.
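Plugging the fitted coefficients above into code (a sketch; since both exponents are essentially 0.5, N and D each scale roughly as √C):

```python
def chinchilla_optimal(compute_flops: float) -> tuple:
    """Compute-optimal parameter and token counts from the Hoffmann et al. fits."""
    n_opt = 0.1174 * compute_flops ** 0.4999
    d_opt = 1.6972 * compute_flops ** 0.5001
    return n_opt, d_opt

# Example: the compute budget of a 70B model trained on 1.4T tokens
c = 6 * 70e9 * 1.4e12  # ≈ 5.9e23 FLOPs
n_opt, d_opt = chinchilla_optimal(c)
print(f"N_opt ≈ {n_opt:.2e} params, D_opt ≈ {d_opt:.2e} tokens")
print(f"tokens per parameter ≈ {d_opt / n_opt:.1f}")
```

Note that the fitted coefficients do not land exactly on the 20-tokens-per-parameter rule of thumb; the rule is a rounded summary, not an exact consequence of the fit.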

1.3 Practical Estimation Formulas

Estimating FLOPs:

def estimate_training_flops(
    num_params,      # number of parameters
    num_tokens,      # number of training tokens
):
    """
    C ≈ 6 × N × D (forward + backward pass)
    The backward pass costs roughly 2× the forward pass
    """
    flops_per_token = 6 * num_params
    total_flops = flops_per_token * num_tokens
    return total_flops

# Example: LLaMA-2 7B, 2T tokens
flops = estimate_training_flops(7e9, 2e12)
print(f"Total FLOPs: {flops:.2e}")
# Total FLOPs: 8.40e+22

# Estimating GPU training time (A100 at 312 TFLOPs, assuming 50% MFU)
a100_flops = 312e12  # 312 TFLOPs
mfu = 0.5            # Model FLOP Utilization
single_gpu_seconds = flops / (a100_flops * mfu)
single_gpu_hours = single_gpu_seconds / 3600  # this also equals the total GPU-hours

# With 1000 A100s
num_gpus = 1000
wall_clock_hours = single_gpu_hours / num_gpus
print(f"Wall-clock time: {wall_clock_hours:.1f} hours")
print(f"Total GPU-hours: {single_gpu_hours:.0f}")

1.4 The Post-Chinchilla Trend: the "Overtrain" Strategy

In practice, there is a clear trend toward training well beyond the Chinchilla-optimal token count.

Reason: the trade-off between training cost and inference cost

  • Training happens once, but inference runs millions of times
  • Training a smaller model for longer reduces inference cost

Example: LLaMA-3 8B was trained on 15T tokens (roughly 100× the Chinchilla-optimal amount)
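A quick sanity check on that ratio, using the 20-tokens-per-parameter rule of thumb from section 1.2:

```python
params = 8e9                      # LLaMA-3 8B
chinchilla_tokens = 20 * params   # ≈ 160B tokens would be compute-optimal
actual_tokens = 15e12             # what it was actually trained on

ratio = actual_tokens / chinchilla_tokens
print(f"overtraining factor: {ratio:.0f}x")  # roughly 94x, i.e. ~100x
```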


2. Pre-training Data Pipeline

2.1 Composing Data Sources

High-quality pre-training data is the core driver of LLM performance.

Major data sources:

| Source | Characteristics | Quality notes |
|---|---|---|
| Common Crawl | Web crawl, multi-PB scale | Very noisy, cleaning essential |
| Books3/Gutenberg | Books, high quality | Limited diversity |
| Wikipedia | Encyclopedia, verified information | Limited volume |
| GitHub | Code, improves logical reasoning | Watch the licensing |
| ArXiv/PubMed | Academic papers | Highly specialized |
| StackExchange | Q&A, practical knowledge | Good quality |

Actual mixing ratios (estimated, in the style of Llama-2):

data_mixture = {
    "Common Crawl (filtered)": 0.67,   # 67%
    "Books": 0.14,                     # 14%
    "GitHub": 0.045,                   # 4.5%
    "Wikipedia": 0.045,                # 4.5%
    "Gutenberg": 0.025,                # 2.5%
    "ArXiv": 0.025,                    # 2.5%
    "StackExchange": 0.02,             # 2%
}
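Given such a table of weights, one simple way to realize the mixture at training time is to sample each document's source proportionally. A minimal sketch (the weights below repeat the estimate above):

```python
import random

data_mixture = {
    "Common Crawl (filtered)": 0.67,
    "Books": 0.14,
    "GitHub": 0.045,
    "Wikipedia": 0.045,
    "Gutenberg": 0.025,
    "ArXiv": 0.025,
    "StackExchange": 0.02,
}

def sample_source(mixture: dict, rng: random.Random) -> str:
    """Draw a data source with probability proportional to its mixture weight."""
    sources = list(mixture)
    weights = [mixture[s] for s in sources]
    return rng.choices(sources, weights=weights, k=1)[0]

rng = random.Random(0)
counts = {s: 0 for s in data_mixture}
for _ in range(10_000):
    counts[sample_source(data_mixture, rng)] += 1

# Common Crawl should account for roughly 67% of the draws
print(counts["Common Crawl (filtered)"] / 10_000)
```

Real pipelines usually pre-shuffle shards per source rather than sampling per document, but the proportions work out the same way.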

2.2 Data Cleaning Pipeline

import re
from typing import List, Optional
from dataclasses import dataclass

@dataclass
class DocumentFilter:
    min_tokens: int = 50
    max_tokens: int = 100000
    min_avg_word_length: float = 3.0
    max_symbol_ratio: float = 0.1
    min_alpha_ratio: float = 0.7

def filter_document(text: str, config: DocumentFilter) -> Optional[str]:
    """Basic quality filters"""
    tokens = text.split()
    token_count = len(tokens)

    # Length filter
    if not (config.min_tokens <= token_count <= config.max_tokens):
        return None

    # Average word length
    avg_word_len = sum(len(t) for t in tokens) / token_count
    if avg_word_len < config.min_avg_word_length:
        return None

    # Ratio of alphabetic characters
    alpha_chars = sum(1 for c in text if c.isalpha())
    if alpha_chars / len(text) < config.min_alpha_ratio:
        return None

    return text

def deduplicate_documents(texts: List[str], n_gram_size: int = 13) -> List[str]:
    """
    MinHash LSH-based near-duplicate removal
    (uses the datasketch library, as a real implementation would)
    """
    from datasketch import MinHash, MinHashLSH

    lsh = MinHashLSH(threshold=0.8, num_perm=128)
    unique_texts = []

    for i, text in enumerate(texts):
        minhash = MinHash(num_perm=128)
        # Hash the document in n-gram units
        words = text.lower().split()
        for j in range(len(words) - n_gram_size + 1):
            ngram = " ".join(words[j:j+n_gram_size])
            minhash.update(ngram.encode("utf-8"))

        if not lsh.query(minhash):
            lsh.insert(str(i), minhash)
            unique_texts.append(text)

    return unique_texts

2.3 Tokenizer Training

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel

def train_tokenizer(
    corpus_files: List[str],
    vocab_size: int = 32000,
    output_path: str = "tokenizer.json"
):
    """BPE 토크나이저 학습 (SentencePiece 방식)"""
    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)

    trainer = BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=2,
        special_tokens=["[UNK]", "[BOS]", "[EOS]", "[PAD]"],
        show_progress=True
    )

    tokenizer.train(files=corpus_files, trainer=trainer)
    tokenizer.save(output_path)
    return tokenizer

# Count the tokens in a file
def count_tokens(file_path: str, tokenizer) -> int:
    total = 0
    with open(file_path, "r") as f:
        for line in f:
            tokens = tokenizer.encode(line.strip())
            total += len(tokens.ids)
    return total

3. Megatron-LM

3.1 Introducing NVIDIA Megatron

Megatron-LM is a framework for training large language models, developed by NVIDIA. It provides parallelization techniques specialized for large models such as GPT, BERT, and T5.

Key features:

  • Tensor Parallelism: splits matrix operations across GPUs
  • Pipeline Parallelism: distributes layers sequentially across GPUs
  • Sequence Parallelism: shards the sequence dimension as well (added in later Megatron-LM releases)
  • Flash Attention integration: memory-efficient attention computation

3.2 How Tensor Parallelism Works

The core operation in a transformer is matrix multiplication, and how those matrices are split is the heart of tensor parallelism.

Tensor splitting in multi-head attention:

Q, K, V projections:  [d_model, d_head × n_heads]
→ per GPU: [d_model, d_head × (n_heads/tp_size)]

Tensor splitting in the FFN layers:

First linear:   [d_model, d_ffn]   → column-wise split
Second linear:  [d_ffn, d_model]   → row-wise split
→ per GPU: [d_model, d_ffn/tp_size] + [d_ffn/tp_size, d_model]
# The idea behind Megatron's ColumnParallelLinear (simplified)
import torch
import torch.nn as nn
import torch.distributed as dist

class ColumnParallelLinear(nn.Module):
    """Split the weight matrix column-wise"""
    def __init__(self, in_features, out_features, tp_size):
        super().__init__()
        self.tp_size = tp_size
        assert out_features % tp_size == 0
        local_out = out_features // tp_size
        # Each GPU holds only 1/tp_size of the full matrix
        self.weight = nn.Parameter(torch.empty(local_out, in_features))

    def forward(self, x):
        # Input is full-size; output is the local shard
        return torch.nn.functional.linear(x, self.weight)
        # An all-reduce or all-gather follows

class RowParallelLinear(nn.Module):
    """Split the weight matrix row-wise"""
    def __init__(self, in_features, out_features, tp_size):
        super().__init__()
        self.tp_size = tp_size
        assert in_features % tp_size == 0
        local_in = in_features // tp_size
        # Each GPU processes 1/tp_size of the input dimension
        self.weight = nn.Parameter(torch.empty(out_features, local_in))

    def forward(self, x):
        # x: [batch, seq, in_features/tp_size]
        local_output = torch.nn.functional.linear(x, self.weight)
        # All-reduce sums the partial results from every GPU
        dist.all_reduce(local_output)
        return local_output

3.3 Sequence Parallelism

Sequence parallelism shards LayerNorm and Dropout along the sequence dimension.

Attention input: [batch, seq/tp_size, d_model] (per GPU)
All-gather: [batch, seq, d_model]
Self-attention (column split)
Reduce-scatter: [batch, seq/tp_size, d_model]
FFN (row split)

This also cuts the activation memory for LayerNorm (and Dropout) by a factor of tp_size.
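The all-gather / reduce-scatter flow above can be simulated on a single process to check the shape bookkeeping (a sketch: the tp ranks are modeled as list entries, all-gather as concatenation along the sequence dimension, and reduce-scatter as a sum followed by re-chunking):

```python
import torch

def all_gather_sim(shards):
    """[batch, seq/tp, d] shards -> full [batch, seq, d] tensor."""
    return torch.cat(shards, dim=1)

def reduce_scatter_sim(partials, tp_size):
    """Sum the per-rank partial results, then hand each rank its sequence shard."""
    total = torch.stack(partials).sum(dim=0)
    return list(total.chunk(tp_size, dim=1))

tp_size, batch, seq, d_model = 4, 2, 16, 8
shards = [torch.randn(batch, seq // tp_size, d_model) for _ in range(tp_size)]

full = all_gather_sim(shards)                       # gathered before self-attention
partials = [full.clone() for _ in range(tp_size)]   # stand-in for per-rank partial sums
new_shards = reduce_scatter_sim(partials, tp_size)  # re-sharded before LayerNorm/FFN

print(full.shape)           # torch.Size([2, 16, 8])
print(new_shards[0].shape)  # torch.Size([2, 4, 8])
```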

3.4 Example Megatron Configuration

#!/bin/bash
# Training GPT-3 175B with Megatron-LM (example)

GPUS_PER_NODE=8
NNODES=64
TP_SIZE=8    # Tensor Parallelism
PP_SIZE=16   # Pipeline Parallelism
DP_SIZE=$((GPUS_PER_NODE * NNODES / TP_SIZE / PP_SIZE))
# DP_SIZE = 8 * 64 / 8 / 16 = 4

torchrun \
    --nproc_per_node $GPUS_PER_NODE \
    --nnodes $NNODES \
    pretrain_gpt.py \
    --num-layers 96 \
    --hidden-size 12288 \
    --num-attention-heads 96 \
    --seq-length 2048 \
    --max-position-embeddings 2048 \
    --global-batch-size 1536 \
    --train-iters 300000 \
    --lr 6e-5 \
    --lr-decay-style cosine \
    --min-lr 6e-6 \
    --lr-warmup-fraction 0.001 \
    --weight-decay 0.1 \
    --tensor-model-parallel-size $TP_SIZE \
    --pipeline-model-parallel-size $PP_SIZE \
    --micro-batch-size 1 \
    --bf16 \
    --use-flash-attn \
    --sequence-parallel \
    --data-path /data/pile_text_document \
    --vocab-file /data/gpt2-vocab.json \
    --merge-file /data/gpt2-merges.txt \
    --save /checkpoints/gpt3-175b \
    --load /checkpoints/gpt3-175b

4. 3D Parallelism (DP + TP + PP)

4.1 Combining the Three Forms of Parallelism

Large-scale LLM training uses all three forms of parallelism at the same time.

| Parallelism | What is sharded | Communication pattern | Overhead |
|---|---|---|---|
| Data parallelism (DP) | Batch | All-reduce of gradients | Low |
| Tensor parallelism (TP) | Matrices within a layer | All-reduce per layer | Medium (latency-sensitive) |
| Pipeline parallelism (PP) | Groups of layers | Point-to-point | Pipeline bubble |

4.2 Finding the Optimal Parallel Configuration

def find_optimal_parallelism(
    world_size: int,
    num_layers: int,
    hidden_size: int,
    gpu_memory_gb: float = 80.0
) -> list:
    """
    Search for good 3D-parallel configurations.
    Rules of thumb:
    - TP: within a single node (exploits NVLink), usually 4 or 8
    - PP: across nodes, chosen based on the layer count
    - DP: whatever GPUs remain
    """
    configs = []
    for tp in [1, 2, 4, 8]:
        for pp in range(1, world_size + 1):
            dp = world_size // (tp * pp)
            if dp < 1:
                continue
            if world_size != dp * tp * pp:
                continue
            if num_layers % pp != 0:
                continue

            # Pipeline bubble fraction:
            # with m micro-batches and p stages, bubble = (p-1)/(m+p-1),
            # where m = global_batch / (dp * micro_batch)
            micro_batch = 4  # assumed
            global_batch = 2048  # assumed
            m = global_batch // (dp * micro_batch)
            if m <= 0:
                continue
            bubble_rate = (pp - 1) / (m + pp - 1)

            configs.append({
                "tp": tp, "pp": pp, "dp": dp,
                "bubble_rate": bubble_rate,
                "layers_per_stage": num_layers // pp
            })

    # Prefer low bubble rates and modest TP degrees
    configs.sort(key=lambda x: (x["bubble_rate"], x["tp"]))
    return configs[:5]

# Example: 512 GPUs, GPT-3 scale (96 layers)
best_configs = find_optimal_parallelism(512, 96, 12288)
for c in best_configs:
    print(f"TP={c['tp']}, PP={c['pp']}, DP={c['dp']}, "
          f"Bubble={c['bubble_rate']:.2%}, Layers/Stage={c['layers_per_stage']}")

4.3 Combining DeepSpeed with Megatron

# Megatron-DeepSpeed integration settings
deepspeed_config = {
    "train_batch_size": 2048,
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 256,   # 2048 / (8 DP * 1 micro)
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 1,   # TP+PP와 함께 ZeRO-1 사용
    },
    "gradient_clipping": 1.0,
    "wall_clock_breakdown": True,
    "steps_per_print": 10
}

4.4 Optimizing the Communication Topology

# Bandwidth considerations: NVLink vs. InfiniBand
# NVLink 3.0: 600 GB/s (bidirectional)
# InfiniBand HDR: 200 Gb/s = 25 GB/s per port

# Communication-group design principles:
# TP group: GPUs within the same node (exploits NVLink)
# PP group: across nodes (over InfiniBand)
# DP group: across nodes (ZeRO minimizes the communication)

def setup_process_groups(tp_size, pp_size, dp_size):
    """Set up the process groups"""
    world_size = tp_size * pp_size * dp_size
    rank = torch.distributed.get_rank()

    # Tensor-parallel groups (same node recommended)
    for dp_rank in range(dp_size):
        for pp_rank in range(pp_size):
            ranks = [
                dp_rank * tp_size * pp_size + pp_rank * tp_size + tp_rank
                for tp_rank in range(tp_size)
            ]
            group = torch.distributed.new_group(ranks)
            if rank in ranks:
                tp_group = group

    return tp_group

5. Training Stability

5.1 Handling Loss Spikes

Sudden loss spikes are a common problem in large-scale training.

Root causes:

  1. Excessive gradient norm: a particular batch produces abnormally large gradients
  2. Learning rate too high: the loss landscape changes too sharply
  3. Bad data batches: outliers or incorrectly processed data
  4. Numerical instability: overflow/underflow in bf16/fp16

Monitoring code:

import torch
import wandb
from collections import deque

class TrainingMonitor:
    def __init__(self, spike_threshold: float = 3.0, window_size: int = 100):
        self.loss_history = deque(maxlen=window_size)
        self.grad_norm_history = deque(maxlen=window_size)
        self.spike_threshold = spike_threshold
        self.spike_count = 0

    def check_loss_spike(self, current_loss: float) -> bool:
        if len(self.loss_history) < 10:
            self.loss_history.append(current_loss)
            return False

        mean_loss = sum(self.loss_history) / len(self.loss_history)
        if current_loss > mean_loss * self.spike_threshold:
            self.spike_count += 1
            print(f"SPIKE 감지: current={current_loss:.4f}, mean={mean_loss:.4f}")
            return True

        self.loss_history.append(current_loss)
        return False

    def compute_grad_norm(self, model) -> float:
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        return total_norm ** 0.5

    def log_step(self, step: int, loss: float, model, lr: float):
        grad_norm = self.compute_grad_norm(model)
        is_spike = self.check_loss_spike(loss)

        wandb.log({
            "train/loss": loss,
            "train/grad_norm": grad_norm,
            "train/lr": lr,
            "train/is_spike": int(is_spike),
            "train/spike_count": self.spike_count,
        }, step=step)

        return grad_norm, is_spike

5.2 Gradient Clipping and the Gradient Noise Scale

def adaptive_gradient_clipping(
    model,
    optimizer,
    max_norm: float = 1.0,
    clip_coef: float = 0.01
):
    """
    Adaptive Gradient Clipping (AGC): clip per layer,
    relative to each parameter's own norm
    """
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if param.grad is None:
                continue

            # Compare the parameter norm with the gradient norm
            param_norm = param.data.norm(2)
            grad_norm = param.grad.data.norm(2)

            # Clipping threshold: clip_coef times the parameter norm
            max_grad_norm = clip_coef * param_norm

            if grad_norm > max_grad_norm and grad_norm > 0:
                param.grad.data.mul_(max_grad_norm / grad_norm)

def compute_gradient_noise_scale(model, world_size: int) -> float:
    """
    Gradient Noise Scale (GNS): GNS = tr(S) / ||g||^2
    A high GNS suggests a larger batch size is usable.
    (This is a crude single-sample estimate; the full method
    compares gradient norms measured at two batch sizes.)
    """
    grads = [p.grad.data for p in model.parameters() if p.grad is not None]

    # Squared norm of the full gradient
    g_norm_sq = sum(g.norm(2).item() ** 2 for g in grads)

    # Simplified variance estimate
    variance = sum((g - g.mean()).norm(2).item() ** 2 for g in grads)

    gns = variance / g_norm_sq if g_norm_sq > 0 else 0
    return gns

5.3 Learning-Rate Scheduling Strategies

import math

def cosine_schedule_with_warmup(
    current_step: int,
    warmup_steps: int,
    total_steps: int,
    max_lr: float,
    min_lr: float
) -> float:
    """
    Cosine decay with linear warmup;
    the de-facto standard for LLM training
    """
    if current_step < warmup_steps:
        # Linear warmup
        return max_lr * current_step / warmup_steps
    else:
        # Cosine decay
        progress = (current_step - warmup_steps) / (total_steps - warmup_steps)
        cosine_val = 0.5 * (1 + math.cos(math.pi * progress))
        return min_lr + (max_lr - min_lr) * cosine_val

# The WSD (Warmup-Stable-Decay) schedule, recently popular
def wsd_schedule(
    current_step: int,
    warmup_steps: int,
    stable_steps: int,
    decay_steps: int,
    max_lr: float,
    min_lr: float
) -> float:
    """
    A three-phase Warmup → Stable → Decay schedule.
    Used by MiniCPM, LLaMA-3, and others;
    well suited to continual training
    """
    if current_step < warmup_steps:
        return max_lr * current_step / warmup_steps
    elif current_step < warmup_steps + stable_steps:
        return max_lr
    else:
        decay_progress = (current_step - warmup_steps - stable_steps) / decay_steps
        return max_lr - (max_lr - min_lr) * min(decay_progress, 1.0)

5.4 Batch-Size Scheduling

# Batch-size ramp-up strategy in the style of the GPT-3 paper
def get_batch_size_schedule(
    current_tokens: int,
    final_batch_tokens: int = 4_000_000,  # final batch size (in tokens)
    initial_batch_tokens: int = 32_000,   # initial batch size
    ramp_tokens: int = 1_200_000_000      # ramp-up period (in tokens)
) -> int:
    """
    Grow the batch size linearly with the number of tokens processed.
    Early on: small batches for fast convergence.
    Later: large batches for stable training.
    """
    if current_tokens < ramp_tokens:
        progress = current_tokens / ramp_tokens
        batch_tokens = initial_batch_tokens + (
            final_batch_tokens - initial_batch_tokens
        ) * progress
        return int(batch_tokens)
    return final_batch_tokens

6. Checkpointing Strategies

6.1 Distributed Checkpoints

import os
import shutil
import time
import torch
import torch.distributed as dist
from pathlib import Path

class DistributedCheckpointer:
    def __init__(self, save_dir: str, max_checkpoints: int = 5):
        self.save_dir = Path(save_dir)
        self.max_checkpoints = max_checkpoints
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def save(
        self,
        model,
        optimizer,
        scheduler,
        step: int,
        rank: int,
        world_size: int
    ):
        """Save a distributed checkpoint"""
        checkpoint_dir = self.save_dir / f"step_{step:08d}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Each rank saves its own shard
        rank_path = checkpoint_dir / f"rank_{rank:04d}_of_{world_size:04d}.pt"

        # Model state (ZeRO shard)
        model_state = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict() if scheduler else None,
            "step": step,
            "rank": rank,
            "world_size": world_size,
        }

        torch.save(model_state, rank_path)

        # Rank 0 saves the metadata
        if rank == 0:
            meta = {
                "step": step,
                "world_size": world_size,
                "timestamp": time.time()
            }
            torch.save(meta, checkpoint_dir / "meta.pt")

        # Prune old checkpoints
        if rank == 0:
            self._cleanup_old_checkpoints()

    def _cleanup_old_checkpoints(self):
        checkpoints = sorted(self.save_dir.glob("step_*"))
        while len(checkpoints) > self.max_checkpoints:
            oldest = checkpoints.pop(0)
            shutil.rmtree(oldest)

6.2 Asynchronous Checkpoint Saving

import threading
from queue import Queue

class AsyncCheckpointer:
    """Save checkpoints from a separate thread (training never pauses)"""
    def __init__(self, save_dir: str):
        self.save_dir = save_dir
        self.queue = Queue(maxsize=2)
        self.worker = threading.Thread(target=self._worker, daemon=True)
        self.worker.start()

    def _worker(self):
        while True:
            item = self.queue.get()
            if item is None:
                break
            state_dict, path = item
            torch.save(state_dict, path)
            self.queue.task_done()

    def save_async(self, model, step: int, rank: int):
        """Enqueue a save without blocking the main thread"""
        # Copy to CPU (releases the GPU memory)
        state_dict = {
            k: v.cpu().clone() for k, v in model.state_dict().items()
        }
        path = os.path.join(self.save_dir, f"step_{step}_rank_{rank}.pt")

        if not self.queue.full():
            self.queue.put((state_dict, path))
        else:
            # Queue is full: fall back to a synchronous save
            torch.save(state_dict, path)

7. Training Monitoring

7.1 Tracking the Key Metrics

import wandb
import numpy as np

class LLMTrainingTracker:
    def __init__(self, project_name: str, run_name: str):
        wandb.init(project=project_name, name=run_name)
        self.step = 0
        self.token_count = 0

    def log_training_step(
        self,
        loss: float,
        learning_rate: float,
        grad_norm: float,
        batch_tokens: int,
        elapsed_seconds: float
    ):
        tokens_per_second = batch_tokens / elapsed_seconds
        # MFU can be estimated from token throughput:
        # MFU = actual_flops / theoretical_peak_flops

        self.token_count += batch_tokens
        self.step += 1

        wandb.log({
            # Training metrics
            "train/loss": loss,
            "train/perplexity": np.exp(min(loss, 20)),
            "train/grad_norm": grad_norm,
            "train/learning_rate": learning_rate,
            # Efficiency metrics
            "throughput/tokens_per_second": tokens_per_second,
            "throughput/samples_per_second": tokens_per_second / 2048,  # assumes seq_len 2048
            # Progress
            "progress/total_tokens": self.token_count,
            "progress/step": self.step,
        }, step=self.step)

    def log_evaluation(self, eval_loss: float, perplexities: dict):
        wandb.log({
            "eval/loss": eval_loss,
            "eval/perplexity": np.exp(eval_loss),
            **{f"eval/{k}_ppl": v for k, v in perplexities.items()}
        }, step=self.step)

7.2 Interpreting the Loss Curve

What a healthy training curve looks like:

  • Warmup phase: rapid decrease
  • Stable phase: slow but steady decrease
  • Late phase: very gradual decrease

Warning signs:

  • Complete stagnation (loss plateau): learning rate too low, or the data is exhausted
  • Sharp spikes: bad batches, exploding gradients
  • NaN/Inf: numerical instability, fp16 overflow
import numpy as np

def analyze_loss_curve(losses: list, window: int = 100) -> dict:
    """Automatic loss-curve analysis"""
    if len(losses) < window * 2:
        return {}

    recent = losses[-window:]
    previous = losses[-2*window:-window]

    recent_mean = np.mean(recent)
    previous_mean = np.mean(previous)
    improvement = (previous_mean - recent_mean) / previous_mean

    # Spike detection
    global_mean = np.mean(losses)
    global_std = np.std(losses)
    spikes = [l for l in losses if l > global_mean + 3 * global_std]

    return {
        "recent_loss": recent_mean,
        "improvement_rate": improvement,
        "spike_count": len(spikes),
        "is_stagnant": improvement < 0.001,
        "recommendation": "consider raising the learning rate" if improvement < 0.001 else "normal"
    }

8. Open-Source LLM Training Codebases

8.1 GPT-NeoX (EleutherAI)

# Install and run GPT-NeoX
git clone https://github.com/EleutherAI/gpt-neox
cd gpt-neox
pip install -r requirements/requirements.txt

# Config file (configs/20B.yml)
# Start training
python deepy.py train configs/20B.yml

GPT-NeoX combines Megatron-style pipeline parallelism with DeepSpeed. EleutherAI's Pythia series was trained on this codebase.

8.2 OLMo (Allen AI)

# OLMo training configuration (simplified)
from olmo import TrainConfig, ModelConfig, OptimizerConfig

train_config = TrainConfig(
    model=ModelConfig(
        d_model=4096,
        n_heads=32,
        n_layers=32,
        mlp_ratio=8/3,
        vocab_size=50280,
        max_sequence_length=2048,
        attention_type="flash",
    ),
    optimizer=OptimizerConfig(
        name="adamw",
        learning_rate=3e-4,
        weight_decay=0.1,
        betas=(0.9, 0.95),
    ),
    max_duration="300000ba",    # 300,000 batches
    global_train_batch_size=2048,
    device_train_microbatch_size=2,
    precision="bf16",
    fsdp_config={
        "wrapping_strategy": "by_block",
        "precision": "bf16",
        "sharding_strategy": "FULL_SHARD",
    },
)

OLMo aims for complete transparency, releasing its training data, code, and intermediate checkpoints in full.

8.3 torchtitan

# torchtitan: PyTorch-native LLM training
# A recent pre-training framework developed at Meta

# torchtitan configuration (TOML format)
config = """
[model]
name = "llama3"
flavor = "8B"
tokenizer_path = "./original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 8
seq_len = 8192
max_norm = 1.0
steps = 10000
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1  # FSDP2
tensor_parallel_degree = 1
compile = true   # enable torch.compile

[checkpoint]
enable_checkpoint = true
folder = "outputs/checkpoint"
interval_type = "steps"
interval = 500
"""

torchtitan implements FSDP2, tensor parallelism, and pipeline parallelism natively in PyTorch.

8.4 FSDP2 (Fully Sharded Data Parallel)

import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# FSDP wrapping policy (the policy must be partially applied:
# FSDP calls it later with module/recurse/numel arguments)
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

# FSDP2 (the new API in torch 2.x)
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)

for layer in model.model.layers:
    fully_shard(
        layer,
        mp_policy=mp_policy,
        reshard_after_forward=True,  # similar to ZeRO-3
    )

fully_shard(model, mp_policy=mp_policy)

9. Cost Estimation and Efficiency Strategies

9.1 Computing GPU-Hours

def estimate_training_cost(
    model_params: float,     # number of parameters (e.g. 7e9)
    training_tokens: float,  # training tokens (e.g. 2e12)
    num_gpus: int,
    gpu_type: str = "A100-80GB",
    cloud_provider: str = "AWS"
) -> dict:
    """Estimate the cost of a training run"""
    # GPU specs (theoretical peak, bf16)
    gpu_specs = {
        "A100-80GB": {"tflops": 312, "memory_gb": 80, "nvlink": True},
        "H100-80GB": {"tflops": 989, "memory_gb": 80, "nvlink": True},
        "A10G-24GB": {"tflops": 125, "memory_gb": 24, "nvlink": False},
        "V100-32GB": {"tflops": 130, "memory_gb": 32, "nvlink": True},
    }

    # Cloud prices (USD per hour, approximate)
    cloud_prices = {
        "AWS": {"A100-80GB": 3.97, "H100-80GB": 8.0, "A10G-24GB": 1.006},
        "GCP": {"A100-80GB": 3.67, "H100-80GB": 7.0},
        "Azure": {"A100-80GB": 3.40, "H100-80GB": 6.0},
        "Lambda": {"A100-80GB": 1.99, "H100-80GB": 3.29},
    }

    spec = gpu_specs[gpu_type]
    price_per_hour = cloud_prices.get(cloud_provider, {}).get(gpu_type, 0)

    # FLOPs
    total_flops = 6 * model_params * training_tokens

    # Assumed real-world MFU (typically 35-55%)
    mfu = 0.45
    effective_tflops = spec["tflops"] * 1e12 * mfu

    # Total training time
    total_seconds = total_flops / (effective_tflops * num_gpus)
    total_hours = total_seconds / 3600
    total_days = total_hours / 24

    # Cost
    total_cost = price_per_hour * num_gpus * total_hours

    return {
        "total_flops": f"{total_flops:.2e}",
        "training_hours": f"{total_hours:.1f}",
        "training_days": f"{total_days:.1f}",
        "gpu_hours": f"{total_hours * num_gpus:.0f}",
        "estimated_cost_usd": f"{total_cost:,.0f}",
        "mfu": mfu,
    }

# Example
result = estimate_training_cost(
    model_params=7e9,
    training_tokens=2e12,
    num_gpus=256,
    gpu_type="A100-80GB",
    cloud_provider="Lambda"
)
for k, v in result.items():
    print(f"{k}: {v}")

9.2 Optimizing MFU (Model FLOP Utilization)

def measure_mfu(
    model,
    batch_tokens: int,
    elapsed_seconds: float,
    theoretical_peak_tflops: float
) -> float:
    """
    Measure the achieved MFU:
    MFU = actual_FLOP_rate / theoretical_peak_FLOP_rate
    """
    # Estimate the model's actual FLOPs
    num_params = sum(p.numel() for p in model.parameters())
    actual_flops = 6 * num_params * batch_tokens  # forward (2) + backward (4)

    actual_tflops = actual_flops / elapsed_seconds / 1e12
    mfu = actual_tflops / theoretical_peak_tflops

    return mfu

# Ways to raise MFU:
# 1. Use Flash Attention (minimizes memory I/O)
# 2. Enable torch.compile (kernel fusion)
# 3. Use activation checkpointing judiciously (it costs speed)
# 4. Choose the batch size that maximizes GPU occupancy
# 5. Overlap communication with compute (FSDP/DeepSpeed settings)

9.3 Summary of Efficiency Strategies

| Strategy | Memory savings | Speed impact | Difficulty |
|---|---|---|---|
| ZeRO-2 | Medium | -5% | Easy |
| ZeRO-3 | High | -15% | Medium |
| Flash Attention 2 | High | +20% | Easy |
| torch.compile | None | +15-30% | Easy |
| Activation checkpointing | Medium | -30% | Easy |
| bf16 (vs fp16) | None | Better stability | Easy |
| Sequence packing | None | +20% | Medium |
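As one concrete instance from the table, here is a minimal sketch of activation checkpointing using `torch.utils.checkpoint` (assumes a recent PyTorch with the `use_reentrant` argument): the checkpointed forward discards intermediate activations and recomputes them during backward, trading speed for memory while producing the same values.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(4, 64, requires_grad=True)

# Normal forward: every intermediate activation stays in memory for backward
y_ref = block(x)

# Checkpointed forward: activations are dropped and recomputed during backward
y_ckpt = checkpoint(block, x, use_reentrant=False)

print(torch.allclose(y_ref, y_ckpt))  # True: same math, lower peak memory
```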

10. A Complete Pre-training Launcher Script

#!/usr/bin/env python3
"""
Complete example of large-scale LLM pre-training:
a GPT-style model with DeepSpeed ZeRO-2 + Flash Attention
"""
import os
import time
import math
import torch
import torch.nn as nn
import deepspeed
from torch.utils.data import DataLoader, IterableDataset
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    get_cosine_schedule_with_warmup,
)

# Configuration
TRAINING_CONFIG = {
    "model_name": "meta-llama/Llama-2-7b-hf",
    "total_tokens": 1_000_000_000,   # 1B tokens (for the demo)
    "max_seq_len": 2048,
    "global_batch_size": 1024,       # global batch (in tokens: 1024 * 2048 = 2M)
    "micro_batch_per_gpu": 4,
    "learning_rate": 3e-4,
    "min_lr": 3e-5,
    "warmup_tokens": 20_000_000,     # 2% warmup
    "weight_decay": 0.1,
    "grad_clip": 1.0,
    "save_every_steps": 1000,
    "eval_every_steps": 500,
    "log_every_steps": 10,
}

import numpy as np  # the token file is stored as a flat array of uint16 ids

class StreamingTokenDataset(IterableDataset):
    """Streaming token dataset (for corpora too large to hold in memory)"""
    def __init__(self, data_path: str, seq_len: int, rank: int, world_size: int):
        self.data_path = data_path
        self.seq_len = seq_len
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        # Each rank processes a different slice of the data
        with open(self.data_path, "rb") as f:
            f.seek(self.rank * self.seq_len * 2)  # 2 bytes per uint16 token
            while True:
                chunk = f.read(self.seq_len * 2 * self.world_size)
                if not chunk:
                    break
                tokens = torch.from_numpy(
                    np.frombuffer(chunk, dtype=np.uint16).astype(np.int64)
                )
                if len(tokens) < self.seq_len + 1:
                    break
                input_ids = tokens[:self.seq_len]
                labels = tokens[1:self.seq_len + 1]
                yield {"input_ids": input_ids, "labels": labels}

def train():
    deepspeed.init_distributed()
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    # Load the model
    config = AutoConfig.from_pretrained(TRAINING_CONFIG["model_name"])
    with deepspeed.zero.Init():
        model = AutoModelForCausalLM.from_config(config)

    # Initialize the DeepSpeed engine
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        config={
            "train_micro_batch_size_per_gpu": TRAINING_CONFIG["micro_batch_per_gpu"],
            "gradient_accumulation_steps": (
                TRAINING_CONFIG["global_batch_size"]
                // TRAINING_CONFIG["micro_batch_per_gpu"]
                // world_size
            ),
            "bf16": {"enabled": True},
            "zero_optimization": {
                "stage": 2,
                "overlap_comm": True,
                "contiguous_gradients": True,
                "allgather_bucket_size": 2e8,
                "reduce_bucket_size": 2e8,
            },
            "gradient_clipping": TRAINING_CONFIG["grad_clip"],
            "optimizer": {
                "type": "AdamW",
                "params": {
                    "lr": TRAINING_CONFIG["learning_rate"],
                    "betas": [0.9, 0.95],
                    "eps": 1e-8,
                    "weight_decay": TRAINING_CONFIG["weight_decay"],
                }
            },
            "steps_per_print": TRAINING_CONFIG["log_every_steps"],
        }
    )

    # Dataset
    train_dataset = StreamingTokenDataset(
        "/data/train_tokens.bin",
        TRAINING_CONFIG["max_seq_len"],
        rank, world_size
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=TRAINING_CONFIG["micro_batch_per_gpu"],
        num_workers=4,
        pin_memory=True,
    )

    # Training loop
    total_tokens = 0
    target_tokens = TRAINING_CONFIG["total_tokens"]

    for batch in train_loader:
        if total_tokens >= target_tokens:
            break

        input_ids = batch["input_ids"].to(model_engine.device)
        labels = batch["labels"].to(model_engine.device)

        # Learning-rate schedule
        warmup_tokens = TRAINING_CONFIG["warmup_tokens"]
        lr = cosine_schedule_with_warmup(
            total_tokens, warmup_tokens, target_tokens,
            TRAINING_CONFIG["learning_rate"], TRAINING_CONFIG["min_lr"]
        )
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # Forward / backward
        outputs = model_engine(input_ids=input_ids, labels=labels)
        loss = outputs.loss
        model_engine.backward(loss)
        model_engine.step()

        # Update the token count
        batch_tokens = input_ids.numel() * world_size
        total_tokens += batch_tokens

        # Logging (rank 0 only)
        if rank == 0 and model_engine.global_steps % TRAINING_CONFIG["log_every_steps"] == 0:
            print(f"Tokens: {total_tokens/1e9:.2f}B, "
                  f"Loss: {loss.item():.4f}, "
                  f"LR: {lr:.2e}, "
                  f"PPL: {math.exp(loss.item()):.1f}")

        # Save checkpoint
        if model_engine.global_steps % TRAINING_CONFIG["save_every_steps"] == 0:
            model_engine.save_checkpoint(
                "./checkpoints",
                tag=f"step_{model_engine.global_steps}"
            )

    if rank == 0:
        print("Training complete!")
        model_engine.save_checkpoint("./checkpoints", tag="final")

def cosine_schedule_with_warmup(step, warmup, total, max_lr, min_lr):
    if step < warmup:
        return max_lr * step / max(warmup, 1)
    progress = (step - warmup) / max(total - warmup, 1)
    return min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

if __name__ == "__main__":
    train()

Run:

deepspeed --num_nodes=4 --num_gpus=8 \
    --master_addr=node0 --master_port=29500 \
    pretrain.py

Conclusion

Pre-training a 100B+ parameter LLM is not simply a matter of running code; it is the work of balancing countless engineering decisions and trade-offs.

Key takeaways:

  • Scaling laws: Set model size and data volume in balance per Chinchilla-optimal scaling, though "overtraining" is reasonable once inference cost is considered
  • Data is key: High-quality data cleaning and a balanced mixture determine model performance
  • 3D parallelism: Combine DP + TP + PP to use thousands of GPUs efficiently
  • Stability first: Monitoring loss spikes and gradient explosions, and responding quickly, is essential
  • Cost awareness: Continuously monitor MFU and GPU utilization to optimize efficiency

The open-source ecosystem (OLMo, GPT-NeoX, torchtitan) is lowering the barrier to entry for large-scale training. Apply the techniques in this guide to real projects and train your own LLM.


Large-Scale Model Training Complete Guide: Strategies for Pre-training 100B+ Parameter LLMs

Introduction

LLaMA-3 405B, GPT-4, Falcon 180B — how were these models actually trained? Saying "just use a lot of GPUs" grossly oversimplifies the real complexity involved. Training a hundred-billion parameter LLM requires hyperparameter design grounded in scaling laws, sophisticated distributed training strategies, training stability assurance, and efficient management of enormous computing resources.

This guide covers every critical aspect of large-scale LLM pre-training from a practical perspective.


1. Scaling Laws

1.1 Kaplan et al. (OpenAI) Scaling Laws

In 2020, Kaplan et al.'s paper "Scaling Laws for Neural Language Models" demonstrated that language model performance improves as a power law with three key factors:

  • N: Number of model parameters
  • D: Number of training data tokens
  • C: Total compute budget (FLOPs)

Key findings:

  1. Model size priority (compute-efficient view): For a fixed compute budget, increasing model size is more efficient than increasing data
  2. Data efficiency: Training the same model size longer yields diminishing returns
  3. Power law: Loss L follows approximately L(N) ≈ (Nc/N)^αN

This law suggested investing most of the compute budget in larger models while training on less data than optimal — which is why GPT-3 (175B) was trained on relatively few tokens (300B).

1.2 Chinchilla Optimal Scaling (Hoffmann et al. 2022)

In 2022, DeepMind's Hoffmann et al. published "Training Compute-Optimal Large Language Models" with a critical finding that overturned Kaplan's conclusions.

Core claim: Existing large models were severely undertrained.

Chinchilla experiment:

  • Previous approach: Gopher (280B parameters, 300B tokens)
  • Chinchilla: 70B parameters, 1.4T tokens
  • Result: Chinchilla, at a smaller size, decisively outperformed Gopher

Chinchilla optimal scaling law:

Given compute budget C, optimal model size N and data D:

N_optimal ≈ 0.1174 × C^0.4999  (parameter count)
D_optimal ≈ 1.6972 × C^0.5001  (token count)

N and D should scale nearly equally with compute. The empirical rule of thumb:

D_optimal ≈ 20 × N

7B parameter model → minimum 140B tokens required
70B parameter model → minimum 1.4T tokens required
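The fitted coefficients above can be wrapped in a small helper for sanity checks (a sketch; since both exponents are essentially 0.5, N and D each grow roughly as √C):

```python
def chinchilla_optimal(compute_flops: float) -> tuple[float, float]:
    """Chinchilla-optimal parameter and token counts for compute budget C (FLOPs)."""
    n_opt = 0.1174 * compute_flops ** 0.4999   # parameters
    d_opt = 1.6972 * compute_flops ** 0.5001   # tokens
    return n_opt, d_opt

# Example: Chinchilla's own budget, C = 6 * N * D with N=70e9, D=1.4e12
n, d = chinchilla_optimal(6 * 70e9 * 1.4e12)
print(f"N_opt ≈ {n:.3e} params, D_opt ≈ {d:.3e} tokens")
```

Note that the fitted formula gives a tokens-per-parameter ratio somewhat below the simple 20× rule of thumb; both are coarse guides, not exact prescriptions.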

1.3 Practical Computation Formulas

FLOPs estimation:

def estimate_training_flops(
    num_params,       # Number of parameters
    num_tokens,       # Training token count
    include_backward=True
):
    """
    C ≈ 6 × N × D (forward + backward)
    Backward pass costs roughly 2x forward pass
    """
    flops_per_token = 6 * num_params
    total_flops = flops_per_token * num_tokens
    return total_flops

# Example: LLaMA-2 7B, 2T tokens
flops = estimate_training_flops(7e9, 2e12)
print(f"Total FLOPs: {flops:.2e}")
# Total FLOPs: 8.40e+22

# GPU training time estimation (A100 312 TFLOPs, 50% MFU assumed)
a100_flops = 312e12   # A100 bf16 peak: 312 TFLOPs
mfu = 0.5             # Model FLOP Utilization
training_seconds = flops / (a100_flops * mfu)   # time on a single GPU
gpu_hours = training_seconds / 3600             # == total GPU-hours

# With 1000 A100s
num_gpus = 1000
wall_clock_hours = gpu_hours / num_gpus
print(f"Wall-clock time with {num_gpus} GPUs: {wall_clock_hours:.1f} hours")
print(f"Total GPU-hours: {gpu_hours:.0f}")

1.4 Post-Chinchilla Trend: "Overtrain" Strategy

In practice, there is a trend toward training beyond Chinchilla-optimal values.

Reason: Trade-off between training cost vs. inference cost

  • Training happens once, but inference runs millions of times
  • Smaller model trained longer → lower inference cost

Example: LLaMA-3 8B was trained on 15T tokens (~100x Chinchilla optimal)
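A back-of-the-envelope check of that multiplier, using the 20-tokens-per-parameter rule of thumb from above:

```python
params = 8e9             # LLaMA-3 8B
tokens_trained = 15e12   # 15T tokens
chinchilla_tokens = 20 * params          # ≈ 160B tokens
overtrain_factor = tokens_trained / chinchilla_tokens
print(f"Overtrained by ~{overtrain_factor:.0f}x")
# Overtrained by ~94x
```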


2. Pre-training Data Pipeline

2.1 Data Source Composition

High-quality pre-training data is central to LLM performance.

Major data sources:

| Source | Characteristics | Quality |
|---|---|---|
| Common Crawl | Web crawl, PB-scale | Noisy, must filter |
| Books3/Gutenberg | Books, high quality | Limited diversity |
| Wikipedia | Encyclopedia, verified | Limited volume |
| GitHub | Code, improves reasoning | License considerations |
| ArXiv/PubMed | Academic papers | High expertise |
| StackExchange | Q&A, practical knowledge | Good quality |

Typical mixing ratios (estimated, Llama-2 basis):

data_mixture = {
    "Common Crawl (filtered)": 0.67,   # 67%
    "Books": 0.14,                     # 14%
    "GitHub": 0.045,                   # 4.5%
    "Wikipedia": 0.045,                # 4.5%
    "Gutenberg": 0.025,                # 2.5%
    "ArXiv": 0.025,                    # 2.5%
    "StackExchange": 0.02,             # 2%
}
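One simple way to realize such a mixture at training time is to sample each document's source according to these weights (a standard-library sketch; production pipelines usually pre-shuffle weighted shards instead of sampling per document):

```python
import random

data_mixture = {
    "Common Crawl (filtered)": 0.67,
    "Books": 0.14,
    "GitHub": 0.045,
    "Wikipedia": 0.045,
    "Gutenberg": 0.025,
    "ArXiv": 0.025,
    "StackExchange": 0.02,
}

def sample_source(rng: random.Random) -> str:
    """Pick a data source with probability proportional to its mixture weight."""
    sources = list(data_mixture)
    weights = list(data_mixture.values())
    return rng.choices(sources, weights=weights, k=1)[0]

rng = random.Random(0)
counts = {s: 0 for s in data_mixture}
for _ in range(10_000):
    counts[sample_source(rng)] += 1
# Common Crawl should dominate at roughly 67% of draws
```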

2.2 Data Cleaning Pipeline

import re
from typing import List, Optional
from dataclasses import dataclass

@dataclass
class DocumentFilter:
    min_tokens: int = 50
    max_tokens: int = 100000
    min_avg_word_length: float = 3.0
    max_symbol_ratio: float = 0.1
    min_alpha_ratio: float = 0.7

def filter_document(text: str, config: DocumentFilter) -> Optional[str]:
    """Basic quality filter"""
    tokens = text.split()
    token_count = len(tokens)

    # Length filter
    if not (config.min_tokens <= token_count <= config.max_tokens):
        return None

    # Average word length
    avg_word_len = sum(len(t) for t in tokens) / token_count
    if avg_word_len < config.min_avg_word_length:
        return None

    # Alphabetic character ratio
    alpha_chars = sum(1 for c in text if c.isalpha())
    if alpha_chars / len(text) < config.min_alpha_ratio:
        return None

    return text

def deduplicate_documents(texts: List[str], n_gram_size: int = 13) -> List[str]:
    """
    MinHash LSH-based deduplication
    (real implementation uses datasketch library)
    """
    from datasketch import MinHash, MinHashLSH

    lsh = MinHashLSH(threshold=0.8, num_perm=128)
    unique_texts = []

    for i, text in enumerate(texts):
        minhash = MinHash(num_perm=128)
        words = text.lower().split()
        for j in range(len(words) - n_gram_size + 1):
            ngram = " ".join(words[j:j+n_gram_size])
            minhash.update(ngram.encode("utf-8"))

        if not lsh.query(minhash):
            lsh.insert(str(i), minhash)
            unique_texts.append(text)

    return unique_texts

2.3 Tokenizer Training

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel

def train_tokenizer(
    corpus_files: List[str],
    vocab_size: int = 32000,
    output_path: str = "tokenizer.json"
):
    """Train BPE tokenizer (SentencePiece-style)"""
    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)

    trainer = BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=2,
        special_tokens=["[UNK]", "[BOS]", "[EOS]", "[PAD]"],
        show_progress=True
    )

    tokenizer.train(files=corpus_files, trainer=trainer)
    tokenizer.save(output_path)
    return tokenizer

def count_tokens(file_path: str, tokenizer) -> int:
    total = 0
    with open(file_path, "r") as f:
        for line in f:
            tokens = tokenizer.encode(line.strip())
            total += len(tokens.ids)
    return total

3. Megatron-LM

3.1 Introduction to NVIDIA Megatron

Megatron-LM is a large-scale language model training framework developed by NVIDIA. It provides specialized parallelism techniques for training large models like GPT, BERT, and T5.

Core features:

  • Tensor Parallelism: Splits matrix operations across GPUs
  • Pipeline Parallelism: Distributes layers sequentially across GPUs
  • Sequence Parallelism: Also distributes activations along the sequence dimension (Korthikanti et al. 2022)
  • Flash Attention integration: Memory-efficient attention computation

3.2 Tensor Parallelism Implementation Principles

The core operation in transformers is matrix multiplication. How to partition it is the essence of tensor parallelism.

Tensor splitting in Multi-Head Attention:

Q, K, V projection:  [d_model, d_head x n_heads]
→ per GPU: [d_model, d_head x (n_heads/tp_size)]

Tensor splitting in FFN layers:

First linear:   [d_model, d_ffn]  → column-wise split
Second linear:  [d_ffn, d_model]  → row-wise split
→ per GPU: [d_model, d_ffn/tp_size] + [d_ffn/tp_size, d_model]

# Conceptual Megatron ColumnParallelLinear (simplified)
import torch
import torch.nn as nn
import torch.distributed as dist

class ColumnParallelLinear(nn.Module):
    """Split weight matrix along columns"""
    def __init__(self, in_features, out_features, tp_size):
        super().__init__()
        self.tp_size = tp_size
        assert out_features % tp_size == 0
        local_out = out_features // tp_size
        # Each GPU holds 1/tp_size of the full weight
        self.weight = nn.Parameter(torch.empty(local_out, in_features))

    def forward(self, x):
        # Input is full size, output is partitioned
        return torch.nn.functional.linear(x, self.weight)
        # Subsequent All-reduce or All-gather needed

class RowParallelLinear(nn.Module):
    """Split weight matrix along rows"""
    def __init__(self, in_features, out_features, tp_size):
        super().__init__()
        self.tp_size = tp_size
        assert in_features % tp_size == 0
        local_in = in_features // tp_size
        # Each GPU handles 1/tp_size of the input dimension
        self.weight = nn.Parameter(torch.empty(out_features, local_in))

    def forward(self, x):
        # x: [batch, seq, in_features/tp_size]
        local_output = torch.nn.functional.linear(x, self.weight)
        # All-reduce to sum partial results from each GPU
        dist.all_reduce(local_output)
        return local_output

3.3 Sequence Parallelism

Sequence parallelism distributes LayerNorm and Dropout along the sequence dimension.

Attention input: [batch, seq/tp_size, d_model] (per GPU)
All-gather: [batch, seq, d_model]
Self-Attention (column split)
Reduce-scatter: [batch, seq/tp_size, d_model]
FFN (row split)

This also reduces the LayerNorm and Dropout activation memory per GPU by a factor of tp_size.
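The saving is easy to quantify: for bf16 activations of shape [batch, seq, d_model], sequence parallelism shrinks the LayerNorm/Dropout input held per GPU by a factor of tp_size (illustrative numbers; GPT-3-scale d_model assumed):

```python
batch, seq, d_model, tp_size = 1, 2048, 12288, 8
bytes_per_elem = 2  # bf16

full_mb = batch * seq * d_model * bytes_per_elem / 2**20   # full activation, MiB
sharded_mb = full_mb / tp_size                             # per-GPU with sequence parallelism
print(f"LayerNorm input: {full_mb:.0f} MiB -> {sharded_mb:.0f} MiB per GPU")
# LayerNorm input: 48 MiB -> 6 MiB per GPU
```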

3.4 Megatron Configuration Example

#!/bin/bash
# Training GPT-3 175B with Megatron-LM (example)

GPUS_PER_NODE=8
NNODES=64
TP_SIZE=8    # Tensor Parallelism
PP_SIZE=16   # Pipeline Parallelism
DP_SIZE=$((GPUS_PER_NODE * NNODES / TP_SIZE / PP_SIZE))
# DP_SIZE = 8 * 64 / 8 / 16 = 4

torchrun \
    --nproc_per_node $GPUS_PER_NODE \
    --nnodes $NNODES \
    pretrain_gpt.py \
    --num-layers 96 \
    --hidden-size 12288 \
    --num-attention-heads 96 \
    --seq-length 2048 \
    --max-position-embeddings 2048 \
    --global-batch-size 1536 \
    --train-iters 300000 \
    --lr 6e-5 \
    --lr-decay-style cosine \
    --min-lr 6e-6 \
    --lr-warmup-fraction 0.001 \
    --weight-decay 0.1 \
    --tensor-model-parallel-size $TP_SIZE \
    --pipeline-model-parallel-size $PP_SIZE \
    --micro-batch-size 1 \
    --bf16 \
    --use-flash-attn \
    --sequence-parallel \
    --data-path /data/pile_text_document \
    --vocab-file /data/gpt2-vocab.json \
    --merge-file /data/gpt2-merges.txt \
    --save /checkpoints/gpt3-175b \
    --load /checkpoints/gpt3-175b

4. 3D Parallelism (DP + TP + PP)

4.1 Combining Three Parallelism Types

Large-scale LLM training uses all three parallelism types simultaneously.

| Parallelism Type | What It Distributes | Communication Pattern | Overhead |
|---|---|---|---|
| Data Parallelism (DP) | Batches | All-reduce gradients | Low |
| Tensor Parallelism (TP) | Intra-layer matrices | All-reduce per layer | Medium (latency-sensitive) |
| Pipeline Parallelism (PP) | Layer groups | Point-to-point | Pipeline bubble overhead |
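The PP overhead has a simple closed form: with p pipeline stages and m microbatches in flight, the idle "bubble" fraction is (p-1)/(m+p-1). A quick sketch of how it shrinks as microbatches increase:

```python
def pipeline_bubble(p: int, m: int) -> float:
    """Fraction of time pipeline stages sit idle (GPipe/1F1B-style schedule)."""
    return (p - 1) / (m + p - 1)

# 16 stages: more microbatches amortize the fill/drain phases
for m in (16, 64, 256):
    print(f"m={m}: bubble={pipeline_bubble(16, m):.1%}")
```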

4.2 Finding Optimal Parallelism Configuration

def find_optimal_parallelism(
    world_size: int,
    num_layers: int,
    hidden_size: int,
    gpu_memory_gb: float = 80.0
) -> dict:
    """
    Search for optimal 3D parallelism configuration.
    General rules:
    - TP: Within single node (leverage NVLink), typically 4 or 8
    - PP: Across nodes, determined by layer count
    - DP: Remaining GPUs
    """
    configs = []
    for tp in [1, 2, 4, 8]:
        for pp in range(1, world_size + 1):
            dp = world_size // (tp * pp)
            if dp < 1:
                continue
            if world_size != dp * tp * pp:
                continue
            if num_layers % pp != 0:
                continue

            # Calculate pipeline bubble fraction
            # With m microbatches and p pipeline stages: bubble = (p-1)/(m+p-1)
            micro_batch = 4   # assumed
            global_batch = 2048   # assumed
            m = global_batch // (dp * micro_batch)
            if m <= 0:
                continue
            bubble_rate = (pp - 1) / (m + pp - 1)

            configs.append({
                "tp": tp, "pp": pp, "dp": dp,
                "bubble_rate": bubble_rate,
                "layers_per_stage": num_layers // pp
            })

    # Prioritize low bubble rate and reasonable TP size
    configs.sort(key=lambda x: (x["bubble_rate"], x["tp"]))
    return configs[:5]

# Example: 512 GPUs, GPT-3 scale (96 layers)
best_configs = find_optimal_parallelism(512, 96, 12288)
for c in best_configs:
    print(f"TP={c['tp']}, PP={c['pp']}, DP={c['dp']}, "
          f"Bubble={c['bubble_rate']:.2%}, Layers/Stage={c['layers_per_stage']}")

4.3 DeepSpeed + Megatron Combination

# Megatron-DeepSpeed integration configuration
deepspeed_config = {
    "train_batch_size": 2048,
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 256,   # 2048 / (8 DP * 1 micro)
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 1,   # Use ZeRO-1 alongside TP+PP
    },
    "gradient_clipping": 1.0,
    "wall_clock_breakdown": True,
    "steps_per_print": 10
}
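DeepSpeed requires train_batch_size == micro_batch × gradient_accumulation_steps × DP world size, and aborts at startup if the identity fails. A pre-flight check mirroring the numbers above (a sketch) avoids wasting a queued job:

```python
def check_batch_config(global_batch: int, micro_batch: int,
                       grad_accum: int, dp_size: int) -> None:
    """DeepSpeed identity: global = micro * grad_accum * dp."""
    product = micro_batch * grad_accum * dp_size
    assert product == global_batch, (
        f"batch mismatch: {micro_batch} * {grad_accum} * {dp_size} = {product}, "
        f"expected {global_batch}"
    )

check_batch_config(global_batch=2048, micro_batch=1, grad_accum=256, dp_size=8)
```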

4.4 Communication Topology Optimization

# Consider NVLink vs InfiniBand bandwidth
# NVLink 3.0: 600 GB/s (bidirectional)
# InfiniBand HDR: 200 Gb/s = 25 GB/s per port

# Communication group design principles:
# TP group: GPUs within same node (leverage NVLink)
# PP group: Across nodes (leverage InfiniBand)
# DP group: Across nodes (minimize communication with ZeRO)

def setup_process_groups(tp_size, pp_size, dp_size):
    """Configure process groups"""
    world_size = tp_size * pp_size * dp_size
    rank = torch.distributed.get_rank()

    # Tensor parallel group (recommended within single node)
    for dp_rank in range(dp_size):
        for pp_rank in range(pp_size):
            ranks = [
                dp_rank * tp_size * pp_size + pp_rank * tp_size + tp_rank
                for tp_rank in range(tp_size)
            ]
            group = torch.distributed.new_group(ranks)
            if rank in ranks:
                tp_group = group

    return tp_group

5. Training Stability

5.1 Handling Loss Spikes

Sudden loss spikes are common in large-scale training.

Root cause analysis:

  1. Excessive gradient norm: Unusually large gradients in specific batches
  2. Learning rate too high: Abrupt changes in loss landscape topology
  3. Bad data batches: Outliers or incorrectly processed data
  4. Numerical instability: Overflow/underflow in bf16/fp16

Monitoring code:

import torch
import wandb
from collections import deque

class TrainingMonitor:
    def __init__(self, spike_threshold: float = 3.0, window_size: int = 100):
        self.loss_history = deque(maxlen=window_size)
        self.grad_norm_history = deque(maxlen=window_size)
        self.spike_threshold = spike_threshold
        self.spike_count = 0

    def check_loss_spike(self, current_loss: float) -> bool:
        if len(self.loss_history) < 10:
            self.loss_history.append(current_loss)
            return False

        mean_loss = sum(self.loss_history) / len(self.loss_history)
        if current_loss > mean_loss * self.spike_threshold:
            self.spike_count += 1
            print(f"SPIKE detected: current={current_loss:.4f}, mean={mean_loss:.4f}")
            return True

        self.loss_history.append(current_loss)
        return False

    def compute_grad_norm(self, model) -> float:
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        return total_norm ** 0.5

    def log_step(self, step: int, loss: float, model, lr: float):
        grad_norm = self.compute_grad_norm(model)
        is_spike = self.check_loss_spike(loss)

        wandb.log({
            "train/loss": loss,
            "train/grad_norm": grad_norm,
            "train/lr": lr,
            "train/is_spike": int(is_spike),
            "train/spike_count": self.spike_count,
        }, step=step)

        return grad_norm, is_spike

5.2 Gradient Clipping and Noise Scale

def adaptive_gradient_clipping(
    model,
    optimizer,
    max_norm: float = 1.0,
    clip_coef: float = 0.01
):
    """
    Adaptive gradient clipping (AGC)
    Clip per-layer based on parameter norm ratio
    """
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if param.grad is None:
                continue

            param_norm = param.data.norm(2)
            grad_norm = param.grad.data.norm(2)

            # Clipping threshold: clip_coef times the parameter norm
            max_grad_norm = clip_coef * param_norm

            if grad_norm > max_grad_norm and grad_norm > 0:
                param.grad.data.mul_(max_grad_norm / grad_norm)

def compute_gradient_noise_scale(model, world_size: int) -> float:
    """
    Rough proxy for the Gradient Noise Scale (GNS).
    GNS = tr(S) / ||g||^2; high GNS → large batch sizes are tolerable.
    Note: the true GNS (McCandlish et al. 2018) compares gradient estimates
    at two different batch sizes; this single-batch variance-over-norm
    ratio is only a crude approximation.
    """
    grads = [p.grad.data for p in model.parameters() if p.grad is not None]

    g_norm_sq = sum(g.norm(2).item() ** 2 for g in grads)
    variance = sum((g - g.mean()).norm(2).item() ** 2 for g in grads)

    return variance / g_norm_sq if g_norm_sq > 0 else 0.0

5.3 Learning Rate Scheduling Strategies

import math

def cosine_schedule_with_warmup(
    current_step: int,
    warmup_steps: int,
    total_steps: int,
    max_lr: float,
    min_lr: float
) -> float:
    """
    Cosine decay with linear warmup.
    Standard for most LLM training runs.
    """
    if current_step < warmup_steps:
        # Linear warmup
        return max_lr * current_step / max(warmup_steps, 1)
    else:
        # Cosine decay (guard against division by zero)
        progress = (current_step - warmup_steps) / max(total_steps - warmup_steps, 1)
        cosine_val = 0.5 * (1 + math.cos(math.pi * progress))
        return min_lr + (max_lr - min_lr) * cosine_val

# WSD (Warmup-Stable-Decay) schedule — popular in recent work
def wsd_schedule(
    current_step: int,
    warmup_steps: int,
    stable_steps: int,
    decay_steps: int,
    max_lr: float,
    min_lr: float
) -> float:
    """
    Three-phase schedule: Warmup -> Stable -> Decay
    Used in MiniCPM, LLaMA-3, etc.
    Advantageous for continual learning
    """
    if current_step < warmup_steps:
        return max_lr * current_step / warmup_steps
    elif current_step < warmup_steps + stable_steps:
        return max_lr
    else:
        decay_progress = (current_step - warmup_steps - stable_steps) / decay_steps
        return max_lr - (max_lr - min_lr) * min(decay_progress, 1.0)

5.4 Batch Size Scheduling

# Batch size ramp strategy (as proposed in the OpenAI GPT-3 paper)
def get_batch_size_schedule(
    current_tokens: int,
    final_batch_tokens: int = 4_000_000,   # Final batch size (tokens)
    initial_batch_tokens: int = 32_000,    # Initial batch size
    ramp_tokens: int = 1_200_000_000       # Ramp-up interval (tokens)
) -> int:
    """
    Linearly increase batch size with token throughput.
    Early: small batches for fast convergence.
    Late: large batches for stable training.
    """
    if current_tokens < ramp_tokens:
        progress = current_tokens / ramp_tokens
        batch_tokens = initial_batch_tokens + (
            final_batch_tokens - initial_batch_tokens
        ) * progress
        return int(batch_tokens)
    return final_batch_tokens

6. Checkpointing Strategies

6.1 Distributed Checkpointing

import os
import torch
import torch.distributed as dist
from pathlib import Path

class DistributedCheckpointer:
    def __init__(self, save_dir: str, max_checkpoints: int = 5):
        self.save_dir = Path(save_dir)
        self.max_checkpoints = max_checkpoints
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def save(
        self,
        model,
        optimizer,
        scheduler,
        step: int,
        rank: int,
        world_size: int
    ):
        """Save distributed checkpoint"""
        checkpoint_dir = self.save_dir / f"step_{step:08d}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Each rank saves its own shard
        rank_path = checkpoint_dir / f"rank_{rank:04d}_of_{world_size:04d}.pt"

        model_state = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict() if scheduler else None,
            "step": step,
            "rank": rank,
            "world_size": world_size,
        }

        torch.save(model_state, rank_path)

        # Rank 0 saves metadata
        if rank == 0:
            meta = {
                "step": step,
                "world_size": world_size,
                "timestamp": __import__("time").time()
            }
            torch.save(meta, checkpoint_dir / "meta.pt")

        # Clean up old checkpoints
        if rank == 0:
            self._cleanup_old_checkpoints()

    def _cleanup_old_checkpoints(self):
        checkpoints = sorted(self.save_dir.glob("step_*"))
        while len(checkpoints) > self.max_checkpoints:
            oldest = checkpoints.pop(0)
            import shutil
            shutil.rmtree(oldest)
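On restart, each rank needs to locate the most recent step directory. A small companion helper (a sketch; it assumes the `step_{step:08d}` naming used in `save()` above):

```python
from pathlib import Path
from typing import Optional

def latest_checkpoint(save_dir: str) -> Optional[Path]:
    """Return the step_XXXXXXXX directory with the highest step, or None."""
    candidates = sorted(Path(save_dir).glob("step_*"))
    return candidates[-1] if candidates else None
```

Zero-padding the step number makes lexicographic order equal numeric order, which is why `save()` formats the directory name as `step_{step:08d}`.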

6.2 Asynchronous Checkpoint Saving

import threading
from queue import Queue

class AsyncCheckpointer:
    """Save checkpoints in a separate thread (no training interruption)"""
    def __init__(self, save_dir: str):
        self.save_dir = save_dir
        self.queue = Queue(maxsize=2)
        self.worker = threading.Thread(target=self._worker, daemon=True)
        self.worker.start()

    def _worker(self):
        while True:
            item = self.queue.get()
            if item is None:
                break
            state_dict, path = item
            torch.save(state_dict, path)
            self.queue.task_done()

    def save_async(self, model, step: int, rank: int):
        """Add to save queue without blocking main thread"""
        # Copy to CPU (free GPU memory)
        state_dict = {
            k: v.cpu().clone() for k, v in model.state_dict().items()
        }
        path = os.path.join(self.save_dir, f"step_{step}_rank_{rank}.pt")

        if not self.queue.full():
            self.queue.put((state_dict, path))
        else:
            # Queue full: save synchronously
            torch.save(state_dict, path)

7. Training Monitoring

7.1 Core Metrics Tracking

import wandb
import numpy as np

class LLMTrainingTracker:
    def __init__(self, project_name: str, run_name: str):
        wandb.init(project=project_name, name=run_name)
        self.step = 0
        self.token_count = 0

    def log_training_step(
        self,
        loss: float,
        learning_rate: float,
        grad_norm: float,
        batch_tokens: int,
        elapsed_seconds: float
    ):
        tokens_per_second = batch_tokens / elapsed_seconds

        self.token_count += batch_tokens
        self.step += 1

        wandb.log({
            # Training metrics
            "train/loss": loss,
            "train/perplexity": np.exp(min(loss, 20)),
            "train/grad_norm": grad_norm,
            "train/learning_rate": learning_rate,
            # Efficiency metrics
            "throughput/tokens_per_second": tokens_per_second,
            "throughput/samples_per_second": tokens_per_second / 2048,  # assumes seq_len=2048
            # Progress
            "progress/total_tokens": self.token_count,
            "progress/step": self.step,
        }, step=self.step)

    def log_evaluation(self, eval_loss: float, perplexities: dict):
        wandb.log({
            "eval/loss": eval_loss,
            "eval/perplexity": np.exp(eval_loss),
            **{f"eval/{k}_ppl": v for k, v in perplexities.items()}
        }, step=self.step)

7.2 Interpreting Loss Curves

Characteristics of a healthy loss curve:

  • Warmup phase: Rapid decrease
  • Stable phase: Slow but steady decrease
  • Late phase: Very gradual decrease

Warning signals:

  • Complete plateau: Learning rate too low, data exhausted
  • Sharp spikes: Bad batches, gradient explosion
  • NaN/Inf: Numerical instability, fp16 overflow

def analyze_loss_curve(losses: list, window: int = 100) -> dict:
    """Automated loss curve analysis"""
    if len(losses) < window * 2:
        return {}

    recent = losses[-window:]
    previous = losses[-2*window:-window]

    recent_mean = np.mean(recent)
    previous_mean = np.mean(previous)
    improvement = (previous_mean - recent_mean) / previous_mean

    global_mean = np.mean(losses)
    global_std = np.std(losses)
    spikes = [l for l in losses if l > global_mean + 3 * global_std]

    return {
        "recent_loss": recent_mean,
        "improvement_rate": improvement,
        "spike_count": len(spikes),
        "is_stagnant": improvement < 0.001,
        "recommendation": "Consider increasing LR" if improvement < 0.001 else "Normal"
    }
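For the NaN/Inf case, the usual response is to skip the offending optimizer step entirely rather than let a non-finite loss poison the optimizer state. A minimal guard (pure-Python sketch; in practice the same check is applied to the global gradient norm, and the threshold is a tunable assumption):

```python
import math

def should_skip_step(loss: float, grad_norm: float,
                     max_grad_norm: float = 100.0) -> bool:
    """Skip the optimizer step on a non-finite loss or an absurd gradient norm."""
    if not (math.isfinite(loss) and math.isfinite(grad_norm)):
        return True
    return grad_norm > max_grad_norm

assert should_skip_step(float("nan"), 1.0)       # NaN loss → skip
assert should_skip_step(2.3, float("inf"))       # Inf gradient → skip
assert not should_skip_step(2.3, 0.8)            # healthy step → proceed
```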

8. Open-Source LLM Training Codebases

8.1 GPT-NeoX (EleutherAI)

# Install and run GPT-NeoX
git clone https://github.com/EleutherAI/gpt-neox
cd gpt-neox
pip install -r requirements/requirements.txt

# Configuration file (configs/20B.yml)
# Start training
python deepy.py train configs/20B.yml

GPT-NeoX combines Megatron-based pipeline parallelism with DeepSpeed. EleutherAI's Pythia model series was trained with this codebase.

8.2 OLMo (Allen AI)

# OLMo training configuration (simplified)
from olmo import TrainConfig, ModelConfig, OptimizerConfig

train_config = TrainConfig(
    model=ModelConfig(
        d_model=4096,
        n_heads=32,
        n_layers=32,
        mlp_ratio=8/3,
        vocab_size=50280,
        max_sequence_length=2048,
        attention_type="flash",
    ),
    optimizer=OptimizerConfig(
        name="adamw",
        learning_rate=3e-4,
        weight_decay=0.1,
        betas=(0.9, 0.95),
    ),
    max_duration="300000ba",    # 300,000 batches
    global_train_batch_size=2048,
    device_train_microbatch_size=2,
    precision="bf16",
    fsdp_config={
        "wrapping_strategy": "by_block",
        "precision": "bf16",
        "sharding_strategy": "FULL_SHARD",
    },
)

OLMo aims for full transparency, releasing training data, code, and intermediate checkpoints.

8.3 torchtitan

# torchtitan - PyTorch-native LLM training
# A modern pre-training framework developed by Meta

# torchtitan configuration (TOML format)
config = """
[model]
name = "llama3"
flavor = "8B"
tokenizer_path = "./original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 8
seq_len = 8192
max_norm = 1.0
steps = 10000
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1  # FSDP2
tensor_parallel_degree = 1
compile = true   # Enable torch.compile

[checkpoint]
enable_checkpoint = true
folder = "outputs/checkpoint"
interval_type = "steps"
interval = 500
"""

torchtitan implements FSDP2, Tensor Parallel, and Pipeline Parallel natively in PyTorch.

8.4 FSDP2 (Fully Sharded Data Parallel)

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# FSDP wrap policy (transformer_auto_wrap_policy must be partially applied,
# not called directly: FSDP invokes it per module during wrapping)
import functools

auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

# FSDP2 (new API in torch 2.x)
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)

for layer in model.model.layers:
    fully_shard(
        layer,
        mp_policy=mp_policy,
        reshard_after_forward=True,   # Similar to ZeRO-3
    )

fully_shard(model, mp_policy=mp_policy)

9. Cost Estimation and Efficiency Strategies

9.1 GPU-Hour Calculation

def estimate_training_cost(
    model_params: float,   # Parameter count (e.g., 7e9)
    training_tokens: float,  # Training tokens (e.g., 2e12)
    num_gpus: int,
    gpu_type: str = "A100-80GB",
    cloud_provider: str = "AWS"
) -> dict:
    """Estimate training cost"""
    # GPU specs (theoretical peak, bf16)
    gpu_specs = {
        "A100-80GB": {"tflops": 312, "memory_gb": 80, "nvlink": True},
        "H100-80GB": {"tflops": 989, "memory_gb": 80, "nvlink": True},
        "A10G-24GB": {"tflops": 125, "memory_gb": 24, "nvlink": False},
        "V100-32GB": {"tflops": 130, "memory_gb": 32, "nvlink": True},
    }

    # Cloud pricing (approximate USD per hour)
    cloud_prices = {
        "AWS": {"A100-80GB": 3.97, "H100-80GB": 8.0, "A10G-24GB": 1.006},
        "GCP": {"A100-80GB": 3.67, "H100-80GB": 7.0},
        "Azure": {"A100-80GB": 3.40, "H100-80GB": 6.0},
        "Lambda": {"A100-80GB": 1.99, "H100-80GB": 3.29},
    }

    spec = gpu_specs[gpu_type]
    price_per_hour = cloud_prices.get(cloud_provider, {}).get(gpu_type, 0)

    # FLOPs calculation
    total_flops = 6 * model_params * training_tokens

    # Assumed MFU (typically 35-55%)
    mfu = 0.45
    effective_tflops = spec["tflops"] * 1e12 * mfu

    # Total training time
    total_seconds = total_flops / (effective_tflops * num_gpus)
    total_hours = total_seconds / 3600
    total_days = total_hours / 24

    # Cost calculation
    total_cost = price_per_hour * num_gpus * total_hours

    return {
        "total_flops": f"{total_flops:.2e}",
        "training_hours": f"{total_hours:.1f}",
        "training_days": f"{total_days:.1f}",
        "gpu_hours": f"{total_hours * num_gpus:.0f}",
        "estimated_cost_usd": f"{total_cost:,.0f}",
        "mfu": mfu,
    }

# Example
result = estimate_training_cost(
    model_params=7e9,
    training_tokens=2e12,
    num_gpus=256,
    gpu_type="A100-80GB",
    cloud_provider="Lambda"
)
for k, v in result.items():
    print(f"{k}: {v}")

9.2 MFU (Model FLOP Utilization) Optimization

def measure_mfu(
    model,
    batch_tokens: int,
    elapsed_seconds: float,
    theoretical_peak_tflops: float
) -> float:
    """
    Measure actual MFU.
    MFU = actual_FLOP_rate / theoretical_peak_FLOP_rate
    """
    num_params = sum(p.numel() for p in model.parameters())
    actual_flops = 6 * num_params * batch_tokens   # forward(2) + backward(4)

    actual_tflops = actual_flops / elapsed_seconds / 1e12
    mfu = actual_tflops / theoretical_peak_tflops

    return mfu

# MFU improvement strategies:
# 1. Use Flash Attention (minimize memory I/O)
# 2. Enable torch.compile (kernel fusion)
# 3. Use activation checkpointing carefully (speed trade-off)
# 4. Choose optimal batch size (maximize GPU occupancy)
# 5. Overlap communication with computation (FSDP/DeepSpeed settings)
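A useful complement to measuring MFU is inverting C ≈ 6·N·D into the throughput a given MFU implies, which yields a concrete tokens/sec target to compare training logs against. A minimal sketch (the A100 peak, 45% MFU, and 256-GPU figures mirror the assumptions used earlier in this section):

```python
def tokens_per_second_at_mfu(num_params, peak_tflops, mfu, num_gpus=1):
    """Throughput implied by an MFU target, by inverting C = 6 * N * D."""
    cluster_flops = peak_tflops * 1e12 * mfu * num_gpus  # achieved FLOP/s
    return cluster_flops / (6 * num_params)              # tokens/s

# 7B model on 256 A100s (312 TFLOPs peak) at 45% MFU
tps = tokens_per_second_at_mfu(7e9, 312, 0.45, num_gpus=256)
print(f"{tps:,.0f} tokens/s")
```

If logged throughput falls well below this number, MFU has dropped and one of the five levers above is usually the cause.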

9.3 Efficiency Strategy Summary

| Strategy | Memory Savings | Speed Impact | Implementation Difficulty |
|---|---|---|---|
| ZeRO-2 | Medium | -5% | Easy |
| ZeRO-3 | High | -15% | Medium |
| Flash Attention 2 | High | +20% | Easy |
| torch.compile | None | +15-30% | Easy |
| Activation Checkpointing | Medium | -30% | Easy |
| bf16 (vs fp16) | None | Stability improvement | Easy |
| Sequence packing | None | +20% | Medium |
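The ZeRO rows above follow directly from the standard per-parameter accounting for mixed-precision Adam (2 bytes for bf16 params, 2 for bf16 grads, 12 for the fp32 master copy plus two Adam moments). A sketch of model-state memory per GPU, ignoring activations and fragmentation:

```python
def zero_memory_per_gpu(num_params, num_gpus, stage):
    """Approximate model-state memory (GB) per GPU under ZeRO sharding."""
    p, g, o = 2, 2, 12  # bytes/param: bf16 params, bf16 grads, fp32 Adam states
    if stage == 0:
        per_param = p + g + o                  # fully replicated
    elif stage == 1:
        per_param = p + g + o / num_gpus       # shard optimizer states
    elif stage == 2:
        per_param = p + (g + o) / num_gpus     # also shard gradients
    else:
        per_param = (p + g + o) / num_gpus     # stage 3: shard everything
    return num_params * per_param / 1e9

for stage in range(4):
    print(f"ZeRO-{stage}: {zero_memory_per_gpu(7e9, 64, stage):.2f} GB/GPU")
    # 7B on 64 GPUs: 112.00 -> 29.31 -> 15.53 -> 1.75 GB
```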

10. Complete Pre-training Launcher Script

#!/usr/bin/env python3
"""
Complete large-scale LLM pre-training example.
GPT-style model, DeepSpeed ZeRO-2 + Flash Attention.
"""
import os
import time
import math
import torch
import torch.nn as nn
import deepspeed
from torch.utils.data import DataLoader, IterableDataset
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    get_cosine_schedule_with_warmup,
)

# Configuration
TRAINING_CONFIG = {
    "model_name": "meta-llama/Llama-2-7b-hf",
    "total_tokens": 1_000_000_000,    # 1B tokens (demo)
    "max_seq_len": 2048,
    "global_batch_size": 1024,        # Global batch (tokens: 1024 * 2048 = 2M)
    "micro_batch_per_gpu": 4,
    "learning_rate": 3e-4,
    "min_lr": 3e-5,
    "warmup_tokens": 20_000_000,      # 2% warmup
    "weight_decay": 0.1,
    "grad_clip": 1.0,
    "save_every_steps": 1000,
    "eval_every_steps": 500,
    "log_every_steps": 10,
}

class StreamingTokenDataset(IterableDataset):
    """Streaming token dataset (handles large-scale data)"""
    def __init__(self, data_path: str, seq_len: int, rank: int, world_size: int):
        self.data_path = data_path
        self.seq_len = seq_len
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        # Each rank starts at its own shard and then strides past the other ranks'
        import numpy as np  # local import: uint16 buffers are easiest to decode via numpy

        shard_bytes = self.seq_len * 2                  # uint16 = 2 bytes per token
        read_bytes = (self.seq_len + 1) * 2             # +1 token for the shifted labels
        stride = shard_bytes * self.world_size          # one round across all ranks
        with open(self.data_path, "rb") as f:
            f.seek(self.rank * shard_bytes)
            while True:
                chunk = f.read(read_bytes)
                if len(chunk) < read_bytes:
                    break
                tokens = torch.from_numpy(
                    np.frombuffer(chunk, dtype=np.uint16).astype(np.int64)
                )
                yield {"input_ids": tokens[:self.seq_len],
                       "labels": tokens[1:self.seq_len + 1]}
                f.seek(stride - read_bytes, 1)          # skip the other ranks' shards

def train():
    deepspeed.init_distributed()
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    # Load model (deepspeed.zero.Init is only needed for ZeRO-3 parameter
    # partitioning; with ZeRO stage 2 the model is built normally on each rank)
    config = AutoConfig.from_pretrained(TRAINING_CONFIG["model_name"])
    model = AutoModelForCausalLM.from_config(config)

    # Initialize DeepSpeed engine
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        config={
            "train_micro_batch_size_per_gpu": TRAINING_CONFIG["micro_batch_per_gpu"],
            "gradient_accumulation_steps": (
                TRAINING_CONFIG["global_batch_size"]
                // TRAINING_CONFIG["micro_batch_per_gpu"]
                // world_size
            ),
            "bf16": {"enabled": True},
            "zero_optimization": {
                "stage": 2,
                "overlap_comm": True,
                "contiguous_gradients": True,
                "allgather_bucket_size": 2e8,
                "reduce_bucket_size": 2e8,
            },
            "gradient_clipping": TRAINING_CONFIG["grad_clip"],
            "optimizer": {
                "type": "AdamW",
                "params": {
                    "lr": TRAINING_CONFIG["learning_rate"],
                    "betas": [0.9, 0.95],
                    "eps": 1e-8,
                    "weight_decay": TRAINING_CONFIG["weight_decay"],
                }
            },
            "steps_per_print": TRAINING_CONFIG["log_every_steps"],
        }
    )

    # Dataset
    train_dataset = StreamingTokenDataset(
        "/data/train_tokens.bin",
        TRAINING_CONFIG["max_seq_len"],
        rank, world_size
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=TRAINING_CONFIG["micro_batch_per_gpu"],
        num_workers=0,   # >0 would replicate the IterableDataset in every worker
        pin_memory=True,
    )

    # Training loop
    total_tokens = 0
    target_tokens = TRAINING_CONFIG["total_tokens"]

    for batch in train_loader:
        if total_tokens >= target_tokens:
            break

        input_ids = batch["input_ids"].to(model_engine.device)
        labels = batch["labels"].to(model_engine.device)

        # Learning rate schedule
        warmup_tokens = TRAINING_CONFIG["warmup_tokens"]
        lr = cosine_schedule_with_warmup(
            total_tokens, warmup_tokens, target_tokens,
            TRAINING_CONFIG["learning_rate"], TRAINING_CONFIG["min_lr"]
        )
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # Forward/backward
        outputs = model_engine(input_ids=input_ids, labels=labels)
        loss = outputs.loss
        model_engine.backward(loss)
        model_engine.step()

        # Update token count
        batch_tokens = input_ids.numel() * world_size
        total_tokens += batch_tokens

        # Logging (rank 0 only)
        if rank == 0 and model_engine.global_steps % TRAINING_CONFIG["log_every_steps"] == 0:
            print(f"Tokens: {total_tokens/1e9:.2f}B, "
                  f"Loss: {loss.item():.4f}, "
                  f"LR: {lr:.2e}, "
                  f"PPL: {math.exp(loss.item()):.1f}")

        # Save checkpoint
        if model_engine.global_steps % TRAINING_CONFIG["save_every_steps"] == 0:
            model_engine.save_checkpoint(
                "./checkpoints",
                tag=f"step_{model_engine.global_steps}"
            )

    if rank == 0:
        print("Training complete!")
        model_engine.save_checkpoint("./checkpoints", tag="final")

def cosine_schedule_with_warmup(step, warmup, total, max_lr, min_lr):
    if step < warmup:
        return max_lr * step / max(warmup, 1)
    progress = (step - warmup) / max(total - warmup, 1)
    return min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

if __name__ == "__main__":
    train()
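The `cosine_schedule_with_warmup` helper in the script is easy to sanity-check at its boundary points; restated here as a standalone snippet with the demo config's values:

```python
import math

def cosine_schedule_with_warmup(step, warmup, total, max_lr, min_lr):
    if step < warmup:
        return max_lr * step / max(warmup, 1)   # linear warmup from 0
    progress = (step - warmup) / max(total - warmup, 1)
    return min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

max_lr, min_lr = 3e-4, 3e-5
warmup, total = 20_000_000, 1_000_000_000  # token-based schedule from TRAINING_CONFIG

print(cosine_schedule_with_warmup(0, warmup, total, max_lr, min_lr))       # 0 at start
print(cosine_schedule_with_warmup(warmup, warmup, total, max_lr, min_lr))  # peak: 3e-4
print(cosine_schedule_with_warmup(total, warmup, total, max_lr, min_lr))   # floor: 3e-5
```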

Launch:

deepspeed --num_nodes=4 --num_gpus=8 \
    --master_addr=node0 --master_port=29500 \
    pretrain.py
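In practice, multi-node DeepSpeed launches usually supply the node list through a hostfile rather than `--num_nodes` alone; a minimal sketch (hostnames are placeholders for your cluster):

```
# hostfile: one node per line, with its GPU slot count
node0 slots=8
node1 slots=8
node2 slots=8
node3 slots=8
```

Launch with `deepspeed --hostfile=hostfile --master_addr=node0 --master_port=29500 pretrain.py`; `--num_nodes` and `--num_gpus` then act as optional limits on what the hostfile provides.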

Conclusion

Pre-training a 100B+ parameter LLM is not simply running code — it's a careful balancing act of countless engineering decisions and trade-offs.

Key takeaways:

  • Scaling laws: Balance model size and data volume per Chinchilla optimization, but "overtraining" can be rational when considering inference costs
  • Data is fundamental: High-quality data curation and balanced mixing determine model capability
  • 3D parallelism: DP + TP + PP combination efficiently utilizes thousands of GPUs
  • Stability first: Monitoring loss spikes and gradient explosions with fast response is essential
  • Cost awareness: Continuously monitor MFU and GPU utilization to optimize efficiency

The open-source ecosystem (OLMo, GPT-NeoX, torchtitan) is lowering the barrier to large-scale training. Apply the techniques from this guide to your own projects and train your own LLM.

References