Skip to content

Split View: LLM 처음부터 만들기: 코드로 이해하는 GPT 완전 구현 가이드

|

LLM 처음부터 만들기: 코드로 이해하는 GPT 완전 구현 가이드

들어가며

GPT, Llama, Mistral 같은 대규모 언어 모델(LLM)은 현재 AI 혁명의 중심에 있습니다. 하지만 이 모델들이 어떻게 동작하는지 코드 수준에서 이해하는 사람은 많지 않습니다. 이 가이드에서는 "miniGPT"라는 소형 GPT를 처음부터 PyTorch로 직접 구현하며, 현대 LLM의 모든 핵심 개념을 마스터합니다.


1. LLM 전체 아키텍처 개요

1.1 사전학습 언어 모델의 파이프라인

현대 LLM의 개발 파이프라인은 세 단계로 나뉩니다.

  1. 사전학습 (Pretraining): 수십억 개의 텍스트 토큰으로 다음 토큰 예측(next token prediction)을 학습합니다. 이 과정에서 모델은 언어의 통계적 패턴, 지식, 추론 능력을 습득합니다.

  2. 지시 파인튜닝 (Instruction Fine-Tuning / SFT): 사람이 작성한 질문-답변 쌍으로 모델이 지시를 따르도록 학습합니다.

  3. 선호도 학습 (RLHF / DPO): 사람의 선호도를 학습하여 더 유용하고 안전한 응답을 생성하도록 개선합니다.

1.2 miniGPT 사양

이 가이드에서 구현할 miniGPT의 사양:

| 파라미터 | 값 |
|---|---|
| 어휘 크기 | 50,257 (GPT-2와 동일) |
| 컨텍스트 길이 | 1024 |
| 임베딩 차원 | 768 |
| 레이어 수 | 12 |
| Attention 헤드 수 | 12 |
| FFN 확장 비율 | 4x |
| 총 파라미터 | 약 124M |

모델 크기(임베딩 차원·레이어 수·헤드 수·컨텍스트 길이)는 GPT-2 Small과 동일하지만, 아키텍처는 RMSNorm·RoPE·SwiGLU 같은 Llama 스타일 구성요소를 사용합니다. 이 정도 규모는 단일 GPU에서도 학습 가능합니다.


2. 토크나이저 구현

2.1 BPE 알고리즘

BPE(Byte-Pair Encoding)는 대부분의 현대 LLM이 사용하는 토크나이저 알고리즘입니다. 아래 구현은 원리를 보여주기 위한 문자(character) 단위 BPE이며, GPT-2가 사용하는 tiktoken 인코딩은 바이트(byte) 단위 BPE입니다.

import re
from collections import defaultdict
from typing import List, Tuple, Dict

class SimpleBPETokenizer:
    """Minimal character-level BPE trainer, written from scratch for teaching.

    Words are represented as space-separated symbols terminated by ``</w>``;
    training repeatedly merges the most frequent adjacent symbol pair.
    """

    def __init__(self):
        self.vocab = {}  # token string -> id, filled in by train()
        self.merges = {}  # (left, right) symbol pair -> merge rank (0 = first merge)
        # NOTE(review): these ids mirror GPT-2 numbering and are not added to self.vocab.
        self.special_tokens = {
            "<|endoftext|>": 50256,
            "<|padding|>": 50257,
        }

    def get_stats(self, vocab: Dict[str, int]) -> Dict[Tuple, int]:
        """Count adjacent symbol pairs, weighted by each word's frequency.

        Args:
            vocab: mapping of space-separated symbol strings to frequencies.

        Returns:
            Mapping ``(left, right) -> total frequency``.
        """
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += freq
        return pairs

    def merge_vocab(self, pair: Tuple, v_in: Dict) -> Dict:
        """Merge every whole-symbol occurrence of ``pair`` into one symbol.

        A plain ``str.replace`` would also match across symbol boundaries
        (e.g. merging ("e", "s") inside "t he s" would wrongly yield "t hes"),
        so the bigram only matches when delimited by whitespace or string ends.

        Args:
            pair: the ``(left, right)`` symbols to merge.
            v_in: mapping of space-separated symbol strings to frequencies.

        Returns:
            A new mapping with the pair merged; frequencies of words that
            become identical are summed rather than overwritten.
        """
        bigram = re.escape(" ".join(pair))
        pattern = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
        replacement = "".join(pair)
        v_out = {}
        for word, freq in v_in.items():
            # A callable replacement stops `re` from interpreting backslashes in symbols.
            w_out = pattern.sub(lambda _match: replacement, word)
            v_out[w_out] = v_out.get(w_out, 0) + freq
        return v_out

    def train(self, corpus: List[str], vocab_size: int = 1000):
        """Learn BPE merges from ``corpus``.

        Note that ``vocab_size`` bounds the number of merge steps, not the size
        of the resulting vocabulary; training stops early when no pairs remain.

        Returns:
            ``(self.vocab, self.merges)``.
        """
        # 1. Character-level initial vocabulary with an end-of-word marker.
        word_freq = defaultdict(int)
        for text in corpus:
            for word in text.split():
                word_freq[" ".join(word) + " </w>"] += 1

        vocab = dict(word_freq)

        # 2. Greedily merge the most frequent pair, up to vocab_size times.
        for i in range(vocab_size):
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            vocab = self.merge_vocab(best, vocab)
            self.merges[best] = i
            print(f"Step {i}: merged {best} (freq={pairs[best]})")

        # 3. Assign ids to every symbol that survives in the merged vocabulary.
        for word in vocab:
            for token in word.split():
                if token not in self.vocab:
                    self.vocab[token] = len(self.vocab)

        return self.vocab, self.merges

2.2 tiktoken 호환 토크나이저 사용

실제로는 OpenAI의 tiktoken 라이브러리를 사용하는 것이 훨씬 효율적입니다.

import tiktoken
import torch
from typing import List

class GPTTokenizer:
    """Small wrapper exposing a tiktoken encoding through a tokenizer-like API.

    ``encode`` prepends the end-of-text token by default, using it as a
    beginning-of-sequence marker.
    """

    def __init__(self, encoding_name: str = "gpt2"):
        encoding = tiktoken.get_encoding(encoding_name)
        self.enc = encoding
        self.vocab_size = encoding.n_vocab
        self.eot_token = encoding.eot_token  # 50256 for the gpt2 encoding

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Turn ``text`` into token ids; the literal "<|endoftext|>" is allowed."""
        body = self.enc.encode(text, allowed_special={"<|endoftext|>"})
        prefix = [self.eot_token] if add_special_tokens else []
        return prefix + body

    def decode(self, ids: List[int]) -> str:
        """Turn a list of token ids back into text."""
        return self.enc.decode(ids)

    def encode_batch(self, texts: List[str]) -> List[List[int]]:
        """Encode each text with the default (special-token) settings."""
        return list(map(self.encode, texts))

    def __len__(self):
        return self.vocab_size

# Usage example: round-trip a Korean sentence through the tokenizer.
tokenizer = GPTTokenizer()

text = "안녕하세요! GPT를 처음부터 만들어봅시다."
ids = tokenizer.encode(text)
decoded = tokenizer.decode(ids)

for line in (f"토큰 수: {len(ids)}", f"토큰 ID: {ids}", f"디코딩: {decoded}"):
    print(line)

3. 데이터 준비

3.1 텍스트 데이터셋 클래스

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os

class TextDataset(Dataset):
    """Sliding-window language-modeling dataset over a single text file.

    The whole file is tokenized once. Sample ``idx`` is the window that starts
    at token ``idx * stride``; it yields ``(x, y)`` LongTensors of length
    ``seq_len`` where ``y`` is ``x`` shifted left by one token.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: "GPTTokenizer",
        seq_len: int = 1024,
        stride: int = 512
    ):
        if seq_len <= 0 or stride <= 0:
            raise ValueError(f"seq_len and stride must be positive, got {seq_len}, {stride}")
        self.seq_len = seq_len
        self.stride = stride

        # Load and tokenize the whole file.
        print(f"데이터 로딩: {data_path}")
        with open(data_path, "r", encoding="utf-8") as f:
            text = f.read()

        print(f"총 문자 수: {len(text):,}")
        self.tokens = tokenizer.encode(text, add_special_tokens=False)
        print(f"총 토큰 수: {len(self.tokens):,}")

        # int32 numpy array: far smaller than a Python list of ints.
        self.tokens = np.array(self.tokens, dtype=np.int32)

    def __len__(self):
        # A window starting at s needs tokens[s : s + seq_len + 1], so the last
        # valid start is N - seq_len - 1; count starts 0, stride, 2*stride, ...
        n_tokens = len(self.tokens)
        if n_tokens <= self.seq_len:
            return 0
        return (n_tokens - self.seq_len - 1) // self.stride + 1

    def __getitem__(self, idx):
        start = idx * self.stride
        end = start + self.seq_len + 1  # one extra token for the shifted labels

        chunk = self.tokens[start:end]

        # input: tokens[0:seq_len], labels: tokens[1:seq_len+1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)

        return x, y


def create_dataloader(
    data_path: str,
    tokenizer: GPTTokenizer,
    batch_size: int = 8,
    seq_len: int = 1024,
    stride: int = 512,
    num_workers: int = 4,
    shuffle: bool = True
) -> DataLoader:
    """Wrap a TextDataset over ``data_path`` in a DataLoader.

    Host memory is pinned and the final incomplete batch is dropped.
    """
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )
    dataset = TextDataset(data_path, tokenizer, seq_len, stride)
    print(f"데이터셋 크기: {len(dataset):,} 샘플")
    return DataLoader(dataset, **loader_kwargs)

4. 모델 아키텍처 구현

4.1 설정 클래스

from dataclasses import dataclass

@dataclass
class GPTConfig:
    """Hyperparameters for MiniGPT.

    The defaults use GPT-2 Small's sizes (50257-token vocabulary, 1024-token
    context, width 768, 12 layers, 12 heads) and switch on the SwiGLU FFN and
    rotary position embeddings.
    """

    # Vocabulary and context window.
    vocab_size: int = 50257
    max_seq_len: int = 1024

    # Transformer width and depth.
    n_embd: int = 768
    n_layer: int = 12
    n_head: int = 12

    # Regularization; `bias` controls bias terms on linear layers that consult it.
    dropout: float = 0.1
    bias: bool = False

    # Feed-forward hidden width = n_embd * ffn_mult (before SwiGLU's 2/3 adjustment).
    ffn_mult: int = 4

    # True -> SwiGLU feed-forward; False -> GELU MLP.
    use_swiglu: bool = True

    # True -> rotary position embeddings; False -> learned absolute positions.
    use_rope: bool = True

4.2 RMSNorm

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RMSNorm(nn.Module):
    """RMS layer normalization (as used in Llama): x / rms(x) * weight.

    The statistic is computed in float32 and the result cast back to the
    input dtype before applying the learnable per-channel gain.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        normalized = self._norm(x.float()).type_as(x)
        return self.weight * normalized

4.3 Rotary Position Embedding (RoPE)

class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) with precomputed cos/sin tables.

    ``dim`` is the per-head size being rotated and must be even. Tables for
    ``max_seq_len`` positions are stored as buffers shaped ``[1, 1, pos, dim]``,
    so they broadcast over q/k laid out as ``[B, n_head, T, dim]``.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE dimension must be even, got {dim}")
        self.dim = dim
        self.max_seq_len = max_seq_len

        # Frequencies theta_i = base^(-2i/dim), one per rotated pair.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Precompute the cos/sin tables once.
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        t = torch.arange(seq_len, device=self.inv_freq.device).float()
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cache", emb.cos()[None, None, :, :])
        self.register_buffer("sin_cache", emb.sin()[None, None, :, :])

    def rotate_half(self, x):
        """Map the two halves (x1, x2) of the last dim to (-x2, x1)."""
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int, offset: int = 0):
        """Rotate ``q`` and ``k`` for absolute positions ``offset .. offset+seq_len-1``.

        Args:
            q, k: tensors whose last two dims are ``(seq_len, dim)``.
            seq_len: number of positions present in ``q``/``k``.
            offset: absolute position of the first token; with a KV cache this
                is the number of already-cached tokens. Defaults to 0.

        Raises:
            ValueError: if the requested positions fall outside the cached table.
        """
        end = offset + seq_len
        cached = self.cos_cache.shape[2]
        if offset < 0 or end > cached:
            raise ValueError(
                f"RoPE positions [{offset}, {end}) exceed cached range [0, {cached})"
            )
        cos = self.cos_cache[:, :, offset:end, :].to(q.dtype)
        sin = self.sin_cache[:, :, offset:end, :].to(q.dtype)

        q_rot = (q * cos) + (self.rotate_half(q) * sin)
        k_rot = (k * cos) + (self.rotate_half(k) * sin)
        return q_rot, k_rot

4.4 Multi-Head Causal Attention with KV Cache

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional RoPE and KV cache.

    ``forward`` returns ``(y, present_kv)``. When ``use_cache`` is true,
    ``present_kv`` is ``(k, v)``, each ``[B, n_head, kv_len, head_dim]``,
    covering cached plus new tokens; pass it back as ``past_kv`` on the next
    call so new tokens attend to earlier ones without recomputation.
    """

    def __init__(self, config: "GPTConfig"):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head

        # Fused Q, K, V projection.
        self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # RoPE
        if config.use_rope:
            self.rope = RotaryEmbedding(self.head_dim, config.max_seq_len)
        else:
            self.rope = None

        # Lower-triangular mask buffer. Kept registered so existing checkpoints
        # (whose state_dict contains it) still load; masks are now built per call.
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
            .view(1, 1, config.max_seq_len, config.max_seq_len)
        )

        # Legacy attributes (not read by forward; caching goes through past_kv).
        self.use_cache = False
        self.cache_k = None
        self.cache_v = None

    def _apply_rope(self, q: torch.Tensor, k: torch.Tensor, offset: int):
        """Rotate q/k for absolute positions starting at ``offset``."""
        T = q.shape[2]
        if offset == 0:
            return self.rope(q, k, T)
        # Read the rope's cached tables directly so the offset is honoured.
        end = offset + T
        cached = self.rope.cos_cache.shape[2]
        if end > cached:
            raise ValueError(f"RoPE positions [{offset}, {end}) exceed cached range [0, {cached})")
        cos = self.rope.cos_cache[:, :, offset:end, :].to(q.dtype)
        sin = self.rope.sin_cache[:, :, offset:end, :].to(q.dtype)
        rotate = self.rope.rotate_half
        return q * cos + rotate(q) * sin, k * cos + rotate(k) * sin

    @staticmethod
    def _causal_mask(T: int, past_len: int, device) -> torch.Tensor:
        """Boolean [T, past_len + T] mask; True where query i may attend key j.

        New query i sits at absolute position past_len + i, so it sees keys
        0 .. past_len + i (the mask is bottom-right aligned).
        """
        return torch.ones(T, past_len + T, dtype=torch.bool, device=device).tril(diagonal=past_len)

    def forward(
        self,
        x: torch.Tensor,
        use_cache: bool = False,
        past_kv=None
    ):
        B, T, C = x.shape  # batch, seq_len, n_embd
        past_len = past_kv[0].shape[2] if past_kv is not None else 0

        # Split Q, K, V
        qkv = self.qkv_proj(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        # [B, T, C] -> [B, n_head, T, head_dim]
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # RoPE at the tokens' true positions (after any cached tokens).
        if self.rope is not None:
            q, k = self._apply_rope(q, k, past_len)

        # Prepend cached keys/values.
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        present_kv = (k, v) if use_cache else None

        if hasattr(F, "scaled_dot_product_attention"):
            dropout_p = self.attn_dropout.p if self.training else 0.0
            if past_len == 0:
                # Square causal case: the built-in flag enables the fast kernels.
                y = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True
                )
            else:
                # is_causal=True is top-left aligned, which is wrong when
                # kv_len > T; pass an explicit bottom-right aligned mask.
                mask = self._causal_mask(T, past_len, q.device)
                y = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=mask, dropout_p=dropout_p
                )
        else:
            # Manual fallback for PyTorch versions without SDPA.
            scale = 1.0 / math.sqrt(self.head_dim)
            attn = (q @ k.transpose(-2, -1)) * scale
            mask = self._causal_mask(T, past_len, q.device)
            attn = attn.masked_fill(~mask, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_dropout(attn)
            y = attn @ v

        # [B, n_head, T, head_dim] -> [B, T, C]
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.out_proj(y))

        return y, present_kv

4.5 SwiGLU Feed-Forward Network

class SwiGLUFFN(nn.Module):
    """Feed-forward network with SwiGLU gating (Llama style).

    Computes ``down(SiLU(gate(x)) * up(x))``. The hidden width is
    ``int(n_embd * ffn_mult * 2/3)`` rounded up to a multiple of 64, which keeps
    the parameter count close to a plain ``ffn_mult``-wide MLP.
    """

    def __init__(self, config: "GPTConfig"):
        super().__init__()
        hidden_dim = int(config.n_embd * config.ffn_mult * 2 / 3)
        # Round up to a multiple of 64 (hardware-friendly matmul sizes).
        hidden_dim = ((hidden_dim + 63) // 64) * 64

        self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: SiLU(x @ W_gate) * (x @ W_up), then project back down.
        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.dropout(self.down_proj(gate * up))


class GPTFFN(nn.Module):
    """Original-GPT style MLP: Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        width = config.n_embd
        inner = width * config.ffn_mult
        self.fc1 = nn.Linear(width, inner, bias=config.bias)
        self.fc2 = nn.Linear(inner, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.fc1(x))
        return self.dropout(self.fc2(hidden))

4.6 Transformer Block

class TransformerBlock(nn.Module):
    """One pre-norm decoder layer.

    Computes ``h = x + Attn(RMSNorm(x))`` then ``h + FFN(RMSNorm(h))`` and
    returns it together with the attention layer's ``present_kv``.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.norm1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.norm2 = RMSNorm(config.n_embd)

        ffn_cls = SwiGLUFFN if config.use_swiglu else GPTFFN
        self.ffn = ffn_cls(config)

    def forward(self, x: torch.Tensor, use_cache: bool = False, past_kv=None):
        normed = self.norm1(x)
        attn_out, present_kv = self.attn(normed, use_cache=use_cache, past_kv=past_kv)
        hidden = x + attn_out
        return hidden + self.ffn(self.norm2(hidden)), present_kv

4.7 전체 GPT 모델

class MiniGPT(nn.Module):
    """Decoder-only GPT language model.

    token embedding (+ learned positions when RoPE is off) -> TransformerBlocks
    -> final RMSNorm -> LM head. The LM head shares its weight matrix with the
    token embedding.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        # Token embedding
        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)

        # Learned absolute positions only when RoPE is disabled
        if not config.use_rope:
            self.pos_emb = nn.Embedding(config.max_seq_len, config.n_embd)

        self.emb_dropout = nn.Dropout(config.dropout)

        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layer)
        ])

        # Final normalization
        self.norm_final = RMSNorm(config.n_embd)

        # LM head (projects to vocabulary logits)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: input embedding and output projection share one matrix.
        self.lm_head.weight = self.token_emb.weight

        # Weight initialization
        self.apply(self._init_weights)

        # Residual-stream scaling: shrink the init of every projection that
        # writes into the residual stream (attention out_proj, SwiGLU
        # down_proj, and GELU-MLP fc2) by 1/sqrt(2 * n_layer).
        for pn, p in self.named_parameters():
            if pn.endswith(("out_proj.weight", "down_proj.weight", "fc2.weight")):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

        print(f"파라미터 수: {self.count_params() / 1e6:.1f}M")

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def count_params(self) -> int:
        """Number of unique parameters (the tied embedding is counted once)."""
        return sum(p.numel() for p in self.parameters())

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor = None,
        use_cache: bool = False,
        past_kvs: list = None
    ):
        """Run the model.

        Args:
            input_ids: token ids, ``[B, T]``.
            targets: optional ``[B, T]`` next-token ids (``-1`` is ignored).
            use_cache: ask each layer to return its ``(k, v)`` for reuse.
            past_kvs: optional per-layer ``(k, v)`` from an earlier call; the
                new tokens are positioned after the cached ones.

        Returns:
            ``(logits [B, T, V], loss)`` when ``targets`` is given, otherwise
            ``(logits [B, 1, V], present_kvs)`` for the last position only.

        Raises:
            ValueError: if cached plus new tokens exceed ``config.max_seq_len``.
        """
        B, T = input_ids.shape
        device = input_ids.device

        # Tokens already in the KV cache; new tokens start at this position.
        past_len = 0
        if past_kvs is not None and len(past_kvs) > 0 and past_kvs[0] is not None:
            past_len = past_kvs[0][0].shape[2]
        if past_len + T > self.config.max_seq_len:
            raise ValueError(
                f"sequence length {past_len + T} exceeds max_seq_len={self.config.max_seq_len}"
            )

        # Token embedding: [B, T, n_embd]
        x = self.token_emb(input_ids)

        # Learned positions (RoPE disabled), offset by the cached length.
        if not self.config.use_rope:
            positions = torch.arange(past_len, past_len + T, device=device)
            x = x + self.pos_emb(positions)

        x = self.emb_dropout(x)

        present_kvs = []
        for i, layer in enumerate(self.layers):
            past_kv = past_kvs[i] if past_kvs is not None else None
            x, present_kv = layer(x, use_cache=use_cache, past_kv=past_kv)
            present_kvs.append(present_kv)

        x = self.norm_final(x)

        if targets is not None:
            # Training: loss over every position.
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=-1
            )
            return logits, loss

        # Inference: only the last position's logits are needed.
        logits = self.lm_head(x[:, [-1], :])
        return logits, present_kvs

    @classmethod
    def from_pretrained(cls, model_path: str):
        """Load a model from a checkpoint dict with "config" and "model" entries.

        Checkpoints written by ``Trainer.save_checkpoint`` pickle the GPTConfig
        object, which PyTorch >= 2.6's default ``weights_only=True`` rejects, so
        the file is loaded with ``weights_only=False``.
        SECURITY: this unpickles arbitrary objects; only load trusted files.
        A plain dict config is also accepted.
        """
        checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
        config = checkpoint["config"]
        if isinstance(config, dict):
            config = GPTConfig(**config)
        model = cls(config)
        model.load_state_dict(checkpoint["model"])
        return model

5. 사전학습 구현

5.1 학습률 스케줄러

import math

def get_cosine_schedule_with_warmup(
    optimizer,
    warmup_steps: int,
    total_steps: int,
    min_lr_ratio: float = 0.1
):
    """LambdaLR schedule: linear warmup from 0, then cosine decay with a floor.

    The LR multiplier is ``step / warmup_steps`` during warmup and afterwards
    ``max(min_lr_ratio, 0.5 * (1 + cos(pi * progress)))``. ``progress`` is
    clamped to ``[0, 1]``, so past ``total_steps`` the LR stays at the floor
    instead of climbing back up the cosine curve.
    """

    def lr_lambda(current_step: int):
        if current_step < warmup_steps:
            # Linear warmup
            return current_step / max(1, warmup_steps)
        # Cosine decay
        progress = (current_step - warmup_steps) / max(1, total_steps - warmup_steps)
        progress = min(1.0, progress)
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
        return max(min_lr_ratio, cosine_decay)

    from torch.optim.lr_scheduler import LambdaLR
    return LambdaLR(optimizer, lr_lambda)

5.2 학습 루프

import torch
import torch.nn as nn
from torch.optim import AdamW
import os
import time
from contextlib import nullcontext

class Trainer:
    """Pretraining loop for MiniGPT.

    Handles mixed precision (bf16/fp16 autocast on CUDA, fp32 elsewhere),
    gradient accumulation, gradient clipping, a warmup + cosine LR schedule,
    periodic validation and checkpointing.

    ``config`` keys: ``num_epochs`` and ``learning_rate`` are required.
    Optional keys and defaults: ``device``, ``warmup_ratio`` (0.05),
    ``weight_decay`` (0.1), ``gradient_accumulation_steps`` (1),
    ``max_grad_norm`` (1.0), ``log_interval`` (100), ``eval_interval`` (500),
    ``save_interval`` (1000).
    """

    def __init__(
        self,
        model: "MiniGPT",
        train_dataloader,
        val_dataloader,
        config: dict
    ):
        self.model = model
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.config = config
        self.device = config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        use_cuda = torch.device(self.device).type == "cuda"

        # Mixed precision only on CUDA; fp16 additionally needs loss scaling.
        if use_cuda:
            self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            self.ctx = torch.amp.autocast(device_type="cuda", dtype=self.dtype)
        else:
            self.dtype = torch.float32
            self.ctx = nullcontext()
        self.scaler = torch.cuda.amp.GradScaler(enabled=(self.dtype == torch.float16))

        # Move parameters first: fused AdamW validates parameter devices when
        # it is constructed, so a CPU-resident model would make it raise.
        self.model.to(self.device)
        self.optimizer = self._create_optimizer()

        # The scheduler steps once per optimizer step, i.e. once per
        # `grad_accum` micro-batches plus once for a trailing partial group.
        grad_accum = max(1, config.get("gradient_accumulation_steps", 1))
        steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)
        total_steps = steps_per_epoch * config["num_epochs"]
        warmup_steps = int(total_steps * config.get("warmup_ratio", 0.05))
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            warmup_steps=warmup_steps,
            total_steps=total_steps
        )

        self.step = 0
        self.best_val_loss = float("inf")

    def _create_optimizer(self):
        """AdamW with weight decay on matrices only.

        1-D tensors (biases, norm gains) and parameters whose name contains
        "emb" get no decay. Fused kernels are requested only when every
        parameter is on CUDA.
        """
        decay_params = []
        no_decay_params = []

        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if param.dim() == 1 or "emb" in name:
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optim_groups = [
            {"params": decay_params, "weight_decay": self.config.get("weight_decay", 0.1)},
            {"params": no_decay_params, "weight_decay": 0.0}
        ]
        optim_groups = [group for group in optim_groups if group["params"]]
        use_fused = all(p.is_cuda for group in optim_groups for p in group["params"])

        return AdamW(
            optim_groups,
            lr=self.config["learning_rate"],
            betas=(0.9, 0.95),
            fused=use_fused
        )

    def compute_val_loss(self) -> float:
        """Mean validation loss over at most 50 batches; NaN if the loader is empty."""
        self.model.eval()
        total_loss = 0.0
        seen = 0
        max_batches = 50

        try:
            with torch.no_grad():
                for x, y in self.val_dl:
                    if seen >= max_batches:
                        break
                    x, y = x.to(self.device), y.to(self.device)
                    with self.ctx:
                        _, loss = self.model(x, y)
                    total_loss += loss.item()
                    seen += 1
        finally:
            # Always restore training mode, even if evaluation raised.
            self.model.train()
        return total_loss / seen if seen else float("nan")

    def save_checkpoint(self, path: str):
        """Save model, optimizer, scheduler, config, step and best val loss to ``path``."""
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save({
            "step": self.step,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "config": self.model.config,
            "best_val_loss": self.best_val_loss
        }, path)
        print(f"체크포인트 저장: {path}")

    def train(self):
        """Run the main training loop for ``num_epochs`` epochs."""
        self.model.to(self.device)
        self.model.train()

        num_epochs = self.config["num_epochs"]
        grad_accum = max(1, self.config.get("gradient_accumulation_steps", 1))
        max_grad_norm = self.config.get("max_grad_norm", 1.0)
        log_interval = self.config.get("log_interval", 100)
        save_interval = self.config.get("save_interval", 1000)
        eval_interval = self.config.get("eval_interval", 500)

        print(f"학습 시작: {self.device}, dtype={self.dtype}")

        num_batches = len(self.train_dl)
        for epoch in range(num_epochs):
            epoch_loss = 0.0
            t0 = time.time()
            tokens_since_log = 0

            for batch_idx, (x, y) in enumerate(self.train_dl):
                x, y = x.to(self.device), y.to(self.device)

                # Mixed-precision forward; loss scaled for accumulation.
                with self.ctx:
                    _, loss = self.model(x, y)
                    loss = loss / grad_accum

                self.scaler.scale(loss).backward()

                batch_loss = loss.item() * grad_accum
                epoch_loss += batch_loss
                tokens_since_log += x.numel()

                # Step after every `grad_accum` micro-batches, and also on the
                # last batch so leftover gradients never leak into the next
                # epoch (that final partial group is still scaled by 1/grad_accum).
                is_last_batch = batch_idx + 1 == num_batches
                if (batch_idx + 1) % grad_accum != 0 and not is_last_batch:
                    continue

                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), max_grad_norm
                )
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.scheduler.step()
                self.optimizer.zero_grad(set_to_none=True)
                self.step += 1

                # Logging / eval / checkpointing run once per optimizer step:
                # never at step 0 and never repeated across micro-batches.
                if self.step % log_interval == 0:
                    lr = self.optimizer.param_groups[0]["lr"]
                    elapsed = max(time.time() - t0, 1e-9)
                    tokens_per_sec = tokens_since_log / elapsed
                    print(
                        f"Epoch {epoch} | Step {self.step} | "
                        f"Loss: {batch_loss:.4f} | "
                        f"LR: {lr:.6f} | "
                        f"Tokens/s: {tokens_per_sec:.0f}"
                    )
                    t0 = time.time()
                    tokens_since_log = 0

                if self.step % eval_interval == 0:
                    val_loss = self.compute_val_loss()
                    try:
                        ppl = math.exp(val_loss)
                    except OverflowError:
                        ppl = float("inf")
                    print(f"Val Loss: {val_loss:.4f} | Perplexity: {ppl:.2f}")

                    if val_loss < self.best_val_loss:
                        self.best_val_loss = val_loss
                        self.save_checkpoint("./checkpoints/best_model.pt")

                if self.step % save_interval == 0:
                    self.save_checkpoint(f"./checkpoints/step_{self.step}.pt")

            avg_loss = epoch_loss / max(1, num_batches)
            print(f"\n=== Epoch {epoch} 완료 | 평균 Loss: {avg_loss:.4f} ===\n")

6. 텍스트 생성

6.1 다양한 디코딩 전략

import torch
import torch.nn.functional as F
from typing import Optional, List

class TextGenerator:
    """Autoregressive text generation for MiniGPT.

    Strategies: "greedy", "temperature", "top_k", "top_p", "beam". All decode a
    single prompt (batch size 1) and re-run the model on the whole context,
    cropped to ``model.config.max_seq_len``, at each step.
    """

    def __init__(self, model: "MiniGPT", tokenizer: "GPTTokenizer", device: str = "cuda"):
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 200,
        strategy: str = "top_p",
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
        repetition_penalty: float = 1.1,
        num_beams: int = 4
    ) -> str:
        """Generate a continuation of ``prompt`` and return only the new text.

        Args:
            repetition_penalty: CTRL-style penalty applied to tokens already in
                the context (>1 discourages repeats, 1.0 disables). Used by the
                greedy and sampling strategies; ignored by beam search.

        Raises:
            ValueError: unknown strategy, non-positive temperature for a
                sampling strategy, or non-positive repetition_penalty.
        """
        if strategy in ("temperature", "top_k", "top_p") and temperature <= 0:
            raise ValueError(f"temperature must be > 0 for sampling, got {temperature}")
        if repetition_penalty <= 0:
            raise ValueError(f"repetition_penalty must be > 0, got {repetition_penalty}")

        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        if not prompt_ids:
            # The model needs at least one input token; seed an empty prompt with EOT.
            prompt_ids = [self.tokenizer.eot_token]
        input_ids = torch.tensor(prompt_ids, dtype=torch.long, device=self.device).unsqueeze(0)

        if strategy == "greedy":
            generated = self._greedy_decode(input_ids, max_new_tokens, repetition_penalty)
        elif strategy == "temperature":
            generated = self._temperature_sampling(input_ids, max_new_tokens, temperature, repetition_penalty)
        elif strategy == "top_k":
            generated = self._top_k_sampling(input_ids, max_new_tokens, temperature, top_k, repetition_penalty)
        elif strategy == "top_p":
            generated = self._top_p_sampling(input_ids, max_new_tokens, temperature, top_p, repetition_penalty)
        elif strategy == "beam":
            generated = self._beam_search(input_ids, max_new_tokens, num_beams)
        else:
            raise ValueError(f"알 수 없는 전략: {strategy}")

        # Decode only the newly generated tokens.
        new_tokens = generated[0][input_ids.shape[1]:].tolist()
        return self.tokenizer.decode(new_tokens)

    def _next_token_logits(self, ids: torch.Tensor) -> torch.Tensor:
        """Next-token logits ``[1, V]`` for ``ids``, cropped to the context window."""
        max_len = getattr(getattr(self.model, "config", None), "max_seq_len", None)
        if max_len is not None and ids.shape[1] > max_len:
            ids = ids[:, -max_len:]
        logits, _ = self.model(ids)
        return logits[:, -1, :]

    @staticmethod
    def _apply_repetition_penalty(logits: torch.Tensor, ids: torch.Tensor, penalty: float) -> torch.Tensor:
        """Shrink logits of tokens present in ``ids`` (divide if > 0, multiply if <= 0)."""
        if penalty == 1.0:
            return logits
        seen = torch.gather(logits, 1, ids)
        seen = torch.where(seen > 0, seen / penalty, seen * penalty)
        return logits.scatter(1, ids, seen)

    def _decode_loop(self, input_ids: torch.Tensor, max_new_tokens: int, choose, repetition_penalty: float) -> torch.Tensor:
        """Shared loop: ``choose`` maps logits ``[1, V]`` to a ``[1, 1]`` token id; stops at EOT."""
        current_ids = input_ids.clone()

        for _ in range(max_new_tokens):
            logits = self._next_token_logits(current_ids)
            logits = self._apply_repetition_penalty(logits, current_ids, repetition_penalty)
            next_token = choose(logits)
            current_ids = torch.cat([current_ids, next_token], dim=1)

            if next_token.item() == self.tokenizer.eot_token:
                break

        return current_ids

    def _greedy_decode(self, input_ids: torch.Tensor, max_new_tokens: int, repetition_penalty: float = 1.0) -> torch.Tensor:
        """Greedy decoding: always pick the highest-scoring token."""
        return self._decode_loop(
            input_ids, max_new_tokens,
            lambda logits: logits.argmax(dim=-1, keepdim=True),
            repetition_penalty,
        )

    def _temperature_sampling(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float,
        repetition_penalty: float = 1.0
    ) -> torch.Tensor:
        """Sample from softmax(logits / temperature)."""
        def choose(logits):
            probs = F.softmax(logits / temperature, dim=-1)
            return torch.multinomial(probs, num_samples=1)

        return self._decode_loop(input_ids, max_new_tokens, choose, repetition_penalty)

    def _top_k_sampling(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float,
        top_k: int,
        repetition_penalty: float = 1.0
    ) -> torch.Tensor:
        """Sample among the ``top_k`` highest-scoring tokens (k clamped to [1, V])."""
        def choose(logits):
            logits = logits / temperature
            k = max(1, min(top_k, logits.size(-1)))
            top_k_logits, top_k_indices = torch.topk(logits, k=k, dim=-1)
            filtered_logits = torch.full_like(logits, float("-inf")).scatter(1, top_k_indices, top_k_logits)
            return torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

        return self._decode_loop(input_ids, max_new_tokens, choose, repetition_penalty)

    def _top_p_sampling(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        repetition_penalty: float = 1.0
    ) -> torch.Tensor:
        """Nucleus sampling: sample from the smallest prefix whose mass reaches ``top_p``."""
        def choose(logits):
            logits = logits / temperature
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            # Drop a token once the mass of strictly higher-ranked tokens
            # exceeds top_p; the single most likely token is always kept.
            remove = (torch.cumsum(sorted_probs, dim=-1) - sorted_probs) > top_p
            sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
            # Undo the sort (the permutation covers every position).
            filtered = torch.empty_like(logits).scatter(1, sorted_indices, sorted_logits)
            return torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)

        return self._decode_loop(input_ids, max_new_tokens, choose, repetition_penalty)

    def _beam_search(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        num_beams: int
    ) -> torch.Tensor:
        """Beam search ranked by length-normalized log-probability.

        Beams that have already produced EOT are carried over unchanged rather
        than extended. Search stops when the best beam ends in EOT.
        """
        if input_ids.shape[0] != 1:
            raise ValueError("빔 서치는 배치 크기 1만 지원")

        eot = self.tokenizer.eot_token
        prompt_len = input_ids.shape[1]
        beams = [(input_ids, 0.0)]  # (ids, cumulative log-prob)

        for _ in range(max_new_tokens):
            all_candidates = []

            for beam_ids, beam_score in beams:
                finished = beam_ids.shape[1] > prompt_len and beam_ids[0, -1].item() == eot
                if finished:
                    all_candidates.append((beam_ids, beam_score))
                    continue

                log_probs = F.log_softmax(self._next_token_logits(beam_ids), dim=-1)
                k = min(num_beams, log_probs.size(-1))
                top_log_probs, top_ids = log_probs.topk(k)

                for i in range(k):
                    new_ids = torch.cat([beam_ids, top_ids[:, i:i + 1]], dim=1)
                    new_score = beam_score + top_log_probs[0, i].item()
                    all_candidates.append((new_ids, new_score))

            # Keep the num_beams best by length-normalized score.
            all_candidates.sort(key=lambda c: c[1] / c[0].shape[1], reverse=True)
            beams = all_candidates[:num_beams]

            if beams[0][0][0, -1].item() == eot:
                break

        return beams[0][0]  # best-scoring beam

7. Perplexity 계산

import math
import torch

@torch.no_grad()
def compute_perplexity(
    model: "MiniGPT",
    tokenizer: "GPTTokenizer",
    text: str,
    device: str = "cuda",
    seq_len: int = 512
) -> float:
    """Perplexity of ``text`` under ``model``: exp(mean next-token loss).

    The token stream is scored in consecutive, non-overlapping windows of up
    to ``seq_len`` predictions. The last window may be shorter, so every token
    after the first is scored exactly once, even for texts shorter than
    ``seq_len``.

    Raises:
        ValueError: if ``seq_len`` < 1 or ``text`` encodes to fewer than two
            tokens (nothing to predict).
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")
    model.eval()
    model.to(device)

    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) < 2:
        raise ValueError("perplexity needs at least two tokens")

    total_loss = 0.0
    total_tokens = 0

    for start in range(0, len(tokens) - 1, seq_len):
        chunk = tokens[start:start + seq_len + 1]  # always >= 2 tokens here
        n_pred = len(chunk) - 1

        x = torch.tensor(chunk[:-1], dtype=torch.long, device=device).unsqueeze(0)
        y = torch.tensor(chunk[1:], dtype=torch.long, device=device).unsqueeze(0)

        _, loss = model(x, y)
        # The loss is a per-token mean; weight it by the window's prediction count.
        total_loss += loss.item() * n_pred
        total_tokens += n_pred

    return math.exp(total_loss / total_tokens)


# Usage example: load the best checkpoint, measure perplexity, and sample text.
# Fall back to CPU so the example also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = MiniGPT.from_pretrained("./checkpoints/best_model.pt")
tokenizer = GPTTokenizer()
generator = TextGenerator(model, tokenizer, device=device)

test_text = "The quick brown fox jumps over the lazy dog."
ppl = compute_perplexity(model, tokenizer, test_text, device=device)
print(f"Perplexity: {ppl:.2f}")

# Compare decoding strategies on the same prompt.
for strategy in ["greedy", "top_k", "top_p"]:
    generated = generator.generate(
        prompt="Once upon a time",
        max_new_tokens=100,
        strategy=strategy,
        temperature=0.8
    )
    print(f"\n[{strategy}]: {generated}")

8. 규모 확장

8.1 Flash Attention 2

# pip install flash-attn

from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

class FlashCausalAttention(nn.Module):
    """Causal self-attention using flash-attn 2's packed-QKV kernel.

    Returns only the output projection (no KV cache, no RoPE).
    NOTE(review): flash-attn kernels generally require fp16/bf16 CUDA inputs;
    confirm against the installed flash-attn version.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head

        self.qkv_proj = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
        self.out_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.dropout = config.dropout

    def forward(self, x: torch.Tensor):
        batch, seq, width = x.shape

        # The packed kernel expects [B, T, 3, n_head, head_dim].
        packed = self.qkv_proj(x).view(batch, seq, 3, self.n_head, self.head_dim)
        drop = self.dropout if self.training else 0.0

        # causal=True applies the causal mask inside the kernel; output is [B, T, n_head, head_dim].
        out = flash_attn_qkvpacked_func(packed, dropout_p=drop, causal=True)
        return self.out_proj(out.reshape(batch, seq, width))

8.2 FSDP 분산 학습

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import functools

def setup_fsdp_training():
    """Initialise NCCL distributed training and wrap a 24-layer MiniGPT in FSDP.

    Expects the usual ``torchrun`` environment (one process per GPU). Each
    ``TransformerBlock`` becomes its own FSDP unit. Parameters are fully
    sharded across ranks (ZeRO-3 style), and params, gradient reductions and
    buffers use bf16.

    Returns:
        (model, rank, world_size): the FSDP-wrapped model, this process's
        global rank, and the total number of processes.
    """
    import os
    from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # The CUDA device index is the *local* rank. The global rank only matches
    # it on a single node. torchrun exports LOCAL_RANK; fall back to the
    # global rank when it is absent.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)

    config = GPTConfig(
        n_embd=1024,
        n_layer=24,
        n_head=16,
        vocab_size=50257
    )
    model = MiniGPT(config)

    # Wrapping policy: shard at TransformerBlock granularity.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock}
    )

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16
        ),
        # FSDP expects the ShardingStrategy enum, not a string.
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3 equivalent
        device_id=local_rank
    )

    return model, rank, world_size

8.3 Chinchilla 스케일링 법칙

Chinchilla 논문(Hoffmann et al., 2022)에 따르면, 주어진 연산 예산에서 최적의 모델 크기와 학습 토큰 수는 대략 다음 관계를 따릅니다:

최적 학습 토큰 수 ≈ 20 × 파라미터 수

def chinchilla_optimal_tokens(model_params: int, tokens_per_param: int = 20) -> int:
    """Return the compute-optimal number of training tokens (Chinchilla rule of thumb).

    Args:
        model_params: Number of model parameters.
        tokens_per_param: Training tokens per parameter. The default of 20 is
            the usual Chinchilla approximation; override it to explore other budgets.

    Returns:
        ``tokens_per_param * model_params``.
    """
    return tokens_per_param * model_params

# Example: Chinchilla-style token budgets for a few common model sizes.
# Sizes are written as floats (e.g. 124e6) and converted with int() before the call.
models = {
    "124M (GPT-2 Small)": 124e6,
    "1.3B": 1.3e9,
    "7B (Llama-2 7B)": 7e9,
    "13B": 13e9,
    "70B (Llama-2 70B)": 70e9,
}

print("모델 크기별 권장 학습 토큰 수:")
for name, params in models.items():
    optimal_tokens = chinchilla_optimal_tokens(int(params))
    # Report the budget in billions of tokens.
    print(f"  {name}: {optimal_tokens/1e9:.1f}B 토큰")

9. Instruction Tuning

9.1 SFT 데이터 형식

# ChatML format
CHATML_SYSTEM = "<|im_start|>system\n{system}<|im_end|>\n"
CHATML_USER = "<|im_start|>user\n{user}<|im_end|>\n"
CHATML_ASSISTANT = "<|im_start|>assistant\n{assistant}<|im_end|>\n"

def format_chatml(
    system: str,
    user: str,
    assistant: str,
    add_generation_prompt: bool = False
) -> str:
    """Format one system/user/assistant exchange in ChatML.

    Args:
        system: System prompt text.
        user: User message text.
        assistant: Assistant reply text.
        add_generation_prompt: If False (default), the assistant turn is
            closed with ``<|im_end|>``; use this for training samples. If
            True, the assistant turn is left open (no ``<|im_end|>``) so the
            model continues it. ``assistant`` is then used as a prefill and is
            normally ``""``. With ``assistant=""`` the result is exactly the
            prefix of the closed-turn text that precedes the reply, which is
            what prompt masking and inference need.

    Returns:
        The formatted conversation string.
    """
    formatted = CHATML_SYSTEM.format(system=system)
    formatted += CHATML_USER.format(user=user)
    if add_generation_prompt:
        # Leave the turn open. Emitting a closed (empty) assistant turn here
        # followed by another header would not be a prefix of the training text.
        formatted += "<|im_start|>assistant\n" + assistant
    else:
        formatted += CHATML_ASSISTANT.format(assistant=assistant)
    return formatted


# Llama-3 format
def format_llama3(
    system: str,
    user: str,
    assistant: str,
    add_generation_prompt: bool = False
) -> str:
    """Format one system/user/assistant exchange with the Llama 3 chat tokens.

    Turns are concatenated directly (no newline after ``<|eot_id|>``), as in
    Meta's published Llama 3 prompt format.

    Args:
        system: System prompt text.
        user: User message text.
        assistant: Assistant reply text.
        add_generation_prompt: If False (default), the assistant turn is
            closed with ``<|eot_id|>``; use this for training samples. If
            True, the assistant turn is left open so the model continues it.
            ``assistant`` is then a prefill, normally ``""``. With
            ``assistant=""`` the result is exactly the prefix of the
            closed-turn text that precedes the reply.

    Returns:
        The formatted conversation string.
    """
    formatted = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
    formatted += f"<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>"
    formatted += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant}"
    if not add_generation_prompt:
        # Close the assistant turn only for complete (training) conversations.
        formatted += "<|eot_id|>"
    return formatted

9.2 SFT 데이터셋 클래스

class InstructionDataset(Dataset):
    """Instruction fine-tuning dataset; the loss is computed only on assistant responses.

    Each record needs ``"instruction"`` and ``"response"`` keys and may carry
    a ``"system"`` prompt. Samples are stored already shifted for next-token
    prediction, matching the ``(x, y)`` convention of the pretraining data:
    ``labels[t]`` is the token that should follow ``input_ids[t]``. Labels
    whose target token still belongs to the prompt are -100 so the loss
    ignores them. Records whose response is entirely truncated away are
    skipped, because they would contribute no supervised tokens (an
    all-ignored batch makes cross-entropy NaN).
    """

    def __init__(
        self,
        data: list,
        tokenizer: GPTTokenizer,
        max_seq_len: int = 2048,
        format_func = format_chatml
    ):
        self.samples = []

        for item in data:
            system = item.get("system", "You are a helpful assistant.")

            # Full conversation including the assistant response.
            full_text = format_func(
                system=system,
                user=item["instruction"],
                assistant=item["response"]
            )
            full_ids = tokenizer.encode(full_text, add_special_tokens=False)

            # Prompt only (everything before the response), used for masking.
            prompt_text = format_func(
                system=system,
                user=item["instruction"],
                assistant="",
                add_generation_prompt=True
            )
            prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
            # Mask only the tokens the prompt really shares with the full text.
            # Using len(prompt_ids) directly would over-mask response tokens if
            # the template's generation prompt is not an exact prefix.
            prompt_len = self._shared_prefix_len(prompt_ids, full_ids)

            # Keep max_seq_len + 1 tokens so the shifted input is <= max_seq_len.
            full_ids = full_ids[:max_seq_len + 1]
            input_ids = full_ids[:-1]
            labels = full_ids[1:]

            # labels[t] == full_ids[t + 1], which is a prompt token when t + 1 < prompt_len.
            num_masked = min(max(prompt_len - 1, 0), len(labels))
            if num_masked >= len(labels):
                continue  # no response tokens left to learn from
            labels[:num_masked] = [-100] * num_masked

            self.samples.append({
                "input_ids": input_ids,
                "labels": labels
            })

    @staticmethod
    def _shared_prefix_len(a: list, b: list) -> int:
        """Length of the longest common prefix of two token-id lists."""
        n = 0
        for x, y in zip(a, b):
            if x != y:
                break
            n += 1
        return n

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        return (
            torch.tensor(sample["input_ids"], dtype=torch.long),
            torch.tensor(sample["labels"], dtype=torch.long)
        )

9.3 SFT 학습 루프

def train_sft(
    base_model_path: str,
    data_path: str,
    output_dir: str,
    num_epochs: int = 3,
    learning_rate: float = 2e-5,
    batch_size: int = 4
):
    """Supervised instruction fine-tuning (SFT) of a pretrained MiniGPT.

    Args:
        base_model_path: Checkpoint loadable by ``MiniGPT.from_pretrained``.
        data_path: UTF-8 JSON file holding a list of records with
            ``"instruction"`` / ``"response"`` (and optional ``"system"``) keys.
        output_dir: Directory that receives ``sft_model.pt``.
        num_epochs: Passes over the dataset.
        learning_rate: Peak AdamW learning rate (10% linear warmup, then cosine).
        batch_size: Samples per optimizer step.
    """
    import json

    # Load the data with an explicit encoding: instruction data is often non-ASCII.
    with open(data_path, encoding="utf-8") as f:
        data = json.load(f)

    # Load the model first so samples can be truncated to its context window.
    # Sequences longer than config.max_seq_len overflow the causal mask and RoPE cache.
    model = MiniGPT.from_pretrained(base_model_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    tokenizer = GPTTokenizer()
    max_seq_len = min(2048, model.config.max_seq_len)
    dataset = InstructionDataset(data, tokenizer, max_seq_len=max_seq_len)

    def collate_fn(batch):
        """Right-pad a batch to its longest sample; padded labels are -100 (ignored)."""
        input_ids = [item[0] for item in batch]
        labels = [item[1] for item in batch]

        max_len = max(len(x) for x in input_ids)
        # Pad id 0 is harmless: with right padding and a causal mask, real
        # tokens never attend to the padding, and its labels are ignored.
        padded_ids = torch.zeros(len(batch), max_len, dtype=torch.long)
        padded_labels = torch.full((len(batch), max_len), -100, dtype=torch.long)

        for i, (ids, lbls) in enumerate(zip(input_ids, labels)):
            padded_ids[i, :len(ids)] = ids
            padded_labels[i, :len(lbls)] = lbls

        return padded_ids, padded_labels

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    total_steps = len(dataloader) * num_epochs
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(total_steps * 0.1), total_steps)

    model.train()
    for epoch in range(num_epochs):
        for step, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)

            _, loss = model(x, y)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            if step % 50 == 0:
                print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")

    os.makedirs(output_dir, exist_ok=True)
    torch.save({
        "model": model.state_dict(),
        "config": model.config
    }, os.path.join(output_dir, "sft_model.pt"))

    print(f"SFT 완료. 모델 저장: {output_dir}")

10. 오픈소스 LLM 분석

10.1 Llama 3 구조 분석

Llama 3는 Meta에서 공개한 오픈소스 LLM으로, 다음과 같은 특징을 가집니다.

| 특징 | 설명 |
| --- | --- |
| 아키텍처 | Decoder-only Transformer |
| 정규화 | Pre-RMSNorm |
| 활성화 | SwiGLU |
| 위치 인코딩 | RoPE (theta=500,000) |
| 어휘 크기 | 128,256 |
| 컨텍스트 길이 | 128K (Llama 3.1+) |
| GQA | Grouped Query Attention (8B, 70B) |
# Llama 3 style configuration (modeled on the 8B model).
# NOTE(review): this only approximates Llama 3 with GPTConfig's fields. The
# GPTConfig in this guide has no field for the number of KV heads (GQA) or the
# RoPE base (theta=500,000); those are set on GroupedQueryAttention below.
# ffn_mult and dropout fall back to GPTConfig defaults, which may not match
# Llama 3's FFN width or training setup. Confirm before treating this as exact.
llama3_config = GPTConfig(
    vocab_size=128256,   # Llama 3 tokenizer vocabulary (see the table above)
    max_seq_len=8192,    # 8K context; the 128K in the table applies to Llama 3.1+
    n_embd=4096,
    n_layer=32,
    n_head=32,
    use_rope=True,
    use_swiglu=True,
    bias=False
)

10.2 Grouped Query Attention (GQA)

class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (as used by Llama 3 and Mistral).

    Queries keep ``config.n_head`` heads while keys/values use only
    ``n_kv_heads`` heads. Each K/V head is shared by ``n_head // n_kv_heads``
    consecutive query heads, which shrinks the K/V projections by that factor.

    Args:
        config: Model configuration (reads ``n_embd``, ``n_head``, ``max_seq_len``).
        n_kv_heads: Number of key/value heads; must evenly divide ``config.n_head``.
        rope_base: RoPE frequency base (theta). The default 500,000 matches Llama 3.

    Raises:
        ValueError: If ``n_kv_heads`` is not a positive divisor of ``config.n_head``.
    """

    def __init__(self, config: GPTConfig, n_kv_heads: int = 8, rope_base: int = 500000):
        super().__init__()
        if n_kv_heads <= 0 or config.n_head % n_kv_heads != 0:
            raise ValueError(
                f"n_kv_heads ({n_kv_heads}) must be a positive divisor of n_head ({config.n_head})"
            )
        self.n_head = config.n_head          # number of query heads
        self.n_kv_heads = n_kv_heads         # number of K/V heads (fewer)
        self.n_rep = config.n_head // n_kv_heads  # query heads served by each K/V head
        self.head_dim = config.n_embd // config.n_head

        self.q_proj = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.n_embd, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.n_embd, n_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.rope = RotaryEmbedding(self.head_dim, config.max_seq_len, base=rope_base)

    def forward(self, x: torch.Tensor):
        B, T, C = x.shape

        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Rotary position embedding on queries and (fewer) keys.
        q, k = self.rope(q, k, T)

        # Repeat each K/V head n_rep times (the core of GQA):
        # [B, n_kv_heads, T, head_dim] -> [B, n_head, T, head_dim]
        k = k.unsqueeze(2).expand(B, self.n_kv_heads, self.n_rep, T, self.head_dim)
        k = k.reshape(B, self.n_head, T, self.head_dim)
        v = v.unsqueeze(2).expand(B, self.n_kv_heads, self.n_rep, T, self.head_dim)
        v = v.reshape(B, self.n_head, T, self.head_dim)

        # Fused causal attention (PyTorch picks a Flash/efficient kernel when available).
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(y)

10.3 DeepSeek 혁신

DeepSeek은 여러 기술적 혁신을 도입했습니다.

MLA (Multi-head Latent Attention): K, V를 저차원 잠재 공간으로 압축하여 KV 캐시 메모리를 획기적으로 줄입니다.

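DeepSeekMoE: 전문가를 잘게 분할하고(fine-grained expert segmentation) 일부 공유 전문가(shared experts)를 항상 활성화하는 방식으로 전문가 혼합(MoE) 아키텍처를 적용합니다. 아래 코드는 일반적인 top-k 라우팅 MoE를 간략화한 예시입니다.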

class MoEFFN(nn.Module):
    """Mixture-of-Experts FFN with top-k routing (simplified).

    A linear router scores ``num_experts`` SwiGLU experts per token. Each
    token goes to its ``top_k`` highest-probability experts, and the expert
    outputs are combined with the renormalized router probabilities. This
    simplified version computes no auxiliary load-balancing loss and has no
    per-expert capacity limit.
    """

    def __init__(self, config: GPTConfig, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        # Router: one logit per expert for each token.
        self.router = nn.Linear(config.n_embd, num_experts, bias=False)

        # Expert FFNs.
        self.experts = nn.ModuleList([
            SwiGLUFFN(config) for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        # reshape (not view) so non-contiguous inputs are accepted too.
        x_flat = x.reshape(B * T, C)

        # Routing scores.
        router_logits = self.router(x_flat)  # [B*T, num_experts]
        router_probs = F.softmax(router_logits, dim=-1)

        # Keep the top-k experts per token and renormalize their weights to sum to 1.
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)  # [B*T, top_k]
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        output = torch.zeros_like(x_flat)
        # One pass per expert: gather every (token, slot) routed to it, run the
        # expert once on those tokens, and scatter-add the weighted result.
        # topk returns distinct experts per token, so a token appears at most
        # once per expert.
        for e, expert in enumerate(self.experts):
            token_idx, slot_idx = torch.where(top_k_indices == e)
            if token_idx.numel() == 0:
                continue
            weight = top_k_probs[token_idx, slot_idx].unsqueeze(-1)  # [n_tokens, 1]
            expert_out = expert(x_flat[token_idx])
            output.index_add_(0, token_idx, (weight * expert_out).to(output.dtype))

        return output.view(B, T, C)

마치며

이 가이드에서 miniGPT를 처음부터 완전히 구현하며 현대 LLM의 모든 핵심 요소를 살펴보았습니다.

핵심 컴포넌트 요약:

  • 토크나이저: BPE 알고리즘으로 텍스트를 서브워드 토큰으로 분리
  • 임베딩: 토큰 임베딩 + RoPE 위치 인코딩
  • Multi-Head Causal Attention: Q/K/V 투영, 인과 마스크, KV 캐시
  • FFN: SwiGLU 활성화로 비선형성 추가
  • Pre-RMSNorm: 학습 안정성
  • 텍스트 생성: Greedy, Temperature, Top-k, Top-p, Beam Search
  • 사전학습: Cosine LR Schedule, Gradient Clipping, Mixed Precision
  • SFT: 응답 부분에만 손실 계산

이 기반 위에 Llama 3의 GQA, DeepSeek의 MoE 같은 현대적 기법들이 추가됩니다. 직접 구현하며 이해한 내용은 HuggingFace 라이브러리를 더 깊이 활용하는 데 큰 도움이 될 것입니다.

참고 자료

Building LLM from Scratch: Complete Guide to Understanding GPT through Code

Introduction

GPT, Llama, Mistral — large language models (LLMs) are at the center of the current AI revolution. Yet very few people understand how these models work at the code level. In this guide, we build "miniGPT," a small GPT-style model from scratch using PyTorch, mastering every core concept behind modern LLMs.


1. LLM Architecture Overview

1.1 The Pretraining Pipeline

Modern LLM development follows three stages:

  1. Pretraining: The model learns to predict the next token from billions of text tokens. During this stage, it acquires statistical language patterns, world knowledge, and reasoning ability.

  2. Instruction Fine-Tuning (SFT): The model is trained on human-written question-answer pairs so that it learns to follow instructions.

  3. Preference Learning (RLHF / DPO): The model is refined using human preference data to produce more helpful and safe responses.

1.2 miniGPT Specification

| Parameter | Value |
| --- | --- |
| Vocabulary size | 50,257 (same as GPT-2) |
| Context length | 1024 |
| Embedding dimension | 768 |
| Number of layers | 12 |
| Attention heads | 12 |
| FFN expansion ratio | 4x |
| Total parameters | ~124M |

This matches GPT-2 Small and is trainable on a single GPU.


2. Tokenizer Implementation

2.1 BPE Algorithm

Byte-Pair Encoding (BPE) is the tokenization algorithm used by most modern LLMs.

from collections import defaultdict
from typing import List, Tuple, Dict

class SimpleBPETokenizer:
    """BPE tokenizer implemented from scratch (training only).

    Words are represented as space-separated symbol strings ending in the
    end-of-word marker ``</w>``, e.g. ``"l o w </w>"``.
    """

    def __init__(self):
        self.vocab = {}
        self.merges = {}
        self.special_tokens = {
            "<|endoftext|>": 50256,
        }

    def get_stats(self, vocab: Dict[str, int]) -> Dict[Tuple, int]:
        """Count adjacent symbol pairs, weighting each occurrence by its word's frequency."""
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for left, right in zip(symbols, symbols[1:]):
                pairs[(left, right)] += freq
        return pairs

    def merge_vocab(self, pair: Tuple, v_in: Dict) -> Dict:
        """Merge every occurrence of ``pair`` into a single symbol.

        Matching is done on whole symbols. Merging ``("a", "b")`` turns
        ``"a b c"`` into ``"ab c"`` but leaves ``"xa b"`` untouched; a plain
        substring replace of ``"a b"`` would wrongly produce ``"xab"``.
        Frequencies of words that end up identical are summed.
        """
        first, second = pair
        merged_symbol = first + second
        v_out = {}
        for word, freq in v_in.items():
            symbols = word.split()
            out = []
            i = 0
            while i < len(symbols):
                if i + 1 < len(symbols) and symbols[i] == first and symbols[i + 1] == second:
                    out.append(merged_symbol)
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            w_out = " ".join(out)
            v_out[w_out] = v_out.get(w_out, 0) + freq
        return v_out

    def train(self, corpus: List[str], vocab_size: int = 1000):
        """Learn BPE merges from ``corpus``.

        Args:
            corpus: Texts; each is split on whitespace into words.
            vocab_size: Target vocabulary size = initial symbols (characters
                plus ``</w>``) + learned merges. If the initial symbol set is
                already at least this large, no merges are learned. Training
                also stops early when no pairs remain.

        Returns:
            ``(vocab, merges)``. ``vocab`` maps every symbol to an id: base
            symbols first (sorted), then merged symbols in merge order.
            ``merges`` maps each merged pair to its rank.
        """
        # 1. Character-level words with an end-of-word marker.
        word_freq = defaultdict(int)
        for text in corpus:
            for word in text.split():
                word_freq[" ".join(word) + " </w>"] += 1

        vocab = dict(word_freq)

        # Base alphabet: every symbol present before any merge.
        base_symbols = sorted({sym for word in vocab for sym in word.split()})
        num_merges = max(0, vocab_size - len(base_symbols))

        # 2. Repeatedly merge the most frequent pair.
        merged_symbols = []
        for i in range(num_merges):
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            vocab = self.merge_vocab(best, vocab)
            self.merges[best] = i
            merged_symbols.append("".join(best))
            print(f"Step {i}: merged {best} (freq={pairs[best]})")

        # 3. Vocabulary = base symbols + every merge product, so intermediate
        # merges (not just the symbols left in the final words) stay encodable.
        for token in base_symbols + merged_symbols:
            if token not in self.vocab:
                self.vocab[token] = len(self.vocab)

        return self.vocab, self.merges

2.2 Using tiktoken

In practice, OpenAI's tiktoken library is much more efficient.

import tiktoken
import torch
from typing import List

class GPTTokenizer:
    """Wrapper around a tiktoken encoding that exposes encode/decode and the vocabulary size.

    By default ``encode`` prepends the ``<|endoftext|>`` id (``eot_token``).
    """

    def __init__(self, encoding_name: str = "gpt2"):
        encoding = tiktoken.get_encoding(encoding_name)
        self.enc = encoding
        self.vocab_size = encoding.n_vocab
        self.eot_token = encoding.eot_token  # 50256 for the "gpt2" encoding

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Tokenize ``text``; a literal "<|endoftext|>" is mapped to its special id."""
        token_ids = self.enc.encode(text, allowed_special={"<|endoftext|>"})
        if not add_special_tokens:
            return token_ids
        return [self.eot_token, *token_ids]

    def decode(self, ids: List[int]) -> str:
        """Turn a list of token ids back into text."""
        return self.enc.decode(ids)

    def encode_batch(self, texts: List[str]) -> List[List[int]]:
        """Encode each text with the default settings (leading eot id included)."""
        return list(map(self.encode, texts))

    def __len__(self):
        return self.vocab_size

# Usage: round-trip a sentence through the tokenizer.
tokenizer = GPTTokenizer()
text = "Hello! Let's build GPT from scratch."
# encode() prepends the <|endoftext|> id by default (add_special_tokens=True),
# so the count below includes that extra token.
ids = tokenizer.encode(text)
print(f"Token count: {len(ids)}")
print(f"Token IDs: {ids}")
decoded = tokenizer.decode(ids)
print(f"Decoded: {decoded}")

3. Data Preparation

3.1 Text Dataset Class

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class TextDataset(Dataset):
    """Sliding-window language-modeling dataset over a single UTF-8 text file.

    The whole file is tokenized once, without special tokens. Sample ``idx``
    is the window starting at ``idx * stride``. Inputs are ``seq_len`` tokens
    and targets are the same window shifted left by one.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: GPTTokenizer,
        seq_len: int = 1024,
        stride: int = 512
    ):
        self.seq_len = seq_len
        self.stride = stride

        print(f"Loading data: {data_path}")
        with open(data_path, "r", encoding="utf-8") as f:
            text = f.read()

        print(f"Total characters: {len(text):,}")
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        print(f"Total tokens: {len(token_ids):,}")

        # int32 is ample for the GPT-2 vocabulary and halves memory vs int64.
        self.tokens = np.array(token_ids, dtype=np.int32)

    def __len__(self):
        """Number of complete windows.

        Each window needs ``seq_len + 1`` tokens (inputs plus one extra for
        the shifted target). Valid starts are therefore ``0, stride, ...`` up
        to ``len(tokens) - seq_len - 1`` inclusive.
        """
        last_start = len(self.tokens) - self.seq_len - 1
        if last_start < 0:
            return 0
        return last_start // self.stride + 1

    def __getitem__(self, idx):
        start = idx * self.stride
        chunk = self.tokens[start:start + self.seq_len + 1]

        # Input: tokens[0:seq_len], label: tokens[1:seq_len+1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)

        return x, y


def create_dataloader(
    data_path: str,
    tokenizer: GPTTokenizer,
    batch_size: int = 8,
    seq_len: int = 1024,
    stride: int = 512,
    num_workers: int = 4,
    shuffle: bool = True
) -> DataLoader:
    """Wrap a sliding-window TextDataset in a DataLoader.

    Incomplete final batches are dropped, and batches are placed in pinned
    host memory to speed up transfers to the GPU.
    """
    ds = TextDataset(data_path, tokenizer, seq_len, stride)
    print(f"Dataset size: {len(ds):,} samples")

    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": True,
        "drop_last": True,
    }
    return DataLoader(ds, **loader_options)

4. Model Architecture

4.1 Configuration Class

from dataclasses import dataclass

@dataclass
class GPTConfig:
    """Hyperparameters for MiniGPT.

    Defaults match the GPT-2 Small sizing from the spec table: 50,257-token
    vocabulary, 1024 context, 768-dim embeddings, 12 layers and 12 heads.
    Field order is also the positional constructor order.
    """
    vocab_size: int = 50257   # size of the token embedding table and of the output logits
    max_seq_len: int = 1024   # longest sequence; sizes the causal mask, RoPE cache and learned positions
    n_embd: int = 768         # model (residual stream) width
    n_layer: int = 12         # number of TransformerBlocks
    n_head: int = 12          # attention heads; n_embd must be divisible by n_head
    dropout: float = 0.1      # dropout probability for embedding, attention, residual and FFN dropout
    bias: bool = False        # bias on attention and GELU-FFN Linear layers (SwiGLU never uses bias)

    ffn_mult: int = 4         # FFN expansion factor (SwiGLU uses 2/3 of it, rounded up to a multiple of 64)
    use_swiglu: bool = True   # SwiGLU FFN if True, classic GELU FFN otherwise
    use_rope: bool = True     # rotary positions if True, learned position embeddings otherwise

4.2 RMSNorm

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (Llama-style).

    Rescales the last dimension by 1 / sqrt(mean(x^2) + eps), then applies a
    learned per-feature gain. There is no mean-centering and no bias.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide by the RMS over the feature axis (rsqrt avoids a separate division).
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalize in float32 for stability, cast back, then apply the gain.
        normalized = self._norm(x.float()).type_as(x)
        return self.weight * normalized

4.3 Rotary Position Embedding (RoPE)

class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) with precomputed cos/sin tables.

    Uses the "rotate half" layout: the frequency table is duplicated across
    the two halves of the head dimension, and ``rotate_half`` swaps and
    negates those halves.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len

        # Frequencies: theta_i = base^(-2i/dim)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        t = torch.arange(seq_len, device=self.inv_freq.device).float()
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        # Shaped [1, 1, seq_len, dim] to broadcast over batch and heads.
        self.register_buffer("cos_cache", emb.cos()[None, None, :, :])
        self.register_buffer("sin_cache", emb.sin()[None, None, :, :])

    def rotate_half(self, x):
        """Map (x1, x2) halves of the last dim to (-x2, x1)."""
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int, offset: int = 0):
        """Rotate ``q`` and ``k`` for absolute positions ``offset .. offset + seq_len - 1``.

        Args:
            q, k: Tensors whose last two dims are (positions, head_dim).
            seq_len: Number of positions in ``q``/``k``.
            offset: Absolute position of the first token. Pass the KV-cache
                length when decoding incrementally; 0 otherwise.

        Raises:
            ValueError: If the requested positions exceed the precomputed table.
        """
        end = offset + seq_len
        if offset < 0 or end > self.cos_cache.shape[2]:
            raise ValueError(
                f"positions [{offset}, {end}) exceed the RoPE cache of length {self.cos_cache.shape[2]}"
            )
        cos = self.cos_cache[:, :, offset:end, :].to(q.dtype)
        sin = self.sin_cache[:, :, offset:end, :].to(q.dtype)

        q_rot = (q * cos) + (self.rotate_half(q) * sin)
        k_rot = (k * cos) + (self.rotate_half(k) * sin)
        return q_rot, k_rot

4.4 Multi-Head Causal Attention with KV Cache

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional KV cache.

    When ``past_kv`` is given, the new tokens are treated as the positions
    that follow the cached ones. RoPE is applied at those absolute positions,
    and each new query attends to every cached key plus the new keys up to
    and including its own position.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head

        # Combined Q, K, V projection
        self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        if config.use_rope:
            self.rope = RotaryEmbedding(self.head_dim, config.max_seq_len)
        else:
            self.rope = None

        # Causal mask (lower triangular): row i may attend to columns 0..i.
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
            .view(1, 1, config.max_seq_len, config.max_seq_len)
        )

    def _apply_rope(self, q, k, offset: int):
        """Rotate q/k for absolute positions offset .. offset + T - 1 using the RoPE tables."""
        end = offset + q.shape[2]
        cos = self.rope.cos_cache[:, :, offset:end, :].to(q.dtype)
        sin = self.rope.sin_cache[:, :, offset:end, :].to(q.dtype)
        rotate_half = self.rope.rotate_half
        q_rot = (q * cos) + (rotate_half(q) * sin)
        k_rot = (k * cos) + (rotate_half(k) * sin)
        return q_rot, k_rot

    def forward(
        self,
        x: torch.Tensor,
        use_cache: bool = False,
        past_kv=None
    ):
        """
        Args:
            x: [B, T, n_embd] hidden states for the new positions.
            use_cache: If True, also return the concatenated (k, v).
            past_kv: Optional (k, v) from a previous call, each [B, n_head, past_len, head_dim].

        Returns:
            (y, present_kv): y is [B, T, n_embd]; present_kv is (k, v) or None.
        """
        B, T, C = x.shape

        qkv = self.qkv_proj(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        # Split heads: [B, T, C] -> [B, n_head, T, head_dim]
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Positions already in the cache; the new tokens start here.
        past_len = past_kv[0].shape[2] if past_kv is not None else 0

        if self.rope is not None:
            q, k = self._apply_rope(q, k, past_len)

        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        present_kv = (k, v) if use_cache else None
        kv_len = k.shape[2]
        dropout_p = self.attn_dropout.p if self.training else 0.0

        # New query i sits at absolute position past_len + i, so its mask row
        # is causal_mask[past_len + i, :kv_len]. SDPA's is_causal=True builds a
        # top-left aligned mask, which is only correct without a cache.
        if hasattr(F, "scaled_dot_product_attention"):
            if past_len == 0:
                y = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=None,
                    dropout_p=dropout_p,
                    is_causal=True
                )
            else:
                # Boolean mask: True = may attend.
                mask = self.causal_mask[:, :, past_len:past_len + T, :kv_len].bool()
                y = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=mask,
                    dropout_p=dropout_p
                )
        else:
            # Manual implementation
            scale = 1.0 / math.sqrt(self.head_dim)
            attn = (q @ k.transpose(-2, -1)) * scale
            mask = self.causal_mask[:, :, past_len:past_len + T, :kv_len]
            attn = attn.masked_fill(mask == 0, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_dropout(attn)
            y = attn @ v

        # Merge heads: [B, n_head, T, head_dim] -> [B, T, C]
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.out_proj(y))

        return y, present_kv

4.5 SwiGLU Feed-Forward Network

class SwiGLUFFN(nn.Module):
    """Feed-forward block with SwiGLU gating (Llama-style).

    Computes down_proj(SiLU(gate_proj(x)) * up_proj(x)) followed by dropout.
    The hidden width is 2/3 of ``n_embd * ffn_mult`` because there are three
    weight matrices instead of two. It is rounded up to a multiple of 64 for
    hardware-friendly shapes.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        raw_hidden = int(config.n_embd * config.ffn_mult * 2 / 3)
        hidden_dim = -(-raw_hidden // 64) * 64  # ceil to a multiple of 64

        self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: SiLU(gate branch) multiplies the linear "up" branch elementwise.
        hidden = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(hidden))


class GPTFFN(nn.Module):
    """Classic GPT-2 style feed-forward: Linear -> GELU -> Linear -> dropout."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        width = config.n_embd
        inner = width * config.ffn_mult
        self.fc1 = nn.Linear(width, inner, bias=config.bias)
        self.fc2 = nn.Linear(inner, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.fc1(x))
        return self.dropout(self.fc2(hidden))

4.6 Transformer Block

class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: x + Attn(RMSNorm(x)), then x + FFN(RMSNorm(x)).

    The FFN is SwiGLU when ``config.use_swiglu`` is set, otherwise the GELU
    variant. The attention layer's (k, v) cache is passed through unchanged.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.norm1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.norm2 = RMSNorm(config.n_embd)

        ffn_cls = SwiGLUFFN if config.use_swiglu else GPTFFN
        self.ffn = ffn_cls(config)

    def forward(self, x: torch.Tensor, use_cache: bool = False, past_kv=None):
        attn_delta, present_kv = self.attn(self.norm1(x), use_cache=use_cache, past_kv=past_kv)
        x = x + attn_delta
        return x + self.ffn(self.norm2(x)), present_kv

4.7 Complete GPT Model

class MiniGPT(nn.Module):
    """Full miniGPT model: token embedding -> TransformerBlocks -> RMSNorm -> tied LM head."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        # Embedding layers
        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)

        if not config.use_rope:
            self.pos_emb = nn.Embedding(config.max_seq_len, config.n_embd)

        self.emb_dropout = nn.Dropout(config.dropout)

        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layer)
        ])

        # Final normalization
        self.norm_final = RMSNorm(config.n_embd)

        # Language model head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying (input embeddings == output embeddings)
        self.lm_head.weight = self.token_emb.weight

        # Weight initialization
        self.apply(self._init_weights)

        # GPT-2 style scaled init for every Linear that writes into the
        # residual stream: attention out_proj and the FFN output projection
        # (down_proj for SwiGLU, fc2 for the GELU FFN).
        residual_proj_suffixes = ("out_proj.weight", "down_proj.weight", "fc2.weight")
        for pn, p in self.named_parameters():
            if pn.endswith(residual_proj_suffixes):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

        print(f"Parameters: {self.count_params() / 1e6:.1f}M")

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def count_params(self) -> int:
        """Number of parameters (the tied embedding/LM-head weight is counted once)."""
        return sum(p.numel() for p in self.parameters())

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor = None,
        use_cache: bool = False,
        past_kvs: list = None
    ):
        """
        Args:
            input_ids: [B, T] token ids.
            targets: Optional [B, T] next-token ids (already shifted). Entries
                equal to -100 (PyTorch/HF convention) or -1 are ignored.
            use_cache: Ask each layer to return its (k, v) for incremental decoding.
            past_kvs: Optional per-layer (k, v) from a previous call.

        Returns:
            With targets: (logits [B, T, vocab], loss).
            Without: (logits [B, 1, vocab] for the last position, per-layer present_kvs).
        """
        B, T = input_ids.shape
        device = input_ids.device

        # Absolute position of the first new token (non-zero when decoding with a cache).
        past_len = 0
        if past_kvs and past_kvs[0] is not None:
            past_len = past_kvs[0][0].shape[2]

        # Token embeddings
        x = self.token_emb(input_ids)  # [B, T, n_embd]

        # Learned positional embeddings (only when not using RoPE)
        if not self.config.use_rope:
            positions = torch.arange(past_len, past_len + T, device=device)
            x = x + self.pos_emb(positions)

        x = self.emb_dropout(x)

        # Transformer layers
        present_kvs = []
        for i, layer in enumerate(self.layers):
            past_kv = past_kvs[i] if past_kvs is not None else None
            x, present_kv = layer(x, use_cache=use_cache, past_kv=past_kv)
            present_kvs.append(present_kv)

        x = self.norm_final(x)

        if targets is not None:
            # Training: loss over the full sequence.
            logits = self.lm_head(x)
            # Accept the legacy -1 mask value as well as the standard -100.
            targets = targets.masked_fill(targets == -1, -100)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-100
            )
            return logits, loss
        else:
            # Inference: only the last position's logits are needed.
            logits = self.lm_head(x[:, [-1], :])
            return logits, present_kvs

    @classmethod
    def from_pretrained(cls, model_path: str):
        """Rebuild a model from a checkpoint holding {"model": state_dict, "config": GPTConfig}.

        The checkpoint pickles the GPTConfig object itself, which torch.load's
        ``weights_only=True`` mode (the default in recent PyTorch releases)
        refuses to load, so full unpickling is requested explicitly.
        SECURITY: unpickling can execute arbitrary code; only load trusted checkpoints.
        """
        checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
        config = checkpoint["config"]
        model = cls(config)
        model.load_state_dict(checkpoint["model"])
        return model

5. Pretraining Implementation

5.1 Learning Rate Scheduler

import math

def get_cosine_schedule_with_warmup(
    optimizer,
    warmup_steps: int,
    total_steps: int,
    min_lr_ratio: float = 0.1
):
    """Linear warmup followed by cosine decay, returned as a ``LambdaLR``.

    The LR multiplier rises linearly from 0 to 1 over ``warmup_steps``. It
    then follows half a cosine from 1 towards 0 over the remaining steps,
    floored at ``min_lr_ratio``. Steps beyond ``total_steps`` stay at the
    floor instead of following the cosine back up.
    """

    def lr_lambda(current_step: int):
        if current_step < warmup_steps:
            return current_step / max(1, warmup_steps)
        progress = (current_step - warmup_steps) / max(1, total_steps - warmup_steps)
        # cos(pi * p) increases again for p > 1, so clamp the progress.
        progress = min(progress, 1.0)
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
        return max(min_lr_ratio, cosine_decay)

    from torch.optim.lr_scheduler import LambdaLR
    return LambdaLR(optimizer, lr_lambda)

5.2 Training Loop

import torch
import torch.nn as nn
from torch.optim import AdamW
import os
import time

class Trainer:
    """miniGPT pretraining trainer.

    Handles:
    - mixed precision: bf16 when supported; otherwise fp16 with a GradScaler
    - gradient accumulation and gradient clipping
    - a warmup + cosine LR schedule
    - periodic validation and checkpointing

    ``config`` keys:
    - required: ``num_epochs``, ``learning_rate``
    - optional (defaults): ``device``, ``warmup_ratio`` (0.05),
      ``weight_decay`` (0.1), ``gradient_accumulation_steps`` (1),
      ``max_grad_norm`` (1.0), ``log_interval`` (100),
      ``eval_interval`` (500), ``save_interval`` (1000)
    """

    def __init__(self, model: MiniGPT, train_dataloader, val_dataloader, config: dict):
        self.model = model
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.config = config
        self.device = config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        self.device_type = "cuda" if str(self.device).startswith("cuda") else "cpu"

        # Move parameters before building the optimizer: fused AdamW validates
        # the parameters' device when it is constructed.
        self.model.to(self.device)

        # Mixed precision. Only query CUDA bf16 support on CUDA; CPU autocast uses bf16.
        if self.device_type == "cuda" and not torch.cuda.is_bf16_supported():
            self.dtype = torch.float16
        else:
            self.dtype = torch.bfloat16
        self.scaler = torch.cuda.amp.GradScaler(enabled=(self.dtype == torch.float16))
        self.ctx = torch.amp.autocast(device_type=self.device_type, dtype=self.dtype)

        self.optimizer = self._create_optimizer()

        # The scheduler advances once per optimizer step, i.e. once every
        # `grad_accum` micro-batches, so size it in optimizer steps.
        grad_accum = config.get("gradient_accumulation_steps", 1)
        total_steps = (len(train_dataloader) // grad_accum) * config["num_epochs"]
        warmup_steps = int(total_steps * config.get("warmup_ratio", 0.05))
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            warmup_steps=warmup_steps,
            total_steps=total_steps
        )

        self.step = 0
        self.best_val_loss = float("inf")

    def _create_optimizer(self):
        """AdamW with weight decay on matrices only (1-D params and embeddings are exempt)."""
        decay_params = []
        no_decay_params = []

        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if param.dim() == 1 or "emb" in name:
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optim_groups = [
            {"params": decay_params, "weight_decay": self.config.get("weight_decay", 0.1)},
            {"params": no_decay_params, "weight_decay": 0.0}
        ]

        return AdamW(
            optim_groups,
            lr=self.config["learning_rate"],
            betas=(0.9, 0.95),
            # The fused kernel is only used for CUDA parameters.
            fused=(self.device_type == "cuda")
        )

    def compute_val_loss(self) -> float:
        """Mean loss over at most 50 validation batches (``inf`` if there are none)."""
        num_batches = min(len(self.val_dl), 50)
        if num_batches == 0:
            return float("inf")

        self.model.eval()
        total_loss = 0.0

        with torch.no_grad():
            for i, (x, y) in enumerate(self.val_dl):
                if i >= num_batches:
                    break
                x, y = x.to(self.device), y.to(self.device)
                with self.ctx:
                    _, loss = self.model(x, y)
                total_loss += loss.item()

        self.model.train()
        return total_loss / num_batches

    def save_checkpoint(self, path: str):
        """Save model, optimizer and scheduler state plus the model config to ``path``."""
        directory = os.path.dirname(path)
        if directory:  # a bare filename has no directory to create
            os.makedirs(directory, exist_ok=True)
        torch.save({
            "step": self.step,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "config": self.model.config,
            "best_val_loss": self.best_val_loss
        }, path)
        print(f"Checkpoint saved: {path}")

    def train(self):
        """Run the training loop for ``num_epochs`` epochs."""
        self.model.to(self.device)
        self.model.train()

        num_epochs = self.config["num_epochs"]
        grad_accum = self.config.get("gradient_accumulation_steps", 1)
        max_grad_norm = self.config.get("max_grad_norm", 1.0)
        log_interval = self.config.get("log_interval", 100)
        eval_interval = self.config.get("eval_interval", 500)
        save_interval = self.config.get("save_interval", 1000)

        print(f"Training on {self.device}, dtype={self.dtype}")

        t0 = time.time()
        for epoch in range(num_epochs):
            # Drop any partial accumulation left over from the previous epoch
            # (when len(train_dl) is not a multiple of grad_accum).
            self.optimizer.zero_grad(set_to_none=True)

            for batch_idx, (x, y) in enumerate(self.train_dl):
                x, y = x.to(self.device), y.to(self.device)

                with self.ctx:
                    _, loss = self.model(x, y)
                    loss = loss / grad_accum

                self.scaler.scale(loss).backward()

                # Only optimizer steps advance self.step. Logging, evaluation
                # and checkpointing happen right after a step, so each fires
                # once per interval, not once per micro-batch.
                if (batch_idx + 1) % grad_accum != 0:
                    continue

                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.scheduler.step()
                self.optimizer.zero_grad(set_to_none=True)
                self.step += 1

                if self.step % log_interval == 0:
                    lr = self.optimizer.param_groups[0]["lr"]
                    elapsed = time.time() - t0
                    # Each optimizer step consumed grad_accum micro-batches.
                    tokens = log_interval * grad_accum * x.shape[0] * x.shape[1]
                    tokens_per_sec = tokens / elapsed
                    print(
                        f"Epoch {epoch} | Step {self.step} | "
                        f"Loss: {loss.item() * grad_accum:.4f} | "
                        f"LR: {lr:.6f} | "
                        f"Tokens/s: {tokens_per_sec:.0f}"
                    )
                    t0 = time.time()

                if self.step % eval_interval == 0:
                    val_loss = self.compute_val_loss()
                    print(f"Val Loss: {val_loss:.4f} | Perplexity: {math.exp(val_loss):.2f}")
                    if val_loss < self.best_val_loss:
                        self.best_val_loss = val_loss
                        self.save_checkpoint("./checkpoints/best_model.pt")

                if self.step % save_interval == 0:
                    self.save_checkpoint(f"./checkpoints/step_{self.step}.pt")

6. Text Generation

6.1 Decoding Strategies

import torch
import torch.nn.functional as F

class TextGenerator:
    """miniGPT text generator supporting several decoding strategies.

    Wraps a model whose ``forward(ids)`` returns ``(logits, loss)`` with logits
    indexed ``[batch, position, vocab]``, and a tokenizer exposing ``encode``,
    ``decode`` and ``eot_token``. All strategies decode a single prompt
    (batch size 1).
    """

    def __init__(self, model: "MiniGPT", tokenizer: "GPTTokenizer", device: str = "cuda"):
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device
        # Inference only: disable dropout etc.
        self.model.eval()

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 200,
        strategy: str = "top_p",
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
        num_beams: int = 4
    ) -> str:
        """Generate a continuation of ``prompt``.

        Args:
            prompt: Text to condition on. An empty prompt starts from the
                end-of-text token.
            max_new_tokens: Upper bound on generated tokens; generation also
                stops at the end-of-text token.
            strategy: One of "greedy", "temperature", "top_k", "top_p", "beam".
            temperature: Logit divisor for the sampling strategies (must be > 0).
            top_k: Candidate count for "top_k" (clamped to the vocabulary size).
            top_p: Probability-mass threshold for "top_p".
            num_beams: Beam width for "beam".

        Returns:
            The decoded new tokens only (prompt excluded). A trailing
            end-of-text token is stripped.

        Raises:
            ValueError: Unknown ``strategy``, non-positive ``temperature`` for a
                sampling strategy, or ``top_k < 1``.
        """
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        if not prompt_ids:
            # A zero-length input gives the model nothing to condition on;
            # start from the end-of-text token instead.
            prompt_ids = [self.tokenizer.eot_token]
        input_ids = torch.tensor(prompt_ids, dtype=torch.long, device=self.device).unsqueeze(0)

        if strategy == "greedy":
            generated = self._greedy_decode(input_ids, max_new_tokens)
        elif strategy == "temperature":
            generated = self._temperature_sampling(input_ids, max_new_tokens, temperature)
        elif strategy == "top_k":
            generated = self._top_k_sampling(input_ids, max_new_tokens, temperature, top_k)
        elif strategy == "top_p":
            generated = self._top_p_sampling(input_ids, max_new_tokens, temperature, top_p)
        elif strategy == "beam":
            generated = self._beam_search(input_ids, max_new_tokens, num_beams)
        else:
            raise ValueError(f"Unknown strategy: {strategy}")

        new_tokens = generated[0][input_ids.shape[1]:].tolist()
        if new_tokens and new_tokens[-1] == self.tokenizer.eot_token:
            # The stop marker is not part of the generated text.
            new_tokens = new_tokens[:-1]
        return self.tokenizer.decode(new_tokens)

    # ── helpers ──────────────────────────────────────────────

    def _context_window(self, ids: torch.Tensor) -> torch.Tensor:
        """Keep only the most recent ``model.config.max_seq_len`` tokens, if known.

        The model cannot attend beyond its configured context, so long
        generations are fed a sliding window of the latest tokens.
        """
        max_len = getattr(getattr(self.model, "config", None), "max_seq_len", None)
        if max_len is not None and ids.shape[1] > max_len:
            return ids[:, -max_len:]
        return ids

    @torch.no_grad()
    def _next_token_logits(self, ids: torch.Tensor) -> torch.Tensor:
        """Return the model's logits for the position after the last token."""
        logits, _ = self.model(self._context_window(ids))
        return logits[:, -1, :]

    @staticmethod
    def _check_temperature(temperature: float) -> None:
        """Reject temperatures that would make logits / temperature inf or nan."""
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

    def _decode_loop(self, input_ids, max_new_tokens, choose_next):
        """Shared autoregressive loop.

        ``choose_next`` maps last-position logits (shape [1, vocab]) to a next
        token tensor of shape [1, 1]. Stops early at the end-of-text token.
        """
        current_ids = input_ids.clone()
        eot = self.tokenizer.eot_token
        for _ in range(max_new_tokens):
            next_token = choose_next(self._next_token_logits(current_ids))
            current_ids = torch.cat([current_ids, next_token], dim=1)
            if next_token.item() == eot:
                break
        return current_ids

    # ── strategies ───────────────────────────────────────────

    def _greedy_decode(self, input_ids, max_new_tokens):
        """Greedy decoding: always pick the highest probability token"""
        return self._decode_loop(
            input_ids, max_new_tokens,
            lambda logits: logits.argmax(dim=-1, keepdim=True),
        )

    def _temperature_sampling(self, input_ids, max_new_tokens, temperature):
        """Temperature sampling: scale logits before softmax"""
        self._check_temperature(temperature)

        def choose(logits):
            probs = F.softmax(logits / temperature, dim=-1)
            return torch.multinomial(probs, num_samples=1)

        return self._decode_loop(input_ids, max_new_tokens, choose)

    def _top_k_sampling(self, input_ids, max_new_tokens, temperature, top_k):
        """Top-k sampling: sample only from the top k tokens"""
        self._check_temperature(temperature)
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")

        def choose(logits):
            logits = logits / temperature
            # torch.topk fails when k exceeds the vocabulary size.
            k = min(top_k, logits.size(-1))
            top_k_logits, top_k_indices = torch.topk(logits, k=k, dim=-1)
            filtered_logits = torch.full_like(logits, float("-inf"))
            filtered_logits.scatter_(1, top_k_indices, top_k_logits)
            return torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

        return self._decode_loop(input_ids, max_new_tokens, choose)

    def _top_p_sampling(self, input_ids, max_new_tokens, temperature, top_p):
        """Top-p (nucleus) sampling.

        Keeps every token whose preceding cumulative probability (in descending
        order) is at most ``top_p``, then samples among the kept tokens.
        """
        self._check_temperature(temperature)

        def choose(logits):
            logits = logits / temperature
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            # Mass *before* each token; the most likely token is always kept.
            remove = (sorted_probs.cumsum(dim=-1) - sorted_probs) > top_p
            sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
            # sorted_indices is a permutation, so scatter restores vocab order.
            filtered = torch.zeros_like(logits).scatter_(1, sorted_indices, sorted_logits)
            return torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)

        return self._decode_loop(input_ids, max_new_tokens, choose)

    def _beam_search(self, input_ids, max_new_tokens, num_beams):
        """Beam search: explore multiple candidates simultaneously.

        Hypotheses that have generated the end-of-text token are carried over
        unchanged instead of being extended. The search stops when the
        best-scoring hypothesis (log-prob divided by total length) is finished.
        """
        B = input_ids.shape[0]
        assert B == 1, "Beam search only supports batch size 1"

        eot = self.tokenizer.eot_token
        prompt_len = input_ids.shape[1]

        def is_finished(ids):
            # Only a *generated* EOT closes a hypothesis; a prompt that happens
            # to end with EOT must still be expanded.
            return ids.shape[1] > prompt_len and ids[0, -1].item() == eot

        beams = [(input_ids, 0.0)]

        for _ in range(max_new_tokens):
            all_candidates = []
            for beam_ids, beam_score in beams:
                if is_finished(beam_ids):
                    all_candidates.append((beam_ids, beam_score))
                    continue
                log_probs = F.log_softmax(self._next_token_logits(beam_ids), dim=-1)
                width = min(num_beams, log_probs.size(-1))
                top_log_probs, top_ids = log_probs.topk(width)
                for log_p, token in zip(top_log_probs[0].tolist(), top_ids[0]):
                    new_ids = torch.cat([beam_ids, token.view(1, 1)], dim=1)
                    all_candidates.append((new_ids, beam_score + log_p))

            # Length-normalised ranking so longer hypotheses are not penalised
            # merely for accumulating more negative log-probs.
            all_candidates.sort(key=lambda c: c[1] / c[0].shape[1], reverse=True)
            beams = all_candidates[:num_beams]

            if is_finished(beams[0][0]):
                break

        return beams[0][0]

7. Perplexity Evaluation

import math
import torch

@torch.no_grad()
def compute_perplexity(
    model: "MiniGPT",
    tokenizer: "GPTTokenizer",
    text: str,
    device: str = "cuda",
    seq_len: int = 512
) -> float:
    """Compute perplexity of a text.

    The token stream is split into consecutive, non-overlapping windows of at
    most ``seq_len`` input tokens, so every token after the first is predicted
    exactly once. The last window may be shorter. Each window's loss is
    weighted by its number of targets, so short windows count proportionally.

    Assumes ``model(x, y)`` returns ``(logits, loss)`` with ``loss`` the mean
    cross-entropy over the targets ``y`` — TODO confirm against MiniGPT.forward.

    Raises:
        ValueError: if ``seq_len < 1`` or ``text`` encodes to fewer than two
            tokens (there is nothing to predict).
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    model.eval()
    model.to(device)

    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) < 2:
        raise ValueError("text must encode to at least 2 tokens to compute perplexity")

    total_loss = 0.0
    total_tokens = 0

    # Every start < len(tokens) - 1 leaves at least one (input, target) pair,
    # including a final partial window instead of dropping the tail.
    for start in range(0, len(tokens) - 1, seq_len):
        chunk = tokens[start:start + seq_len + 1]
        n_targets = len(chunk) - 1

        x = torch.tensor(chunk[:-1], dtype=torch.long, device=device).unsqueeze(0)
        y = torch.tensor(chunk[1:], dtype=torch.long, device=device).unsqueeze(0)

        _, loss = model(x, y)
        total_loss += loss.item() * n_targets
        total_tokens += n_targets

    return math.exp(total_loss / total_tokens)


# Usage
# Load the best checkpoint saved during training and wrap it for generation.
model = MiniGPT.from_pretrained("./checkpoints/best_model.pt")
tokenizer = GPTTokenizer()
generator = TextGenerator(model, tokenizer)

# Perplexity of one short sentence.
# NOTE(review): this sentence is much shorter than compute_perplexity's default
# seq_len=512 window; confirm compute_perplexity handles texts shorter than a
# single window before relying on this number.
test_text = "The quick brown fox jumps over the lazy dog."
ppl = compute_perplexity(model, tokenizer, test_text)
print(f"Perplexity: {ppl:.2f}")

# Compare decoding strategies on the same prompt; the greedy strategy does not
# use temperature.
for strategy in ["greedy", "top_k", "top_p"]:
    generated = generator.generate(
        prompt="Once upon a time",
        max_new_tokens=100,
        strategy=strategy,
        temperature=0.8
    )
    print(f"\n[{strategy}]: {generated}")

8. Scaling Up

8.1 Flash Attention 2

# pip install flash-attn

from flash_attn import flash_attn_qkvpacked_func

class FlashCausalAttention(nn.Module):
    """Causal self-attention computed by the Flash Attention 2 kernel.

    A single fused linear layer produces Q, K and V. They are handed to
    ``flash_attn_qkvpacked_func`` in packed form with ``causal=True``, so each
    position attends only to itself and earlier positions.
    NOTE(review): flash-attn kernels expect half-precision CUDA tensors —
    confirm the caller's dtype/device.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head

        # Fused [q | k | v] projection followed by the output projection.
        self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        # Attention dropout probability; only applied while training.
        self.dropout = config.dropout

    def forward(self, x: torch.Tensor):
        batch, seq_len, width = x.shape

        # Packed layout expected by the kernel: [batch, seq, 3, heads, head_dim].
        packed = self.qkv_proj(x).view(batch, seq_len, 3, self.n_head, self.head_dim)

        if self.training:
            drop_prob = self.dropout
        else:
            drop_prob = 0.0

        # Result layout: [batch, seq, heads, head_dim] -> merge heads back.
        heads = flash_attn_qkvpacked_func(packed, dropout_p=drop_prob, causal=True)
        merged = heads.reshape(batch, seq_len, width)
        return self.out_proj(merged)

8.2 FSDP Distributed Training

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import functools

def setup_fsdp_training():
    """Setup FSDP distributed training (one process per GPU).

    Intended to be launched with ``torchrun`` (or a launcher that sets the same
    ``torch.distributed`` environment variables). Builds a 24-layer,
    1024-wide miniGPT and shards it at ``TransformerBlock`` boundaries with
    bf16 mixed precision.

    Returns:
        (model, rank, world_size): the FSDP-wrapped model, this process's
        global rank, and the total number of processes.
    """
    import os
    from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # The CUDA device index is the rank *within this node*. torchrun exports it
    # as LOCAL_RANK; the global rank only coincides with it on a single node.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)

    config = GPTConfig(
        n_embd=1024,
        n_layer=24,
        n_head=16,
        vocab_size=50257
    )
    model = MiniGPT(config)

    # Wrap policy: shard at TransformerBlock boundaries
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock}
    )

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16
        ),
        # FSDP expects the enum, not a string. FULL_SHARD is equivalent to ZeRO-3.
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=local_rank
    )

    return model, rank, world_size

8.3 Chinchilla Scaling Laws

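The Chinchilla paper (Hoffmann et al., 2022) found that, for a fixed compute budget, model size and the number of training tokens should be scaled up in roughly equal proportion. This yields a simple rule of thumb for the compute-optimal amount of training data:

Optimal training tokens ≈ 20 × number of parameters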

def chinchilla_optimal_tokens(model_params: int, tokens_per_param: int = 20) -> int:
    """Compute optimal training tokens per Chinchilla scaling law.

    Applies the Chinchilla rule of thumb of roughly ``tokens_per_param``
    (default 20) training tokens per model parameter.

    Args:
        model_params: Number of model parameters.
        tokens_per_param: Tokens-per-parameter ratio. It is exposed so other
            data/compute trade-offs can be explored.

    Returns:
        ``tokens_per_param * model_params``.
    """
    return tokens_per_param * model_params

# Reference model sizes (parameter counts) to tabulate.
models = {
    "124M (GPT-2 Small)": 124e6,
    "1.3B": 1.3e9,
    "7B (Llama-2 7B)": 7e9,
    "13B": 13e9,
    "70B (Llama-2 70B)": 70e9,
}

# Print the Chinchilla-optimal token budget for each size, in billions.
print("Recommended training tokens by model size:")
for label, n_params in models.items():
    billions_of_tokens = chinchilla_optimal_tokens(int(n_params)) / 1e9
    print(f"  {label}: {billions_of_tokens:.1f}B tokens")

9. Instruction Tuning

9.1 SFT Data Formats

# ChatML format
CHATML_SYSTEM = "<|im_start|>system\n{system}<|im_end|>\n"
CHATML_USER = "<|im_start|>user\n{user}<|im_end|>\n"
CHATML_ASSISTANT = "<|im_start|>assistant\n{assistant}<|im_end|>\n"
# Opening header of an assistant turn (the prefix of CHATML_ASSISTANT).
CHATML_GENERATION_PROMPT = "<|im_start|>assistant\n"


def format_chatml(system: str, user: str, assistant: str, add_generation_prompt: bool = False) -> str:
    """Render one system/user/assistant exchange in ChatML.

    Args:
        system: System prompt text.
        user: User message text.
        assistant: Assistant reply. When it is empty and
            ``add_generation_prompt`` is True, no empty assistant turn is
            written. The result is then exactly the prompt prefix of the full
            conversation.
        add_generation_prompt: Append an open assistant header so a model can
            continue with the reply.

    Returns:
        The formatted conversation string.
    """
    formatted = CHATML_SYSTEM.format(system=system)
    formatted += CHATML_USER.format(user=user)
    # Emitting an empty, closed assistant turn before a fresh assistant header
    # would make the prompt diverge from the full conversation. That would
    # shift the prompt/response boundary used for SFT loss masking.
    if assistant or not add_generation_prompt:
        formatted += CHATML_ASSISTANT.format(assistant=assistant)
    if add_generation_prompt:
        formatted += CHATML_GENERATION_PROMPT
    return formatted


# Llama-3 format
def format_llama3(system: str, user: str, assistant: str, add_generation_prompt: bool = False) -> str:
    """Render one system/user/assistant exchange in the Llama 3 chat format.

    Turns are ``<|start_header_id|>role<|end_header_id|>\\n\\n{text}<|eot_id|>``
    and follow each other directly; the official template has no newline
    after ``<|eot_id|>``.

    Args:
        system: System prompt text.
        user: User message text.
        assistant: Assistant reply. When it is empty and
            ``add_generation_prompt`` is True, no empty assistant turn is
            written. The result is then exactly the prompt prefix of the full
            conversation.
        add_generation_prompt: Append an open assistant header so a model can
            continue with the reply.

    Returns:
        The formatted conversation string.
    """
    formatted = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
    formatted += f"<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>"
    # Skip the empty, closed assistant turn when only a generation prompt is
    # wanted, so the prompt stays a prefix of the full conversation.
    if assistant or not add_generation_prompt:
        formatted += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant}<|eot_id|>"
    if add_generation_prompt:
        formatted += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return formatted

9.2 Instruction Dataset Class

class InstructionDataset(Dataset):
    """SFT dataset: only compute loss on assistant responses.

    Each record of ``data`` is a dict with ``"instruction"`` and ``"response"``
    keys and an optional ``"system"`` prompt. ``__getitem__`` returns
    ``(input_ids, labels)`` as equal-length long tensors, where ``labels[t]``
    is the token that should follow ``input_ids[t]``. This matches the
    pre-shifted ``model(x, y)`` convention used by ``compute_perplexity`` in
    this file (``x = chunk[:-1]``, ``y = chunk[1:]``). Targets belonging to
    the prompt are set to -100 so the loss ignores them.

    Records whose response is truncated away entirely (no supervised target
    left) are skipped, since they would contribute no loss.
    """

    def __init__(
        self,
        data: list,
        tokenizer: GPTTokenizer,
        max_seq_len: int = 2048,
        format_func=format_chatml
    ):
        self.samples = []

        for item in data:
            system = item.get("system", "You are a helpful assistant.")

            # Tokenize full conversation
            full_ids = tokenizer.encode(
                format_func(system=system, user=item["instruction"], assistant=item["response"]),
                add_special_tokens=False,
            )

            # Tokenize prompt only (to find where response starts)
            prompt_ids = tokenizer.encode(
                format_func(
                    system=system,
                    user=item["instruction"],
                    assistant="",
                    add_generation_prompt=True,
                ),
                add_special_tokens=False,
            )

            sample = self._build_sample(full_ids, prompt_ids, max_seq_len)
            if sample is not None:
                self.samples.append(sample)

    @staticmethod
    def _build_sample(full_ids: list, prompt_ids: list, max_seq_len: int):
        """Build shifted, prompt-masked ``input_ids``/``labels`` for one record.

        Returns None when no supervised (non-masked) target remains.
        """
        # The response starts after the tokens shared with the prompt-only
        # encoding. Using the common prefix rather than len(prompt_ids) stays
        # correct even if the prompt encoding has trailing tokens that are not
        # part of the full conversation.
        prompt_len = 0
        for p_tok, f_tok in zip(prompt_ids, full_ids):
            if p_tok != f_tok:
                break
            prompt_len += 1

        # Keep max_seq_len + 1 tokens so the shifted input is max_seq_len long.
        full_ids = full_ids[:max_seq_len + 1]
        if len(full_ids) < 2:
            return None

        input_ids = full_ids[:-1]
        labels = full_ids[1:]

        # labels[t] == full_ids[t + 1], a prompt token whenever t + 1 < prompt_len.
        # Clamp to len(labels) so slice assignment never grows the list.
        n_masked = min(max(prompt_len - 1, 0), len(labels))
        labels[:n_masked] = [-100] * n_masked
        if n_masked == len(labels):
            return None

        return {"input_ids": input_ids, "labels": labels}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        return (
            torch.tensor(sample["input_ids"], dtype=torch.long),
            torch.tensor(sample["labels"], dtype=torch.long)
        )

9.3 SFT Training Loop

def train_sft(
    base_model_path: str,
    data_path: str,
    output_dir: str,
    num_epochs: int = 3,
    learning_rate: float = 2e-5,
    batch_size: int = 4,
    device: str = "cuda",
):
    """Supervised Fine-Tuning (SFT) of a pretrained miniGPT checkpoint.

    Args:
        base_model_path: Checkpoint passed to ``MiniGPT.from_pretrained``.
        data_path: JSON file holding the list of records consumed by
            ``InstructionDataset`` (``"instruction"``, ``"response"``,
            optional ``"system"``).
        output_dir: Directory that receives ``sft_model.pt`` (created if needed).
        num_epochs: Number of passes over the dataset.
        learning_rate: Peak AdamW learning rate, reached after a warmup over the
            first 10% of optimizer steps.
        batch_size: Conversations per optimizer step.
        device: Torch device to train on (defaults to "cuda").
    """
    import json
    import os

    # JSON is UTF-8 by spec. Without an explicit encoding, open() uses the
    # platform default locale encoding, which can fail on or garble non-ASCII
    # instruction data.
    with open(data_path, encoding="utf-8") as f:
        data = json.load(f)

    tokenizer = GPTTokenizer()
    dataset = InstructionDataset(data, tokenizer, max_seq_len=2048)

    def collate_fn(batch):
        """Right-pad a batch: inputs with token 0, labels with -100 (ignored in loss)."""
        input_ids = [item[0] for item in batch]
        labels = [item[1] for item in batch]
        max_len = max(len(x) for x in input_ids)
        padded_ids = torch.zeros(len(batch), max_len, dtype=torch.long)
        padded_labels = torch.full((len(batch), max_len), -100, dtype=torch.long)
        for i, (ids, lbls) in enumerate(zip(input_ids, labels)):
            padded_ids[i, :len(ids)] = ids
            padded_labels[i, :len(lbls)] = lbls
        return padded_ids, padded_labels

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    model = MiniGPT.from_pretrained(base_model_path)
    model.to(device)

    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    total_steps = len(dataloader) * num_epochs
    # 10% of all steps are warmup.
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(total_steps * 0.1), total_steps)

    model.train()
    for epoch in range(num_epochs):
        for step, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)
            _, loss = model(x, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            if step % 50 == 0:
                print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")

    os.makedirs(output_dir, exist_ok=True)
    torch.save({
        "model": model.state_dict(),
        "config": model.config
    }, os.path.join(output_dir, "sft_model.pt"))
    print(f"SFT complete. Model saved: {output_dir}")

10. Open-Source LLM Architecture Analysis

10.1 Llama 3 Architecture

Llama 3 is Meta's open-source LLM with these key characteristics:

| Feature | Details |
|---|---|
| Architecture | Decoder-only Transformer |
| Normalization | Pre-RMSNorm |
| Activation | SwiGLU |
| Position Encoding | RoPE (theta=500,000) |
| Vocabulary Size | 128,256 |
| Context Length | 8K (Llama 3); 128K (Llama 3.1+) |
| Attention | GQA (8B, 70B models) |
# Llama 3 style config (8B model)
# NOTE(review): n_kv_heads and the RoPE base are not GPTConfig fields here;
# GroupedQueryAttention receives them as constructor arguments instead.
llama3_config = GPTConfig(
    # Architecture switches: rotary positions, SwiGLU FFN, no bias terms.
    use_rope=True,
    use_swiglu=True,
    bias=False,
    # Depth and width.
    n_layer=32,
    n_head=32,
    n_embd=4096,
    # Vocabulary and context window.
    vocab_size=128256,
    max_seq_len=8192,
)

10.2 Grouped Query Attention (GQA)

class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (used in Llama 3, Mistral).

    Queries use ``config.n_head`` heads while keys/values use only
    ``n_kv_heads``. Each K/V head is shared by ``n_head // n_kv_heads``
    consecutive query heads, which shrinks the K and V projections.

    Args:
        config: Model config providing ``n_embd``, ``n_head`` and ``max_seq_len``.
        n_kv_heads: Number of key/value heads; must divide ``config.n_head``.
        rope_base: Rotary-embedding base (theta); 500000 is the value listed
            for Llama 3.

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_head`` or ``n_head``
            is not divisible by ``n_kv_heads``.
    """

    def __init__(self, config: "GPTConfig", n_kv_heads: int = 8, rope_base: float = 500000):
        super().__init__()
        if config.n_embd % config.n_head != 0:
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
            )
        if n_kv_heads < 1 or config.n_head % n_kv_heads != 0:
            raise ValueError(
                f"n_head ({config.n_head}) must be divisible by n_kv_heads ({n_kv_heads})"
            )

        self.n_head = config.n_head
        self.n_kv_heads = n_kv_heads
        self.n_rep = config.n_head // n_kv_heads
        self.head_dim = config.n_embd // config.n_head

        self.q_proj = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.n_embd, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.n_embd, n_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.rope = RotaryEmbedding(self.head_dim, config.max_seq_len, base=rope_base)

    def forward(self, x: torch.Tensor):
        B, T, C = x.shape

        # [B, heads, T, head_dim]
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Rotary position embedding on queries and keys (RotaryEmbedding is
        # defined elsewhere in the file).
        q, k = self.rope(q, k, T)

        # Repeat KV heads to match number of query heads (core of GQA):
        # query head h reads KV head h // n_rep. expand() avoids a copy until
        # reshape materialises the repeated layout.
        k = k.unsqueeze(2).expand(B, self.n_kv_heads, self.n_rep, T, self.head_dim)
        k = k.reshape(B, self.n_head, T, self.head_dim)
        v = v.unsqueeze(2).expand(B, self.n_kv_heads, self.n_rep, T, self.head_dim)
        v = v.reshape(B, self.n_head, T, self.head_dim)

        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(y)

10.3 Mixture of Experts (MoE)

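DeepSeek and Mixtral use Mixture of Experts (MoE) to increase the total parameter count (model capacity) without a proportional increase in per-token compute. A router sends each token to only a few experts (top-k), so only a fraction of the parameters is active for any given token.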

class MoEFFN(nn.Module):
    """Mixture of Experts FFN (simplified version).

    A bias-free linear router scores each token against ``num_experts`` expert
    FFNs. Every token is processed by its ``top_k`` highest-probability
    experts, and their outputs are mixed with the router probabilities
    renormalised over those ``top_k`` experts. No load-balancing auxiliary loss
    is computed here.

    Raises:
        ValueError: if ``top_k`` is not in ``[1, num_experts]``.
    """

    def __init__(self, config: "GPTConfig", num_experts: int = 8, top_k: int = 2):
        super().__init__()
        if not 1 <= top_k <= num_experts:
            raise ValueError(f"top_k must be in [1, {num_experts}], got {top_k}")
        self.num_experts = num_experts
        self.top_k = top_k

        # Router: one logit per expert for every token.
        self.router = nn.Linear(config.n_embd, num_experts, bias=False)

        # Expert FFNs
        self.experts = nn.ModuleList([
            SwiGLUFFN(config) for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        # reshape (not view) also accepts non-contiguous inputs.
        x_flat = x.reshape(B * T, C)

        # Compute routing scores
        router_probs = F.softmax(self.router(x_flat), dim=-1)

        # Select top-k experts and renormalise their weights to sum to 1.
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Run each expert once on every token routed to it, whichever top-k
        # slot selected it. topk returns distinct experts per token, so a token
        # appears at most once in token_idx for a given expert.
        output = torch.zeros_like(x_flat)
        for e, expert in enumerate(self.experts):
            token_idx, slot_idx = torch.where(top_k_indices == e)
            if token_idx.numel() == 0:
                continue
            weight = top_k_probs[token_idx, slot_idx].unsqueeze(-1)
            expert_out = expert(x_flat[token_idx])
            output.index_add_(0, token_idx, (weight * expert_out).to(output.dtype))

        return output.view(B, T, C)

Putting It All Together

Here is a complete script that ties all components together:

import torch
from dataclasses import dataclass

# ── 1. Configuration ─────────────────────────────────────
# GPT-2 Small dimensions (50,257 vocab, 1024 context, 768 width, 12 layers,
# 12 heads) with rotary positions and a SwiGLU FFN enabled.
config = GPTConfig(
    vocab_size=50257,
    max_seq_len=1024,
    n_embd=768,
    n_layer=12,
    n_head=12,
    dropout=0.1,
    use_rope=True,
    use_swiglu=True
)

# ── 2. Model ─────────────────────────────────────────────
model = MiniGPT(config)
print(f"Model parameters: {model.count_params() / 1e6:.1f}M")

# ── 3. Tokenizer ─────────────────────────────────────────
tokenizer = GPTTokenizer()

# ── 4. Data ───────────────────────────────────────────────
# seq_len matches config.max_seq_len; validation data is not shuffled.
train_loader = create_dataloader("train.txt", tokenizer, batch_size=8, seq_len=1024)
val_loader = create_dataloader("val.txt", tokenizer, batch_size=8, seq_len=1024, shuffle=False)

# ── 5. Trainer ────────────────────────────────────────────
# NOTE(review): assuming Trainer uses "gradient_accumulation_steps" as its
# accumulation factor, each optimizer step sees 8 x 4 = 32 sequences of 1024
# tokens — confirm against Trainer.__init__.
trainer_config = {
    "num_epochs": 5,
    "learning_rate": 6e-4,
    "weight_decay": 0.1,
    "gradient_accumulation_steps": 4,
    "max_grad_norm": 1.0,
    "warmup_ratio": 0.05,
    "log_interval": 100,
    "eval_interval": 500,
    "save_interval": 1000,
    "device": "cuda"
}

trainer = Trainer(model, train_loader, val_loader, trainer_config)
trainer.train()

# ── 6. Generation ─────────────────────────────────────────
# Reuses the in-memory model trained above; TextGenerator puts it in eval mode.
generator = TextGenerator(model, tokenizer)

prompts = [
    "The history of artificial intelligence",
    "In the year 2050, scientists discovered",
    "The best way to learn programming is",
]

# Compare greedy decoding with nucleus sampling (generate's default top_p
# and temperature values) for each prompt.
for prompt in prompts:
    print(f"\nPrompt: {prompt}")
    for strategy in ["greedy", "top_p"]:
        text = generator.generate(prompt, max_new_tokens=100, strategy=strategy)
        print(f"  [{strategy}]: {text}")

Conclusion

In this guide, we built miniGPT from scratch and covered every core element of modern LLMs:

  • Tokenizer: BPE algorithm splits text into subword tokens
  • Embeddings: Token embeddings + RoPE positional encoding
  • Multi-Head Causal Attention: Q/K/V projections, causal mask, KV cache
  • FFN: SwiGLU activation for non-linearity
  • Pre-RMSNorm: Training stability
  • Text Generation: Greedy, Temperature, Top-k, Top-p, Beam Search
  • Pretraining: Cosine LR schedule, gradient clipping, mixed precision
  • SFT: Loss computed only on assistant response tokens
  • Scaling: Flash Attention, FSDP, Chinchilla laws

This foundation gives you the tools to understand and extend any modern open-source LLM. The concepts covered here directly map to the implementation details of Llama 3, Mistral, and Qwen.

References