Skip to content
Published on

LLM 사전 학습 & 스케일링 법칙: Chinchilla, Flash Attention, MoE까지

Authors

들어가며

GPT-3가 등장한 이후 "모델을 크게 만들면 성능이 좋아진다"는 직관이 지배했습니다. 그런데 2022년 DeepMind의 Chinchilla 연구는 이 직관에 정면으로 도전했습니다. 파라미터 수를 4배 줄이고 데이터를 4배 늘린 모델이 GPT-3를 능가한다는 것입니다. 이 가이드에서는 스케일링 법칙의 수학적 기반부터 최신 사전 학습 레시피까지, LLM 사전 학습의 핵심을 체계적으로 정리합니다.


1. 스케일링 법칙 (Scaling Laws)

1.1 Kaplan et al. (2020) — OpenAI 스케일링 법칙

Kaplan et al.은 언어 모델의 손실이 파라미터 수(N), 데이터 크기(D), 컴퓨팅 예산(C)의 거듭제곱 법칙을 따른다는 것을 실험적으로 보였습니다.

L(N)(NcN)αNL(N) \approx \left(\frac{N_c}{N}\right)^{\alpha_N}

L(D)(DcD)αDL(D) \approx \left(\frac{D_c}{D}\right)^{\alpha_D}

핵심 발견은 컴퓨팅 예산이 고정되어 있을 때, 모델 크기에 훨씬 더 많이 투자하는 것이 유리하다는 것이었습니다. 이 결론이 GPT-3 (175B 파라미터)와 같은 거대 모델 개발의 이론적 근거가 되었습니다.

1.2 Chinchilla (Hoffmann et al. 2022) — 재정립된 스케일링 법칙

DeepMind 팀은 더 엄밀한 실험 설계를 통해 Kaplan의 결론이 데이터 측면을 과소평가했음을 보였습니다.

Chinchilla 법칙의 핵심:

최적 컴퓨팅 예산 C (FLOPs)가 주어졌을 때, 최적 모델 크기 NoptN_{opt}와 최적 데이터 크기 DoptD_{opt}는 다음 관계를 만족합니다.

Nopt0.2C0.5N_{opt} \approx 0.2 \cdot C^{0.5}

Dopt10C0.5D_{opt} \approx 10 \cdot C^{0.5}

즉, 파라미터 1개당 학습 토큰은 약 20개가 최적이라는 결론입니다.

모델파라미터학습 토큰Chinchilla 최적비
GPT-3175B300B1:1.7 (데이터 부족)
Chinchilla70B1.4T1:20 (최적)
Llama 3.18B15T1:1875 (데이터 초과)

Llama 3.1처럼 의도적으로 Chinchilla 최적보다 훨씬 더 많은 데이터를 학습하는 경우는 추론 효율성 때문입니다. 작은 모델이 많이 학습할수록 배포 비용이 줄어듭니다.

1.3 최적 컴퓨팅 예산 배분 계산

import math

def chinchilla_optimal(compute_budget_flops: float):
    """
    Chinchilla 법칙에 따른 최적 N, D 계산
    compute_budget_flops: 총 FLOPs (예: 6 * N * D 근사)
    반환: (optimal_params, optimal_tokens)
    """
    # Chinchilla 계수 (Hoffmann et al. Table A3 기준)
    # C = 6 * N * D 근사에서
    # N_opt = (C / (6 * 20))^0.5, D_opt = 20 * N_opt
    N_opt = math.sqrt(compute_budget_flops / (6 * 20))
    D_opt = 20 * N_opt
    return N_opt, D_opt

# 예시: A100 1000개, 30일 학습
# A100: ~312 TFLOPS (bf16), 가동률 40%
flops_per_sec = 1000 * 312e12 * 0.4
duration_sec = 30 * 24 * 3600
total_flops = flops_per_sec * duration_sec

N_opt, D_opt = chinchilla_optimal(total_flops)
print(f"총 FLOPs: {total_flops:.2e}")
print(f"최적 파라미터 수: {N_opt/1e9:.1f}B")
print(f"최적 토큰 수: {D_opt/1e12:.1f}T")

2. 데이터 준비

2.1 Common Crawl 필터링 파이프라인

Common Crawl은 월 수십 TB의 웹 크롤링 데이터를 제공합니다. 그대로 쓰면 품질이 매우 낮기 때문에 다단계 필터링이 필수입니다.

전형적인 필터링 파이프라인:

  1. 언어 감지 (fastText langdetect): 목표 언어만 유지
  2. 품질 필터: 최소 문장 수, 단어 반복 비율, 특수문자 비율
  3. 도메인 블랙리스트: 스팸, 성인, 광고 도메인 제거
  4. 퍼플렉시티 필터: n-gram 언어 모델 기반 낮은 품질 문서 제거
  5. 중복 제거: MinHash LSH 기반

2.2 MinHash 중복 제거

from datasketch import MinHash, MinHashLSH
import re

def get_shingles(text: str, k: int = 5):
    """문자 k-shingle 집합 생성"""
    text = re.sub(r'\s+', ' ', text.lower())
    return {text[i:i+k] for i in range(len(text) - k + 1)}

def build_minhash(text: str, num_perm: int = 128) -> MinHash:
    """텍스트에서 MinHash 서명 생성"""
    m = MinHash(num_perm=num_perm)
    for shingle in get_shingles(text):
        m.update(shingle.encode('utf-8'))
    return m

def deduplicate_corpus(documents: list, threshold: float = 0.8):
    """
    MinHash LSH로 유사 문서 제거
    threshold: Jaccard 유사도 임계값
    """
    lsh = MinHashLSH(threshold=threshold, num_perm=128)
    unique_docs = []
    seen = set()

    for idx, doc in enumerate(documents):
        mh = build_minhash(doc)
        result = lsh.query(mh)

        if len(result) == 0:
            lsh.insert(f"doc_{idx}", mh)
            unique_docs.append(doc)
        # 유사 문서가 이미 존재하면 스킵

    print(f"원본: {len(documents)}개 → 중복 제거 후: {len(unique_docs)}개")
    return unique_docs

MinHash의 핵심 아이디어: 두 집합의 Jaccard 유사도 J(A,B)=AB/ABJ(A,B) = |A \cap B| / |A \cup B|를 해시 함수 k개의 최솟값으로 추정합니다. 128개의 해시 함수를 사용하면 약 3% 오차 내에서 Jaccard 유사도를 추정할 수 있습니다.

2.3 토크나이저 학습

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.processors import ByteLevel as ByteLevelProcessor

def train_bpe_tokenizer(
    corpus_files: list,
    vocab_size: int = 32000,
    save_path: str = "tokenizer.json"
):
    """BPE 토크나이저 학습 (GPT-2 스타일 Byte-level BPE)"""
    tokenizer = Tokenizer(BPE(unk_token=None))
    tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
    tokenizer.post_processor = ByteLevelProcessor(trim_offsets=False)

    trainer = BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=2,
        special_tokens=["<pad>", "<s>", "</s>", "<unk>"],
        show_progress=True,
    )

    tokenizer.train(files=corpus_files, trainer=trainer)
    tokenizer.save(save_path)
    print(f"토크나이저 저장 완료: {save_path} (vocab_size={vocab_size})")
    return tokenizer

# 사용 예시
# train_bpe_tokenizer(["data/train.txt"], vocab_size=32000)

2.4 데이터 혼합 비율

Llama 3.1의 사전 학습 데이터 혼합 비율을 참고하면:

데이터 소스비율
웹 크롤링 (Common Crawl 등)약 80%
코드 (GitHub)약 8%
학술/과학 논문약 5%
책 (Books3 등)약 4%
다국어 데이터약 3%

3. 아키텍처 선택

3.1 주요 오픈소스 LLM 설계 결정 비교

모델위치 임베딩AttentionFFNNormalization
GPT-NeoXALiBiMHASwiGLULayerNorm
LLaMA-3RoPEGQASwiGLURMSNorm
Mistral 7BRoPEGQA + SWASwiGLURMSNorm
Mixtral 8x7BRoPEGQAMoE (SwiGLU)RMSNorm

3.2 RoPE (Rotary Position Embedding)

RoPE는 절대 위치 정보를 직접 더하는 대신, attention score 계산 시 Query/Key 벡터에 회전 변환을 적용합니다.

위치 m에서의 쿼리 벡터 q와 위치 n에서의 키 벡터 k의 내적이 상대 위치 (m-n)의 함수가 됩니다. 이 성질 덕분에 훈련 길이를 넘어서는 긴 컨텍스트에도 어느 정도 외삽(extrapolation)이 가능합니다.

qmTkn=Re[j(q(j)eimθj)(k(j)einθj)]q_m^T k_n = \text{Re}\left[\sum_{j} (q_{(j)} e^{im\theta_j}) \overline{(k_{(j)} e^{in\theta_j})}\right]

3.3 Grouped Query Attention (GQA)

Multi-Head Attention(MHA)에서 KV 캐시는 추론 메모리의 큰 부분을 차지합니다. GQA는 여러 Query 헤드가 소수의 Key/Value 헤드를 공유합니다.

  • MHA: H개 쿼리, H개 키, H개 값 헤드
  • MQA: H개 쿼리, 1개 키, 1개 값 헤드
  • GQA: H개 쿼리, G개 키, G개 값 헤드 (G < H)

Llama 3 8B는 32개 쿼리 헤드에 8개 KV 헤드를 사용합니다. KV 캐시가 1/4로 줄어듭니다.

3.4 Mixture of Experts (MoE)

Mixtral 8x7B와 DeepSeek-V3는 MoE 아키텍처를 사용합니다. 각 토큰은 N개의 전문가(FFN 레이어) 중 Top-K개만 활성화합니다.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    """간단한 Mixture of Experts 레이어"""
    def __init__(self, d_model: int, d_ff: int, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.SiLU(),
                nn.Linear(d_ff, d_model),
            )
            for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor):
        # x: (batch, seq_len, d_model)
        B, T, D = x.shape
        x_flat = x.view(-1, D)  # (B*T, D)

        # 라우터 계산
        router_logits = self.router(x_flat)  # (B*T, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Top-K 전문가 선택
        topk_probs, topk_idx = router_probs.topk(self.top_k, dim=-1)
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # 재정규화

        # Load Balancing Loss (전문가 균형 유지)
        # 각 전문가가 균등하게 선택되도록
        expert_load = router_probs.mean(0)  # (num_experts,)
        load_balance_loss = self.num_experts * (expert_load * expert_load.mean()).sum()

        # 전문가 출력 계산
        output = torch.zeros_like(x_flat)
        for k in range(self.top_k):
            expert_indices = topk_idx[:, k]
            expert_weights = topk_probs[:, k].unsqueeze(-1)
            for e_idx in range(self.num_experts):
                mask = (expert_indices == e_idx)
                if mask.any():
                    expert_out = self.experts[e_idx](x_flat[mask])
                    output[mask] += expert_weights[mask] * expert_out

        return output.view(B, T, D), load_balance_loss

Load Balancing Loss가 필요한 이유: 라우터가 특정 전문가만 선택하도록 수렴하면 나머지 전문가는 학습되지 않고, 선택된 전문가는 과부하가 걸립니다(expert collapse). Load balancing loss는 모든 전문가에게 균등한 토큰을 보내도록 정규화합니다.


4. 학습 안정성

4.1 학습률 스케줄: Cosine Warmup

import math

def cosine_lr_with_warmup(
    optimizer,
    step: int,
    warmup_steps: int,
    total_steps: int,
    max_lr: float,
    min_lr: float = 0.0,
):
    """Cosine Annealing with Linear Warmup"""
    if step < warmup_steps:
        lr = max_lr * step / warmup_steps
    else:
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# 전형적인 사전 학습 설정
# warmup: 전체 스텝의 1~2%
# max_lr: 3e-4 (7B 모델 기준)
# min_lr: max_lr * 0.1

4.2 손실 스파이크 감지 및 복구

대규모 사전 학습에서 손실이 갑자기 치솟는 스파이크가 발생할 수 있습니다. 일반적인 대처 전략:

  1. Gradient Norm Clipping: 그래디언트 노름이 임계값을 초과하면 스케일 다운
  2. 스파이크 감지 후 롤백: 이전 체크포인트에서 재시작, 해당 배치 스킵
  3. Loss 이동 평균 모니터링: 급격한 상승 시 알람
import torch

def train_step_with_stability(model, optimizer, batch, grad_clip: float = 1.0):
    """손실 스파이크 방지를 위한 안전한 학습 스텝"""
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()

    # Gradient Norm 계산
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm=grad_clip
    )

    # 비정상적으로 큰 그래디언트 감지
    if grad_norm > 100 * grad_clip:
        print(f"경고: 비정상 grad_norm={grad_norm:.2f}, 스텝 스킵")
        optimizer.zero_grad()
        return None, grad_norm

    optimizer.step()
    return loss.item(), grad_norm

5. 효율적 사전 학습

5.1 Flash Attention 2

Flash Attention은 GPU의 SRAM을 최대한 활용해 Attention 연산의 메모리 복잡도를 O(N2)O(N^2)에서 O(N)O(N)으로 줄입니다.

# Flash Attention 2 사용 예시
# pip install flash-attn --no-build-isolation

import torch
from flash_attn import flash_attn_func
from flash_attn.bert_padding import unpad_input, pad_input

def flash_attention_forward(
    q: torch.Tensor,  # (batch, seqlen, nheads, headdim)
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = True,
    softmax_scale: float = None,
):
    """
    Flash Attention 2 순전파
    causal=True: 자동 회귀 언어 모델용 인과적 마스크 적용
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    out = flash_attn_func(
        q, k, v,
        dropout_p=0.0,
        softmax_scale=softmax_scale,
        causal=causal,
    )
    return out  # (batch, seqlen, nheads, headdim)

Flash Attention 2는 표준 Attention 대비:

  • 메모리: 시퀀스 길이에 선형 비례 (기존 2차 비례 대비)
  • 속도: 2~4배 빠름 (A100 기준)
  • 수치 안정성: 동일 결과 보장

5.2 Sliding Window Attention (SWA)

Mistral 7B에서 사용된 SWA는 각 토큰이 최근 W개의 토큰에만 attend합니다. 긴 시퀀스에서 Attention 복잡도를 O(WN)O(W \cdot N)으로 줄입니다.


6. 평가와 체크포인팅

6.1 Perplexity 곡선 모니터링

import torch
import math

def compute_perplexity(model, dataloader, device: str = "cuda"):
    """검증 세트 퍼플렉시티 계산"""
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, labels=labels)
            # CrossEntropy loss는 이미 토큰당 평균
            batch_tokens = (labels != -100).sum().item()
            total_loss += outputs.loss.item() * batch_tokens
            total_tokens += batch_tokens

    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    return perplexity

# 전형적인 사전 학습 진행 기준
# GPT-2 수준: PPL ~20 (WebText)
# 좋은 7B 모델: PPL ~6~8 (검증 세트 기준)

6.2 lm-evaluation-harness 벤치마크

# lm-evaluation-harness 설치 및 실행
pip install lm-eval

# Llama 3 8B 평가 예시
lm_eval --model hf \
    --model_args pretrained=meta-llama/Meta-Llama-3-8B \
    --tasks mmlu,hellaswag,arc_challenge,winogrande \
    --device cuda:0 \
    --batch_size 8 \
    --output_path results/llama3-8b/

# 특정 체크포인트 평가
lm_eval --model hf \
    --model_args pretrained=./checkpoints/step-50000 \
    --tasks hellaswag \
    --num_fewshot 10 \
    --device cuda:0

6.3 체크포인트 전략

대규모 사전 학습에서 체크포인트는 세 가지 용도로 관리합니다.

체크포인트 유형저장 간격보존 기간
최근 N개 유지매 500 스텝최근 5개만
마일스톤매 5,000 스텝영구 보존
최고 성능검증 PPL 기준영구 보존

7. 최신 사전 학습 레시피 비교 (2024-2025)

7.1 Llama 3.1

  • 학습 데이터: 15T 토큰 (Llama 2의 7.5배)
  • 컨텍스트 길이: 128K (RoPE 확장 적용)
  • 어휘 크기: 128K (tiktoken 기반)
  • 특이사항: 사전 학습 후 긴 컨텍스트 어닐링 단계 추가

7.2 Mistral Large / Mistral 7B

  • Mistral 7B: GQA + SWA로 효율성 극대화
  • Mixtral 8x7B: 8개 전문가 중 2개 활성화 (실제 파라미터 13B 사용)
  • 컨텍스트: 32K (SWA 기반)

7.3 DeepSeek-V3

  • 아키텍처: 671B MoE, 토큰당 37B 활성화
  • 학습 비용: H800 2,048개 × 60일 (약 557만 달러 — 경쟁사 대비 10배 저렴)
  • 혁신: Multi-head Latent Attention (MLA), FP8 혼합 정밀도 학습
  • 데이터: 14.8T 토큰 (중국어·코드 비중 높음)

7.4 phi-4 (Microsoft)

  • 크기: 14B 파라미터
  • 전략: 합성 데이터 비중을 대폭 높여 데이터 품질 중심 접근
  • 학습 데이터: 9.8T 토큰 (합성 데이터 약 40%)
  • 성과: 70B+ 모델과 경쟁하는 수학·추론 벤치마크 성능

퀴즈

Q1. Chinchilla 법칙이 GPT-3보다 적은 파라미터로 더 좋은 성능을 낼 수 있음을 보인 핵심 주장은?

정답: GPT-3은 175B 파라미터에 300B 토큰만 학습했으나, Chinchilla는 파라미터 수와 데이터 수를 균형 있게 키워야 한다는 것을 실험적으로 증명했습니다.

설명: Chinchilla (Hoffmann et al. 2022)는 컴퓨팅 예산이 고정될 때 최적 파라미터 수와 최적 토큰 수가 대략 1:20 비율을 만족해야 한다고 주장합니다. GPT-3는 175B 파라미터 대비 약 3.5T 토큰이 필요했지만 실제로는 300B 토큰만 사용해 "데이터 부족" 상태였습니다. 같은 컴퓨팅 예산으로 70B 파라미터에 1.4T 토큰을 학습한 Chinchilla가 GPT-3를 능가했습니다.

Q2. MinHash 알고리즘이 대규모 텍스트 코퍼스에서 중복을 효율적으로 감지하는 방법은?

정답: 문서를 shingle(n-gram 집합)로 표현하고, 여러 해시 함수의 최솟값(MinHash 서명)으로 Jaccard 유사도를 근사 추정합니다. LSH(Locality Sensitive Hashing)로 유사 문서를 빠르게 후보 집합으로 좁힙니다.

설명: 두 문서의 Jaccard 유사도를 직접 계산하면 코퍼스 크기 N에 대해 O(N^2) 비교가 필요합니다. MinHash는 각 문서를 k개 해시 함수의 최솟값으로 이루어진 길이 k짜리 서명 벡터로 압축합니다. 두 서명의 해밍 유사도가 원래 Jaccard 유사도의 불편 추정량이 됩니다. LSH는 같은 버킷에 해시된 후보들만 정밀 비교해 전체 복잡도를 O(N)에 가깝게 줄입니다.

Q3. Grouped Query Attention(GQA)이 Multi-Head Attention보다 추론 메모리를 절감하는 원리는?

정답: GQA는 여러 Query 헤드가 소수의 Key/Value 헤드를 공유하므로, KV 캐시 크기가 Query 헤드 수에 비례하지 않고 KV 헤드 수에만 비례합니다.

설명: 자동 회귀 추론 시 이전 토큰의 Key/Value를 KV 캐시에 저장합니다. MHA에서 H개 헤드 모두가 별도 KV를 가지면 KV 캐시 크기는 배치 크기 × 시퀀스 길이 × H × head_dim × 2입니다. GQA에서 G개 KV 헤드만 쓰면 KV 캐시가 G/H로 줄어듭니다. Llama 3 8B (H=32, G=8)는 KV 캐시가 MHA 대비 1/4 수준입니다.

Q4. RoPE(Rotary Position Embedding)가 절대 위치 임베딩보다 긴 컨텍스트 외삽(extrapolation)에 유리한 이유는?

정답: RoPE는 Attention score가 절대 위치가 아닌 상대 위치 (m-n)의 함수가 되도록 설계되어, 훈련 시 본 적 없는 긴 거리의 상대적 위치 관계에도 일반화가 가능합니다.

설명: 절대 위치 임베딩(사인/코사인 또는 학습된 임베딩)은 특정 위치 인덱스에 대한 임베딩을 직접 추가합니다. 훈련 최대 길이를 넘는 위치는 분포 외(out-of-distribution) 입력이 됩니다. RoPE는 Q·K 내적이 두 위치의 차(상대 위치)에만 의존하도록 회전 행렬을 적용합니다. 이론적으로 훈련 길이를 넘어서도 상대 위치 패턴을 활용할 수 있어 YaRN, LongRoPE 같은 확장 기법의 기반이 됩니다.

Q5. Mixture of Experts(MoE)에서 전문가 라우팅의 load balancing loss가 필요한 이유는?

정답: 라우터가 학습 초기에 특정 전문가만 선택하도록 수렴하면(expert collapse), 나머지 전문가는 gradient를 받지 못해 훈련되지 않고 선택된 전문가만 과부하 됩니다. Load balancing loss는 모든 전문가에 균등한 토큰이 할당되도록 강제합니다.

설명: 라우터의 소프트맥스 출력은 각 전문가의 선택 확률입니다. 초기화가 균일해도 학습 중 양성 피드백으로 특정 전문가가 더 자주 선택 → 더 잘 학습 → 더 자주 선택되는 순환이 발생합니다. Load balancing loss는 전문가별 평균 선택 확률이 균등해지도록 auxiliary loss를 추가합니다. 이를 통해 모든 전문가가 충분히 학습되고 분산 학습 시 GPU 간 부하도 균형을 이룹니다.


마치며

LLM 사전 학습은 스케일링 법칙이라는 이론적 나침반, 데이터 품질이라는 실전적 과제, 그리고 메모리·속도 효율화라는 공학적 도전이 만나는 영역입니다. Chinchilla 법칙은 "크기만이 답이 아니다"를 수학적으로 증명했고, Flash Attention과 GQA는 대형 모델을 현실적인 비용으로 학습·배포할 수 있게 했습니다. DeepSeek-V3의 성공은 MoE 아키텍처와 효율적 구현이 결합될 때 비용 대비 최고 성능이 가능함을 보여줍니다. 이 가이드의 개념들을 실험해보며 직접 작은 스케일에서 사전 학습을 체험해 보시기 바랍니다.