Large-Scale Model Training Complete Guide: Strategies for Pre-training 100B+ Parameter LLMs

Introduction

LLaMA-3 405B, GPT-4, Falcon 180B: how were these models actually trained? "Just use a lot of GPUs" grossly oversimplifies the problem. Training an LLM with hundreds of billions of parameters requires hyperparameters grounded in scaling laws, sophisticated distributed training strategies, safeguards for training stability, and efficient management of enormous computing resources.

This guide covers every critical aspect of large-scale LLM pre-training from a practical perspective.

1. Scaling Laws

1.1 Kaplan et al. (OpenAI) Scaling Laws

In 2020, Kaplan et al.'s paper "Scaling Laws for Neural Language Models" showed that language model loss falls as a power law in each of three key factors, as long as the other two are not the bottleneck:

- **N**: Number of model parameters

- **D**: Number of training data tokens

- **C**: Total compute budget (FLOPs)

Key findings:

1. **Model size priority (compute-efficient view)**: For a fixed compute budget, increasing model size is more efficient than increasing data

2. **Data efficiency**: Training the same model size longer yields diminishing returns

3. **Power law**: Loss follows approximately $L(N) \approx (N_c / N)^{\alpha_N}$, where $N_c$ is a fitted constant and $\alpha_N \approx 0.076$ in the paper

This result suggested spending most of a compute budget on larger models while training them on comparatively little data, which is why GPT-3 (175B) was trained on only about 300B tokens.
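To make the power law concrete, the short sketch below evaluates L(N) with the constants Kaplan et al. report for non-embedding parameter count (α_N ≈ 0.076, N_c ≈ 8.8 × 10^13). Treat the absolute loss values as illustrative, since they depend on the paper's tokenizer and dataset. The useful takeaway is the ratio: doubling N multiplies the predicted loss by 2^(−α_N) at any scale. `kaplan_loss` is a hypothetical helper name.

```python
# Kaplan et al. (2020) parameter-count law: L(N) = (N_c / N) ** alpha_N
# Constants are the paper's reported fit (non-embedding parameters,
# loss in nats per token); treat them as approximate.
ALPHA_N = 0.076
N_C = 8.8e13


def kaplan_loss(num_params: float) -> float:
    """Predicted loss when data is not the bottleneck."""
    return (N_C / num_params) ** ALPHA_N


for n in [1e8, 1e9, 1e10, 1e11]:
    print(f"N={n:.0e}: predicted loss ≈ {kaplan_loss(n):.3f}")

# Doubling N multiplies the loss by 2 ** -ALPHA_N (about 0.95), whatever N is
print(kaplan_loss(2e9) / kaplan_loss(1e9))
```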

1.2 Chinchilla Optimal Scaling (Hoffmann et al. 2022)

In 2022, DeepMind's Hoffmann et al. published "Training Compute-Optimal Large Language Models" with a critical finding that overturned Kaplan's conclusions.

**Core claim**: Existing large models were **severely undertrained**.

Chinchilla experiment:

- Previous approach: Gopher (280B parameters, 300B tokens)

- Chinchilla: 70B parameters, **1.4T tokens**

- Result: Chinchilla, 4x smaller but trained with the same compute budget as Gopher, outperformed it on most downstream evaluations

**Chinchilla optimal scaling law:**

Given compute budget C, optimal model size N and data D:

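$$N_{\text{opt}} \approx 0.1174 \cdot C^{0.4999} \quad \text{(parameters)}$$

$$D_{\text{opt}} \approx 1.6972 \cdot C^{0.5001} \quad \text{(tokens)}$$

N and D should scale **nearly equally** with compute (both exponents are about 0.5). The empirical rule of thumb:

$$D_{\text{opt}} \approx 20 \cdot N_{\text{opt}}$$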

7B-parameter model → about 140B tokens for compute-optimal training

70B-parameter model → about 1.4T tokens for compute-optimal training
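Substituting the 20-tokens-per-parameter rule into the training-compute approximation C ≈ 6ND (derived in Section 1.3) gives C ≈ 120·N², so a budget can be split in closed form. The helper below is an illustrative sketch of that arithmetic, not the paper's fitting procedure; `chinchilla_allocation` is a hypothetical name.

```python
import math


def chinchilla_allocation(compute_flops: float, tokens_per_param: float = 20.0):
    """
    Split a compute budget C into (N, D) using two approximations:
      C ≈ 6 * N * D                  (training FLOPs)
      D ≈ tokens_per_param * N       (Chinchilla rule of thumb)
    Substituting gives C ≈ 6 * k * N^2, so N = sqrt(C / (6k)).
    """
    n = math.sqrt(compute_flops / (6 * tokens_per_param))
    d = tokens_per_param * n
    return n, d


# Chinchilla itself: 70B parameters, 1.4T tokens -> C ≈ 5.88e23 FLOPs
n, d = chinchilla_allocation(6 * 70e9 * 1.4e12)
print(f"N ≈ {n / 1e9:.0f}B params, D ≈ {d / 1e12:.2f}T tokens")
# N ≈ 70B params, D ≈ 1.40T tokens
```

Plugging in Chinchilla's own budget recovers its published configuration, which is a quick sanity check on the rule of thumb.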

1.3 Practical Computation Formulas

**FLOPs estimation:**

```python
def estimate_training_flops(
    num_params,            # Number of parameters
    num_tokens,            # Training token count
    include_backward=True,
):
    """
    C ≈ 6 × N × D with forward + backward: the forward pass costs about
    2N FLOPs per token and the backward pass about 4N (roughly 2x forward).
    Forward only: C ≈ 2 × N × D.
    """
    flops_per_token = (6 if include_backward else 2) * num_params
    total_flops = flops_per_token * num_tokens
    return total_flops


# Example: LLaMA-2 7B, 2T tokens
flops = estimate_training_flops(7e9, 2e12)
print(f"Total FLOPs: {flops:.2e}")
# Total FLOPs: 8.40e+22

# GPU training time estimate (A100 bf16 peak 312 TFLOPS, 50% MFU assumed)
a100_flops = 312e12  # 312 TFLOPS
mfu = 0.5            # Model FLOP Utilization
gpu_seconds = flops / (a100_flops * mfu)
total_gpu_hours = gpu_seconds / 3600  # time on one GPU = total GPU-hours

# With 1000 A100s (assuming perfect scaling)
num_gpus = 1000
wall_clock_hours = total_gpu_hours / num_gpus

print(f"Total GPU-hours: {total_gpu_hours:,.0f}")
print(f"Wall-clock time on {num_gpus} GPUs: {wall_clock_hours:.1f} hours")
```

1.4 Post-Chinchilla Trend: "Overtrain" Strategy

In practice, many recent models are trained on far more tokens than the Chinchilla-optimal count.

**Reason**: the trade-off between training cost and inference cost

- Training happens once, but inference runs millions of times

- Smaller model trained longer → lower inference cost

Example: LLaMA-3 8B was trained on over 15T tokens, roughly 94x the Chinchilla-optimal ~160B tokens
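Both the size of the "overtraining" and the reason it pays off fit in a few lines. The sketch below assumes the standard approximations of about 6N FLOPs per training token and about 2N FLOPs per inference token (forward pass only); the function names are hypothetical.

```python
def chinchilla_multiple(num_params: float, num_tokens: float) -> float:
    """How many times the ~20 tokens/parameter Chinchilla budget was used."""
    return num_tokens / (20 * num_params)


def lifetime_flops(num_params: float, train_tokens: float, inference_tokens: float) -> float:
    """
    Rough lifetime compute: training (~6N FLOPs per token) plus
    inference (~2N FLOPs per processed token, forward pass only).
    """
    return 6 * num_params * train_tokens + 2 * num_params * inference_tokens


print(f"LLaMA-3 8B: {chinchilla_multiple(8e9, 15e12):.0f}x Chinchilla-optimal tokens")
# LLaMA-3 8B: 94x Chinchilla-optimal tokens
```

Under these approximations, inference compute overtakes training compute once a model processes more than 3x its training tokens at inference time (2N·I > 6N·D when I > 3D), a point a widely deployed model can reach quickly.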

2. Pre-training Data Pipeline

2.1 Data Source Composition

High-quality pre-training data is central to LLM performance.

**Major data sources:**

| Source           | Characteristics          | Quality                 |
| ---------------- | ------------------------ | ----------------------- |
| Common Crawl     | Web crawl, PB-scale      | Noisy, must be filtered |
| Books3/Gutenberg | Books, high quality      | Limited diversity       |
| Wikipedia        | Encyclopedia, verified   | Limited volume          |
| GitHub           | Code, improves reasoning | License considerations  |
| ArXiv/PubMed     | Academic papers          | Highly specialized      |
| StackExchange    | Q&A, practical knowledge | Good quality            |

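**Typical mixing ratios (LLaMA-1, as reported in its paper; Llama 2 did not publish a per-source breakdown):**

```python
data_mixture = {
    "CommonCrawl (filtered)": 0.67,       # 67%
    "C4": 0.15,                           # 15%
    "GitHub": 0.045,                      # 4.5%
    "Wikipedia": 0.045,                   # 4.5%
    "Books (Gutenberg + Books3)": 0.045,  # 4.5%
    "ArXiv": 0.025,                       # 2.5%
    "StackExchange": 0.02,                # 2%
}
assert abs(sum(data_mixture.values()) - 1.0) < 1e-9
```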
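Mixing ratios are typically applied by sampling the source of each next document (or sequence) with probability proportional to its weight. The sampler below is a minimal sketch under that assumption; the source names and weights in the example are illustrative, not an official mixture, and real pipelines also control how often small sources are repeated.

```python
import random
from collections import Counter


def make_source_sampler(mixture: dict, seed: int = 0):
    """Return a function that draws a data source according to mixture weights.

    Weights are normalized, so they need not sum exactly to 1.
    """
    sources = list(mixture)
    total = sum(mixture.values())
    weights = [mixture[s] / total for s in sources]
    rng = random.Random(seed)  # seeded for reproducible sampling

    def sample() -> str:
        return rng.choices(sources, weights=weights, k=1)[0]

    return sample


# Illustrative weights (not an official mixture)
sample = make_source_sampler({"web": 0.7, "code": 0.2, "books": 0.1})
counts = Counter(sample() for _ in range(100_000))
shares = {k: v / 100_000 for k, v in counts.items()}
print(shares)
```

`random.choices` already treats weights as relative; normalizing up front just makes the probabilities explicit.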

2.2 Data Cleaning Pipeline

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class DocumentFilter:
    min_tokens: int = 50
    max_tokens: int = 100000
    min_avg_word_length: float = 3.0
    max_symbol_ratio: float = 0.1
    min_alpha_ratio: float = 0.7


def filter_document(text: str, config: DocumentFilter) -> Optional[str]:
    """Basic heuristic quality filter; returns None if the document is dropped."""
    tokens = text.split()
    token_count = len(tokens)

    # Length filter (whitespace-separated words)
    if not (config.min_tokens <= token_count <= config.max_tokens):
        return None

    # Average word length
    avg_word_len = sum(len(t) for t in tokens) / token_count
    if avg_word_len < config.min_avg_word_length:
        return None

    # Symbol-to-word ratio ('#' and ellipses), in the spirit of Gopher-style rules
    symbol_count = text.count("#") + text.count("...") + text.count("…")
    if symbol_count / token_count > config.max_symbol_ratio:
        return None

    # Alphabetic character ratio (whitespace counts toward the denominator)
    alpha_chars = sum(1 for c in text if c.isalpha())
    if alpha_chars / len(text) < config.min_alpha_ratio:
        return None

    return text


def deduplicate_documents(texts: List[str], n_gram_size: int = 13) -> List[str]:
    """
    Near-duplicate removal with MinHash LSH (requires `pip install datasketch`).
    Keeps the first document of each near-duplicate cluster.
    """
    from datasketch import MinHash, MinHashLSH

    lsh = MinHashLSH(threshold=0.8, num_perm=128)
    unique_texts = []

    for i, text in enumerate(texts):
        minhash = MinHash(num_perm=128)
        words = text.lower().split()
        # Word n-gram shingles; documents shorter than n_gram_size
        # become a single shingle instead of an empty signature
        for j in range(max(1, len(words) - n_gram_size + 1)):
            ngram = " ".join(words[j:j + n_gram_size])
            minhash.update(ngram.encode("utf-8"))

        if not lsh.query(minhash):
            lsh.insert(str(i), minhash)
            unique_texts.append(text)

    return unique_texts
```
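MinHash LSH is an approximation: the probability that two documents' MinHash values agree equals the Jaccard similarity of their n-gram sets, and the LSH index returns candidates likely to exceed `threshold`. The exact computation below, practical only for small collections, shows the quantity being estimated. It uses 3-grams so the example stays readable, whereas the pipeline above uses 13-grams; the helper names are hypothetical.

```python
def ngram_set(text: str, n: int) -> set:
    """Word n-gram shingles; texts shorter than n become a single shingle."""
    words = text.lower().split()
    if len(words) < n:
        return {" ".join(words)}
    return {" ".join(words[i:i + n]) for i in range(len(words) - n + 1)}


def jaccard(a: str, b: str, n: int = 3) -> float:
    """Exact Jaccard similarity |A ∩ B| / |A ∪ B| of the two shingle sets."""
    sa, sb = ngram_set(a, n), ngram_set(b, n)
    return len(sa & sb) / len(sa | sb)


a = "the quick brown fox jumps over the lazy dog"
b = "the quick brown fox jumps over the lazy cat"
sim = jaccard(a, b)
print(sim)  # 0.75
```

The two sentences differ in one word yet share 6 of 8 distinct trigrams, so their similarity is 0.75, just below the 0.8 threshold used above.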

2.3 Tokenizer Training

```python
from typing import List

from tokenizers import Tokenizer, decoders
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.trainers import BpeTrainer


def train_tokenizer(
    corpus_files: List[str],
    vocab_size: int = 32000,
    output_path: str = "tokenizer.json",
):
    """Train a byte-level BPE tokenizer (GPT-2 style)."""
    # With the full byte alphabet, [UNK] should never be needed; kept for compatibility
    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
    tokenizer.decoder = decoders.ByteLevel()  # map byte symbols back to text

    trainer = BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=2,
        special_tokens=["[UNK]", "[BOS]", "[EOS]", "[PAD]"],
        # Seed the vocabulary with all 256 byte symbols
        initial_alphabet=ByteLevel.alphabet(),
        show_progress=True,
    )

    tokenizer.train(files=corpus_files, trainer=trainer)
    tokenizer.save(output_path)
    return tokenizer


def count_tokens(file_path: str, tokenizer: Tokenizer) -> int:
    """Count tokens in a text file, one document per line."""
    total = 0
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            encoding = tokenizer.encode(line.strip())
            total += len(encoding.ids)
    return total
```
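`BpeTrainer` hides the core algorithm. The toy learner below, written in plain Python and independent of the tokenizers library's actual implementation, shows the idea: represent words as symbol sequences, count adjacent pairs weighted by word frequency, and repeatedly merge the most frequent pair. The byte-level variant trained above runs the same loop over bytes rather than characters. The corpus is a made-up example.

```python
from collections import Counter


def learn_bpe_merges(words: list, num_merges: int) -> list:
    """Toy BPE: repeatedly merge the most frequent adjacent symbol pair."""
    vocab = Counter(tuple(w) for w in words)  # word as symbols -> frequency
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for symbols, freq in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # ties: first pair encountered
        merges.append(best)
        merged = best[0] + best[1]

        # Rewrite every word with the new merged symbol
        new_vocab = Counter()
        for symbols, freq in vocab.items():
            out, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    out.append(merged)
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            new_vocab[tuple(out)] += freq
        vocab = new_vocab
    return merges


corpus = ["low"] * 5 + ["lower"] * 2 + ["newest"] * 6 + ["widest"] * 3
merges = learn_bpe_merges(corpus, 2)
print(merges)  # [('e', 's'), ('es', 't')]
```

On this corpus "es" is merged first (9 occurrences, tied with "st" and chosen because it is encountered first), and "est" follows.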

3. Megatron-LM

3.1 Introduction to NVIDIA Megatron

Megatron-LM is a large-scale language model training framework developed by NVIDIA. It provides specialized parallelism techniques for training large models like GPT, BERT, and T5.

Core features:

- **Tensor Parallelism**: Splits matrix operations across GPUs

- **Pipeline Parallelism**: Distributes layers sequentially across GPUs

- **Sequence Parallelism**: Additionally splits LayerNorm and Dropout activations along the sequence dimension (Korthikanti et al., 2022)

- **Flash Attention integration**: Memory-efficient attention computation

3.2 Tensor Parallelism Implementation Principles

The core operation in transformers is matrix multiplication. How to partition it is the essence of tensor parallelism.

**Tensor splitting in Multi-Head Attention:**

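Q, K, V projections: [d_model, d_head × n_heads] → column-wise split (whole heads per GPU)

→ per GPU: [d_model, d_head × (n_heads / tp_size)]

Output projection: [d_head × n_heads, d_model] → row-wise split, followed by one all-reduce

**Tensor splitting in FFN layers:**

First linear: [d_model, d_ffn] → column-wise split

Second linear: [d_ffn, d_model] → row-wise split, followed by one all-reduce

→ per GPU: [d_model, d_ffn / tp_size] + [d_ffn / tp_size, d_model]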

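```python
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


# Conceptual Megatron-style parallel linear layers (simplified)
class ColumnParallelLinear(nn.Module):
    """Split the weight matrix along the output (column) dimension."""

    def __init__(self, in_features, out_features, tp_size):
        super().__init__()
        self.tp_size = tp_size
        assert out_features % tp_size == 0
        local_out = out_features // tp_size
        # Each GPU holds 1/tp_size of the full weight (nn.Linear layout: [out, in])
        self.weight = nn.Parameter(torch.empty(local_out, in_features))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        # Input is full size; output is partitioned: [..., out_features / tp_size].
        # No forward communication is needed if a RowParallelLinear consumes it;
        # an all-gather is needed only when the full output is required.
        # (In the backward pass Megatron all-reduces the input gradient.)
        return F.linear(x, self.weight)


class RowParallelLinear(nn.Module):
    """Split the weight matrix along the input (row) dimension."""

    def __init__(self, in_features, out_features, tp_size, group=None):
        super().__init__()
        self.tp_size = tp_size
        self.group = group  # tensor-parallel process group
        assert in_features % tp_size == 0
        local_in = in_features // tp_size
        # Each GPU handles 1/tp_size of the input dimension
        self.weight = nn.Parameter(torch.empty(out_features, local_in))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        # x: [batch, seq, in_features / tp_size]
        local_output = F.linear(x, self.weight)
        # All-reduce sums the partial results from each GPU. Megatron wraps this
        # in a custom autograd.Function; a bare dist.all_reduce is not tracked
        # by autograd.
        dist.all_reduce(local_output, group=self.group)
        return local_output
```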
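Why a column split followed by a row split needs only one all-reduce per FFN can be checked numerically. The sketch below simulates `tp_size` devices inside one NumPy process (no torch.distributed); ReLU stands in for the model's activation, which works because any elementwise function can be applied to each column shard independently. Summing the partial outputs plays the role of `dist.all_reduce` in `RowParallelLinear`. Sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, d_model, d_ffn, tp_size = 4, 8, 32, 4

x = rng.standard_normal((batch, d_model))
w1 = rng.standard_normal((d_model, d_ffn))  # first linear (column-parallel)
w2 = rng.standard_normal((d_ffn, d_model))  # second linear (row-parallel)


def relu(z):
    return np.maximum(z, 0.0)


# Reference: unsplit FFN on one device
y_full = relu(x @ w1) @ w2

# Simulated tensor parallelism: "GPU" i holds a column slice of w1 and the
# matching row slice of w2; the elementwise activation needs no communication
w1_shards = np.split(w1, tp_size, axis=1)
w2_shards = np.split(w2, tp_size, axis=0)
partials = [relu(x @ a) @ b for a, b in zip(w1_shards, w2_shards)]

# The all-reduce in RowParallelLinear is a sum over the partial outputs
y_tp = sum(partials)
print(np.allclose(y_full, y_tp))  # True
```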

3.3 Sequence Parallelism

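Sequence parallelism splits the regions that tensor parallelism leaves replicated, LayerNorm and Dropout, along the sequence dimension.

LayerNorm input: [batch, seq/tp_size, d_model] (per GPU)

→ All-gather: [batch, seq, d_model]

→ Self-Attention (QKV column split, output projection row split)

→ Reduce-scatter: [batch, seq/tp_size, d_model]

→ Dropout + residual + LayerNorm on the sequence shard

→ All-gather → FFN (first linear column split, second row split) → Reduce-scatter

Each all-reduce of plain tensor parallelism becomes a reduce-scatter plus an all-gather, which moves the same amount of data, so LayerNorm and Dropout activation memory is divided by tp_size without extra communication volume.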
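Sequence parallelism relies on the identity that an all-reduce equals a reduce-scatter followed by an all-gather. The single-process NumPy simulation below (no real collectives; sizes are arbitrary and the helper is hypothetical) checks that identity along the sequence dimension. Between the two collectives each rank holds only `seq / tp_size` rows, which is the layout LayerNorm and Dropout operate on.

```python
import numpy as np

rng = np.random.default_rng(0)
tp_size, seq, d_model = 4, 8, 6

# Each simulated rank holds a partial sum of the layer output (full sequence)
partials = [rng.standard_normal((seq, d_model)) for _ in range(tp_size)]

# All-reduce: every rank ends with the full sum
all_reduced = sum(partials)


def reduce_scatter(parts, rank):
    """Rank `rank` keeps only its seq/tp_size slice of the summed output."""
    chunk = seq // tp_size
    return sum(p[rank * chunk:(rank + 1) * chunk] for p in parts)


scattered = [reduce_scatter(partials, r) for r in range(tp_size)]

# All-gather: concatenate the slices back into the full sequence
gathered = np.concatenate(scattered, axis=0)

print(np.allclose(all_reduced, gathered))  # True
print(scattered[0].shape)  # (2, 6)
```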

3.4 Megatron Configuration Example

```bash
#!/bin/bash
# Training GPT-3 175B with Megatron-LM (example)
# Run once per node; NODE_RANK and MASTER_ADDR are assumed to be
# provided by your cluster scheduler.

GPUS_PER_NODE=8
NNODES=64
TP_SIZE=8   # Tensor Parallelism
PP_SIZE=16  # Pipeline Parallelism
DP_SIZE=$((GPUS_PER_NODE * NNODES / TP_SIZE / PP_SIZE))
# DP_SIZE = 8 * 64 / 8 / 16 = 4

torchrun \
  --nproc_per_node $GPUS_PER_NODE \
  --nnodes $NNODES \
  --node_rank $NODE_RANK \
  --master_addr $MASTER_ADDR \
  --master_port 6000 \
  pretrain_gpt.py \
  --num-layers 96 \
  --hidden-size 12288 \
  --num-attention-heads 96 \
  --seq-length 2048 \
  --max-position-embeddings 2048 \
  --global-batch-size 1536 \
  --train-iters 300000 \
  --lr 6e-5 \
  --lr-decay-style cosine \
  --min-lr 6e-6 \
  --lr-warmup-fraction 0.001 \
  --weight-decay 0.1 \
  --tensor-model-parallel-size $TP_SIZE \
  --pipeline-model-parallel-size $PP_SIZE \
  --micro-batch-size 1 \
  --bf16 \
  --use-flash-attn \
  --sequence-parallel \
  --data-path /data/pile_text_document \
  --vocab-file /data/gpt2-vocab.json \
  --merge-file /data/gpt2-merges.txt \
  --save /checkpoints/gpt3-175b \
  --load /checkpoints/gpt3-175b
```
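The batch arithmetic implied by this script can be checked by hand. The helper below is a hypothetical illustration, not part of Megatron-LM: it derives the DP size, the number of micro-batches each pipeline processes per optimizer step (global batch divided by micro-batch × DP), and the tokens per step.

```python
def megatron_batch_math(gpus_per_node, nnodes, tp, pp, global_batch, micro_batch, seq_len):
    """Derive DP size, micro-batches per step, and tokens per step."""
    world_size = gpus_per_node * nnodes
    assert world_size % (tp * pp) == 0, "TP x PP must divide the GPU count"
    dp = world_size // (tp * pp)
    assert global_batch % (dp * micro_batch) == 0, "global batch must split evenly"
    # Micro-batches each pipeline processes per optimizer step
    num_micro_batches = global_batch // (dp * micro_batch)
    return {
        "dp": dp,
        "micro_batches_per_step": num_micro_batches,
        "tokens_per_step": global_batch * seq_len,
    }


result = megatron_batch_math(8, 64, 8, 16, 1536, 1, 2048)
print(result)
# {'dp': 4, 'micro_batches_per_step': 384, 'tokens_per_step': 3145728}
```

With 384 micro-batches flowing through 16 pipeline stages, the bubble estimate (p − 1)/(m + p − 1) used in Section 4.2 gives 15/399, or about 3.8%.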

4. 3D Parallelism (DP + TP + PP)

4.1 Combining Three Parallelism Types

Large-scale LLM training uses all three parallelism types simultaneously.

| Parallelism Type          | What It Distributes  | Communication Pattern | Overhead                   |
| ------------------------- | -------------------- | --------------------- | -------------------------- |
| Data Parallelism (DP)     | Batches              | All-reduce gradients  | Low                        |
| Tensor Parallelism (TP)   | Intra-layer matrices | All-reduce per layer  | Medium (latency-sensitive) |
| Pipeline Parallelism (PP) | Layer groups         | Point-to-point        | Pipeline bubble overhead   |

4.2 Finding Optimal Parallelism Configuration

```python
def find_optimal_parallelism(
    world_size: int,
    num_layers: int,
    hidden_size: int,
    gpu_memory_gb: float = 80.0,
    micro_batch: int = 4,      # assumed
    global_batch: int = 2048,  # assumed
) -> list:
    """
    Search candidate 3D parallelism configurations.

    General rules:
    - TP: within a single node (leverage NVLink), typically 4 or 8
    - PP: across nodes, determined by layer count
    - DP: the remaining GPUs
    """
    # Rough parameter count for a GPT-style model: ~12 * L * h^2
    num_params = 12 * num_layers * hidden_size ** 2
    # Mixed-precision Adam without ZeRO: ~16 bytes/param (bf16 weights and
    # grads, fp32 master copy, two Adam moments). Activations are ignored,
    # so keep 20% headroom.
    usable_bytes = gpu_memory_gb * 1e9 * 0.8

    configs = []
    for tp in [1, 2, 4, 8]:
        for pp in range(1, world_size // tp + 1):
            if world_size % (tp * pp) != 0 or num_layers % pp != 0:
                continue
            dp = world_size // (tp * pp)
            if 16 * num_params / (tp * pp) > usable_bytes:
                continue  # model states do not fit on one GPU

            # Pipeline bubble fraction with m microbatches and p stages:
            # bubble = (p - 1) / (m + p - 1)
            m = global_batch // (dp * micro_batch)
            if m <= 0:
                continue
            bubble_rate = (pp - 1) / (m + pp - 1)

            configs.append({
                "tp": tp, "pp": pp, "dp": dp,
                "bubble_rate": bubble_rate,
                "layers_per_stage": num_layers // pp,
            })

    # Prefer a low bubble rate, then a smaller TP size
    configs.sort(key=lambda x: (x["bubble_rate"], x["tp"]))
    return configs[:5]


# Example: 512 GPUs, GPT-3 scale (96 layers, hidden size 12288)
best_configs = find_optimal_parallelism(512, 96, 12288)
for c in best_configs:
    print(f"TP={c['tp']}, PP={c['pp']}, DP={c['dp']}, "
          f"Bubble={c['bubble_rate']:.2%}, Layers/Stage={c['layers_per_stage']}")
# Best: TP=8, PP=8, DP=8, Bubble=9.86%, Layers/Stage=12
```

4.3 DeepSpeed + Megatron Combination

```python
# Megatron-DeepSpeed integration configuration
deepspeed_config = {
    "train_batch_size": 2048,
    "train_micro_batch_size_per_gpu": 1,
    # train_batch_size = micro_batch x grad_accum x DP size: 2048 / (1 micro x 8 DP)
    "gradient_accumulation_steps": 256,
    "bf16": {"enabled": True},
    "zero_optimization": {
        # ZeRO-1 alongside TP+PP: DeepSpeed's pipeline engine does not
        # support ZeRO stages 2 and 3
        "stage": 1,
    },
    "gradient_clipping": 1.0,
    "wall_clock_breakdown": True,
    "steps_per_print": 10,
}
```

4.4 Communication Topology Optimization

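```python
# Consider NVLink vs InfiniBand bandwidth:
#   NVLink 3.0 (A100): 600 GB/s total, bidirectional, per GPU
#   InfiniBand HDR: 200 Gb/s = 25 GB/s per port
#
# Communication group design principles:
#   TP group: GPUs within the same node (leverage NVLink)
#   PP group: across nodes (point-to-point traffic tolerates InfiniBand)
#   DP group: across nodes (gradient all-reduce overlaps with backward;
#             ZeRO-1/2 shard optimizer state without increasing traffic)

import torch.distributed as dist


def setup_process_groups(tp_size, pp_size, dp_size):
    """Create tensor-parallel groups and return the one containing this rank."""
    world_size = tp_size * pp_size * dp_size
    assert dist.get_world_size() == world_size
    rank = dist.get_rank()

    tp_group = None
    # Tensor-parallel groups: TP ranks are contiguous, so with tp_size dividing
    # the per-node GPU count each group stays inside one node.
    # Every rank must call new_group for every group, in the same order.
    for dp_rank in range(dp_size):
        for pp_rank in range(pp_size):
            ranks = [
                dp_rank * tp_size * pp_size + pp_rank * tp_size + tp_rank
                for tp_rank in range(tp_size)
            ]
            group = dist.new_group(ranks)
            if rank in ranks:
                tp_group = group

    return tp_group
```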
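The rank formula in `setup_process_groups` makes the TP index vary fastest, so each TP group is a block of consecutive global ranks. Assuming the launcher gives the GPUs of one node consecutive ranks (the usual torchrun behavior, rank = node_rank × nproc_per_node + local_rank), a TP group of size 8 or less then stays on a single 8-GPU node. The inverse mapping below is a hypothetical helper for checking a layout.

```python
def rank_to_coords(rank: int, tp_size: int, pp_size: int) -> dict:
    """Decompose a global rank laid out as
    rank = dp_rank * (tp_size * pp_size) + pp_rank * tp_size + tp_rank
    (TP fastest-varying, so a TP group is a block of consecutive ranks)."""
    tp_rank = rank % tp_size
    pp_rank = (rank // tp_size) % pp_size
    dp_rank = rank // (tp_size * pp_size)
    return {"dp": dp_rank, "pp": pp_rank, "tp": tp_rank}


# 2 nodes x 8 GPUs, TP=8, PP=2, DP=1: each node holds one full TP group,
# i.e. one pipeline stage
for r in [0, 7, 8, 15]:
    print(r, rank_to_coords(r, tp_size=8, pp_size=2))
```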

5. Training Stability

5.1 Handling Loss Spikes

Sudden loss spikes are common in large-scale training.

**Root cause analysis:**

1. **Excessive gradient norm**: Unusually large gradients in specific batches

2. **Learning rate too high**: Updates overshoot in sharp regions of the loss landscape

3. **Bad data batches**: Outliers or incorrectly processed data

4. **Numerical instability**: Overflow/underflow in bf16/fp16

**Monitoring code:**

```python
from collections import deque

import wandb  # assumes wandb.init(...) has already been called


class TrainingMonitor:
    def __init__(self, spike_threshold: float = 3.0, window_size: int = 100):
        self.loss_history = deque(maxlen=window_size)
        self.grad_norm_history = deque(maxlen=window_size)
        self.spike_threshold = spike_threshold
        self.spike_count = 0

    def check_loss_spike(self, current_loss: float) -> bool:
        # Fill the window before flagging anything
        if len(self.loss_history) < 10:
            self.loss_history.append(current_loss)
            return False

        mean_loss = sum(self.loss_history) / len(self.loss_history)
        if current_loss > mean_loss * self.spike_threshold:
            self.spike_count += 1
            print(f"SPIKE detected: current={current_loss:.4f}, mean={mean_loss:.4f}")
            return True  # spikes stay out of the history so they don't skew the mean

        self.loss_history.append(current_loss)
        return False

    def compute_grad_norm(self, model) -> float:
        # Global L2 norm of the local gradients. Under FSDP/ZeRO/TP, gradients
        # are sharded, so this is only the local shard's norm unless reduced.
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                total_norm += p.grad.detach().float().norm(2).item() ** 2
        return total_norm ** 0.5

    def log_step(self, step: int, loss: float, model, lr: float):
        grad_norm = self.compute_grad_norm(model)
        self.grad_norm_history.append(grad_norm)
        is_spike = self.check_loss_spike(loss)

        wandb.log({
            "train/loss": loss,
            "train/grad_norm": grad_norm,
            "train/lr": lr,
            "train/is_spike": int(is_spike),
            "train/spike_count": self.spike_count,
        }, step=step)

        return grad_norm, is_spike
```
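Beyond logging a spike, a common response is to skip the update for an outlier batch rather than apply it. The class below is a hypothetical sketch in plain Python: it uses a z-score on the gradient norm instead of the loss-ratio test above, and its window and threshold are illustrative, not tuned values. In a training loop you would compute the norm after `backward()` and call `optimizer.zero_grad()` instead of `optimizer.step()` when `should_skip` returns True.

```python
import math
from collections import deque


class GradNormGuard:
    """Hypothetical helper: flag an optimizer step for skipping when the
    gradient norm is an outlier relative to recent history (z-score test)."""

    def __init__(self, window: int = 100, z_threshold: float = 6.0, min_history: int = 20):
        self.history = deque(maxlen=window)
        self.z_threshold = z_threshold
        self.min_history = min_history

    def should_skip(self, grad_norm: float) -> bool:
        if not math.isfinite(grad_norm):
            return True  # NaN/Inf: never apply this update
        if len(self.history) >= self.min_history:
            mean = sum(self.history) / len(self.history)
            var = sum((g - mean) ** 2 for g in self.history) / len(self.history)
            std = math.sqrt(var)
            # With zero variance the z-score is undefined, so nothing is flagged
            if std > 0 and (grad_norm - mean) / std > self.z_threshold:
                return True  # outlier: skip it and keep it out of the history
        self.history.append(grad_norm)
        return False


guard = GradNormGuard()
norms = [1.0 + 0.01 * (i % 5) for i in range(50)] + [25.0, 1.02]
decisions = [guard.should_skip(g) for g in norms]
print(decisions[-2:])  # [True, False]
```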

5.2 Gradient Clipping and Noise Scale

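```python
import torch


@torch.no_grad()
def adaptive_gradient_clipping(model, clip_coef: float = 0.01, eps: float = 1e-3):
    """
    Adaptive gradient clipping (AGC, Brock et al. 2021).
    Rescale each parameter's gradient so that
    ||grad|| <= clip_coef * max(||param||, eps).
    Simplification: norms are per tensor; the paper uses unit-wise (per-row) norms.
    """
    for param in model.parameters():
        if param.grad is None:
            continue
        # eps keeps zero-initialized params (e.g. biases) from getting all-zero grads
        param_norm = param.detach().norm(2).clamp_min(eps)
        grad_norm = param.grad.detach().norm(2)
        # Clipping threshold: clip_coef times the parameter norm
        max_grad_norm = clip_coef * param_norm
        if grad_norm > max_grad_norm:
            param.grad.mul_(max_grad_norm / grad_norm)


def compute_gradient_noise_scale(
    local_grad_norm_sq: float,   # |G|^2 of one rank's gradient (batch size b_small)
    global_grad_norm_sq: float,  # |G|^2 after all-reduce across ranks (batch b_big)
    b_small: int,
    b_big: int,                  # e.g. b_small * world_size under data parallelism
) -> float:
    """
    Estimate the simple gradient noise scale B_simple = tr(Σ) / |G|^2
    (McCandlish et al., 2018) from gradient norms at two batch sizes.
    High GNS -> large batch sizes are tolerable.
    The estimates are noisy: in practice, average the numerator and
    denominator over many steps before dividing.
    """
    g_sq = (b_big * global_grad_norm_sq - b_small * local_grad_norm_sq) / (b_big - b_small)
    trace_sigma = (local_grad_norm_sq - global_grad_norm_sq) / (1 / b_small - 1 / b_big)
    return trace_sigma / g_sq if g_sq > 0 else float("nan")
```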
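AGC clips each tensor relative to its own weights. The more common default, and what `"gradient_clipping": 1.0` does in the DeepSpeed configs in this guide, is clipping by the global L2 norm across all parameters. The sketch below shows that arithmetic on plain Python lists, using the same scale factor that `torch.nn.utils.clip_grad_norm_` applies, max_norm / (total_norm + 1e-6) capped at 1; the helper name is hypothetical.

```python
import math


def clip_by_global_norm(grads: list, max_norm: float, eps: float = 1e-6):
    """Scale all gradients by one factor so their combined L2 norm is <= max_norm.

    grads: list of gradient vectors (lists of floats), one per parameter tensor.
    Returns (clipped_grads, total_norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = min(1.0, max_norm / (total_norm + eps))  # never scales up
    return [[g * scale for g in grad] for grad in grads], total_norm


clipped, norm_before = clip_by_global_norm([[3.0, 0.0], [0.0, 4.0]], max_norm=1.0)
print(norm_before)  # 5.0
print(clipped)      # direction preserved, combined norm ≈ 1.0
```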

5.3 Learning Rate Scheduling Strategies

```python
import math


def cosine_schedule_with_warmup(
    current_step: int,
    warmup_steps: int,
    total_steps: int,
    max_lr: float,
    min_lr: float,
) -> float:
    """
    Cosine decay with linear warmup.
    Standard for most LLM training runs.
    """
    if current_step < warmup_steps:
        # Linear warmup
        return max_lr * current_step / warmup_steps
    # Cosine decay from max_lr to min_lr; held at min_lr after total_steps
    progress = min((current_step - warmup_steps) / (total_steps - warmup_steps), 1.0)
    cosine_val = 0.5 * (1 + math.cos(math.pi * progress))
    return min_lr + (max_lr - min_lr) * cosine_val


# WSD (Warmup-Stable-Decay) schedule, popular in recent work
def wsd_schedule(
    current_step: int,
    warmup_steps: int,
    stable_steps: int,
    decay_steps: int,
    max_lr: float,
    min_lr: float,
) -> float:
    """
    Three-phase schedule: Warmup -> Stable -> Decay (linear decay here).
    Popularized by MiniCPM. Because the LR is constant during the stable
    phase, training can be extended, or a decay branched off, from any
    stable-phase checkpoint, which suits continual training.
    """
    if current_step < warmup_steps:
        return max_lr * current_step / warmup_steps
    elif current_step < warmup_steps + stable_steps:
        return max_lr
    else:
        decay_progress = (current_step - warmup_steps - stable_steps) / decay_steps
        return max_lr - (max_lr - min_lr) * min(decay_progress, 1.0)
```

5.4 Batch Size Scheduling

```python
# Batch size ramp (as used in the OpenAI GPT-3 paper, which ramped linearly
# from 32k tokens to the full batch over the first 4-12B tokens of training,
# depending on model size)
def get_batch_size_schedule(
    current_tokens: int,
    final_batch_tokens: int = 4_000_000,   # Final batch size (tokens); GPT-3 175B used 3.2M
    initial_batch_tokens: int = 32_000,    # Initial batch size (tokens)
    ramp_tokens: int = 4_000_000_000,      # Ramp-up interval (tokens)
) -> int:
    """
    Linearly increase batch size with the number of tokens processed.
    Early: small batches give more optimizer updates per token.
    Late: large batches for stable training.
    In practice, round the result to a multiple of micro_batch x DP x seq_len.
    """
    if current_tokens < ramp_tokens:
        progress = current_tokens / ramp_tokens
        batch_tokens = initial_batch_tokens + (
            final_batch_tokens - initial_batch_tokens
        ) * progress
        return int(batch_tokens)
    return final_batch_tokens
```

6. Checkpointing Strategies

6.1 Distributed Checkpointing

```python
import shutil
import time
from pathlib import Path

import torch
import torch.distributed as dist


class DistributedCheckpointer:
    def __init__(self, save_dir: str, max_checkpoints: int = 5):
        self.save_dir = Path(save_dir)
        self.max_checkpoints = max_checkpoints
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def save(
        self,
        model,
        optimizer,
        scheduler,
        step: int,
        rank: int,
        world_size: int,
    ):
        """Save a distributed checkpoint (one shard file per rank)."""
        checkpoint_dir = self.save_dir / f"step_{step:08d}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Each rank saves its own shard
        rank_path = checkpoint_dir / f"rank_{rank:04d}_of_{world_size:04d}.pt"
        model_state = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict() if scheduler else None,
            "step": step,
            "rank": rank,
            "world_size": world_size,
        }
        torch.save(model_state, rank_path)

        # Wait until every shard is written before marking the checkpoint complete
        if dist.is_initialized():
            dist.barrier()

        # Rank 0 writes metadata last (so meta.pt doubles as a "complete"
        # marker), then cleans up old checkpoints
        if rank == 0:
            meta = {"step": step, "world_size": world_size, "timestamp": time.time()}
            torch.save(meta, checkpoint_dir / "meta.pt")
            self._cleanup_old_checkpoints()

    def _cleanup_old_checkpoints(self):
        # Zero-padded step numbers make lexicographic order chronological
        checkpoints = sorted(self.save_dir.glob("step_*"))
        while len(checkpoints) > self.max_checkpoints:
            oldest = checkpoints.pop(0)
            shutil.rmtree(oldest)
```

6.2 Asynchronous Checkpoint Saving

```python
import os
import threading
from queue import Queue

import torch


class AsyncCheckpointer:
    """Save checkpoints in a background thread so disk I/O does not block training."""

    def __init__(self, save_dir: str):
        self.save_dir = save_dir
        self.queue = Queue(maxsize=2)
        self.worker = threading.Thread(target=self._worker, daemon=True)
        self.worker.start()

    def _worker(self):
        while True:
            item = self.queue.get()
            if item is None:  # shutdown sentinel
                break
            state_dict, path = item
            torch.save(state_dict, path)
            self.queue.task_done()

    def save_async(self, model, step: int, rank: int):
        """Snapshot weights and enqueue the write without waiting for disk."""
        # Copy to CPU: a snapshot the background thread can write while training
        # keeps updating the GPU weights (the device-to-host copy itself is synchronous)
        state_dict = {
            k: v.detach().cpu().clone() for k, v in model.state_dict().items()
        }
        path = os.path.join(self.save_dir, f"step_{step}_rank_{rank}.pt")

        if not self.queue.full():
            self.queue.put((state_dict, path))
        else:
            # Queue full: save synchronously
            torch.save(state_dict, path)

    def close(self):
        """Flush pending saves and stop the worker; call before the process exits."""
        self.queue.put(None)
        self.worker.join()
```
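A crash or preemption mid-save can leave a truncated file that later looks like a valid checkpoint. A common safeguard, not part of the classes above, is to write to a temporary file in the same directory and rename it into place. The sketch below uses `pickle` so it runs without PyTorch (`torch.save` also accepts an open file object), and `atomic_save` is a hypothetical name. Rename atomicity is a POSIX guarantee within one local filesystem and may be weaker on some network filesystems.

```python
import os
import pickle
import tempfile


def atomic_save(obj, path: str) -> None:
    """Write obj to path so readers never see a half-written file.

    Serialize to a temp file in the same directory, fsync it, then
    atomically rename it over the destination with os.replace.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f)  # torch.save(obj, f) works the same way
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)  # atomic rename on the same filesystem
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


with tempfile.TemporaryDirectory() as d:
    target = os.path.join(d, "step_00001000_rank_0000.pt")
    atomic_save({"step": 1000}, target)
    with open(target, "rb") as f:
        loaded = pickle.load(f)
    leftovers = [name for name in os.listdir(d) if name.endswith(".tmp")]
print(loaded, leftovers)  # {'step': 1000} []
```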

7. Training Monitoring

7.1 Core Metrics Tracking

```python
import numpy as np
import wandb


class LLMTrainingTracker:
    def __init__(self, project_name: str, run_name: str, seq_len: int = 2048):
        wandb.init(project=project_name, name=run_name)
        self.seq_len = seq_len
        self.step = 0
        self.token_count = 0

    def log_training_step(
        self,
        loss: float,
        learning_rate: float,
        grad_norm: float,
        batch_tokens: int,
        elapsed_seconds: float,
    ):
        tokens_per_second = batch_tokens / elapsed_seconds
        self.token_count += batch_tokens
        self.step += 1

        wandb.log({
            # Training metrics (loss clamped so exp() cannot overflow)
            "train/loss": loss,
            "train/perplexity": np.exp(min(loss, 20)),
            "train/grad_norm": grad_norm,
            "train/learning_rate": learning_rate,
            # Efficiency metrics
            "throughput/tokens_per_second": tokens_per_second,
            "throughput/samples_per_second": tokens_per_second / self.seq_len,
            # Progress
            "progress/total_tokens": self.token_count,
            "progress/step": self.step,
        }, step=self.step)

    def log_evaluation(self, eval_loss: float, perplexities: dict):
        wandb.log({
            "eval/loss": eval_loss,
            "eval/perplexity": np.exp(eval_loss),
            **{f"eval/{k}_ppl": v for k, v in perplexities.items()},
        }, step=self.step)
```

7.2 Interpreting Loss Curves

Characteristics of a healthy loss curve:

- **Warmup phase**: Rapid decrease

- **Stable phase**: Slow but steady decrease

- **Late phase**: Very gradual decrease

Warning signals:

- **Complete plateau**: Learning rate too low, data exhausted

- **Sharp spikes**: Bad batches, gradient explosion

- **NaN/Inf**: Numerical instability, fp16 overflow

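```python
import numpy as np


def analyze_loss_curve(losses: list, window: int = 100) -> dict:
    """Automated loss curve analysis over the two most recent windows."""
    if len(losses) < window * 2:
        return {}

    recent = np.asarray(losses[-window:])
    previous = np.asarray(losses[-2 * window:-window])
    recent_mean = recent.mean()
    previous_mean = previous.mean()
    improvement = (previous_mean - recent_mean) / previous_mean

    # Spikes are measured against the previous window, not the whole run:
    # early losses are far higher, so global statistics would hide late spikes
    threshold = previous_mean + 3 * previous.std()
    spikes = recent[recent > threshold]

    is_stagnant = improvement < 0.001
    return {
        "recent_loss": float(recent_mean),
        "improvement_rate": float(improvement),
        "spike_count": int(len(spikes)),
        "is_stagnant": bool(is_stagnant),
        # Slow improvement is expected late in a decaying LR schedule,
        # so treat this as a prompt to investigate, not an automatic fix
        "recommendation": "Consider increasing LR" if is_stagnant else "Normal",
    }
```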

8. Open-Source LLM Training Codebases

8.1 GPT-NeoX (EleutherAI)

```bash
# Install and run GPT-NeoX
git clone https://github.com/EleutherAI/gpt-neox
cd gpt-neox
pip install -r requirements/requirements.txt

# Configuration file: configs/20B.yml
# Start training (deepy.py wraps the DeepSpeed launcher)
python ./deepy.py train.py configs/20B.yml
```

GPT-NeoX builds on Megatron-LM's model-parallel code and combines it with DeepSpeed, including DeepSpeed's pipeline engine. EleutherAI's Pythia model series was trained with this codebase.

8.2 OLMo (Allen AI)

```python
# OLMo training configuration (simplified, illustrative).
# Field names differ between OLMo releases; check the config dataclasses
# in your checkout (olmo/config.py) before relying on them.
from olmo import TrainConfig, ModelConfig, OptimizerConfig

train_config = TrainConfig(
    model=ModelConfig(
        d_model=4096,
        n_heads=32,
        n_layers=32,
        mlp_ratio=8/3,
        vocab_size=50280,
        max_sequence_length=2048,
        attention_type="flash",
    ),
    optimizer=OptimizerConfig(
        name="adamw",
        learning_rate=3e-4,
        weight_decay=0.1,
        betas=(0.9, 0.95),
    ),
    max_duration="300000ba",  # 300,000 batches
    global_train_batch_size=2048,
    device_train_microbatch_size=2,
    precision="bf16",
    fsdp_config={
        "wrapping_strategy": "by_block",
        "precision": "bf16",
        "sharding_strategy": "FULL_SHARD",
    },
)
```

OLMo aims for full transparency, releasing training data, code, and intermediate checkpoints.

8.3 torchtitan

torchtitan is a PyTorch-native LLM pre-training framework developed by Meta. Runs are configured with a TOML file like the one below; key names change between torchtitan versions, so compare against the `train_configs/` files in your checkout.

```toml
[model]
name = "llama3"
flavor = "8B"
tokenizer_path = "./original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 8
seq_len = 8192
max_norm = 1.0
steps = 10000
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1  # FSDP2
tensor_parallel_degree = 1
compile = true  # Enable torch.compile

[checkpoint]
enable_checkpoint = true
folder = "outputs/checkpoint"
interval_type = "steps"
interval = 500
```

torchtitan implements FSDP2, Tensor Parallel, and Pipeline Parallel natively in PyTorch.

8.4 FSDP2 (Fully Sharded Data Parallel)

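```python
import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# `model` is assumed to be a Hugging Face LlamaForCausalLM created on each
# rank after torch.distributed is initialized. Use ONE of the two options.

# Option A: FSDP1 (wrapper API). transformer_auto_wrap_policy must be bound
# with functools.partial because FSDP calls it with extra arguments.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)
# model = FSDP(model, auto_wrap_policy=auto_wrap_policy)

# Option B: FSDP2 (per-module fully_shard API in torch 2.x; newer
# PyTorch releases also expose these from torch.distributed.fsdp)
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)

# Shard each decoder layer first, then the root module
for layer in model.model.layers:
    fully_shard(
        layer,
        mp_policy=mp_policy,
        reshard_after_forward=True,  # Similar to ZeRO-3
    )
fully_shard(model, mp_policy=mp_policy)
```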

9. Cost Estimation and Efficiency Strategies

9.1 GPU-Hour Calculation

```python
def estimate_training_cost(
    model_params: float,     # Parameter count (e.g., 7e9)
    training_tokens: float,  # Training tokens (e.g., 2e12)
    num_gpus: int,
    gpu_type: str = "A100-80GB",
    cloud_provider: str = "AWS",
    mfu: float = 0.45,       # Assumed MFU (typically 35-55%)
) -> dict:
    """Estimate training time and cost (assumes perfect scaling across GPUs)."""
    # GPU specs (theoretical dense peak, bf16)
    gpu_specs = {
        "A100-80GB": {"tflops": 312, "memory_gb": 80, "nvlink": True},
        "H100-80GB": {"tflops": 989, "memory_gb": 80, "nvlink": True},
        "A10G-24GB": {"tflops": 125, "memory_gb": 24, "nvlink": False},
        # V100 has no bf16 support; this is its fp16 tensor-core peak
        "V100-32GB": {"tflops": 130, "memory_gb": 32, "nvlink": True},
    }

    # Cloud pricing (approximate USD per GPU-hour; changes often)
    cloud_prices = {
        "AWS": {"A100-80GB": 3.97, "H100-80GB": 8.0, "A10G-24GB": 1.006},
        "GCP": {"A100-80GB": 3.67, "H100-80GB": 7.0},
        "Azure": {"A100-80GB": 3.40, "H100-80GB": 6.0},
        "Lambda": {"A100-80GB": 1.99, "H100-80GB": 3.29},
    }

    spec = gpu_specs[gpu_type]
    price_per_hour = cloud_prices.get(cloud_provider, {}).get(gpu_type)
    if price_per_hour is None:
        raise ValueError(f"No price listed for {gpu_type} on {cloud_provider}")

    # FLOPs calculation
    total_flops = 6 * model_params * training_tokens

    # Achieved FLOP rate per GPU
    effective_flops_per_gpu = spec["tflops"] * 1e12 * mfu

    # Total training time
    total_seconds = total_flops / (effective_flops_per_gpu * num_gpus)
    total_hours = total_seconds / 3600
    total_days = total_hours / 24

    # Cost calculation
    total_cost = price_per_hour * num_gpus * total_hours

    return {
        "total_flops": f"{total_flops:.2e}",
        "training_hours": f"{total_hours:.1f}",
        "training_days": f"{total_days:.1f}",
        "gpu_hours": f"{total_hours * num_gpus:.0f}",
        "estimated_cost_usd": f"{total_cost:,.0f}",
        "mfu": mfu,
    }


# Example
result = estimate_training_cost(
    model_params=7e9,
    training_tokens=2e12,
    num_gpus=256,
    gpu_type="A100-80GB",
    cloud_provider="Lambda",
)
for k, v in result.items():
    print(f"{k}: {v}")
```

9.2 MFU (Model FLOP Utilization) Optimization

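```python
def measure_mfu(
    model,
    batch_tokens: int,
    elapsed_seconds: float,
    theoretical_peak_tflops: float,
    num_params: int = None,
) -> float:
    """
    Measure actual MFU for one GPU.
    MFU = achieved model FLOP rate / theoretical peak FLOP rate

    Assumes each GPU pushes its own batch_tokens through the whole model
    (data parallel or FSDP/ZeRO). theoretical_peak_tflops is the per-GPU peak
    (e.g. 312 for A100 bf16). Under FSDP/ZeRO-3, model.parameters() holds only
    the local shard, so pass the full num_params explicitly.
    """
    if num_params is None:
        num_params = sum(p.numel() for p in model.parameters())
    # 6N per token: forward (2N) + backward (4N); attention FLOPs are ignored
    actual_flops = 6 * num_params * batch_tokens
    actual_tflops = actual_flops / elapsed_seconds / 1e12
    return actual_tflops / theoretical_peak_tflops
```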

MFU improvement strategies:

1. Use Flash Attention (minimize memory I/O)

2. Enable torch.compile (kernel fusion)

3. Use activation checkpointing carefully (speed trade-off)

4. Choose optimal batch size (maximize GPU occupancy)

5. Overlap communication with computation (FSDP/DeepSpeed settings)
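The 6N count in `measure_mfu` ignores attention's sequence-length-dependent FLOPs, which matter at long context. The PaLM paper's MFU definition adds a 12·L·H·Q·T term per token (L layers, H heads, Q head dimension, T sequence length). The sketch below implements that accounting; the 7B-like shape is illustrative and the function names are hypothetical.

```python
def model_flops_per_token(num_params, num_layers, num_heads, head_dim, seq_len):
    """Training FLOPs per token in the PaLM-style MFU accounting:
    6N for the dense matmuls plus 12 * L * H * Q * T for self-attention
    (no causal-mask discount, no activation recomputation)."""
    return 6 * num_params + 12 * num_layers * num_heads * head_dim * seq_len


def compute_mfu(tokens_per_second, flops_per_token, peak_flops_per_gpu, num_gpus):
    """Achieved model FLOP rate divided by the cluster's theoretical peak."""
    return tokens_per_second * flops_per_token / (peak_flops_per_gpu * num_gpus)


# 7B-like shape at 4k context: how large is the attention term?
dense = 6 * 6.7e9
total = model_flops_per_token(6.7e9, 32, 32, 128, 4096)
attn = total - dense
print(f"{attn / dense:.1%}")  # 16.0%
```

At this context length the attention term adds about 16% to the 6N estimate, so ignoring it understates MFU for long-sequence training.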

9.3 Efficiency Strategy Summary

| Strategy                 | Memory Savings | Speed Impact (approx.) | Implementation Difficulty |
| ------------------------ | -------------- | ---------------------- | ------------------------- |
| ZeRO-2                   | Medium         | -5%                    | Easy                      |
| ZeRO-3                   | High           | -15%                   | Medium                    |
| Flash Attention 2        | High           | +20%                   | Easy                      |
| torch.compile            | None           | +15-30%                | Easy                      |
| Activation Checkpointing | Medium         | -30%                   | Easy                      |
| bf16 (vs fp16)           | None           | Stability improvement  | Easy                      |
| Sequence packing         | None           | +20%                   | Medium                    |

10. Complete Pre-training Launcher Script

```python
#!/usr/bin/env python3
"""
Complete large-scale LLM pre-training example.
Llama-architecture model trained from scratch, DeepSpeed ZeRO-2 + Flash Attention.
"""
import math
import os

import deepspeed
import numpy as np
import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoConfig, AutoModelForCausalLM

# Configuration
TRAINING_CONFIG = {
    "model_name": "meta-llama/Llama-2-7b-hf",  # only the architecture config is used
    "total_tokens": 1_000_000_000,  # 1B tokens (demo)
    "max_seq_len": 2048,
    "global_batch_size": 1024,      # sequences (tokens: 1024 * 2048 ≈ 2M)
    "micro_batch_per_gpu": 4,
    "learning_rate": 3e-4,
    "min_lr": 3e-5,
    "warmup_tokens": 20_000_000,    # 2% warmup
    "weight_decay": 0.1,
    "grad_clip": 1.0,
    "save_every_steps": 1000,
    "eval_every_steps": 500,        # not used in this demo
    "log_every_steps": 10,
}


class StreamingTokenDataset(IterableDataset):
    """Streams fixed-length sequences from a flat binary file of uint16 token ids."""

    def __init__(self, data_path: str, seq_len: int, rank: int, world_size: int):
        self.data_path = data_path
        self.seq_len = seq_len
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        # Shard by (rank, DataLoader worker) so no two readers yield the same data
        worker = get_worker_info()
        num_workers = worker.num_workers if worker else 1
        worker_id = worker.id if worker else 0
        shard = self.rank * num_workers + worker_id
        num_shards = self.world_size * num_workers

        bytes_per_seq = self.seq_len * 2  # uint16: 2 bytes per token
        with open(self.data_path, "rb") as f:
            index = shard
            while True:
                f.seek(index * bytes_per_seq)
                chunk = f.read(bytes_per_seq)
                if len(chunk) < bytes_per_seq:
                    break
                tokens = np.frombuffer(chunk, dtype=np.uint16).astype(np.int64)
                input_ids = torch.from_numpy(tokens)
                # Hugging Face causal LMs shift labels internally, so labels = inputs
                yield {"input_ids": input_ids, "labels": input_ids.clone()}
                index += num_shards


def cosine_schedule_with_warmup(step, warmup, total, max_lr, min_lr):
    if step < warmup:
        return max_lr * step / max(warmup, 1)
    progress = min((step - warmup) / max(total - warmup, 1), 1.0)
    return min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))


def train():
    deepspeed.init_distributed()
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    # Randomly initialized model from the config. ZeRO-2 does not need
    # deepspeed.zero.Init (that context is for ZeRO-3 parameter partitioning).
    # "flash_attention_2" requires the flash-attn package; use "sdpa" otherwise.
    config = AutoConfig.from_pretrained(TRAINING_CONFIG["model_name"])
    model = AutoModelForCausalLM.from_config(
        config,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )

    grad_accum = (
        TRAINING_CONFIG["global_batch_size"]
        // TRAINING_CONFIG["micro_batch_per_gpu"]
        // world_size
    )

    # Initialize DeepSpeed engine
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config={
            "train_micro_batch_size_per_gpu": TRAINING_CONFIG["micro_batch_per_gpu"],
            "gradient_accumulation_steps": grad_accum,
            "bf16": {"enabled": True},
            "zero_optimization": {
                "stage": 2,
                "overlap_comm": True,
                "contiguous_gradients": True,
                "allgather_bucket_size": 2e8,
                "reduce_bucket_size": 2e8,
            },
            "gradient_clipping": TRAINING_CONFIG["grad_clip"],
            "optimizer": {
                "type": "AdamW",
                "params": {
                    "lr": TRAINING_CONFIG["learning_rate"],
                    "betas": [0.9, 0.95],
                    "eps": 1e-8,
                    "weight_decay": TRAINING_CONFIG["weight_decay"],
                },
            },
            "steps_per_print": TRAINING_CONFIG["log_every_steps"],
        },
    )

    # Dataset
    train_dataset = StreamingTokenDataset(
        "/data/train_tokens.bin", TRAINING_CONFIG["max_seq_len"], rank, world_size
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=TRAINING_CONFIG["micro_batch_per_gpu"],
        num_workers=4,
        pin_memory=True,
    )

    # Training loop
    total_tokens = 0
    micro_step = 0
    target_tokens = TRAINING_CONFIG["total_tokens"]
    warmup_tokens = TRAINING_CONFIG["warmup_tokens"]

    for batch in train_loader:
        if total_tokens >= target_tokens:
            break

        input_ids = batch["input_ids"].to(model_engine.device)
        labels = batch["labels"].to(model_engine.device)

        # Token-based learning rate schedule
        lr = cosine_schedule_with_warmup(
            total_tokens, warmup_tokens, target_tokens,
            TRAINING_CONFIG["learning_rate"], TRAINING_CONFIG["min_lr"],
        )
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # Forward/backward; step() applies the update only at
        # gradient-accumulation boundaries
        outputs = model_engine(input_ids=input_ids, labels=labels)
        loss = outputs.loss
        model_engine.backward(loss)
        model_engine.step()
        micro_step += 1

        # Update token count (this micro-step, summed over all ranks)
        total_tokens += input_ids.numel() * world_size

        # Log and save once per optimizer step, not once per micro-batch
        if micro_step % grad_accum != 0:
            continue
        step = micro_step // grad_accum

        # Logging (rank 0 only)
        if rank == 0 and step % TRAINING_CONFIG["log_every_steps"] == 0:
            print(f"Tokens: {total_tokens / 1e9:.2f}B, "
                  f"Loss: {loss.item():.4f}, "
                  f"LR: {lr:.2e}, "
                  f"PPL: {math.exp(min(loss.item(), 20)):.1f}")

        # save_checkpoint is collective under ZeRO: every rank must call it
        if step % TRAINING_CONFIG["save_every_steps"] == 0:
            model_engine.save_checkpoint("./checkpoints", tag=f"step_{step}")

    if rank == 0:
        print("Training complete!")
    model_engine.save_checkpoint("./checkpoints", tag="final")


if __name__ == "__main__":
    train()
```

Launch:

```bash
# 4 nodes x 8 GPUs; the hostfile lists each node and its GPU slots
deepspeed --hostfile=hostfile --num_nodes=4 --num_gpus=8 \
  --master_addr=node0 --master_port=29500 \
  pretrain.py
```

Conclusion

Pre-training a 100B+ parameter LLM is not simply running code — it's a careful balancing act of countless engineering decisions and trade-offs.

Key takeaways:

- **Scaling laws**: Balance model size and data volume per Chinchilla optimization, but "overtraining" can be rational when considering inference costs

- **Data is fundamental**: High-quality data curation and balanced mixing determine model capability

- **3D parallelism**: DP + TP + PP combination efficiently utilizes thousands of GPUs

- **Stability first**: Monitoring loss spikes and gradient explosions with fast response is essential

- **Cost awareness**: Continuously monitor MFU and GPU utilization to optimize efficiency

The open-source ecosystem (OLMo, GPT-NeoX, torchtitan) is lowering the barrier to large-scale training. Apply the techniques from this guide to your own projects and train your own LLM.

References

- [Chinchilla paper](https://arxiv.org/abs/2203.15556): Training Compute-Optimal Large Language Models

- [Kaplan Scaling Laws](https://arxiv.org/abs/2001.08361): Scaling Laws for Neural Language Models

- [Megatron-LM GitHub](https://github.com/NVIDIA/Megatron-LM)

- [OLMo GitHub](https://github.com/allenai/OLMo)

- [3D Parallelism paper](https://arxiv.org/abs/2104.04473): Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM

- [torchtitan](https://github.com/pytorch/torchtitan)

- [GPT-NeoX](https://github.com/EleutherAI/gpt-neox)
