- Authors

- Name
- Youngju Kim
- @fjvbn20031
Introduction
GPT, Llama, Mistral — large language models (LLMs) are at the center of the current AI revolution. Yet very few people understand how these models work at the code level. In this guide, we build "miniGPT," a small GPT-style model from scratch using PyTorch, mastering every core concept behind modern LLMs.
1. LLM Architecture Overview
1.1 The Pretraining Pipeline
Modern LLM development follows three stages:
- Pretraining: The model learns to predict the next token from billions of text tokens. During this stage, it acquires statistical language patterns, world knowledge, and reasoning ability.
- Instruction Fine-Tuning (SFT): The model is trained on human-written question-answer pairs so that it learns to follow instructions.
- Preference Learning (RLHF / DPO): The model is refined using human preference data to produce more helpful and safe responses.
1.2 miniGPT Specification
| Parameter | Value |
|---|---|
| Vocabulary size | 50,257 (same as GPT-2) |
| Context length | 1024 |
| Embedding dimension | 768 |
| Number of layers | 12 |
| Attention heads | 12 |
| FFN expansion ratio | 4x |
| Total parameters | ~124M |
This matches GPT-2 Small and is trainable on a single GPU.
2. Tokenizer Implementation
2.1 BPE Algorithm
Byte-Pair Encoding (BPE) is the tokenization algorithm used by most modern LLMs.
from collections import defaultdict
from typing import List, Tuple, Dict
class SimpleBPETokenizer:
    """A from-scratch Byte-Pair Encoding tokenizer.

    Words are stored as space-separated symbol strings terminated by the
    </w> end-of-word marker; training repeatedly fuses the most frequent
    adjacent symbol pair into a single symbol.
    """

    def __init__(self):
        self.vocab = {}
        self.merges = {}
        self.special_tokens = {
            "<|endoftext|>": 50256,
        }

    def get_stats(self, vocab: Dict[str, int]) -> Dict[Tuple, int]:
        """Count adjacent symbol pairs across all words, weighted by word frequency."""
        pair_counts = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for left, right in zip(symbols, symbols[1:]):
                pair_counts[(left, right)] += freq
        return pair_counts

    def merge_vocab(self, pair: Tuple, v_in: Dict) -> Dict:
        """Rewrite every word with the given symbol pair fused into one symbol."""
        bigram = " ".join(pair)
        replacement = "".join(pair)
        return {word.replace(bigram, replacement): freq for word, freq in v_in.items()}

    def train(self, corpus: List[str], vocab_size: int = 1000):
        """Learn up to `vocab_size` merges from `corpus`; returns (vocab, merges)."""
        # 1. Start from a character-level vocabulary with </w> word boundaries.
        word_freq = defaultdict(int)
        for text in corpus:
            for word in text.split():
                word_freq[" ".join(list(word)) + " </w>"] += 1
        vocab = dict(word_freq)
        # 2. Greedily merge the most frequent pair, once per iteration.
        for step in range(vocab_size):
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            vocab = self.merge_vocab(best, vocab)
            self.merges[best] = step
            print(f"Step {step}: merged {best} (freq={pairs[best]})")
        # 3. Assign sequential IDs to every surviving symbol.
        for word in vocab:
            for token in word.split():
                if token not in self.vocab:
                    self.vocab[token] = len(self.vocab)
        return self.vocab, self.merges
2.2 Using tiktoken
In practice, OpenAI's tiktoken library is much more efficient.
import tiktoken
import torch
from typing import List
class GPTTokenizer:
    """Thin wrapper around a tiktoken encoding for GPT-style models."""

    def __init__(self, encoding_name: str = "gpt2"):
        self.enc = tiktoken.get_encoding(encoding_name)
        self.vocab_size = self.enc.n_vocab
        self.eot_token = self.enc.eot_token  # 50256 for the gpt2 encoding

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Tokenize `text`; optionally prepend the <|endoftext|> token."""
        token_ids = self.enc.encode(text, allowed_special={"<|endoftext|>"})
        if not add_special_tokens:
            return token_ids
        return [self.eot_token] + token_ids

    def decode(self, ids: List[int]) -> str:
        """Map token IDs back to a string."""
        return self.enc.decode(ids)

    def encode_batch(self, texts: List[str]) -> List[List[int]]:
        """Encode several texts, returning one ID list per text."""
        return [self.encode(text) for text in texts]

    def __len__(self):
        return self.vocab_size
# Usage
tokenizer = GPTTokenizer()
text = "Hello! Let's build GPT from scratch."
# encode() prepends <|endoftext|> by default, so len(ids) = BPE tokens + 1.
ids = tokenizer.encode(text)
print(f"Token count: {len(ids)}")
print(f"Token IDs: {ids}")
# Round trip: decode() should reproduce the original text after the EOT token.
decoded = tokenizer.decode(ids)
print(f"Decoded: {decoded}")
3. Data Preparation
3.1 Text Dataset Class
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
class TextDataset(Dataset):
    """Sliding-window language-modeling dataset.

    Tokenizes a text file once, then serves overlapping windows of
    `seq_len` tokens as (input, next-token target) pairs; consecutive
    windows start `stride` tokens apart.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: "GPTTokenizer",
        seq_len: int = 1024,
        stride: int = 512
    ):
        self.seq_len = seq_len
        self.stride = stride
        print(f"Loading data: {data_path}")
        with open(data_path, "r", encoding="utf-8") as f:
            text = f.read()
        print(f"Total characters: {len(text):,}")
        self.tokens = tokenizer.encode(text, add_special_tokens=False)
        print(f"Total tokens: {len(self.tokens):,}")
        # int32 keeps the token buffer compact; converted to int64 per-sample.
        self.tokens = np.array(self.tokens, dtype=np.int32)

    def __len__(self):
        # A window at index i needs tokens[i*stride : i*stride + seq_len + 1].
        # BUGFIX: the old formula (len - seq_len) // stride could drop the last
        # valid window (e.g. 9 tokens, seq_len=4, stride=3 gave 1 instead of 2).
        usable = len(self.tokens) - self.seq_len - 1
        if usable < 0:
            return 0
        return usable // self.stride + 1

    def __getitem__(self, idx):
        start = idx * self.stride
        end = start + self.seq_len + 1
        chunk = self.tokens[start:end]
        # Input: tokens[0:seq_len], Label: tokens[1:seq_len+1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y
def create_dataloader(
    data_path: str,
    tokenizer: GPTTokenizer,
    batch_size: int = 8,
    seq_len: int = 1024,
    stride: int = 512,
    num_workers: int = 4,
    shuffle: bool = True
) -> DataLoader:
    """Build a DataLoader over a sliding-window TextDataset for LM training."""
    dataset = TextDataset(data_path, tokenizer, seq_len, stride)
    print(f"Dataset size: {len(dataset):,} samples")
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,   # faster host-to-GPU copies
        drop_last=True,    # keep batch shapes uniform
    )
    return DataLoader(dataset, **loader_options)
4. Model Architecture
4.1 Configuration Class
from dataclasses import dataclass
@dataclass
class GPTConfig:
    """GPT model configuration"""
    vocab_size: int = 50257   # GPT-2 BPE vocabulary size
    max_seq_len: int = 1024   # context window; also sizes RoPE / position tables
    n_embd: int = 768         # embedding / residual-stream width
    n_layer: int = 12         # number of Transformer blocks
    n_head: int = 12          # attention heads per block
    dropout: float = 0.1      # dropout rate used throughout the model
    bias: bool = False        # bias terms in attention/FFN Linear layers
    ffn_mult: int = 4         # FFN expansion ratio relative to n_embd
    use_swiglu: bool = True   # SwiGLU FFN (Llama-style) instead of GELU FFN
    use_rope: bool = True     # RoPE instead of learned absolute positions
4.2 RMSNorm
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (Llama-style).

    Scales each vector by the reciprocal of its RMS (no mean subtraction,
    no bias), then applies a learned per-dimension gain. The statistic is
    computed in float32 and cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        normalized = self._norm(x.float()).type_as(x)
        return self.weight * normalized
4.3 Rotary Position Embedding (RoPE)
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    Precomputes per-position cos/sin tables as buffers `cos_cache` and
    `sin_cache`, each shaped [1, 1, max_seq_len, dim], and applies the
    rotation to query/key tensors in forward().
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        # Per-pair frequencies: theta_i = base^(-2i/dim)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        positions = torch.arange(seq_len, device=self.inv_freq.device).float()
        angles = torch.outer(positions, self.inv_freq)
        # Duplicate the angle table so it covers both halves of the head dim.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cache", table.cos()[None, None, :, :])
        self.register_buffer("sin_cache", table.sin()[None, None, :, :])

    def rotate_half(self, x):
        """Map (x1, x2) halves of the last dim to (-x2, x1)."""
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int):
        cos = self.cos_cache[:, :, :seq_len, :].to(q.dtype)
        sin = self.sin_cache[:, :, :seq_len, :].to(q.dtype)
        rotated_q = q * cos + self.rotate_half(q) * sin
        rotated_k = k * cos + self.rotate_half(k) * sin
        return rotated_q, rotated_k
4.4 Multi-Head Causal Attention with KV Cache
class CausalSelfAttention(nn.Module):
    """Causal self-attention with KV-cache support.

    forward() returns (output, present_kv). When `past_kv` is given, the
    incoming x is treated as a continuation: RoPE angles are offset by the
    cache length and the causal mask is aligned so each new query may attend
    to all cached keys plus its own prefix.
    """

    def __init__(self, config: "GPTConfig"):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        # Combined Q, K, V projection
        self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        if config.use_rope:
            self.rope = RotaryEmbedding(self.head_dim, config.max_seq_len)
        else:
            self.rope = None
        # Causal mask (lower triangular), sliced per forward call
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
            .view(1, 1, config.max_seq_len, config.max_seq_len)
        )

    def forward(
        self,
        x: torch.Tensor,
        use_cache: bool = False,
        past_kv=None
    ):
        """x: [B, T, C]; past_kv: optional (k, v), each [B, n_head, past_len, head_dim]."""
        B, T, C = x.shape
        past_len = past_kv[0].shape[2] if past_kv is not None else 0
        qkv = self.qkv_proj(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        # Split heads: [B, T, C] -> [B, n_head, T, head_dim]
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        if self.rope is not None:
            if past_len == 0:
                q, k = self.rope(q, k, T)
            else:
                # BUGFIX: with a KV cache the new tokens sit at absolute
                # positions [past_len, past_len + T); rotating them with
                # angles for positions [0, T) corrupts cached generation.
                cos = self.rope.cos_cache[:, :, past_len:past_len + T, :].to(q.dtype)
                sin = self.rope.sin_cache[:, :, past_len:past_len + T, :].to(q.dtype)
                q = (q * cos) + (self.rope.rotate_half(q) * sin)
                k = (k * cos) + (self.rope.rotate_half(k) * sin)
        # KV cache handling
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        present_kv = (k, v) if use_cache else None
        kv_len = k.shape[2]
        dropout_p = self.attn_dropout.p if self.training else 0.0
        # Scaled Dot-Product Attention (uses Flash Attention if available)
        if hasattr(F, "scaled_dot_product_attention"):
            if past_len == 0:
                y = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=None,
                    dropout_p=dropout_p,
                    is_causal=True
                )
            else:
                # BUGFIX: is_causal=True aligns the mask top-left, which is
                # wrong when q is shorter than kv (cached decoding). Slice the
                # precomputed mask so query row i maps to position past_len + i.
                attn_mask = self.causal_mask[:, :, past_len:kv_len, :kv_len].bool()
                y = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=attn_mask,
                    dropout_p=dropout_p,
                    is_causal=False
                )
        else:
            # Manual implementation
            scale = 1.0 / math.sqrt(self.head_dim)
            attn = (q @ k.transpose(-2, -1)) * scale
            # Same offset alignment as above for the cached case.
            mask = self.causal_mask[:, :, past_len:kv_len, :kv_len]
            attn = attn.masked_fill(mask == 0, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_dropout(attn)
            y = attn @ v
        # Merge heads: [B, n_head, T, head_dim] -> [B, T, C]
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.out_proj(y))
        return y, present_kv
4.5 SwiGLU Feed-Forward Network
class SwiGLUFFN(nn.Module):
    """FFN with SwiGLU activation (Llama-style).

    Computes down_proj(SiLU(gate_proj(x)) * up_proj(x)). The hidden size is
    scaled by 2/3 so the three-matrix SwiGLU FFN has roughly the same
    parameter count as a two-matrix GELU FFN with the same expansion ratio.
    """

    def __init__(self, config: "GPTConfig"):
        super().__init__()
        hidden_dim = int(config.n_embd * config.ffn_mult * 2 / 3)
        # Round up to nearest 64 for hardware efficiency
        hidden_dim = ((hidden_dim + 63) // 64) * 64
        self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # BUGFIX(doc): the old comment read "gate * SiLU(x)"; the actual
        # computation is SiLU(gate_proj(x)) * up_proj(x).
        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.dropout(self.down_proj(gate * up))
class GPTFFN(nn.Module):
    """Classic GPT-style two-layer FFN with GELU activation and dropout."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        inner_dim = config.n_embd * config.ffn_mult
        self.fc1 = nn.Linear(config.n_embd, inner_dim, bias=config.bias)
        self.fc2 = nn.Linear(inner_dim, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.fc1(x))
        return self.dropout(self.fc2(hidden))
4.6 Transformer Block
class TransformerBlock(nn.Module):
    """Pre-RMSNorm Transformer block: x + Attn(norm1(x)), then x + FFN(norm2(x))."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.norm1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.norm2 = RMSNorm(config.n_embd)
        self.ffn = SwiGLUFFN(config) if config.use_swiglu else GPTFFN(config)

    def forward(self, x: torch.Tensor, use_cache: bool = False, past_kv=None):
        """Return (hidden states, present KV cache or None)."""
        normed = self.norm1(x)
        attn_out, present_kv = self.attn(normed, use_cache=use_cache, past_kv=past_kv)
        # Residual around attention, then residual around the FFN.
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x, present_kv
4.7 Complete GPT Model
class MiniGPT(nn.Module):
    """Full miniGPT model: token embeddings -> N Transformer blocks -> LM head.

    forward() has two modes:
      * training (targets given): returns (logits, loss) over the full sequence
      * inference (no targets): returns (last-token logits, per-layer KV caches)
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        # Embedding layers
        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
        if not config.use_rope:
            # Learned absolute positions are only needed without RoPE.
            self.pos_emb = nn.Embedding(config.max_seq_len, config.n_embd)
        self.emb_dropout = nn.Dropout(config.dropout)
        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layer)
        ])
        # Final normalization
        self.norm_final = RMSNorm(config.n_embd)
        # Language model head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying (input embeddings == output embeddings)
        self.lm_head.weight = self.token_emb.weight
        # Weight initialization
        self.apply(self._init_weights)
        # GPT-2-style scaled init for projections feeding the residual stream:
        # std shrinks with depth so residual variance stays bounded.
        for pn, p in self.named_parameters():
            if pn.endswith("out_proj.weight") or pn.endswith("down_proj.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
        print(f"Parameters: {self.count_params() / 1e6:.1f}M")

    def _init_weights(self, module):
        """Default init: N(0, 0.02) weights, zero biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def count_params(self) -> int:
        """Total number of parameters (tied weights counted once by PyTorch)."""
        return sum(p.numel() for p in self.parameters())

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor = None,
        use_cache: bool = False,
        past_kvs: list = None
    ):
        """Run the model.

        Args:
            input_ids: [B, T] token IDs.
            targets: optional [B, T] next-token IDs (already shifted by one);
                positions equal to -100 are excluded from the loss.
            use_cache: collect per-layer KV caches for incremental decoding.
            past_kvs: per-layer KV caches from a previous call.

        Returns:
            (logits, loss) when targets is given, otherwise
            (last-token logits [B, 1, vocab], present_kvs).
        """
        B, T = input_ids.shape
        device = input_ids.device
        # Token embeddings
        x = self.token_emb(input_ids)  # [B, T, n_embd]
        # Positional embeddings (only when not using RoPE)
        if not self.config.use_rope:
            positions = torch.arange(T, device=device)
            x = x + self.pos_emb(positions)
        x = self.emb_dropout(x)
        # Transformer layers
        present_kvs = []
        for i, layer in enumerate(self.layers):
            past_kv = past_kvs[i] if past_kvs is not None else None
            x, present_kv = layer(x, use_cache=use_cache, past_kv=past_kv)
            present_kvs.append(present_kv)
        x = self.norm_final(x)
        if targets is not None:
            # Training: compute loss over full sequence
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                # BUGFIX: was ignore_index=-1, but the SFT dataset and collate
                # mask positions with -100 (PyTorch's default), which -1 would
                # not ignore and which would crash cross_entropy.
                ignore_index=-100
            )
            return logits, loss
        else:
            # Inference: only need last token logits
            logits = self.lm_head(x[:, [-1], :])
            return logits, present_kvs

    @classmethod
    def from_pretrained(cls, model_path: str):
        """Load a model from a checkpoint saved by Trainer.save_checkpoint.

        NOTE: weights_only=False unpickles arbitrary objects (the GPTConfig
        dataclass stored in the checkpoint) — only load trusted checkpoints.
        """
        checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
        config = checkpoint["config"]
        model = cls(config)
        model.load_state_dict(checkpoint["model"])
        return model
5. Pretraining Implementation
5.1 Learning Rate Scheduler
import math
def get_cosine_schedule_with_warmup(
    optimizer,
    warmup_steps: int,
    total_steps: int,
    min_lr_ratio: float = 0.1
):
    """Cosine LR schedule with linear warmup.

    The LR rises linearly from 0 over `warmup_steps`, then follows a cosine
    decay, never dropping below `min_lr_ratio` of the base LR.
    """
    from torch.optim.lr_scheduler import LambdaLR

    def scale_at(step: int):
        if step < warmup_steps:
            # Linear ramp from 0 to the base LR.
            return step / max(1, warmup_steps)
        frac = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        decayed = 0.5 * (1.0 + math.cos(math.pi * frac))
        # Floor the multiplier at min_lr_ratio.
        return max(min_lr_ratio, decayed)

    return LambdaLR(optimizer, scale_at)
5.2 Training Loop
import torch
import torch.nn as nn
from torch.optim import AdamW
import os
import time
class Trainer:
    """miniGPT pretraining trainer.

    Drives mixed-precision training with gradient accumulation, cosine LR
    scheduling with warmup, periodic validation, and checkpointing.

    Required `config` keys: learning_rate, num_epochs. Optional keys (with
    defaults): device, weight_decay (0.1), warmup_ratio (0.05),
    gradient_accumulation_steps (1), max_grad_norm (1.0), log_interval (100),
    eval_interval (500), save_interval (1000).
    """

    def __init__(self, model: MiniGPT, train_dataloader, val_dataloader, config: dict):
        self.model = model
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.config = config
        self.device = config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        # Mixed precision: prefer bf16 (no loss scaling needed); the GradScaler
        # is only enabled for fp16.
        # NOTE(review): torch.cuda.amp.GradScaler is deprecated in recent
        # PyTorch in favor of torch.amp.GradScaler("cuda") — confirm version.
        self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        self.scaler = torch.cuda.amp.GradScaler(enabled=(self.dtype == torch.float16))
        # NOTE(review): autocast device_type is hard-coded to "cuda" even when
        # self.device resolves to "cpu" — confirm CPU runs are not expected.
        self.ctx = torch.amp.autocast(device_type="cuda", dtype=self.dtype)
        self.optimizer = self._create_optimizer()
        total_steps = len(train_dataloader) * config["num_epochs"]
        warmup_steps = int(total_steps * config.get("warmup_ratio", 0.05))
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            warmup_steps=warmup_steps,
            total_steps=total_steps
        )
        self.step = 0
        self.best_val_loss = float("inf")

    def _create_optimizer(self):
        """Separate params with/without weight decay.

        1-D tensors (norm gains, biases) and embedding weights (name contains
        "emb") are excluded from decay, matching common GPT training practice.
        """
        decay_params = []
        no_decay_params = []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if param.dim() == 1 or "emb" in name:
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        optim_groups = [
            {"params": decay_params, "weight_decay": self.config.get("weight_decay", 0.1)},
            {"params": no_decay_params, "weight_decay": 0.0}
        ]
        # NOTE(review): fused=True requires CUDA tensors; this fails on CPU.
        return AdamW(
            optim_groups,
            lr=self.config["learning_rate"],
            betas=(0.9, 0.95),
            fused=True
        )

    def compute_val_loss(self) -> float:
        """Average loss over at most 50 validation batches (model restored to train mode)."""
        self.model.eval()
        total_loss = 0.0
        num_batches = min(len(self.val_dl), 50)
        with torch.no_grad():
            for i, (x, y) in enumerate(self.val_dl):
                if i >= num_batches:
                    break
                x, y = x.to(self.device), y.to(self.device)
                with self.ctx:
                    _, loss = self.model(x, y)
                total_loss += loss.item()
        self.model.train()
        return total_loss / num_batches

    def save_checkpoint(self, path: str):
        """Persist model/optimizer/scheduler state plus training metadata to `path`."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save({
            "step": self.step,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "config": self.model.config,
            "best_val_loss": self.best_val_loss
        }, path)
        print(f"Checkpoint saved: {path}")

    def train(self):
        """Run the full training loop: epochs x micro-batches with accumulation."""
        self.model.to(self.device)
        self.model.train()
        num_epochs = self.config["num_epochs"]
        grad_accum = self.config.get("gradient_accumulation_steps", 1)
        max_grad_norm = self.config.get("max_grad_norm", 1.0)
        log_interval = self.config.get("log_interval", 100)
        eval_interval = self.config.get("eval_interval", 500)
        save_interval = self.config.get("save_interval", 1000)
        print(f"Training on {self.device}, dtype={self.dtype}")
        for epoch in range(num_epochs):
            t0 = time.time()
            for batch_idx, (x, y) in enumerate(self.train_dl):
                x, y = x.to(self.device), y.to(self.device)
                with self.ctx:
                    _, loss = self.model(x, y)
                    # Scale down so accumulated grads average over micro-batches.
                    loss = loss / grad_accum
                self.scaler.scale(loss).backward()
                # Optimizer step only every grad_accum micro-batches.
                if (batch_idx + 1) % grad_accum == 0:
                    # Unscale before clipping so the norm is measured unscaled.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.scheduler.step()
                    self.optimizer.zero_grad(set_to_none=True)
                    self.step += 1
                    if self.step % log_interval == 0:
                        lr = self.optimizer.param_groups[0]["lr"]
                        elapsed = time.time() - t0
                        # Throughput over the last log window (approximate:
                        # uses the current batch's shape for all batches).
                        tokens_per_sec = log_interval * x.shape[0] * x.shape[1] / elapsed
                        print(
                            f"Epoch {epoch} | Step {self.step} | "
                            f"Loss: {loss.item() * grad_accum:.4f} | "
                            f"LR: {lr:.6f} | "
                            f"Tokens/s: {tokens_per_sec:.0f}"
                        )
                        t0 = time.time()
                    if self.step % eval_interval == 0:
                        val_loss = self.compute_val_loss()
                        print(f"Val Loss: {val_loss:.4f} | Perplexity: {math.exp(val_loss):.2f}")
                        if val_loss < self.best_val_loss:
                            self.best_val_loss = val_loss
                            self.save_checkpoint("./checkpoints/best_model.pt")
                    if self.step % save_interval == 0:
                        self.save_checkpoint(f"./checkpoints/step_{self.step}.pt")
6. Text Generation
6.1 Decoding Strategies
import torch
import torch.nn.functional as F
class TextGenerator:
    """miniGPT text generator.

    Implements several decoding strategies on top of a trained MiniGPT.
    Note: every step re-runs the model on the whole sequence (no KV cache is
    passed), so generation cost grows with sequence length each step.
    """

    def __init__(self, model: MiniGPT, tokenizer: GPTTokenizer, device: str = "cuda"):
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 200,
        strategy: str = "top_p",
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
        num_beams: int = 4
    ) -> str:
        """Generate a continuation of `prompt` and return only the new text.

        strategy: one of "greedy", "temperature", "top_k", "top_p", "beam".

        Raises:
            ValueError: for an unknown strategy name.
        """
        input_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        input_ids = torch.tensor(input_ids, dtype=torch.long, device=self.device).unsqueeze(0)
        if strategy == "greedy":
            generated = self._greedy_decode(input_ids, max_new_tokens)
        elif strategy == "temperature":
            generated = self._temperature_sampling(input_ids, max_new_tokens, temperature)
        elif strategy == "top_k":
            generated = self._top_k_sampling(input_ids, max_new_tokens, temperature, top_k)
        elif strategy == "top_p":
            generated = self._top_p_sampling(input_ids, max_new_tokens, temperature, top_p)
        elif strategy == "beam":
            generated = self._beam_search(input_ids, max_new_tokens, num_beams)
        else:
            raise ValueError(f"Unknown strategy: {strategy}")
        # Strip the prompt: decode only the newly generated tokens.
        new_tokens = generated[0][input_ids.shape[1]:].tolist()
        return self.tokenizer.decode(new_tokens)

    def _greedy_decode(self, input_ids, max_new_tokens):
        """Greedy decoding: always pick the highest probability token"""
        current_ids = input_ids.clone()
        for _ in range(max_new_tokens):
            # Without targets the model returns logits for the final position
            # only, shaped [B, 1, vocab]; [:, -1, :] selects that position.
            logits, _ = self.model(current_ids)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            current_ids = torch.cat([current_ids, next_token], dim=1)
            # Stop early at end-of-text.
            if next_token.item() == self.tokenizer.eot_token:
                break
        return current_ids

    def _temperature_sampling(self, input_ids, max_new_tokens, temperature):
        """Temperature sampling: scale logits before softmax"""
        current_ids = input_ids.clone()
        for _ in range(max_new_tokens):
            logits, _ = self.model(current_ids)
            # temperature < 1 sharpens the distribution, > 1 flattens it.
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            current_ids = torch.cat([current_ids, next_token], dim=1)
            if next_token.item() == self.tokenizer.eot_token:
                break
        return current_ids

    def _top_k_sampling(self, input_ids, max_new_tokens, temperature, top_k):
        """Top-k sampling: sample only from the top k tokens"""
        current_ids = input_ids.clone()
        for _ in range(max_new_tokens):
            logits, _ = self.model(current_ids)
            logits = logits[:, -1, :] / temperature
            top_k_logits, top_k_indices = torch.topk(logits, k=top_k, dim=-1)
            # Every token outside the top-k gets -inf, i.e. probability zero.
            filtered_logits = torch.full_like(logits, float("-inf"))
            filtered_logits.scatter_(1, top_k_indices, top_k_logits)
            probs = F.softmax(filtered_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            current_ids = torch.cat([current_ids, next_token], dim=1)
            if next_token.item() == self.tokenizer.eot_token:
                break
        return current_ids

    def _top_p_sampling(self, input_ids, max_new_tokens, temperature, top_p):
        """Top-p (Nucleus) sampling: sample from the top tokens covering probability mass p"""
        current_ids = input_ids.clone()
        for _ in range(max_new_tokens):
            logits, _ = self.model(current_ids)
            logits = logits[:, -1, :] / temperature
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Remove tokens with cumulative probability above top_p.
            # Subtracting each token's own probability shifts the cutoff so
            # the first token crossing the top_p boundary is still kept.
            sorted_indices_to_remove = cumulative_probs - F.softmax(sorted_logits, dim=-1) > top_p
            sorted_logits[sorted_indices_to_remove] = float("-inf")
            # Scatter the filtered logits back into vocabulary order.
            logits = torch.zeros_like(logits).scatter_(1, sorted_indices, sorted_logits)
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            current_ids = torch.cat([current_ids, next_token], dim=1)
            if next_token.item() == self.tokenizer.eot_token:
                break
        return current_ids

    def _beam_search(self, input_ids, max_new_tokens, num_beams):
        """Beam search: explore multiple candidates simultaneously"""
        B = input_ids.shape[0]
        assert B == 1, "Beam search only supports batch size 1"
        # Each beam is a (token ids, cumulative log-probability) pair.
        beams = [(input_ids, 0.0)]
        for _ in range(max_new_tokens):
            all_candidates = []
            for beam_ids, beam_score in beams:
                logits, _ = self.model(beam_ids)
                log_probs = F.log_softmax(logits[:, -1, :], dim=-1)
                top_log_probs, top_ids = log_probs.topk(num_beams)
                for i in range(num_beams):
                    token = top_ids[0, i].unsqueeze(0).unsqueeze(0)
                    new_ids = torch.cat([beam_ids, token], dim=1)
                    new_score = beam_score + top_log_probs[0, i].item()
                    all_candidates.append((new_ids, new_score))
            # Rank by length-normalized score so longer beams aren't penalized.
            all_candidates.sort(key=lambda x: x[1] / x[0].shape[1], reverse=True)
            beams = all_candidates[:num_beams]
            # Stop once the best beam ends with <|endoftext|>.
            if beams[0][0][0, -1].item() == self.tokenizer.eot_token:
                break
        return beams[0][0]
7. Perplexity Evaluation
import math
import torch
@torch.no_grad()
def compute_perplexity(
    model: "MiniGPT",
    tokenizer: "GPTTokenizer",
    text: str,
    device: str = "cuda",
    seq_len: int = 512
) -> float:
    """Compute the perplexity of `text` under `model`.

    Splits the token stream into non-overlapping windows of `seq_len`
    tokens, averages the model's next-token loss over them, and returns
    exp(mean loss).

    Raises:
        ValueError: if `text` yields fewer than seq_len + 1 tokens, so no
            complete evaluation window exists (previously this path crashed
            with ZeroDivisionError).
    """
    model.eval()
    model.to(device)
    tokens = tokenizer.encode(text, add_special_tokens=False)
    total_loss = 0.0
    total_tokens = 0
    for i in range(0, len(tokens) - seq_len, seq_len):
        chunk = tokens[i:i + seq_len + 1]
        if len(chunk) < seq_len + 1:
            break
        x = torch.tensor(chunk[:-1], dtype=torch.long, device=device).unsqueeze(0)
        y = torch.tensor(chunk[1:], dtype=torch.long, device=device).unsqueeze(0)
        _, loss = model(x, y)
        # Each full window contributes exactly seq_len target tokens.
        total_loss += loss.item() * seq_len
        total_tokens += seq_len
    if total_tokens == 0:
        raise ValueError(f"text too short for perplexity: need at least {seq_len + 1} tokens")
    avg_loss = total_loss / total_tokens
    return math.exp(avg_loss)
# Usage
# Load the best pretraining checkpoint and build a generator around it.
model = MiniGPT.from_pretrained("./checkpoints/best_model.pt")
tokenizer = GPTTokenizer()
generator = TextGenerator(model, tokenizer)
# Sanity-check the model with a perplexity measurement on a short sentence.
test_text = "The quick brown fox jumps over the lazy dog."
ppl = compute_perplexity(model, tokenizer, test_text)
print(f"Perplexity: {ppl:.2f}")
# Compare decoding strategies on the same prompt.
for strategy in ["greedy", "top_k", "top_p"]:
    generated = generator.generate(
        prompt="Once upon a time",
        max_new_tokens=100,
        strategy=strategy,
        temperature=0.8
    )
    print(f"\n[{strategy}]: {generated}")
8. Scaling Up
8.1 Flash Attention 2
# pip install flash-attn
from flash_attn import flash_attn_qkvpacked_func
class FlashCausalAttention(nn.Module):
    """Causal multi-head attention backed by the Flash Attention 2 kernel."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.dropout = config.dropout

    def forward(self, x: torch.Tensor):
        batch, seq, dim = x.shape
        # Pack Q, K, V as [B, T, 3, n_head, head_dim] — the layout the
        # fused flash_attn kernel expects.
        packed = self.qkv_proj(x).view(batch, seq, 3, self.n_head, self.head_dim)
        drop_rate = self.dropout if self.training else 0.0
        attn_out = flash_attn_qkvpacked_func(
            packed,
            dropout_p=drop_rate,
            causal=True
        )  # -> [B, T, n_head, head_dim]
        merged = attn_out.reshape(batch, seq, dim)
        return self.out_proj(merged)
8.2 FSDP Distributed Training
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import functools
def setup_fsdp_training():
    """Setup FSDP distributed training.

    Initializes the NCCL process group, builds a GPT-medium-sized MiniGPT,
    and wraps it in FSDP with bf16 mixed precision, sharding parameters at
    TransformerBlock boundaries (ZeRO-3 style).

    Returns:
        (model, rank, world_size)
    """
    import os
    from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # BUGFIX: the CUDA device index must be the *local* rank; the global rank
    # is wrong on multi-node jobs (torchrun exports LOCAL_RANK).
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)
    config = GPTConfig(
        n_embd=1024,
        n_layer=24,
        n_head=16,
        vocab_size=50257
    )
    model = MiniGPT(config)
    # Wrap policy: shard at TransformerBlock boundaries
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock}
    )
    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16
        ),
        # BUGFIX: sharding_strategy expects the ShardingStrategy enum, not the
        # string "FULL_SHARD". FULL_SHARD is equivalent to ZeRO-3.
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=local_rank
    )
    return model, rank, world_size
8.3 Chinchilla Scaling Laws
The Chinchilla paper (Hoffmann et al., 2022) gives the following relationship between optimal model size and training tokens:
Optimal tokens = 20 x number of parameters
def chinchilla_optimal_tokens(model_params: int) -> int:
    """Compute-optimal training-token budget for a parameter count (Chinchilla)."""
    tokens_per_parameter = 20
    return tokens_per_parameter * model_params
# Print the Chinchilla-recommended token budget for a few reference models.
models = {
    "124M (GPT-2 Small)": 124e6,
    "1.3B": 1.3e9,
    "7B (Llama-2 7B)": 7e9,
    "13B": 13e9,
    "70B (Llama-2 70B)": 70e9,
}
print("Recommended training tokens by model size:")
for name, params in models.items():
    budget = chinchilla_optimal_tokens(int(params))
    print(f" {name}: {budget/1e9:.1f}B tokens")
9. Instruction Tuning
9.1 SFT Data Formats
# ChatML format
CHATML_SYSTEM = "<|im_start|>system\n{system}<|im_end|>\n"
CHATML_USER = "<|im_start|>user\n{user}<|im_end|>\n"
CHATML_ASSISTANT = "<|im_start|>assistant\n{assistant}<|im_end|>\n"

def format_chatml(system: str, user: str, assistant: str, add_generation_prompt: bool = False) -> str:
    """Render a single-turn conversation in ChatML.

    With add_generation_prompt=True the assistant content is omitted and only
    the assistant header is appended, so the result is an exact prefix of the
    full conversation — required for prompt-length masking in SFT.

    BUGFIX: previously the generation-prompt path also emitted an empty
    assistant turn ("<|im_start|>assistant\\n<|im_end|>\\n") before the
    header, so the prompt was NOT a prefix of the full text and SFT loss
    masking covered the wrong span.
    """
    formatted = CHATML_SYSTEM.format(system=system)
    formatted += CHATML_USER.format(user=user)
    if add_generation_prompt:
        formatted += "<|im_start|>assistant\n"
    else:
        formatted += CHATML_ASSISTANT.format(assistant=assistant)
    return formatted
# Llama-3 format
def format_llama3(system: str, user: str, assistant: str, add_generation_prompt: bool = False) -> str:
    """Render a single-turn conversation in the Llama-3 chat format.

    With add_generation_prompt=True only the assistant header is emitted, so
    the result is an exact prefix of the full conversation — required for
    prompt-length masking in SFT.

    BUGFIX: previously the generation-prompt path also emitted an empty
    assistant turn ending in <|eot_id|> before the header, so the prompt was
    NOT a prefix of the full text.
    """
    formatted = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>\n"
    formatted += f"<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>\n"
    if add_generation_prompt:
        formatted += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    else:
        formatted += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant}<|eot_id|>"
    return formatted
9.2 Instruction Dataset Class
class InstructionDataset(Dataset):
    """SFT dataset: next-token pairs with loss computed only on responses.

    Each sample is (input_ids, labels) where labels are the inputs shifted
    left by one (next-token targets, matching the pretraining convention),
    and every target belonging to the prompt is masked with -100 so the loss
    ignores it.
    """

    def __init__(
        self,
        data: list,
        tokenizer: "GPTTokenizer",
        max_seq_len: int = 2048,
        format_func=format_chatml
    ):
        self.samples = []
        for item in data:
            # Tokenize full conversation
            full_text = format_func(
                system=item.get("system", "You are a helpful assistant."),
                user=item["instruction"],
                assistant=item["response"]
            )
            full_ids = tokenizer.encode(full_text, add_special_tokens=False)
            # Tokenize prompt only (to find where response starts)
            prompt_text = format_func(
                system=item.get("system", "You are a helpful assistant."),
                user=item["instruction"],
                assistant="",
                add_generation_prompt=True
            )
            prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
            prompt_len = len(prompt_ids)
            # Keep max_seq_len + 1 tokens so inputs still span max_seq_len
            # positions after the shift below.
            if len(full_ids) > max_seq_len + 1:
                full_ids = full_ids[:max_seq_len + 1]
            # BUGFIX: the model consumes pre-shifted targets (as in
            # pretraining), but labels used to be an unshifted copy of the
            # inputs — the model would learn to predict each token from
            # itself. Shift by one instead.
            input_ids = full_ids[:-1]
            labels = full_ids[1:]
            # Mask prompt positions with -100 (= ignored by the loss).
            # labels[prompt_len - 1] is the first response token.
            mask_len = min(max(prompt_len - 1, 0), len(labels))
            labels[:mask_len] = [-100] * mask_len
            self.samples.append({
                "input_ids": input_ids,
                "labels": labels
            })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        return (
            torch.tensor(sample["input_ids"], dtype=torch.long),
            torch.tensor(sample["labels"], dtype=torch.long)
        )
9.3 SFT Training Loop
def train_sft(
    base_model_path: str,
    data_path: str,
    output_dir: str,
    num_epochs: int = 3,
    learning_rate: float = 2e-5,
    batch_size: int = 4
):
    """Supervised Fine-Tuning (SFT).

    Loads a pretrained checkpoint, fine-tunes it on instruction data
    (a JSON list of {"system"?, "instruction", "response"} objects), and
    saves the result to {output_dir}/sft_model.pt. Runs in fp32 on CUDA.
    """
    import json
    with open(data_path) as f:
        data = json.load(f)
    tokenizer = GPTTokenizer()
    dataset = InstructionDataset(data, tokenizer, max_seq_len=2048)

    def collate_fn(batch):
        # Right-pad inputs with 0 and labels with -100 to the longest
        # sequence in the batch.
        # NOTE(review): the -100 padding only masks the loss if the model's
        # cross_entropy uses ignore_index=-100 — verify they match.
        input_ids = [item[0] for item in batch]
        labels = [item[1] for item in batch]
        max_len = max(len(x) for x in input_ids)
        padded_ids = torch.zeros(len(batch), max_len, dtype=torch.long)
        padded_labels = torch.full((len(batch), max_len), -100, dtype=torch.long)
        for i, (ids, lbls) in enumerate(zip(input_ids, labels)):
            padded_ids[i, :len(ids)] = ids
            padded_labels[i, :len(lbls)] = lbls
        return padded_ids, padded_labels

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    model = MiniGPT.from_pretrained(base_model_path)
    device = "cuda"
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    total_steps = len(dataloader) * num_epochs
    # 10% of total steps as warmup, then cosine decay.
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(total_steps * 0.1), total_steps)
    model.train()
    for epoch in range(num_epochs):
        for step, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)
            logits, loss = model(x, y)
            loss.backward()
            # Clip gradients to stabilize fine-tuning at small batch sizes.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            if step % 50 == 0:
                print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")
    os.makedirs(output_dir, exist_ok=True)
    torch.save({
        "model": model.state_dict(),
        "config": model.config
    }, f"{output_dir}/sft_model.pt")
    print(f"SFT complete. Model saved: {output_dir}")
10. Open-Source LLM Architecture Analysis
10.1 Llama 3 Architecture
Llama 3 is Meta's open-source LLM with these key characteristics:
| Feature | Details |
|---|---|
| Architecture | Decoder-only Transformer |
| Normalization | Pre-RMSNorm |
| Activation | SwiGLU |
| Position Encoding | RoPE (theta=500,000) |
| Vocabulary Size | 128,256 |
| Context Length | 128K (Llama 3.1+) |
| Attention | GQA (8B, 70B models) |
# Llama 3 style config (8B model)
# NOTE(review): this only mirrors the 8B model's dimensions; the real model
# also uses GQA (8 KV heads) and RoPE base 500,000, which GPTConfig does not
# capture — confirm before treating this as an exact replica.
llama3_config = GPTConfig(
    vocab_size=128256,
    max_seq_len=8192,
    n_embd=4096,
    n_layer=32,
    n_head=32,
    use_rope=True,
    use_swiglu=True,
    bias=False
)
10.2 Grouped Query Attention (GQA)
class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (used in Llama 3, Mistral).

    Projects fewer K/V heads than Q heads, then repeats each K/V head
    n_rep times so the attention kernel sees matching head counts — this
    shrinks the KV cache without changing the output shape.
    """

    def __init__(self, config: GPTConfig, n_kv_heads: int = 8):
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_heads = n_kv_heads
        self.n_rep = config.n_head // n_kv_heads
        self.head_dim = config.n_embd // config.n_head
        self.q_proj = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.n_embd, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.n_embd, n_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.rope = RotaryEmbedding(self.head_dim, config.max_seq_len, base=500000)

    def _expand_kv(self, heads, B, T):
        # Repeat each KV head n_rep times: [B, n_kv, T, hd] -> [B, n_head, T, hd].
        expanded = heads.unsqueeze(2).expand(B, self.n_kv_heads, self.n_rep, T, self.head_dim)
        return expanded.reshape(B, self.n_head, T, self.head_dim)

    def forward(self, x: torch.Tensor):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        q, k = self.rope(q, k, T)
        # Core of GQA: broadcast the reduced KV heads up to the query head count.
        k = self._expand_kv(k, B, T)
        v = self._expand_kv(v, B, T)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(y)
10.3 Mixture of Experts (MoE)
DeepSeek and Mixtral use Mixture of Experts to increase model capacity without proportionally increasing compute.
class MoEFFN(nn.Module):
    """Mixture-of-Experts feed-forward layer (simplified).

    A linear router scores each token against ``num_experts`` expert FFNs.
    Every token is then processed by its ``top_k`` highest-scoring experts,
    and the expert outputs are combined with renormalized routing weights.
    """

    def __init__(self, config: GPTConfig, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        # Token -> per-expert routing scores.
        self.router = nn.Linear(config.n_embd, num_experts, bias=False)
        # One independent SwiGLU FFN per expert.
        self.experts = nn.ModuleList(SwiGLUFFN(config) for _ in range(num_experts))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        flat = x.view(B * T, C)

        # Routing distribution over experts for every token.
        probs = F.softmax(self.router(flat), dim=-1)

        # Keep each token's top-k experts and renormalize their weights
        # so the kept weights sum to one.
        weights, chosen = torch.topk(probs, self.top_k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)

        combined = torch.zeros_like(flat)
        for slot in range(self.top_k):
            slot_expert = chosen[:, slot]
            slot_weight = weights[:, slot].unsqueeze(-1)
            # Run each expert once on the subset of tokens routed to it.
            for idx in range(self.num_experts):
                routed = (slot_expert == idx)
                if routed.any():
                    combined[routed] += slot_weight[routed] * self.experts[idx](flat[routed])
        return combined.view(B, T, C)
Putting It All Together
Here is a complete script that ties all components together:
# End-to-end demo: configure, build, train, and sample from miniGPT.
import torch
from dataclasses import dataclass

# Step 1: model configuration (GPT-2 Small sized).
cfg = GPTConfig(
    vocab_size=50257,
    max_seq_len=1024,
    n_embd=768,
    n_layer=12,
    n_head=12,
    dropout=0.1,
    use_rope=True,
    use_swiglu=True
)

# Step 2: instantiate the model and report its size.
gpt = MiniGPT(cfg)
print(f"Model parameters: {gpt.count_params() / 1e6:.1f}M")

# Step 3: tokenizer.
tok = GPTTokenizer()

# Step 4: training and validation data.
loader_train = create_dataloader("train.txt", tok, batch_size=8, seq_len=1024)
loader_val = create_dataloader("val.txt", tok, batch_size=8, seq_len=1024, shuffle=False)

# Step 5: trainer.
train_cfg = {
    "num_epochs": 5,
    "learning_rate": 6e-4,
    "weight_decay": 0.1,
    "gradient_accumulation_steps": 4,
    "max_grad_norm": 1.0,
    "warmup_ratio": 0.05,
    "log_interval": 100,
    "eval_interval": 500,
    "save_interval": 1000,
    "device": "cuda"
}
trainer = Trainer(gpt, loader_train, loader_val, train_cfg)
trainer.train()

# Step 6: sample completions under two decoding strategies.
sampler = TextGenerator(gpt, tok)
for prompt in [
    "The history of artificial intelligence",
    "In the year 2050, scientists discovered",
    "The best way to learn programming is",
]:
    print(f"\nPrompt: {prompt}")
    for strategy in ("greedy", "top_p"):
        text = sampler.generate(prompt, max_new_tokens=100, strategy=strategy)
        print(f" [{strategy}]: {text}")
Conclusion
In this guide, we built miniGPT from scratch and covered every core element of modern LLMs:
- Tokenizer: BPE algorithm splits text into subword tokens
- Embeddings: Token embeddings + RoPE positional encoding
- Multi-Head Causal Attention: Q/K/V projections, causal mask, KV cache
- FFN: SwiGLU activation for non-linearity
- Pre-RMSNorm: Training stability
- Text Generation: Greedy, Temperature, Top-k, Top-p, Beam Search
- Pretraining: Cosine LR schedule, gradient clipping, mixed precision
- SFT: Loss computed only on assistant response tokens
- Scaling: Flash Attention, FSDP, and the Chinchilla scaling laws for compute-optimal training
This foundation gives you the tools to understand and extend any modern open-source LLM. The concepts covered here directly map to the implementation details of Llama 3, Mistral, and Qwen.
References
- Karpathy's nanoGPT: https://github.com/karpathy/nanoGPT
- Chinchilla Scaling Laws (Hoffmann et al., 2022): https://arxiv.org/abs/2203.15556
- Llama 3 Technical Report: https://ai.meta.com/blog/meta-llama-3
- Flash Attention: https://github.com/Dao-AILab/flash-attention
- OpenAI tiktoken: https://github.com/openai/tiktoken
- Attention is All You Need (Vaswani et al., 2017)
- RoFormer: Enhanced Transformer with Rotary Position Embedding (Su et al., 2021)