Split View: LLM 사전 학습 & 스케일링 법칙: Chinchilla, Flash Attention, MoE까지
LLM 사전 학습 & 스케일링 법칙: Chinchilla, Flash Attention, MoE까지
들어가며
GPT-3가 등장한 이후 "모델을 크게 만들면 성능이 좋아진다"는 직관이 지배했습니다. 그런데 2022년 DeepMind의 Chinchilla 연구는 이 직관에 정면으로 도전했습니다. 파라미터 수를 4배 줄이고 데이터를 4배 늘린 모델이 GPT-3를 능가한다는 것입니다. 이 가이드에서는 스케일링 법칙의 수학적 기반부터 최신 사전 학습 레시피까지, LLM 사전 학습의 핵심을 체계적으로 정리합니다.
1. 스케일링 법칙 (Scaling Laws)
1.1 Kaplan et al. (2020) — OpenAI 스케일링 법칙
Kaplan et al.은 언어 모델의 손실이 파라미터 수(N), 데이터 크기(D), 컴퓨팅 예산(C)의 거듭제곱 법칙을 따른다는 것을 실험적으로 보였습니다.
핵심 발견은 컴퓨팅 예산이 고정되어 있을 때, 모델 크기에 훨씬 더 많이 투자하는 것이 유리하다는 것이었습니다. 이 결론이 GPT-3 (175B 파라미터)와 같은 거대 모델 개발의 이론적 근거가 되었습니다.
1.2 Chinchilla (Hoffmann et al. 2022) — 재정립된 스케일링 법칙
DeepMind 팀은 더 엄밀한 실험 설계를 통해 Kaplan의 결론이 데이터 측면을 과소평가했음을 보였습니다.
Chinchilla 법칙의 핵심:
최적 컴퓨팅 예산 C (FLOPs)가 주어졌을 때, 최적 모델 크기 와 최적 데이터 크기 는 다음 관계를 만족합니다.
즉, 파라미터 1개당 학습 토큰은 약 20개가 최적이라는 결론입니다.
| 모델 | 파라미터 | 학습 토큰 | Chinchilla 최적비 |
|---|---|---|---|
| GPT-3 | 175B | 300B | 1:1.7 (데이터 부족) |
| Chinchilla | 70B | 1.4T | 1:20 (최적) |
| Llama 3.1 | 8B | 15T | 1:1875 (데이터 초과) |
Llama 3.1처럼 의도적으로 Chinchilla 최적보다 훨씬 더 많은 데이터를 학습하는 경우는 추론 효율성 때문입니다. 작은 모델이 많이 학습할수록 배포 비용이 줄어듭니다.
1.3 최적 컴퓨팅 예산 배분 계산
import math
def chinchilla_optimal(compute_budget_flops: float):
"""
Chinchilla 법칙에 따른 최적 N, D 계산
compute_budget_flops: 총 FLOPs (예: 6 * N * D 근사)
반환: (optimal_params, optimal_tokens)
"""
# Chinchilla 계수 (Hoffmann et al. Table A3 기준)
# C = 6 * N * D 근사에서
# N_opt = (C / (6 * 20))^0.5, D_opt = 20 * N_opt
N_opt = math.sqrt(compute_budget_flops / (6 * 20))
D_opt = 20 * N_opt
return N_opt, D_opt
# 예시: A100 1000개, 30일 학습
# A100: ~312 TFLOPS (bf16), 가동률 40%
flops_per_sec = 1000 * 312e12 * 0.4
duration_sec = 30 * 24 * 3600
total_flops = flops_per_sec * duration_sec
N_opt, D_opt = chinchilla_optimal(total_flops)
print(f"총 FLOPs: {total_flops:.2e}")
print(f"최적 파라미터 수: {N_opt/1e9:.1f}B")
print(f"최적 토큰 수: {D_opt/1e12:.1f}T")
2. 데이터 준비
2.1 Common Crawl 필터링 파이프라인
Common Crawl은 월 수십 TB의 웹 크롤링 데이터를 제공합니다. 그대로 쓰면 품질이 매우 낮기 때문에 다단계 필터링이 필수입니다.
전형적인 필터링 파이프라인:
- 언어 감지 (fastText langdetect): 목표 언어만 유지
- 품질 필터: 최소 문장 수, 단어 반복 비율, 특수문자 비율
- 도메인 블랙리스트: 스팸, 성인, 광고 도메인 제거
- 퍼플렉시티 필터: n-gram 언어 모델 기반 낮은 품질 문서 제거
- 중복 제거: MinHash LSH 기반
2.2 MinHash 중복 제거
from datasketch import MinHash, MinHashLSH
import re
def get_shingles(text: str, k: int = 5):
"""문자 k-shingle 집합 생성"""
text = re.sub(r'\s+', ' ', text.lower())
return {text[i:i+k] for i in range(len(text) - k + 1)}
def build_minhash(text: str, num_perm: int = 128) -> MinHash:
"""텍스트에서 MinHash 서명 생성"""
m = MinHash(num_perm=num_perm)
for shingle in get_shingles(text):
m.update(shingle.encode('utf-8'))
return m
def deduplicate_corpus(documents: list, threshold: float = 0.8):
"""
MinHash LSH로 유사 문서 제거
threshold: Jaccard 유사도 임계값
"""
lsh = MinHashLSH(threshold=threshold, num_perm=128)
unique_docs = []
seen = set()
for idx, doc in enumerate(documents):
mh = build_minhash(doc)
result = lsh.query(mh)
if len(result) == 0:
lsh.insert(f"doc_{idx}", mh)
unique_docs.append(doc)
# 유사 문서가 이미 존재하면 스킵
print(f"원본: {len(documents)}개 → 중복 제거 후: {len(unique_docs)}개")
return unique_docs
MinHash의 핵심 아이디어: 두 집합의 Jaccard 유사도 를 해시 함수 k개의 최솟값으로 추정합니다. 128개의 해시 함수를 사용하면 약 3% 오차 내에서 Jaccard 유사도를 추정할 수 있습니다.
2.3 토크나이저 학습
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.processors import ByteLevel as ByteLevelProcessor
def train_bpe_tokenizer(
corpus_files: list,
vocab_size: int = 32000,
save_path: str = "tokenizer.json"
):
"""BPE 토크나이저 학습 (GPT-2 스타일 Byte-level BPE)"""
tokenizer = Tokenizer(BPE(unk_token=None))
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
tokenizer.post_processor = ByteLevelProcessor(trim_offsets=False)
trainer = BpeTrainer(
vocab_size=vocab_size,
min_frequency=2,
special_tokens=["<pad>", "<s>", "</s>", "<unk>"],
show_progress=True,
)
tokenizer.train(files=corpus_files, trainer=trainer)
tokenizer.save(save_path)
print(f"토크나이저 저장 완료: {save_path} (vocab_size={vocab_size})")
return tokenizer
# 사용 예시
# train_bpe_tokenizer(["data/train.txt"], vocab_size=32000)
2.4 데이터 혼합 비율
Llama 3.1의 사전 학습 데이터 혼합 비율을 참고하면:
| 데이터 소스 | 비율 |
|---|---|
| 웹 크롤링 (Common Crawl 등) | 약 80% |
| 코드 (GitHub) | 약 8% |
| 학술/과학 논문 | 약 5% |
| 책 (Books3 등) | 약 4% |
| 다국어 데이터 | 약 3% |
3. 아키텍처 선택
3.1 주요 오픈소스 LLM 설계 결정 비교
| 모델 | 위치 임베딩 | Attention | FFN | Normalization |
|---|---|---|---|---|
| GPT-NeoX | ALiBi | MHA | SwiGLU | LayerNorm |
| LLaMA-3 | RoPE | GQA | SwiGLU | RMSNorm |
| Mistral 7B | RoPE | GQA + SWA | SwiGLU | RMSNorm |
| Mixtral 8x7B | RoPE | GQA | MoE (SwiGLU) | RMSNorm |
3.2 RoPE (Rotary Position Embedding)
RoPE는 절대 위치 정보를 직접 더하는 대신, attention score 계산 시 Query/Key 벡터에 회전 변환을 적용합니다.
위치 m에서의 쿼리 벡터 q와 위치 n에서의 키 벡터 k의 내적이 상대 위치 (m-n)의 함수가 됩니다. 이 성질 덕분에 훈련 길이를 넘어서는 긴 컨텍스트에도 어느 정도 외삽(extrapolation)이 가능합니다.
3.3 Grouped Query Attention (GQA)
Multi-Head Attention(MHA)에서 KV 캐시는 추론 메모리의 큰 부분을 차지합니다. GQA는 여러 Query 헤드가 소수의 Key/Value 헤드를 공유합니다.
- MHA: H개 쿼리, H개 키, H개 값 헤드
- MQA: H개 쿼리, 1개 키, 1개 값 헤드
- GQA: H개 쿼리, G개 키, G개 값 헤드 (G < H)
Llama 3 8B는 32개 쿼리 헤드에 8개 KV 헤드를 사용합니다. KV 캐시가 1/4로 줄어듭니다.
3.4 Mixture of Experts (MoE)
Mixtral 8x7B와 DeepSeek-V3는 MoE 아키텍처를 사용합니다. 각 토큰은 N개의 전문가(FFN 레이어) 중 Top-K개만 활성화합니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
class MoELayer(nn.Module):
"""간단한 Mixture of Experts 레이어"""
def __init__(self, d_model: int, d_ff: int, num_experts: int = 8, top_k: int = 2):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.router = nn.Linear(d_model, num_experts, bias=False)
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(d_model, d_ff),
nn.SiLU(),
nn.Linear(d_ff, d_model),
)
for _ in range(num_experts)
])
def forward(self, x: torch.Tensor):
# x: (batch, seq_len, d_model)
B, T, D = x.shape
x_flat = x.view(-1, D) # (B*T, D)
# 라우터 계산
router_logits = self.router(x_flat) # (B*T, num_experts)
router_probs = F.softmax(router_logits, dim=-1)
# Top-K 전문가 선택
topk_probs, topk_idx = router_probs.topk(self.top_k, dim=-1)
topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True) # 재정규화
# Load Balancing Loss (전문가 균형 유지)
# 각 전문가가 균등하게 선택되도록
expert_load = router_probs.mean(0) # (num_experts,)
load_balance_loss = self.num_experts * (expert_load * expert_load.mean()).sum()
# 전문가 출력 계산
output = torch.zeros_like(x_flat)
for k in range(self.top_k):
expert_indices = topk_idx[:, k]
expert_weights = topk_probs[:, k].unsqueeze(-1)
for e_idx in range(self.num_experts):
mask = (expert_indices == e_idx)
if mask.any():
expert_out = self.experts[e_idx](x_flat[mask])
output[mask] += expert_weights[mask] * expert_out
return output.view(B, T, D), load_balance_loss
Load Balancing Loss가 필요한 이유: 라우터가 특정 전문가만 선택하도록 수렴하면 나머지 전문가는 학습되지 않고, 선택된 전문가는 과부하가 걸립니다(expert collapse). Load balancing loss는 모든 전문가에게 균등한 토큰을 보내도록 정규화합니다.
4. 학습 안정성
4.1 학습률 스케줄: Cosine Warmup
import math
def cosine_lr_with_warmup(
optimizer,
step: int,
warmup_steps: int,
total_steps: int,
max_lr: float,
min_lr: float = 0.0,
):
"""Cosine Annealing with Linear Warmup"""
if step < warmup_steps:
lr = max_lr * step / warmup_steps
else:
progress = (step - warmup_steps) / (total_steps - warmup_steps)
lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
return lr
# 전형적인 사전 학습 설정
# warmup: 전체 스텝의 1~2%
# max_lr: 3e-4 (7B 모델 기준)
# min_lr: max_lr * 0.1
4.2 손실 스파이크 감지 및 복구
대규모 사전 학습에서 손실이 갑자기 치솟는 스파이크가 발생할 수 있습니다. 일반적인 대처 전략:
- Gradient Norm Clipping: 그래디언트 노름이 임계값을 초과하면 스케일 다운
- 스파이크 감지 후 롤백: 이전 체크포인트에서 재시작, 해당 배치 스킵
- Loss 이동 평균 모니터링: 급격한 상승 시 알람
import torch
def train_step_with_stability(model, optimizer, batch, grad_clip: float = 1.0):
"""손실 스파이크 방지를 위한 안전한 학습 스텝"""
optimizer.zero_grad()
loss = model(batch)
loss.backward()
# Gradient Norm 계산
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=grad_clip
)
# 비정상적으로 큰 그래디언트 감지
if grad_norm > 100 * grad_clip:
print(f"경고: 비정상 grad_norm={grad_norm:.2f}, 스텝 스킵")
optimizer.zero_grad()
return None, grad_norm
optimizer.step()
return loss.item(), grad_norm
5. 효율적 사전 학습
5.1 Flash Attention 2
Flash Attention은 GPU의 SRAM을 최대한 활용해 Attention 연산의 메모리 복잡도를 에서 으로 줄입니다.
# Flash Attention 2 사용 예시
# pip install flash-attn --no-build-isolation
import torch
from flash_attn import flash_attn_func
from flash_attn.bert_padding import unpad_input, pad_input
def flash_attention_forward(
q: torch.Tensor, # (batch, seqlen, nheads, headdim)
k: torch.Tensor,
v: torch.Tensor,
causal: bool = True,
softmax_scale: float = None,
):
"""
Flash Attention 2 순전파
causal=True: 자동 회귀 언어 모델용 인과적 마스크 적용
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out = flash_attn_func(
q, k, v,
dropout_p=0.0,
softmax_scale=softmax_scale,
causal=causal,
)
return out # (batch, seqlen, nheads, headdim)
Flash Attention 2는 표준 Attention 대비:
- 메모리: 시퀀스 길이에 선형 비례 (기존 2차 비례 대비)
- 속도: 2~4배 빠름 (A100 기준)
- 수치 안정성: 동일 결과 보장
5.2 Sliding Window Attention (SWA)
Mistral 7B에서 사용된 SWA는 각 토큰이 최근 W개의 토큰에만 attend합니다. 긴 시퀀스에서 Attention 복잡도를 으로 줄입니다.
6. 평가와 체크포인팅
6.1 Perplexity 곡선 모니터링
import torch
import math
def compute_perplexity(model, dataloader, device: str = "cuda"):
"""검증 세트 퍼플렉시티 계산"""
model.eval()
total_loss = 0.0
total_tokens = 0
with torch.no_grad():
for batch in dataloader:
input_ids = batch["input_ids"].to(device)
labels = batch["labels"].to(device)
outputs = model(input_ids=input_ids, labels=labels)
# CrossEntropy loss는 이미 토큰당 평균
batch_tokens = (labels != -100).sum().item()
total_loss += outputs.loss.item() * batch_tokens
total_tokens += batch_tokens
avg_loss = total_loss / total_tokens
perplexity = math.exp(avg_loss)
return perplexity
# 전형적인 사전 학습 진행 기준
# GPT-2 수준: PPL ~20 (WebText)
# 좋은 7B 모델: PPL ~6~8 (검증 세트 기준)
6.2 lm-evaluation-harness 벤치마크
# lm-evaluation-harness 설치 및 실행
pip install lm-eval
# Llama 3 8B 평가 예시
lm_eval --model hf \
--model_args pretrained=meta-llama/Meta-Llama-3-8B \
--tasks mmlu,hellaswag,arc_challenge,winogrande \
--device cuda:0 \
--batch_size 8 \
--output_path results/llama3-8b/
# 특정 체크포인트 평가
lm_eval --model hf \
--model_args pretrained=./checkpoints/step-50000 \
--tasks hellaswag \
--num_fewshot 10 \
--device cuda:0
6.3 체크포인트 전략
대규모 사전 학습에서 체크포인트는 세 가지 용도로 관리합니다.
| 체크포인트 유형 | 저장 간격 | 보존 기간 |
|---|---|---|
| 최근 N개 유지 | 매 500 스텝 | 최근 5개만 |
| 마일스톤 | 매 5,000 스텝 | 영구 보존 |
| 최고 성능 | 검증 PPL 기준 | 영구 보존 |
7. 최신 사전 학습 레시피 비교 (2024-2025)
7.1 Llama 3.1
- 학습 데이터: 15T 토큰 (Llama 2의 7.5배)
- 컨텍스트 길이: 128K (RoPE 확장 적용)
- 어휘 크기: 128K (tiktoken 기반)
- 특이사항: 사전 학습 후 긴 컨텍스트 어닐링 단계 추가
7.2 Mistral Large / Mistral 7B
- Mistral 7B: GQA + SWA로 효율성 극대화
- Mixtral 8x7B: 8개 전문가 중 2개 활성화 (실제 파라미터 13B 사용)
- 컨텍스트: 32K (SWA 기반)
7.3 DeepSeek-V3
- 아키텍처: 671B MoE, 토큰당 37B 활성화
- 학습 비용: H800 2,048개 × 60일 (약 557만 달러 — 경쟁사 대비 10배 저렴)
- 혁신: Multi-head Latent Attention (MLA), FP8 혼합 정밀도 학습
- 데이터: 14.8T 토큰 (중국어·코드 비중 높음)
7.4 phi-4 (Microsoft)
- 크기: 14B 파라미터
- 전략: 합성 데이터 비중을 대폭 높여 데이터 품질 중심 접근
- 학습 데이터: 9.8T 토큰 (합성 데이터 약 40%)
- 성과: 70B+ 모델과 경쟁하는 수학·추론 벤치마크 성능
퀴즈
Q1. Chinchilla 법칙이 GPT-3보다 적은 파라미터로 더 좋은 성능을 낼 수 있음을 보인 핵심 주장은?
정답: GPT-3은 175B 파라미터에 300B 토큰만 학습했으나, Chinchilla는 파라미터 수와 데이터 수를 균형 있게 키워야 한다는 것을 실험적으로 증명했습니다.
설명: Chinchilla (Hoffmann et al. 2022)는 컴퓨팅 예산이 고정될 때 최적 파라미터 수와 최적 토큰 수가 대략 1:20 비율을 만족해야 한다고 주장합니다. GPT-3는 175B 파라미터 대비 약 3.5T 토큰이 필요했지만 실제로는 300B 토큰만 사용해 "데이터 부족" 상태였습니다. 같은 컴퓨팅 예산으로 70B 파라미터에 1.4T 토큰을 학습한 Chinchilla가 GPT-3를 능가했습니다.
Q2. MinHash 알고리즘이 대규모 텍스트 코퍼스에서 중복을 효율적으로 감지하는 방법은?
정답: 문서를 shingle(n-gram 집합)로 표현하고, 여러 해시 함수의 최솟값(MinHash 서명)으로 Jaccard 유사도를 근사 추정합니다. LSH(Locality Sensitive Hashing)로 유사 문서를 빠르게 후보 집합으로 좁힙니다.
설명: 두 문서의 Jaccard 유사도를 직접 계산하면 코퍼스 크기 N에 대해 O(N^2) 비교가 필요합니다. MinHash는 각 문서를 k개 해시 함수의 최솟값으로 이루어진 길이 k짜리 서명 벡터로 압축합니다. 두 서명의 해밍 유사도가 원래 Jaccard 유사도의 불편 추정량이 됩니다. LSH는 같은 버킷에 해시된 후보들만 정밀 비교해 전체 복잡도를 O(N)에 가깝게 줄입니다.
Q3. Grouped Query Attention(GQA)이 Multi-Head Attention보다 추론 메모리를 절감하는 원리는?
정답: GQA는 여러 Query 헤드가 소수의 Key/Value 헤드를 공유하므로, KV 캐시 크기가 Query 헤드 수에 비례하지 않고 KV 헤드 수에만 비례합니다.
설명: 자동 회귀 추론 시 이전 토큰의 Key/Value를 KV 캐시에 저장합니다. MHA에서 H개 헤드 모두가 별도 KV를 가지면 KV 캐시 크기는 배치 크기 × 시퀀스 길이 × H × head_dim × 2입니다. GQA에서 G개 KV 헤드만 쓰면 KV 캐시가 G/H로 줄어듭니다. Llama 3 8B (H=32, G=8)는 KV 캐시가 MHA 대비 1/4 수준입니다.
Q4. RoPE(Rotary Position Embedding)가 절대 위치 임베딩보다 긴 컨텍스트 외삽(extrapolation)에 유리한 이유는?
정답: RoPE는 Attention score가 절대 위치가 아닌 상대 위치 (m-n)의 함수가 되도록 설계되어, 훈련 시 본 적 없는 긴 거리의 상대적 위치 관계에도 일반화가 가능합니다.
설명: 절대 위치 임베딩(사인/코사인 또는 학습된 임베딩)은 특정 위치 인덱스에 대한 임베딩을 직접 추가합니다. 훈련 최대 길이를 넘는 위치는 분포 외(out-of-distribution) 입력이 됩니다. RoPE는 Q·K 내적이 두 위치의 차(상대 위치)에만 의존하도록 회전 행렬을 적용합니다. 이론적으로 훈련 길이를 넘어서도 상대 위치 패턴을 활용할 수 있어 YaRN, LongRoPE 같은 확장 기법의 기반이 됩니다.
Q5. Mixture of Experts(MoE)에서 전문가 라우팅의 load balancing loss가 필요한 이유는?
정답: 라우터가 학습 초기에 특정 전문가만 선택하도록 수렴하면(expert collapse), 나머지 전문가는 gradient를 받지 못해 훈련되지 않고 선택된 전문가만 과부하 됩니다. Load balancing loss는 모든 전문가에 균등한 토큰이 할당되도록 강제합니다.
설명: 라우터의 소프트맥스 출력은 각 전문가의 선택 확률입니다. 초기화가 균일해도 학습 중 양성 피드백으로 특정 전문가가 더 자주 선택 → 더 잘 학습 → 더 자주 선택되는 순환이 발생합니다. Load balancing loss는 전문가별 평균 선택 확률이 균등해지도록 auxiliary loss를 추가합니다. 이를 통해 모든 전문가가 충분히 학습되고 분산 학습 시 GPU 간 부하도 균형을 이룹니다.
마치며
LLM 사전 학습은 스케일링 법칙이라는 이론적 나침반, 데이터 품질이라는 실전적 과제, 그리고 메모리·속도 효율화라는 공학적 도전이 만나는 영역입니다. Chinchilla 법칙은 "크기만이 답이 아니다"를 수학적으로 증명했고, Flash Attention과 GQA는 대형 모델을 현실적인 비용으로 학습·배포할 수 있게 했습니다. DeepSeek-V3의 성공은 MoE 아키텍처와 효율적 구현이 결합될 때 비용 대비 최고 성능이 가능함을 보여줍니다. 이 가이드의 개념들을 실험해보며 직접 작은 스케일에서 사전 학습을 체험해 보시기 바랍니다.
LLM Pretraining & Scaling Laws: From Chinchilla to Flash Attention and MoE
Introduction
When GPT-3 arrived in 2020, "bigger is better" became the dominant paradigm in LLM development. Then in 2022, DeepMind's Chinchilla paper challenged that intuition head-on: a model with 4x fewer parameters but 4x more training data outperformed GPT-3. This guide covers everything from the mathematics of scaling laws to the cutting-edge pretraining recipes used in Llama 3.1, DeepSeek-V3, and phi-4.
1. Scaling Laws
1.1 Kaplan et al. (2020) — The OpenAI Scaling Laws
Kaplan et al. empirically demonstrated that language model loss follows power laws with respect to the number of parameters (N), dataset size (D), and compute budget (C).
The key takeaway: given a fixed compute budget, investing overwhelmingly in model size is optimal. This became the theoretical justification for building models like GPT-3 (175B parameters).
1.2 Chinchilla (Hoffmann et al. 2022) — The Revised Scaling Laws
The DeepMind team showed through more rigorous experimental design that Kaplan's conclusion substantially underestimated the value of data.
The Chinchilla finding:
Given a compute budget C (in FLOPs), the optimal model size and optimal token count satisfy:
The rule of thumb: roughly 20 training tokens per parameter for optimal compute allocation.
| Model | Parameters | Training Tokens | Chinchilla Ratio |
|---|---|---|---|
| GPT-3 | 175B | 300B | 1:1.7 (data-starved) |
| Chinchilla | 70B | 1.4T | 1:20 (optimal) |
| Llama 3.1 | 8B | 15T | 1:1875 (over-trained) |
Llama 3.1's deliberately over-trained approach is driven by inference efficiency: a smaller model trained longer costs less to deploy, even if it uses more compute during training.
1.3 Computing Optimal Resource Allocation
import math
def chinchilla_optimal(compute_budget_flops: float):
"""
Compute optimal N and D given a compute budget (Chinchilla scaling laws).
compute_budget_flops: total FLOPs (approximated as 6 * N * D)
Returns: (optimal_params, optimal_tokens)
"""
# From Hoffmann et al. Table A3 coefficients
# With C = 6 * N * D approximation:
# N_opt = sqrt(C / (6 * 20)), D_opt = 20 * N_opt
N_opt = math.sqrt(compute_budget_flops / (6 * 20))
D_opt = 20 * N_opt
return N_opt, D_opt
# Example: 1000x A100 GPUs, 30 days of training
# A100: ~312 TFLOPS (bf16), assuming 40% utilization
flops_per_sec = 1000 * 312e12 * 0.4
duration_sec = 30 * 24 * 3600
total_flops = flops_per_sec * duration_sec
N_opt, D_opt = chinchilla_optimal(total_flops)
print(f"Total FLOPs: {total_flops:.2e}")
print(f"Optimal parameters: {N_opt/1e9:.1f}B")
print(f"Optimal tokens: {D_opt/1e12:.1f}T")
2. Data Preparation
2.1 Common Crawl Filtering Pipeline
Common Crawl provides tens of terabytes of web-crawled data monthly. Raw, it is extremely noisy — a multi-stage filtering pipeline is essential.
Typical filtering stages:
- Language detection (fastText): retain only target language(s)
- Quality filters: minimum sentence count, word repetition ratio, special character density
- Domain blocklist: spam, adult content, ad-heavy domains
- Perplexity filter: remove low-quality text using an n-gram language model
- Deduplication: MinHash LSH-based fuzzy deduplication
2.2 MinHash Deduplication
from datasketch import MinHash, MinHashLSH
import re
def get_shingles(text: str, k: int = 5):
"""Generate character k-shingle set from text."""
text = re.sub(r'\s+', ' ', text.lower())
return {text[i:i+k] for i in range(len(text) - k + 1)}
def build_minhash(text: str, num_perm: int = 128) -> MinHash:
"""Create a MinHash signature from text."""
m = MinHash(num_perm=num_perm)
for shingle in get_shingles(text):
m.update(shingle.encode('utf-8'))
return m
def deduplicate_corpus(documents: list, threshold: float = 0.8):
"""
Remove near-duplicate documents using MinHash LSH.
threshold: Jaccard similarity threshold for considering duplicates.
"""
lsh = MinHashLSH(threshold=threshold, num_perm=128)
unique_docs = []
for idx, doc in enumerate(documents):
mh = build_minhash(doc)
result = lsh.query(mh)
if len(result) == 0:
lsh.insert(f"doc_{idx}", mh)
unique_docs.append(doc)
# If similar doc already exists, skip
print(f"Original: {len(documents)} docs -> After dedup: {len(unique_docs)} docs")
return unique_docs
The core idea: the Jaccard similarity between two sets can be estimated by the probability that the minimum hash value under a random permutation is the same for both sets. With 128 hash functions, Jaccard similarity is estimated within roughly 3% error. LSH bins documents into buckets so only candidate pairs are compared — bringing overall complexity close to O(N).
2.3 Tokenizer Training
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.processors import ByteLevel as ByteLevelProcessor
def train_bpe_tokenizer(
corpus_files: list,
vocab_size: int = 32000,
save_path: str = "tokenizer.json"
):
"""Train a GPT-2-style Byte-level BPE tokenizer."""
tokenizer = Tokenizer(BPE(unk_token=None))
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
tokenizer.post_processor = ByteLevelProcessor(trim_offsets=False)
trainer = BpeTrainer(
vocab_size=vocab_size,
min_frequency=2,
special_tokens=["<pad>", "<s>", "</s>", "<unk>"],
show_progress=True,
)
tokenizer.train(files=corpus_files, trainer=trainer)
tokenizer.save(save_path)
print(f"Tokenizer saved: {save_path} (vocab_size={vocab_size})")
return tokenizer
2.4 Data Mixture Ratios
Llama 3.1's pretraining data mixture as a reference:
| Data Source | Proportion |
|---|---|
| Web crawl (Common Crawl, etc.) | ~80% |
| Code (GitHub) | ~8% |
| Academic / scientific papers | ~5% |
| Books (Books3, etc.) | ~4% |
| Multilingual data | ~3% |
3. Architecture Choices
3.1 Key Design Decisions Across Open-Source LLMs
| Model | Positional Encoding | Attention | FFN | Normalization |
|---|---|---|---|---|
| GPT-NeoX | ALiBi | MHA | SwiGLU | LayerNorm |
| LLaMA-3 | RoPE | GQA | SwiGLU | RMSNorm |
| Mistral 7B | RoPE | GQA + SWA | SwiGLU | RMSNorm |
| Mixtral 8x7B | RoPE | GQA | MoE (SwiGLU) | RMSNorm |
3.2 RoPE (Rotary Position Embedding)
Rather than adding absolute positional information to token embeddings, RoPE applies a rotation transformation to Query and Key vectors before the dot product.
The result: the inner product between a query at position m and a key at position n depends only on the relative offset (m-n). This relative nature enables better extrapolation to sequence lengths beyond what was seen during training.
Extensions like YaRN and LongRoPE use this property to extend context to hundreds of thousands of tokens.
3.3 Grouped Query Attention (GQA)
During autoregressive inference, previous tokens' Key and Value tensors must be stored in the KV cache. GQA reduces this cache by having multiple Query heads share a smaller number of Key/Value heads.
- MHA: H query heads, H key heads, H value heads
- MQA: H query heads, 1 key head, 1 value head
- GQA: H query heads, G key heads, G value heads (G < H)
Llama 3 8B uses 32 query heads with 8 KV heads — reducing KV cache to 1/4 of MHA.
3.4 Mixture of Experts (MoE)
Mixtral 8x7B and DeepSeek-V3 use MoE architectures. Each token activates only Top-K out of N expert FFN layers.
import torch
import torch.nn as nn
import torch.nn.functional as F
class MoELayer(nn.Module):
"""Simple Mixture of Experts layer."""
def __init__(self, d_model: int, d_ff: int, num_experts: int = 8, top_k: int = 2):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.router = nn.Linear(d_model, num_experts, bias=False)
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(d_model, d_ff),
nn.SiLU(),
nn.Linear(d_ff, d_model),
)
for _ in range(num_experts)
])
def forward(self, x: torch.Tensor):
# x: (batch, seq_len, d_model)
B, T, D = x.shape
x_flat = x.view(-1, D) # (B*T, D)
# Router
router_logits = self.router(x_flat) # (B*T, num_experts)
router_probs = F.softmax(router_logits, dim=-1)
# Top-K expert selection
topk_probs, topk_idx = router_probs.topk(self.top_k, dim=-1)
topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True) # renormalize
# Load Balancing Loss
expert_load = router_probs.mean(0) # (num_experts,)
load_balance_loss = self.num_experts * (expert_load * expert_load.mean()).sum()
# Compute expert outputs
output = torch.zeros_like(x_flat)
for k in range(self.top_k):
expert_indices = topk_idx[:, k]
expert_weights = topk_probs[:, k].unsqueeze(-1)
for e_idx in range(self.num_experts):
mask = (expert_indices == e_idx)
if mask.any():
expert_out = self.experts[e_idx](x_flat[mask])
output[mask] += expert_weights[mask] * expert_out
return output.view(B, T, D), load_balance_loss
Why load balancing loss is necessary: Without it, the router collapses to always selecting the same expert (expert collapse). The chosen expert improves most, gets selected even more, and the rest never train — a positive feedback loop. Load balancing loss penalizes unequal distribution by adding an auxiliary loss term that encourages all experts to receive roughly equal token counts.
4. Training Stability
4.1 Learning Rate Schedule: Cosine Warmup
import math
def cosine_lr_with_warmup(
optimizer,
step: int,
warmup_steps: int,
total_steps: int,
max_lr: float,
min_lr: float = 0.0,
):
"""Cosine Annealing with Linear Warmup."""
if step < warmup_steps:
lr = max_lr * step / warmup_steps
else:
progress = (step - warmup_steps) / (total_steps - warmup_steps)
lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
return lr
# Typical pretraining settings:
# warmup: 1-2% of total steps
# max_lr: 3e-4 (for ~7B model)
# min_lr: max_lr * 0.1
4.2 Loss Spike Detection and Recovery
Large-scale pretraining runs frequently encounter sudden loss spikes. Standard mitigation strategies:
- Gradient norm clipping: scale down gradients when the norm exceeds a threshold
- Spike detection with rollback: restart from a previous checkpoint, skipping the offending batch
- Exponential moving average monitoring: trigger alerts on sudden loss increases
import torch
def train_step_with_stability(model, optimizer, batch, grad_clip: float = 1.0):
"""Training step with gradient norm clipping and spike detection."""
optimizer.zero_grad()
loss = model(batch)
loss.backward()
# Compute and clip gradient norm
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=grad_clip
)
# Detect abnormally large gradients
if grad_norm > 100 * grad_clip:
print(f"WARNING: abnormal grad_norm={grad_norm:.2f}, skipping step")
optimizer.zero_grad()
return None, grad_norm
optimizer.step()
return loss.item(), grad_norm
5. Efficient Pretraining
5.1 Flash Attention 2
Flash Attention rewrites the attention computation to maximize GPU SRAM utilization, reducing memory complexity from to .
# Flash Attention 2 usage
# pip install flash-attn --no-build-isolation
import torch
from flash_attn import flash_attn_func
def flash_attention_forward(
q: torch.Tensor, # (batch, seqlen, nheads, headdim)
k: torch.Tensor,
v: torch.Tensor,
causal: bool = True,
softmax_scale: float = None,
):
"""
Flash Attention 2 forward pass.
causal=True: applies causal mask for autoregressive LMs.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out = flash_attn_func(
q, k, v,
dropout_p=0.0,
softmax_scale=softmax_scale,
causal=causal,
)
return out # (batch, seqlen, nheads, headdim)
Flash Attention 2 compared to standard attention:
- Memory: linear in sequence length (vs. quadratic)
- Speed: 2-4x faster on A100
- Numerical accuracy: guaranteed identical results
5.2 Sliding Window Attention (SWA)
Used in Mistral 7B, SWA limits each token's attention to the most recent W tokens. This reduces attention complexity to for long sequences while maintaining local context.
6. Evaluation and Checkpointing
6.1 Monitoring the Perplexity Curve
import torch
import math
def compute_perplexity(model, dataloader, device: str = "cuda"):
"""Compute perplexity on a validation set."""
model.eval()
total_loss = 0.0
total_tokens = 0
with torch.no_grad():
for batch in dataloader:
input_ids = batch["input_ids"].to(device)
labels = batch["labels"].to(device)
outputs = model(input_ids=input_ids, labels=labels)
batch_tokens = (labels != -100).sum().item()
total_loss += outputs.loss.item() * batch_tokens
total_tokens += batch_tokens
avg_loss = total_loss / total_tokens
perplexity = math.exp(avg_loss)
return perplexity
# Typical benchmarks:
# GPT-2 level: PPL ~20 (WebText)
# Good 7B model: PPL ~6-8 (validation set)
6.2 Running lm-evaluation-harness Benchmarks
# Install and run lm-evaluation-harness
pip install lm-eval
# Evaluate Llama 3 8B on standard benchmarks
lm_eval --model hf \
--model_args pretrained=meta-llama/Meta-Llama-3-8B \
--tasks mmlu,hellaswag,arc_challenge,winogrande \
--device cuda:0 \
--batch_size 8 \
--output_path results/llama3-8b/
# Evaluate a specific checkpoint during training
lm_eval --model hf \
--model_args pretrained=./checkpoints/step-50000 \
--tasks hellaswag \
--num_fewshot 10 \
--device cuda:0
6.3 Checkpoint Strategy
| Checkpoint Type | Save Interval | Retention |
|---|---|---|
| Rolling recent | Every 500 steps | Keep last 5 |
| Milestone | Every 5,000 steps | Permanent |
| Best validation | Based on PPL | Permanent |
7. Recent Pretraining Recipes (2024-2025)
7.1 Llama 3.1
- Training data: 15T tokens (7.5x Llama 2)
- Context length: 128K (via RoPE extension)
- Vocabulary size: 128K (tiktoken-based)
- Notable: post-pretraining long-context annealing phase
7.2 Mistral Large / Mistral 7B
- Mistral 7B: GQA + SWA for maximal efficiency
- Mixtral 8x7B: 8 experts, 2 active per token (13B parameters in use per forward pass)
- Context: 32K via SWA
7.3 DeepSeek-V3
- Architecture: 671B MoE, 37B active parameters per token
- Training cost: ~$5.57M on 2,048x H800 for 60 days (10x cheaper than comparable models)
- Innovations: Multi-head Latent Attention (MLA), FP8 mixed-precision training
- Data: 14.8T tokens with heavy Chinese and code representation
7.4 phi-4 (Microsoft)
- Size: 14B parameters
- Strategy: data quality over data quantity — heavy synthetic data augmentation
- Training data: 9.8T tokens (~40% synthetic)
- Achievement: competitive with 70B+ models on math and reasoning benchmarks
Quiz
Q1. What is the core argument of the Chinchilla paper that showed smaller models can outperform GPT-3?
Answer: GPT-3 trained 175B parameters on only 300B tokens. Chinchilla showed that parameters and training tokens should scale together at a 1:20 ratio — GPT-3 was severely data-starved.
Explanation: Hoffmann et al. 2022 demonstrated through rigorous experiments that the Kaplan scaling laws undervalued data. Under the same compute budget, a 70B parameter model trained on 1.4T tokens (Chinchilla) outperformed GPT-3 (175B / 300B tokens) on nearly all benchmarks. The key insight is that the optimal number of training tokens is proportional to the square root of the compute budget, not just the model size.
Q2. How does the MinHash algorithm efficiently detect near-duplicates in a large text corpus?
Answer: Documents are represented as shingle (n-gram) sets, then compressed to fixed-length MinHash signatures that approximate Jaccard similarity. LSH groups similar signatures into the same bucket, so only candidate pairs need precise comparison.
Explanation: Direct Jaccard similarity between all pairs requires O(N^2) comparisons. MinHash represents each document as k hash minimums. Two signatures' agreement rate is an unbiased estimator of Jaccard similarity. LSH further reduces comparisons to near O(N) by only comparing documents that hash to the same bucket, making corpus-scale deduplication tractable.
Q3. How does Grouped Query Attention (GQA) reduce inference memory compared to Multi-Head Attention?
Answer: GQA uses fewer Key/Value heads than Query heads. Since the KV cache scales with the number of KV heads (not Query heads), the cache is reduced by a factor of KV heads / Query heads.
Explanation: During autoregressive inference, all previous tokens' Keys and Values must be cached. With MHA, the cache size is batch_size x seq_len x num_heads x head_dim x 2. GQA with G KV heads (G < H query heads) reduces cache to G/H of MHA. Llama 3 8B (H=32 query heads, G=8 KV heads) uses 1/4 the KV cache of equivalent MHA.
Q4. Why is RoPE (Rotary Position Embedding) better suited for long-context extrapolation than absolute positional embeddings?
Answer: RoPE makes attention scores a function of relative position (m-n) rather than absolute positions. This means the model generalizes to position offsets it has not seen during training, enabling length extrapolation.
Explanation: Absolute positional embeddings (sinusoidal or learned) are added to token embeddings at specific indices. Positions beyond the training length are out-of-distribution. RoPE applies rotation matrices to Q and K such that their dot product depends only on the difference between their positions. Extensions like YaRN and LongRoPE leverage this property to extend context windows from 4K to 128K+ tokens with minimal fine-tuning.
Q5. Why is a load balancing loss necessary in MoE expert routing?
Answer: Without it, the router collapses to always choosing the same expert (expert collapse). That expert improves most, gets selected more, and the cycle repeats — other experts receive no gradients and remain undertrained.
Explanation: The router's softmax produces expert selection probabilities. Even with uniform initialization, a positive feedback loop causes one or a few experts to dominate. Load balancing loss adds an auxiliary penalty when the distribution of tokens to experts is uneven. The loss term equals num_experts times the sum of (expert fraction) times (average routing probability), encouraging a uniform distribution and ensuring all experts are adequately trained.
Conclusion
LLM pretraining is where theoretical scaling laws, practical data engineering, and systems optimization intersect. Chinchilla proved mathematically that "bigger is not always better" — data efficiency matters as much as parameter count. Flash Attention and GQA made large-scale training and deployment economically feasible. DeepSeek-V3 demonstrated that MoE architecture combined with efficient implementation can achieve frontier performance at a fraction of the cost. The concepts in this guide form the foundation for understanding and conducting LLM pretraining experiments at any scale.