LLM Pretraining & Scaling Laws: From Chinchilla to Flash Attention and MoE

Introduction

When GPT-3 arrived in 2020, "bigger is better" became the dominant paradigm in LLM development. Then in 2022, DeepMind's Chinchilla paper challenged that intuition head-on: a model with 2.5x fewer parameters, trained on roughly 4.7x more data, outperformed GPT-3. This guide covers everything from the mathematics of scaling laws to the cutting-edge pretraining recipes used in Llama 3.1, DeepSeek-V3, and phi-4.


1. Scaling Laws

1.1 Kaplan et al. (2020) — The OpenAI Scaling Laws

Kaplan et al. empirically demonstrated that language model loss follows power laws with respect to the number of parameters (N), dataset size (D), and compute budget (C).

L(N) \approx \left(\frac{N_c}{N}\right)^{\alpha_N}

L(D) \approx \left(\frac{D_c}{D}\right)^{\alpha_D}

The key takeaway: given a fixed compute budget, investing overwhelmingly in model size is optimal. This became the theoretical justification for building models like GPT-3 (175B parameters).

1.2 Chinchilla (Hoffmann et al. 2022) — The Revised Scaling Laws

The DeepMind team showed through more rigorous experimental design that Kaplan's conclusion substantially underestimated the value of data.

The Chinchilla finding:

Given a compute budget C (in FLOPs), the optimal model size N_opt and optimal token count D_opt both scale as the square root of C. Using the C \approx 6ND approximation together with the ~20 tokens-per-parameter optimum:

N_{opt} \approx 0.09 \cdot C^{0.5}

D_{opt} \approx 1.8 \cdot C^{0.5}

The rule of thumb: roughly 20 training tokens per parameter for optimal compute allocation.

| Model | Parameters | Training Tokens | Chinchilla Ratio |
|---|---|---|---|
| GPT-3 | 175B | 300B | 1:1.7 (data-starved) |
| Chinchilla | 70B | 1.4T | 1:20 (optimal) |
| Llama 3.1 8B | 8B | 15T | 1:1875 (over-trained) |

Llama 3.1's deliberately over-trained approach is driven by inference efficiency: a smaller model trained longer costs less to deploy, even if it uses more compute during training.

1.3 Computing Optimal Resource Allocation

import math

def chinchilla_optimal(compute_budget_flops: float):
    """
    Compute optimal N and D given a compute budget (Chinchilla scaling laws).
    compute_budget_flops: total FLOPs (approximated as 6 * N * D)
    Returns: (optimal_params, optimal_tokens)
    """
    # Chinchilla rule of thumb: D_opt = 20 * N_opt
    # With C = 6 * N * D = 120 * N^2:
    # N_opt = sqrt(C / 120), D_opt = 20 * N_opt
    N_opt = math.sqrt(compute_budget_flops / (6 * 20))
    D_opt = 20 * N_opt
    return N_opt, D_opt

# Example: 1000x A100 GPUs, 30 days of training
# A100: ~312 TFLOPS (bf16), assuming 40% utilization
flops_per_sec = 1000 * 312e12 * 0.4
duration_sec = 30 * 24 * 3600
total_flops = flops_per_sec * duration_sec

N_opt, D_opt = chinchilla_optimal(total_flops)
print(f"Total FLOPs: {total_flops:.2e}")
print(f"Optimal parameters: {N_opt/1e9:.1f}B")
print(f"Optimal tokens: {D_opt/1e12:.1f}T")

2. Data Preparation

2.1 Common Crawl Filtering Pipeline

Common Crawl provides tens of terabytes of web-crawled data monthly. In its raw form the data is extremely noisy, so a multi-stage filtering pipeline is essential.

Typical filtering stages:

  1. Language detection (fastText): retain only target language(s)
  2. Quality filters: minimum sentence count, word repetition ratio, special character density
  3. Domain blocklist: spam, adult content, ad-heavy domains
  4. Perplexity filter: remove low-quality text using an n-gram language model
  5. Deduplication: MinHash LSH-based fuzzy deduplication
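A minimal sketch of what the stage-2 quality filters might look like. The specific thresholds and heuristics below are illustrative assumptions, not values from any published pipeline:

```python
import re

def passes_quality_filters(
    text: str,
    min_words: int = 50,
    max_word_repetition: float = 0.3,
    max_special_char_ratio: float = 0.2,
) -> bool:
    """Toy quality filters: length, word repetition, special-char density."""
    words = text.split()
    if len(words) < min_words:
        return False

    # Word repetition ratio: fraction of words that are duplicates
    repetition = 1.0 - len(set(words)) / len(words)
    if repetition > max_word_repetition:
        return False

    # Special character density (punctuation, symbols)
    special = len(re.findall(r'[^\w\s]', text))
    if special / max(len(text), 1) > max_special_char_ratio:
        return False

    return True
```

Real pipelines (e.g. the CCNet or RefinedWeb recipes) tune such thresholds per language and combine them with model-based filters.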

2.2 MinHash Deduplication

from datasketch import MinHash, MinHashLSH
import re

def get_shingles(text: str, k: int = 5):
    """Generate character k-shingle set from text."""
    text = re.sub(r'\s+', ' ', text.lower())
    return {text[i:i+k] for i in range(len(text) - k + 1)}

def build_minhash(text: str, num_perm: int = 128) -> MinHash:
    """Create a MinHash signature from text."""
    m = MinHash(num_perm=num_perm)
    for shingle in get_shingles(text):
        m.update(shingle.encode('utf-8'))
    return m

def deduplicate_corpus(documents: list, threshold: float = 0.8):
    """
    Remove near-duplicate documents using MinHash LSH.
    threshold: Jaccard similarity threshold for considering duplicates.
    """
    lsh = MinHashLSH(threshold=threshold, num_perm=128)
    unique_docs = []

    for idx, doc in enumerate(documents):
        mh = build_minhash(doc)
        result = lsh.query(mh)

        if len(result) == 0:
            lsh.insert(f"doc_{idx}", mh)
            unique_docs.append(doc)
        # If similar doc already exists, skip

    print(f"Original: {len(documents)} docs -> After dedup: {len(unique_docs)} docs")
    return unique_docs

The core idea: the Jaccard similarity J(A,B) = |A \cap B| / |A \cup B| between two sets can be estimated by the probability that the minimum hash value under a random permutation is the same for both sets. With 128 hash functions, Jaccard similarity is estimated to within a few percent (the standard error is \sqrt{J(1-J)/128}, at most about 4.4%). LSH bins documents into buckets so only candidate pairs are compared, bringing overall complexity close to O(N).
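To make the estimation claim concrete, here is a small self-contained demo using only the standard library. The salted-blake2b construction is an illustrative stand-in for the random permutations of true MinHash:

```python
import hashlib

def _hash(item: str, salt: int) -> int:
    """Deterministic 64-bit hash; each salt simulates one random permutation."""
    d = hashlib.blake2b(item.encode(), digest_size=8,
                        salt=salt.to_bytes(4, "big")).digest()
    return int.from_bytes(d, "big")

def minhash_signature(items, num_perm: int = 128):
    """One minimum hash value per simulated permutation."""
    return [min(_hash(x, i) for x in items) for i in range(num_perm)]

def estimated_jaccard(sig_a, sig_b) -> float:
    """Fraction of agreeing signature positions estimates Jaccard similarity."""
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)

a = {f"shingle{i}" for i in range(100)}
b = {f"shingle{i}" for i in range(50, 150)}  # 50 shared / 150 union -> J = 1/3
est = estimated_jaccard(minhash_signature(a), minhash_signature(b))
print(f"estimated J = {est:.2f}")  # close to 0.33
```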

2.3 Tokenizer Training

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.processors import ByteLevel as ByteLevelProcessor

def train_bpe_tokenizer(
    corpus_files: list,
    vocab_size: int = 32000,
    save_path: str = "tokenizer.json"
):
    """Train a GPT-2-style Byte-level BPE tokenizer."""
    tokenizer = Tokenizer(BPE(unk_token=None))
    tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
    tokenizer.post_processor = ByteLevelProcessor(trim_offsets=False)

    trainer = BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=2,
        special_tokens=["<pad>", "<s>", "</s>", "<unk>"],
        show_progress=True,
    )

    tokenizer.train(files=corpus_files, trainer=trainer)
    tokenizer.save(save_path)
    print(f"Tokenizer saved: {save_path} (vocab_size={vocab_size})")
    return tokenizer

2.4 Data Mixture Ratios

Llama 3.1's pretraining data mixture as a reference:

| Data Source | Proportion |
|---|---|
| Web crawl (Common Crawl, etc.) | ~80% |
| Code (GitHub) | ~8% |
| Academic / scientific papers | ~5% |
| Books (Books3, etc.) | ~4% |
| Multilingual data | ~3% |
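Such a mixture is typically enforced at sampling time: each training batch draws from the sources in proportion to their weights. A toy sketch using the table's proportions (the source names are placeholders):

```python
import random

# Mixture weights from the table above; source names are illustrative.
MIXTURE = {
    "web": 0.80,
    "code": 0.08,
    "papers": 0.05,
    "books": 0.04,
    "multilingual": 0.03,
}

def sample_source(rng: random.Random) -> str:
    """Draw one data source according to the mixture weights."""
    return rng.choices(list(MIXTURE), weights=list(MIXTURE.values()), k=1)[0]

rng = random.Random(42)
counts = {s: 0 for s in MIXTURE}
for _ in range(10_000):
    counts[sample_source(rng)] += 1
print(counts)  # roughly 8000 web / 800 code / 500 papers / 400 books / 300 multilingual
```

Production trainers implement the same idea at the shard/dataloader level rather than per-document.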

3. Architecture Choices

3.1 Key Design Decisions Across Open-Source LLMs

| Model | Positional Encoding | Attention | FFN | Normalization |
|---|---|---|---|---|
| GPT-NeoX | RoPE | MHA | GeLU MLP | LayerNorm |
| LLaMA-3 | RoPE | GQA | SwiGLU | RMSNorm |
| Mistral 7B | RoPE | GQA + SWA | SwiGLU | RMSNorm |
| Mixtral 8x7B | RoPE | GQA | MoE (SwiGLU) | RMSNorm |

3.2 RoPE (Rotary Position Embedding)

Rather than adding absolute positional information to token embeddings, RoPE applies a rotation transformation to Query and Key vectors before the dot product.

The result: the inner product between a query at position m and a key at position n depends only on the relative offset (m-n). This relative nature enables better extrapolation to sequence lengths beyond what was seen during training.

q_m^\top k_n = \mathrm{Re}\left[\sum_{j} \left(q_{(j)} e^{im\theta_j}\right) \overline{\left(k_{(j)} e^{in\theta_j}\right)}\right] = \mathrm{Re}\left[\sum_{j} q_{(j)} \overline{k_{(j)}}\, e^{i(m-n)\theta_j}\right]

Extensions like YaRN and LongRoPE use this property to extend context to hundreds of thousands of tokens.
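A minimal single-head RoPE sketch (interleaved-pair variant, with simplified tensor shapes) that also checks the relative-position property numerically:

```python
import torch

def apply_rope(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0):
    """Minimal RoPE: rotate consecutive feature pairs by position-dependent angles.
    x: (seq_len, head_dim) with head_dim even; positions: (seq_len,)"""
    head_dim = x.shape[-1]
    # Frequencies theta_j = base^(-2j/d), one per feature pair
    theta = base ** (-torch.arange(0, head_dim, 2).float() / head_dim)
    angles = positions[:, None].float() * theta[None, :]  # (seq_len, head_dim/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin  # 2D rotation of each pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

# Relative-position property: <q_m, k_n> depends only on the offset m - n
q, k = torch.randn(1, 64), torch.randn(1, 64)
s1 = apply_rope(q, torch.tensor([3])) @ apply_rope(k, torch.tensor([1])).T
s2 = apply_rope(q, torch.tensor([10])) @ apply_rope(k, torch.tensor([8])).T
print(torch.allclose(s1, s2, atol=1e-4))  # True: both offsets are 2
```

Production implementations vectorize this across heads and cache the cos/sin tables, but the rotation itself is exactly this.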

3.3 Grouped Query Attention (GQA)

During autoregressive inference, previous tokens' Key and Value tensors must be stored in the KV cache. GQA reduces this cache by having multiple Query heads share a smaller number of Key/Value heads.

  • MHA: H query heads, H key heads, H value heads
  • MQA: H query heads, 1 key head, 1 value head
  • GQA: H query heads, G key heads, G value heads (G < H)

Llama 3 8B uses 32 query heads with 8 KV heads — reducing KV cache to 1/4 of MHA.
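Two small sketches make this concrete: expanding G KV heads to match H query heads, and estimating KV cache size. The Llama-3-8B-like shapes (32 layers, head_dim 128, bf16) are assumptions for illustration:

```python
import torch

def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    """GQA sketch: repeat each of the G KV heads n_rep times so shapes match
    the H = G * n_rep query heads. kv: (batch, n_kv_heads, seq, head_dim)"""
    b, g, t, d = kv.shape
    return kv[:, :, None].expand(b, g, n_rep, t, d).reshape(b, g * n_rep, t, d)

def kv_cache_bytes(batch, seq, n_kv_heads, head_dim, bytes_per_el=2):
    """Per-layer KV cache size: keys + values for every cached token (bf16)."""
    return 2 * batch * seq * n_kv_heads * head_dim * bytes_per_el

# Llama-3-8B-like: 32 query heads, 8 KV heads, head_dim 128, 32 layers, 8K context
mha = 32 * kv_cache_bytes(1, 8192, 32, 128)
gqa = 32 * kv_cache_bytes(1, 8192, 8, 128)
print(f"MHA: {mha/1e9:.2f} GB, GQA: {gqa/1e9:.2f} GB")  # GQA is 1/4 of MHA
```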

3.4 Mixture of Experts (MoE)

Mixtral 8x7B and DeepSeek-V3 use MoE architectures. Each token activates only Top-K out of N expert FFN layers.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    """Simple Mixture of Experts layer."""
    def __init__(self, d_model: int, d_ff: int, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.SiLU(),
                nn.Linear(d_ff, d_model),
            )
            for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor):
        # x: (batch, seq_len, d_model)
        B, T, D = x.shape
        x_flat = x.view(-1, D)  # (B*T, D)

        # Router
        router_logits = self.router(x_flat)  # (B*T, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Top-K expert selection
        topk_probs, topk_idx = router_probs.topk(self.top_k, dim=-1)
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # renormalize

        # Load balancing loss (Switch Transformer style):
        # f_i = fraction of tokens whose top-1 expert is i,
        # P_i = mean router probability mass on expert i.
        top1_mask = F.one_hot(topk_idx[:, 0], self.num_experts).float()
        f = top1_mask.mean(0)                # (num_experts,)
        P = router_probs.mean(0)             # (num_experts,)
        load_balance_loss = self.num_experts * (f * P).sum()

        # Compute expert outputs
        output = torch.zeros_like(x_flat)
        for k in range(self.top_k):
            expert_indices = topk_idx[:, k]
            expert_weights = topk_probs[:, k].unsqueeze(-1)
            for e_idx in range(self.num_experts):
                mask = (expert_indices == e_idx)
                if mask.any():
                    expert_out = self.experts[e_idx](x_flat[mask])
                    output[mask] += expert_weights[mask] * expert_out

        return output.view(B, T, D), load_balance_loss

Why load balancing loss is necessary: Without it, the router collapses to always selecting the same expert (expert collapse). The chosen expert improves most, gets selected even more, and the rest never train — a positive feedback loop. Load balancing loss penalizes unequal distribution by adding an auxiliary loss term that encourages all experts to receive roughly equal token counts.


4. Training Stability

4.1 Learning Rate Schedule: Cosine Warmup

import math

def cosine_lr_with_warmup(
    optimizer,
    step: int,
    warmup_steps: int,
    total_steps: int,
    max_lr: float,
    min_lr: float = 0.0,
):
    """Cosine Annealing with Linear Warmup."""
    if step < warmup_steps:
        lr = max_lr * step / warmup_steps
    else:
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# Typical pretraining settings:
# warmup: 1-2% of total steps
# max_lr: 3e-4 (for ~7B model)
# min_lr: max_lr * 0.1

4.2 Loss Spike Detection and Recovery

Large-scale pretraining runs frequently encounter sudden loss spikes. Standard mitigation strategies:

  1. Gradient norm clipping: scale down gradients when the norm exceeds a threshold
  2. Spike detection with rollback: restart from a previous checkpoint, skipping the offending batch
  3. Exponential moving average monitoring: trigger alerts on sudden loss increases

import torch

def train_step_with_stability(model, optimizer, batch, grad_clip: float = 1.0):
    """Training step with gradient norm clipping and spike detection."""
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()

    # Compute and clip gradient norm
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm=grad_clip
    )

    # Detect abnormally large gradients
    if grad_norm > 100 * grad_clip:
        print(f"WARNING: abnormal grad_norm={grad_norm:.2f}, skipping step")
        optimizer.zero_grad()
        return None, grad_norm

    optimizer.step()
    return loss.item(), grad_norm
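The exponential-moving-average monitoring mentioned in strategy 3 is not shown above. One possible sketch, where the decay and spike_factor values are illustrative defaults:

```python
class LossSpikeDetector:
    """EMA-based loss spike monitor: flag a step whose loss exceeds
    spike_factor times the running EMA of recent losses."""
    def __init__(self, decay: float = 0.99, spike_factor: float = 1.5):
        self.decay = decay
        self.spike_factor = spike_factor
        self.ema = None

    def update(self, loss: float) -> bool:
        if self.ema is None:
            self.ema = loss
            return False
        is_spike = loss > self.spike_factor * self.ema
        if not is_spike:
            # Only fold non-spike losses into the EMA, so a spike
            # does not inflate the baseline it is compared against.
            self.ema = self.decay * self.ema + (1 - self.decay) * loss
        return is_spike

detector = LossSpikeDetector()
losses = [4.0, 3.9, 3.8, 3.7, 9.5, 3.6]  # one artificial spike
flags = [detector.update(l) for l in losses]
print(flags)  # [False, False, False, False, True, False]
```

In practice a flagged step would trigger an alert or the checkpoint-rollback path from strategy 2.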

5. Efficient Pretraining

5.1 Flash Attention 2

Flash Attention rewrites the attention computation to maximize GPU SRAM utilization, reducing memory complexity from O(N^2) to O(N).

# Flash Attention 2 usage
# pip install flash-attn --no-build-isolation

import torch
from flash_attn import flash_attn_func

def flash_attention_forward(
    q: torch.Tensor,  # (batch, seqlen, nheads, headdim)
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = True,
    softmax_scale: float = None,
):
    """
    Flash Attention 2 forward pass.
    causal=True: applies causal mask for autoregressive LMs.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    out = flash_attn_func(
        q, k, v,
        dropout_p=0.0,
        softmax_scale=softmax_scale,
        causal=causal,
    )
    return out  # (batch, seqlen, nheads, headdim)

Flash Attention 2 compared to standard attention:

  • Memory: linear in sequence length (vs. quadratic)
  • Speed: 2-4x faster on A100
  • Numerical accuracy: exact attention (no approximation), though floating-point results can differ slightly from the naive implementation because the summation order changes

5.2 Sliding Window Attention (SWA)

Used in Mistral 7B, SWA limits each token's attention to the most recent W tokens. This reduces attention complexity to O(W \cdot N) for long sequences while maintaining local context.
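The combined causal plus sliding-window constraint can be sketched as a boolean mask. This toy version materializes the full N x N mask, which real kernels (e.g. Flash Attention's windowed variants) avoid:

```python
import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    """Boolean mask: position i may attend to j iff i - window < j <= i,
    i.e. causality restricted to the most recent `window` tokens."""
    i = torch.arange(seq_len)[:, None]
    j = torch.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)

mask = sliding_window_causal_mask(6, window=3)
print(mask.int())
# Each row has at most 3 ones, ending at the diagonal:
# the banded structure is what bounds the cost at O(W * N).
```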


6. Evaluation and Checkpointing

6.1 Monitoring the Perplexity Curve

import torch
import math

def compute_perplexity(model, dataloader, device: str = "cuda"):
    """Compute perplexity on a validation set."""
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, labels=labels)
            batch_tokens = (labels != -100).sum().item()
            total_loss += outputs.loss.item() * batch_tokens
            total_tokens += batch_tokens

    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    return perplexity

# Typical benchmarks:
# GPT-2 level: PPL ~20 (WebText)
# Good 7B model: PPL ~6-8 (validation set)

6.2 Running lm-evaluation-harness Benchmarks

# Install and run lm-evaluation-harness
pip install lm-eval

# Evaluate Llama 3 8B on standard benchmarks
lm_eval --model hf \
    --model_args pretrained=meta-llama/Meta-Llama-3-8B \
    --tasks mmlu,hellaswag,arc_challenge,winogrande \
    --device cuda:0 \
    --batch_size 8 \
    --output_path results/llama3-8b/

# Evaluate a specific checkpoint during training
lm_eval --model hf \
    --model_args pretrained=./checkpoints/step-50000 \
    --tasks hellaswag \
    --num_fewshot 10 \
    --device cuda:0

6.3 Checkpoint Strategy

| Checkpoint Type | Save Interval | Retention |
|---|---|---|
| Rolling recent | Every 500 steps | Keep last 5 |
| Milestone | Every 5,000 steps | Permanent |
| Best validation | Based on PPL | Permanent |
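The "rolling recent" policy can be implemented with a small retention manager. The sketch below only prints what it would delete, and all paths and intervals are illustrative:

```python
from collections import deque

class RollingCheckpointManager:
    """Keep only the last `keep` rolling checkpoints; checkpoints at
    milestone steps are never registered for deletion."""
    def __init__(self, keep: int = 5, milestone_every: int = 5000):
        self.keep = keep
        self.milestone_every = milestone_every
        self.recent = deque()

    def register(self, step: int, path: str):
        if step % self.milestone_every == 0:
            return  # milestone checkpoints are kept permanently
        self.recent.append(path)
        while len(self.recent) > self.keep:
            old = self.recent.popleft()
            print(f"would delete {old}")  # a real run would remove the file

mgr = RollingCheckpointManager(keep=5)
for step in range(500, 5001, 500):
    mgr.register(step, f"checkpoints/step-{step}")
print(list(mgr.recent))  # the 5 most recent non-milestone checkpoints
```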

7. Recent Pretraining Recipes (2024-2025)

7.1 Llama 3.1

  • Training data: 15T tokens (7.5x Llama 2)
  • Context length: 128K (via RoPE extension)
  • Vocabulary size: 128K (tiktoken-based)
  • Notable: post-pretraining long-context annealing phase

7.2 Mistral Large / Mistral 7B

  • Mistral 7B: GQA + SWA for maximal efficiency
  • Mixtral 8x7B: 8 experts, 2 active per token (13B parameters in use per forward pass)
  • Context: 32K via SWA

7.3 DeepSeek-V3

  • Architecture: 671B MoE, 37B active parameters per token
  • Training cost: ~$5.57M on 2,048x H800 GPUs for roughly 60 days, about 10x cheaper than comparable frontier models
  • Innovations: Multi-head Latent Attention (MLA), FP8 mixed-precision training
  • Data: 14.8T tokens with heavy Chinese and code representation

7.4 phi-4 (Microsoft)

  • Size: 14B parameters
  • Strategy: data quality over data quantity — heavy synthetic data augmentation
  • Training data: 9.8T tokens (~40% synthetic)
  • Achievement: competitive with 70B+ models on math and reasoning benchmarks

Quiz

Q1. What is the core argument of the Chinchilla paper that showed smaller models can outperform GPT-3?

Answer: GPT-3 trained 175B parameters on only 300B tokens. Chinchilla showed that parameters and training tokens should scale together at a 1:20 ratio — GPT-3 was severely data-starved.

Explanation: Hoffmann et al. 2022 demonstrated through rigorous experiments that the Kaplan scaling laws undervalued data. Under the same compute budget, a 70B parameter model trained on 1.4T tokens (Chinchilla) outperformed GPT-3 (175B / 300B tokens) on nearly all benchmarks. The key insight is that the optimal number of training tokens is proportional to the square root of the compute budget, not just the model size.

Q2. How does the MinHash algorithm efficiently detect near-duplicates in a large text corpus?

Answer: Documents are represented as shingle (n-gram) sets, then compressed to fixed-length MinHash signatures that approximate Jaccard similarity. LSH groups similar signatures into the same bucket, so only candidate pairs need precise comparison.

Explanation: Direct Jaccard similarity between all pairs requires O(N^2) comparisons. MinHash represents each document as k hash minimums. Two signatures' agreement rate is an unbiased estimator of Jaccard similarity. LSH further reduces comparisons to near O(N) by only comparing documents that hash to the same bucket, making corpus-scale deduplication tractable.

Q3. How does Grouped Query Attention (GQA) reduce inference memory compared to Multi-Head Attention?

Answer: GQA uses fewer Key/Value heads than Query heads. Since the KV cache scales with the number of KV heads (not Query heads), the cache is reduced by a factor of KV heads / Query heads.

Explanation: During autoregressive inference, all previous tokens' Keys and Values must be cached. With MHA, the cache size per layer is 2 x batch_size x seq_len x num_heads x head_dim. GQA with G KV heads (G < H query heads) reduces the cache to G/H of MHA. Llama 3 8B (H=32 query heads, G=8 KV heads) uses 1/4 the KV cache of equivalent MHA.

Q4. Why is RoPE (Rotary Position Embedding) better suited for long-context extrapolation than absolute positional embeddings?

Answer: RoPE makes attention scores a function of relative position (m-n) rather than absolute positions. This means the model generalizes to position offsets it has not seen during training, enabling length extrapolation.

Explanation: Absolute positional embeddings (sinusoidal or learned) are added to token embeddings at specific indices. Positions beyond the training length are out-of-distribution. RoPE applies rotation matrices to Q and K such that their dot product depends only on the difference between their positions. Extensions like YaRN and LongRoPE leverage this property to extend context windows from 4K to 128K+ tokens with minimal fine-tuning.

Q5. Why is a load balancing loss necessary in MoE expert routing?

Answer: Without it, the router collapses to always choosing the same expert (expert collapse). That expert improves most, gets selected more, and the cycle repeats — other experts receive no gradients and remain undertrained.

Explanation: The router's softmax produces expert selection probabilities. Even with uniform initialization, a positive feedback loop causes one or a few experts to dominate. Load balancing loss adds an auxiliary penalty when the distribution of tokens to experts is uneven. The loss term equals num_experts times the sum of (expert fraction) times (average routing probability), encouraging a uniform distribution and ensuring all experts are adequately trained.


Conclusion

LLM pretraining is where theoretical scaling laws, practical data engineering, and systems optimization intersect. Chinchilla proved mathematically that "bigger is not always better" — data efficiency matters as much as parameter count. Flash Attention and GQA made large-scale training and deployment economically feasible. DeepSeek-V3 demonstrated that MoE architecture combined with efficient implementation can achieve frontier performance at a fraction of the cost. The concepts in this guide form the foundation for understanding and conducting LLM pretraining experiments at any scale.