Authors
- Youngju Kim (@fjvbn20031)
Introduction
When GPT-3 arrived in 2020, "bigger is better" became the dominant paradigm in LLM development. Then in 2022, DeepMind's Chinchilla paper challenged that intuition head-on: a model with 2.5x fewer parameters, trained on over 4x more data, outperformed GPT-3. This guide covers everything from the mathematics of scaling laws to the cutting-edge pretraining recipes used in Llama 3.1, DeepSeek-V3, and phi-4.
1. Scaling Laws
1.1 Kaplan et al. (2020) — The OpenAI Scaling Laws
Kaplan et al. empirically demonstrated that language model loss follows power laws with respect to the number of parameters (N), dataset size (D), and compute budget (C).
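The fitted laws take the form of independent power laws in each resource (the exponents below are the approximate values reported by Kaplan et al.; the constants N_c, D_c, C_c are fitted scales):

```latex
L(N) = \left(\frac{N_c}{N}\right)^{\alpha_N}, \quad \alpha_N \approx 0.076
\qquad
L(D) = \left(\frac{D_c}{D}\right)^{\alpha_D}, \quad \alpha_D \approx 0.095
\qquad
L(C) = \left(\frac{C_c}{C}\right)^{\alpha_C}, \quad \alpha_C \approx 0.050
```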
The key takeaway: given a fixed compute budget, investing overwhelmingly in model size is optimal. This became the theoretical justification for building models like GPT-3 (175B parameters).
1.2 Chinchilla (Hoffmann et al. 2022) — The Revised Scaling Laws
The DeepMind team showed through more rigorous experimental design that Kaplan's conclusion substantially underestimated the value of data.
The Chinchilla finding:
Given a compute budget C (in FLOPs), the optimal model size and optimal token count both scale as roughly the square root of compute:

N_opt(C) ∝ C^0.50,  D_opt(C) ∝ C^0.50,  with D_opt ≈ 20 · N_opt

The rule of thumb: roughly 20 training tokens per parameter for optimal compute allocation.
| Model | Parameters | Training Tokens | Chinchilla Ratio |
|---|---|---|---|
| GPT-3 | 175B | 300B | 1:1.7 (data-starved) |
| Chinchilla | 70B | 1.4T | 1:20 (optimal) |
| Llama 3.1 | 8B | 15T | 1:1875 (over-trained) |
Llama 3.1's deliberately over-trained approach is driven by inference efficiency: a smaller model trained longer costs less to deploy, even if it uses more compute during training.
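The tokens-per-parameter ratios in the table can be recomputed in a couple of lines:

```python
# Tokens-per-parameter ratios for the three models discussed above
models = {
    "GPT-3": (175e9, 300e9),        # (parameters, training tokens)
    "Chinchilla": (70e9, 1.4e12),
    "Llama 3.1 8B": (8e9, 15e12),
}
ratios = {name: tokens / params for name, (params, tokens) in models.items()}
for name, r in ratios.items():
    print(f"{name}: {r:.0f} tokens per parameter")
```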
1.3 Computing Optimal Resource Allocation
```python
import math

def chinchilla_optimal(compute_budget_flops: float):
    """
    Compute optimal N and D given a compute budget (Chinchilla scaling laws).
    compute_budget_flops: total FLOPs (approximated as 6 * N * D)
    Returns: (optimal_params, optimal_tokens)
    """
    # Using the ~20 tokens-per-parameter rule of thumb.
    # With C = 6 * N * D and D = 20 * N:
    # C = 120 * N^2  =>  N_opt = sqrt(C / 120), D_opt = 20 * N_opt
    N_opt = math.sqrt(compute_budget_flops / (6 * 20))
    D_opt = 20 * N_opt
    return N_opt, D_opt

# Example: 1000x A100 GPUs, 30 days of training
# A100: ~312 TFLOPS (bf16), assuming 40% utilization
flops_per_sec = 1000 * 312e12 * 0.4
duration_sec = 30 * 24 * 3600
total_flops = flops_per_sec * duration_sec

N_opt, D_opt = chinchilla_optimal(total_flops)
print(f"Total FLOPs: {total_flops:.2e}")
print(f"Optimal parameters: {N_opt/1e9:.1f}B")
print(f"Optimal tokens: {D_opt/1e12:.1f}T")
```
2. Data Preparation
2.1 Common Crawl Filtering Pipeline
Common Crawl provides tens of terabytes of web-crawled data monthly. Raw, it is extremely noisy — a multi-stage filtering pipeline is essential.
Typical filtering stages:
- Language detection (fastText): retain only target language(s)
- Quality filters: minimum sentence count, word repetition ratio, special character density
- Domain blocklist: spam, adult content, ad-heavy domains
- Perplexity filter: remove low-quality text using an n-gram language model
- Deduplication: MinHash LSH-based fuzzy deduplication
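The quality-filter stage above can be sketched as a simple heuristic function. The thresholds here are illustrative defaults, not tuned values from any published pipeline:

```python
import re

def passes_quality_filters(text: str, min_sentences: int = 3,
                           max_word_repetition: float = 0.3,
                           max_special_ratio: float = 0.2) -> bool:
    """Heuristic document-level quality filter (illustrative thresholds)."""
    sentences = [s for s in re.split(r"[.!?]\s+", text) if s.strip()]
    if len(sentences) < min_sentences:
        return False
    words = text.split()
    if not words:
        return False
    # Word repetition ratio: fraction of non-unique words
    if 1 - len(set(words)) / len(words) > max_word_repetition:
        return False
    # Special character density
    special = sum(1 for ch in text if not (ch.isalnum() or ch.isspace()))
    if special / len(text) > max_special_ratio:
        return False
    return True
```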
2.2 MinHash Deduplication
```python
from datasketch import MinHash, MinHashLSH
import re

def get_shingles(text: str, k: int = 5):
    """Generate character k-shingle set from text."""
    text = re.sub(r'\s+', ' ', text.lower())
    return {text[i:i+k] for i in range(len(text) - k + 1)}

def build_minhash(text: str, num_perm: int = 128) -> MinHash:
    """Create a MinHash signature from text."""
    m = MinHash(num_perm=num_perm)
    for shingle in get_shingles(text):
        m.update(shingle.encode('utf-8'))
    return m

def deduplicate_corpus(documents: list, threshold: float = 0.8):
    """
    Remove near-duplicate documents using MinHash LSH.
    threshold: Jaccard similarity threshold for considering duplicates.
    """
    lsh = MinHashLSH(threshold=threshold, num_perm=128)
    unique_docs = []
    for idx, doc in enumerate(documents):
        mh = build_minhash(doc)
        if len(lsh.query(mh)) == 0:
            lsh.insert(f"doc_{idx}", mh)
            unique_docs.append(doc)
        # If a similar doc already exists in the index, skip this one
    print(f"Original: {len(documents)} docs -> After dedup: {len(unique_docs)} docs")
    return unique_docs
```
The core idea: the Jaccard similarity between two sets equals the probability that, under a random permutation, both sets share the same minimum hash value. Averaging over k independent hash functions gives an unbiased estimate with standard error roughly sqrt(J(1-J)/k) — a few percentage points at k = 128. LSH then bins signatures into buckets so only candidate pairs are compared, bringing overall complexity close to O(N).
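A quick self-contained demo of the estimator, using one seeded hash family in place of true random permutations (a common shortcut; `datasketch` does something similar internally):

```python
import hashlib

def shingles(text: str, k: int = 5):
    """Character k-shingle set."""
    text = " ".join(text.lower().split())
    return {text[i:i + k] for i in range(len(text) - k + 1)}

def minhash_sig(s, num_perm: int = 128):
    """MinHash signature: per-seed minimum of a seeded 64-bit hash."""
    return [
        min(int.from_bytes(hashlib.md5(f"{seed}:{x}".encode()).digest()[:8], "big")
            for x in s)
        for seed in range(num_perm)
    ]

def jaccard_estimate(a, b, num_perm: int = 128) -> float:
    """Fraction of matching signature positions estimates Jaccard similarity."""
    sa, sb = minhash_sig(a, num_perm), minhash_sig(b, num_perm)
    return sum(x == y for x, y in zip(sa, sb)) / num_perm

a = shingles("the quick brown fox jumps over the lazy dog")
b = shingles("the quick brown fox leaps over the lazy dog")
exact = len(a & b) / len(a | b)
est = jaccard_estimate(a, b)
print(f"exact Jaccard: {exact:.3f}, MinHash estimate: {est:.3f}")
```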
2.3 Tokenizer Training
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.processors import ByteLevel as ByteLevelProcessor

def train_bpe_tokenizer(
    corpus_files: list,
    vocab_size: int = 32000,
    save_path: str = "tokenizer.json",
):
    """Train a GPT-2-style byte-level BPE tokenizer."""
    tokenizer = Tokenizer(BPE(unk_token=None))
    tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
    tokenizer.post_processor = ByteLevelProcessor(trim_offsets=False)
    trainer = BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=2,
        special_tokens=["<pad>", "<s>", "</s>", "<unk>"],
        show_progress=True,
    )
    tokenizer.train(files=corpus_files, trainer=trainer)
    tokenizer.save(save_path)
    print(f"Tokenizer saved: {save_path} (vocab_size={vocab_size})")
    return tokenizer
```
2.4 Data Mixture Ratios
A representative pretraining data mixture (approximate figures; frontier labs, Meta included, publish only coarse descriptions of their actual ratios):
| Data Source | Proportion |
|---|---|
| Web crawl (Common Crawl, etc.) | ~80% |
| Code (GitHub) | ~8% |
| Academic / scientific papers | ~5% |
| Books (Books3, etc.) | ~4% |
| Multilingual data | ~3% |
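Hitting a target mixture usually means repeating (multi-epoching) small, high-value sources. A sketch computing per-source epoch counts — source sizes below are made-up illustrative numbers, not real corpus sizes:

```python
def mixture_epochs(sizes: dict, targets: dict, total_tokens: float) -> dict:
    """Epochs over each source needed to meet the target mixture.

    sizes: raw tokens available per source
    targets: desired fraction of the training budget per source
    """
    return {name: targets[name] * total_tokens / sizes[name] for name in sizes}

# Illustrative source sizes (tokens) and the mixture from the table above
sizes = {"web": 12e12, "code": 0.8e12, "papers": 0.3e12,
         "books": 0.3e12, "multilingual": 0.5e12}
targets = {"web": 0.80, "code": 0.08, "papers": 0.05,
           "books": 0.04, "multilingual": 0.03}

epochs = mixture_epochs(sizes, targets, total_tokens=15e12)
for name, e in epochs.items():
    print(f"{name}: {e:.2f} epochs")
```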
3. Architecture Choices
3.1 Key Design Decisions Across Open-Source LLMs
| Model | Positional Encoding | Attention | FFN | Normalization |
|---|---|---|---|---|
| GPT-NeoX | RoPE | MHA | GELU MLP | LayerNorm |
| LLaMA-3 | RoPE | GQA | SwiGLU | RMSNorm |
| Mistral 7B | RoPE | GQA + SWA | SwiGLU | RMSNorm |
| Mixtral 8x7B | RoPE | GQA | MoE (SwiGLU) | RMSNorm |
3.2 RoPE (Rotary Position Embedding)
Rather than adding absolute positional information to token embeddings, RoPE applies a rotation transformation to Query and Key vectors before the dot product.
The result: the inner product between a query at position m and a key at position n depends only on the relative offset (m-n). This relative nature enables better extrapolation to sequence lengths beyond what was seen during training.
Extensions like YaRN and LongRoPE use this property to extend context to hundreds of thousands of tokens.
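The relative-position property can be verified in the simplest possible setting — a single 2-dimensional rotation pair (a toy sketch of the mechanism, not a full RoPE implementation):

```python
import math

def rotate(vec, pos: int, theta: float = 1.0):
    """Rotate a 2-d vector by angle pos * theta (RoPE in its d=2 form)."""
    c, s = math.cos(pos * theta), math.sin(pos * theta)
    x, y = vec
    return (c * x - s * y, s * x + c * y)

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

q, k = (1.0, 0.5), (0.3, -0.2)
# The attention score at positions (m, n) equals the score at (m+t, n+t):
# the dot product of the rotated vectors depends only on m - n.
s1 = dot(rotate(q, 5), rotate(k, 3))
s2 = dot(rotate(q, 105), rotate(k, 103))
print(f"score at (5, 3):     {s1:.6f}")
print(f"score at (105, 103): {s2:.6f}")
```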
3.3 Grouped Query Attention (GQA)
During autoregressive inference, previous tokens' Key and Value tensors must be stored in the KV cache. GQA reduces this cache by having multiple Query heads share a smaller number of Key/Value heads.
- MHA: H query heads, H key heads, H value heads
- MQA: H query heads, 1 key head, 1 value head
- GQA: H query heads, G key heads, G value heads (G < H)
Llama 3 8B uses 32 query heads with 8 KV heads — reducing KV cache to 1/4 of MHA.
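A back-of-the-envelope KV-cache calculator makes the saving concrete (shapes below follow Llama-3-8B's published config: 32 layers, head dimension 128, bf16):

```python
def kv_cache_bytes(batch: int, seq_len: int, num_kv_heads: int,
                   head_dim: int, num_layers: int, bytes_per_elem: int = 2) -> int:
    """KV-cache size: K and V tensors for every layer and cached token."""
    # factor of 2 = one K tensor + one V tensor per layer
    return 2 * batch * seq_len * num_kv_heads * head_dim * num_layers * bytes_per_elem

# Llama-3-8B-like shapes at an 8K context, bf16 (2 bytes/element)
mha = kv_cache_bytes(1, 8192, num_kv_heads=32, head_dim=128, num_layers=32)
gqa = kv_cache_bytes(1, 8192, num_kv_heads=8, head_dim=128, num_layers=32)
print(f"MHA cache: {mha / 2**30:.1f} GiB, GQA cache: {gqa / 2**30:.1f} GiB")
```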
3.4 Mixture of Experts (MoE)
Mixtral 8x7B and DeepSeek-V3 use MoE architectures. Each token activates only Top-K out of N expert FFN layers.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    """Simple Mixture of Experts layer."""
    def __init__(self, d_model: int, d_ff: int, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.SiLU(),
                nn.Linear(d_ff, d_model),
            )
            for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor):
        # x: (batch, seq_len, d_model)
        B, T, D = x.shape
        x_flat = x.view(-1, D)  # (B*T, D)

        # Router
        router_logits = self.router(x_flat)  # (B*T, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Top-K expert selection
        topk_probs, topk_idx = router_probs.topk(self.top_k, dim=-1)
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # renormalize

        # Load balancing loss (Switch Transformer style):
        # f_i = fraction of tokens whose top-1 choice is expert i
        # P_i = mean router probability assigned to expert i
        # loss = num_experts * sum_i f_i * P_i, minimized when both are uniform
        f = F.one_hot(topk_idx[:, 0], self.num_experts).float().mean(0)
        P = router_probs.mean(0)
        load_balance_loss = self.num_experts * (f * P).sum()

        # Compute expert outputs
        output = torch.zeros_like(x_flat)
        for k in range(self.top_k):
            expert_indices = topk_idx[:, k]
            expert_weights = topk_probs[:, k].unsqueeze(-1)
            for e_idx in range(self.num_experts):
                mask = (expert_indices == e_idx)
                if mask.any():
                    expert_out = self.experts[e_idx](x_flat[mask])
                    output[mask] += expert_weights[mask] * expert_out

        return output.view(B, T, D), load_balance_loss
```
Why load balancing loss is necessary: Without it, the router collapses to always selecting the same expert (expert collapse). The chosen expert improves most, gets selected even more, and the rest never train — a positive feedback loop. Load balancing loss penalizes unequal distribution by adding an auxiliary loss term that encourages all experts to receive roughly equal token counts.
4. Training Stability
4.1 Learning Rate Schedule: Cosine Warmup
```python
import math

def cosine_lr_with_warmup(
    optimizer,
    step: int,
    warmup_steps: int,
    total_steps: int,
    max_lr: float,
    min_lr: float = 0.0,
):
    """Cosine annealing with linear warmup."""
    if step < warmup_steps:
        lr = max_lr * step / warmup_steps
    else:
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# Typical pretraining settings:
# warmup: 1-2% of total steps
# max_lr: 3e-4 (for ~7B model)
# min_lr: max_lr * 0.1
```
4.2 Loss Spike Detection and Recovery
Large-scale pretraining runs frequently encounter sudden loss spikes. Standard mitigation strategies:
- Gradient norm clipping: scale down gradients when the norm exceeds a threshold
- Spike detection with rollback: restart from a previous checkpoint, skipping the offending batch
- Exponential moving average monitoring: trigger alerts on sudden loss increases
```python
import torch

def train_step_with_stability(model, optimizer, batch, grad_clip: float = 1.0):
    """Training step with gradient norm clipping and spike detection."""
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()
    # Compute and clip gradient norm
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm=grad_clip
    )
    # Detect abnormally large gradients and skip the step entirely
    if grad_norm > 100 * grad_clip:
        print(f"WARNING: abnormal grad_norm={grad_norm:.2f}, skipping step")
        optimizer.zero_grad()
        return None, grad_norm
    optimizer.step()
    return loss.item(), grad_norm
```
5. Efficient Pretraining
5.1 Flash Attention 2
Flash Attention rewrites the attention computation to maximize GPU SRAM utilization, reducing memory complexity from O(N²) to O(N) in sequence length.
```python
# Flash Attention 2 usage
# pip install flash-attn --no-build-isolation
from typing import Optional

import torch
from flash_attn import flash_attn_func

def flash_attention_forward(
    q: torch.Tensor,  # (batch, seqlen, nheads, headdim)
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = True,
    softmax_scale: Optional[float] = None,
):
    """
    Flash Attention 2 forward pass.
    causal=True: applies causal mask for autoregressive LMs.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out = flash_attn_func(
        q, k, v,
        dropout_p=0.0,
        softmax_scale=softmax_scale,
        causal=causal,
    )
    return out  # (batch, seqlen, nheads, headdim)
```
Flash Attention 2 compared to standard attention:
- Memory: linear in sequence length (vs. quadratic)
- Speed: 2-4x faster on A100
- Numerical accuracy: exact attention (no approximation); outputs match standard attention up to floating-point reordering error
5.2 Sliding Window Attention (SWA)
Used in Mistral 7B, SWA limits each token's attention to the most recent W tokens. This reduces attention complexity from O(N²) to O(N·W) for long sequences while maintaining local context.
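The mask rule is a one-liner; a small sketch counting attended pairs for a toy sequence:

```python
def sliding_window_allowed(i: int, j: int, window: int) -> bool:
    """Causal sliding-window attention: position i may attend position j
    iff j is among the last `window` positions up to and including i."""
    return 0 <= i - j < window

T, W = 6, 3
mask = [[sliding_window_allowed(i, j, W) for j in range(T)] for i in range(T)]
windowed = sum(sum(row) for row in mask)
causal = T * (T + 1) // 2  # full causal attention for comparison
print(f"attended pairs: {windowed} windowed vs {causal} full causal")
```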
6. Evaluation and Checkpointing
6.1 Monitoring the Perplexity Curve
```python
import torch
import math

def compute_perplexity(model, dataloader, device: str = "cuda"):
    """Compute perplexity on a validation set."""
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(input_ids=input_ids, labels=labels)
            batch_tokens = (labels != -100).sum().item()
            total_loss += outputs.loss.item() * batch_tokens
            total_tokens += batch_tokens
    avg_loss = total_loss / total_tokens
    return math.exp(avg_loss)

# Typical reference points:
# GPT-2 level: PPL ~20 (WebText)
# Good 7B model: PPL ~6-8 (held-out validation set)
```
6.2 Running lm-evaluation-harness Benchmarks
```bash
# Install and run lm-evaluation-harness
pip install lm-eval

# Evaluate Llama 3 8B on standard benchmarks
lm_eval --model hf \
    --model_args pretrained=meta-llama/Meta-Llama-3-8B \
    --tasks mmlu,hellaswag,arc_challenge,winogrande \
    --device cuda:0 \
    --batch_size 8 \
    --output_path results/llama3-8b/

# Evaluate a specific checkpoint during training
lm_eval --model hf \
    --model_args pretrained=./checkpoints/step-50000 \
    --tasks hellaswag \
    --num_fewshot 10 \
    --device cuda:0
```
6.3 Checkpoint Strategy
| Checkpoint Type | Save Interval | Retention |
|---|---|---|
| Rolling recent | Every 500 steps | Keep last 5 |
| Milestone | Every 5,000 steps | Permanent |
| Best validation | Based on PPL | Permanent |
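The rolling-plus-milestone retention policy above can be sketched as a small bookkeeping class (names and intervals are illustrative; real runs would also handle the best-validation slot and actual file deletion):

```python
from collections import deque

class CheckpointPolicy:
    """Rolling + milestone checkpoint retention (illustrative sketch)."""
    def __init__(self, rolling_every: int = 500, keep_last: int = 5,
                 milestone_every: int = 5000):
        self.rolling_every = rolling_every
        self.keep_last = keep_last
        self.milestone_every = milestone_every
        self.rolling = deque()   # recent checkpoints, oldest first
        self.milestones = []     # kept permanently

    def on_step(self, step: int) -> list:
        """Record a checkpoint at `step`; return steps whose files may be deleted."""
        deleted = []
        if step % self.milestone_every == 0:
            self.milestones.append(step)
        elif step % self.rolling_every == 0:
            self.rolling.append(step)
            while len(self.rolling) > self.keep_last:
                deleted.append(self.rolling.popleft())
        return deleted
```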
7. Recent Pretraining Recipes (2024-2025)
7.1 Llama 3.1
- Training data: 15T tokens (7.5x Llama 2)
- Context length: 128K (via RoPE extension)
- Vocabulary size: 128K (tiktoken-based)
- Notable: post-pretraining long-context annealing phase
7.2 Mistral Large / Mistral 7B
- Mistral 7B: GQA + SWA for maximal efficiency
- Mixtral 8x7B: 8 experts, 2 active per token (13B parameters in use per forward pass)
- Context: 32K via SWA
7.3 DeepSeek-V3
- Architecture: 671B MoE, 37B active parameters per token
- Training cost: ~$5.58M, reported as ~2.79M H800 GPU-hours on a 2,048-GPU cluster (roughly two months) — an order of magnitude below comparable frontier models
- Innovations: Multi-head Latent Attention (MLA), FP8 mixed-precision training
- Data: 14.8T tokens, with an increased share of math and code relative to earlier DeepSeek models
7.4 phi-4 (Microsoft)
- Size: 14B parameters
- Strategy: data quality over data quantity — heavy synthetic data augmentation
- Training data: 9.8T tokens (~40% synthetic)
- Achievement: competitive with 70B+ models on math and reasoning benchmarks
Quiz
Q1. What is the core argument of the Chinchilla paper that showed smaller models can outperform GPT-3?
Answer: GPT-3 trained 175B parameters on only 300B tokens. Chinchilla showed that parameters and training tokens should scale together at a 1:20 ratio — GPT-3 was severely data-starved.
Explanation: Hoffmann et al. 2022 demonstrated through rigorous experiments that the Kaplan scaling laws undervalued data. Under the same compute budget, a 70B parameter model trained on 1.4T tokens (Chinchilla) outperformed GPT-3 (175B / 300B tokens) on nearly all benchmarks. The key insight is that optimal parameters and optimal tokens should grow in equal proportion — each scales roughly as the square root of the compute budget.
Q2. How does the MinHash algorithm efficiently detect near-duplicates in a large text corpus?
Answer: Documents are represented as shingle (n-gram) sets, then compressed to fixed-length MinHash signatures that approximate Jaccard similarity. LSH groups similar signatures into the same bucket, so only candidate pairs need precise comparison.
Explanation: Direct Jaccard similarity between all pairs requires O(N^2) comparisons. MinHash represents each document as k hash minimums. Two signatures' agreement rate is an unbiased estimator of Jaccard similarity. LSH further reduces comparisons to near O(N) by only comparing documents that hash to the same bucket, making corpus-scale deduplication tractable.
Q3. How does Grouped Query Attention (GQA) reduce inference memory compared to Multi-Head Attention?
Answer: GQA uses fewer Key/Value heads than Query heads. Since the KV cache scales with the number of KV heads (not Query heads), the cache is reduced by a factor of KV heads / Query heads.
Explanation: During autoregressive inference, all previous tokens' Keys and Values must be cached. With MHA, the cache size is batch_size x seq_len x num_heads x head_dim x 2. GQA with G KV heads (G < H query heads) reduces cache to G/H of MHA. Llama 3 8B (H=32 query heads, G=8 KV heads) uses 1/4 the KV cache of equivalent MHA.
Q4. Why is RoPE (Rotary Position Embedding) better suited for long-context extrapolation than absolute positional embeddings?
Answer: RoPE makes attention scores a function of relative position (m-n) rather than absolute positions. This means the model generalizes to position offsets it has not seen during training, enabling length extrapolation.
Explanation: Absolute positional embeddings (sinusoidal or learned) are added to token embeddings at specific indices. Positions beyond the training length are out-of-distribution. RoPE applies rotation matrices to Q and K such that their dot product depends only on the difference between their positions. Extensions like YaRN and LongRoPE leverage this property to extend context windows from 4K to 128K+ tokens with minimal fine-tuning.
Q5. Why is a load balancing loss necessary in MoE expert routing?
Answer: Without it, the router collapses to always choosing the same expert (expert collapse). That expert improves most, gets selected more, and the cycle repeats — other experts receive no gradients and remain undertrained.
Explanation: The router's softmax produces expert selection probabilities. Even with uniform initialization, a positive feedback loop causes one or a few experts to dominate. Load balancing loss adds an auxiliary penalty when the distribution of tokens to experts is uneven. The loss term equals num_experts times the sum of (expert fraction) times (average routing probability), encouraging a uniform distribution and ensuring all experts are adequately trained.
Conclusion
LLM pretraining is where theoretical scaling laws, practical data engineering, and systems optimization intersect. Chinchilla proved mathematically that "bigger is not always better" — data efficiency matters as much as parameter count. Flash Attention and GQA made large-scale training and deployment economically feasible. DeepSeek-V3 demonstrated that MoE architecture combined with efficient implementation can achieve frontier performance at a fraction of the cost. The concepts in this guide form the foundation for understanding and conducting LLM pretraining experiments at any scale.