Skip to content
Published on

Building LLM from Scratch: Complete Guide to Understanding GPT through Code

Authors

Introduction

GPT, Llama, Mistral — large language models (LLMs) are at the center of the current AI revolution. Yet very few people understand how these models work at the code level. In this guide, we build "miniGPT," a small GPT-style model from scratch using PyTorch, mastering every core concept behind modern LLMs.


1. LLM Architecture Overview

1.1 The Pretraining Pipeline

Modern LLM development follows three stages:

  1. Pretraining: The model learns to predict the next token from billions of text tokens. During this stage, it acquires statistical language patterns, world knowledge, and reasoning ability.

  2. Instruction Fine-Tuning (SFT): The model is trained on human-written question-answer pairs so that it learns to follow instructions.

  3. Preference Learning (RLHF / DPO): The model is refined using human preference data to produce more helpful and safe responses.

1.2 miniGPT Specification

| Parameter | Value |
| --- | --- |
| Vocabulary size | 50,257 (same as GPT-2) |
| Context length | 1024 |
| Embedding dimension | 768 |
| Number of layers | 12 |
| Attention heads | 12 |
| FFN expansion ratio | 4x |
| Total parameters | ~124M |

This matches GPT-2 Small and is trainable on a single GPU.


2. Tokenizer Implementation

2.1 BPE Algorithm

Byte-Pair Encoding (BPE) is the tokenization algorithm used by most modern LLMs.

import re
from collections import defaultdict
from typing import Dict, List, Tuple

class SimpleBPETokenizer:
    """BPE tokenizer implemented from scratch.

    Learns a merge table from a corpus by repeatedly fusing the most
    frequent adjacent symbol pair (Sennrich et al., 2016).
    """

    def __init__(self):
        self.vocab = {}    # token string -> token id
        self.merges = {}   # (left, right) pair -> merge rank (lower = earlier)
        self.special_tokens = {
            "<|endoftext|>": 50256,
        }

    def get_stats(self, vocab: Dict[str, int]) -> Dict[Tuple, int]:
        """Count frequency of adjacent token pairs across the working vocab."""
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i+1])] += freq
        return pairs

    def merge_vocab(self, pair: Tuple, v_in: Dict) -> Dict:
        """Merge every whole-symbol occurrence of `pair` in the working vocab.

        BUGFIX: a plain str.replace() also matched *inside* neighbouring
        symbols — e.g. merging ("a", "b") turned "xa b" into "xab",
        corrupting the symbol "xa". Anchor both sides on whitespace
        (as in the reference subword-nmt implementation) so that only
        complete symbols are merged.
        """
        bigram = re.escape(" ".join(pair))
        pattern = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
        replacement = "".join(pair)
        return {pattern.sub(replacement, word): freq for word, freq in v_in.items()}

    def train(self, corpus: List[str], vocab_size: int = 1000):
        """Learn up to `vocab_size` merges from `corpus`.

        Returns:
            (vocab, merges): token -> id mapping and pair -> rank table.
        """
        # 1. Start from a character-level vocabulary; "</w>" marks word ends
        word_freq = defaultdict(int)
        for text in corpus:
            for word in text.split():
                word_freq[" ".join(list(word)) + " </w>"] += 1

        vocab = dict(word_freq)

        # 2. Greedily merge the most frequent adjacent pair
        for i in range(vocab_size):
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            vocab = self.merge_vocab(best, vocab)
            self.merges[best] = i
            print(f"Step {i}: merged {best} (freq={pairs[best]})")

        # 3. Assign ids to every surviving symbol
        # NOTE(review): special_tokens are never added to self.vocab here —
        # confirm whether they should be appended after training.
        for word in vocab:
            for token in word.split():
                if token not in self.vocab:
                    self.vocab[token] = len(self.vocab)

        return self.vocab, self.merges

2.2 Using tiktoken

In practice, OpenAI's tiktoken library is much more efficient.

import tiktoken
import torch
from typing import List

class GPTTokenizer:
    """Thin wrapper around OpenAI's tiktoken BPE encodings."""

    def __init__(self, encoding_name: str = "gpt2"):
        self.enc = tiktoken.get_encoding(encoding_name)
        self.vocab_size = self.enc.n_vocab
        self.eot_token = self.enc.eot_token  # 50256

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Convert text to token IDs, optionally prefixed with <|endoftext|>."""
        token_ids = self.enc.encode(text, allowed_special={"<|endoftext|>"})
        if not add_special_tokens:
            return token_ids
        return [self.eot_token, *token_ids]

    def decode(self, ids: List[int]) -> str:
        """Convert a list of token IDs back to text."""
        return self.enc.decode(ids)

    def encode_batch(self, texts: List[str]) -> List[List[int]]:
        """Encode several texts, each with the default special-token handling."""
        return [self.encode(item) for item in texts]

    def __len__(self):
        """Vocabulary size, so len(tokenizer) works."""
        return self.vocab_size

# Usage: round-trip a sentence through the tokenizer
tokenizer = GPTTokenizer()
text = "Hello! Let's build GPT from scratch."
ids = tokenizer.encode(text)  # prepends <|endoftext|> by default
print(f"Token count: {len(ids)}")
print(f"Token IDs: {ids}")
decoded = tokenizer.decode(ids)
print(f"Decoded: {decoded}")

3. Data Preparation

3.1 Text Dataset Class

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class TextDataset(Dataset):
    """Sliding-window language-modeling dataset.

    Tokenizes a whole text file once, then serves overlapping windows:
    sample i covers tokens [i * stride, i * stride + seq_len + 1), split
    into input x = chunk[:-1] and next-token target y = chunk[1:].
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: "GPTTokenizer",
        seq_len: int = 1024,
        stride: int = 512
    ):
        self.seq_len = seq_len
        self.stride = stride

        print(f"Loading data: {data_path}")
        with open(data_path, "r", encoding="utf-8") as f:
            text = f.read()

        print(f"Total characters: {len(text):,}")
        self.tokens = tokenizer.encode(text, add_special_tokens=False)
        print(f"Total tokens: {len(self.tokens):,}")

        # int32 keeps the token buffer compact; converted to long per sample
        self.tokens = np.array(self.tokens, dtype=np.int32)

    def __len__(self):
        # A sample needs seq_len + 1 tokens starting at idx * stride.
        # BUGFIX: the old formula (len - seq_len) // stride dropped the
        # last valid window whenever the token count did not line up with
        # the stride; count exactly the starts that fit a full window.
        return max(0, (len(self.tokens) - self.seq_len - 1) // self.stride + 1)

    def __getitem__(self, idx):
        start = idx * self.stride
        end = start + self.seq_len + 1

        chunk = self.tokens[start:end]

        # Input: tokens[0:seq_len], Label: tokens[1:seq_len+1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)

        return x, y


def create_dataloader(
    data_path: str,
    tokenizer: GPTTokenizer,
    batch_size: int = 8,
    seq_len: int = 1024,
    stride: int = 512,
    num_workers: int = 4,
    shuffle: bool = True
) -> DataLoader:
    """Build a DataLoader over sliding windows of the tokenized file."""
    dataset = TextDataset(data_path, tokenizer, seq_len, stride)
    print(f"Dataset size: {len(dataset):,} samples")

    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,   # faster host-to-device copies
        drop_last=True,    # keep batch shapes constant
    )
    return DataLoader(dataset, **loader_kwargs)

4. Model Architecture

4.1 Configuration Class

from dataclasses import dataclass

@dataclass
class GPTConfig:
    """GPT model configuration"""
    vocab_size: int = 50257   # GPT-2 BPE vocabulary size
    max_seq_len: int = 1024   # maximum context length (positions)
    n_embd: int = 768         # embedding / residual-stream width
    n_layer: int = 12         # number of transformer blocks
    n_head: int = 12          # attention heads per block
    dropout: float = 0.1      # dropout prob (embeddings, attention, FFN)
    bias: bool = False        # include bias terms in attention linear layers

    ffn_mult: int = 4         # FFN expansion factor over n_embd
    use_swiglu: bool = True   # SwiGLU FFN (Llama-style) instead of GELU FFN
    use_rope: bool = True     # rotary positions instead of learned absolute

4.2 RMSNorm

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (Llama-style).

    Rescales each feature vector by the reciprocal of its RMS, then
    applies a learned per-dimension gain. Statistics are computed in
    float32 for stability and cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # 1 / sqrt(mean(x^2) + eps) along the feature dimension
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        normalized = self._norm(x.float()).type_as(x)
        return self.weight * normalized

4.3 Rotary Position Embedding (RoPE)

class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    Pre-computes cos/sin lookup tables for all positions and applies
    the rotation to query/key tensors of shape [B, n_head, T, head_dim].
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len

        # Compute frequencies: theta_i = base^(-2i/dim), one per channel pair
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """(Re)build cos/sin tables for positions [0, seq_len)."""
        t = torch.arange(seq_len, device=self.inv_freq.device).float()
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        # Shapes [1, 1, seq_len, dim] broadcast over (batch, heads).
        # Re-registering a buffer name simply replaces the stored tensor.
        self.register_buffer("cos_cache", emb.cos()[None, None, :, :])
        self.register_buffer("sin_cache", emb.sin()[None, None, :, :])
        self.max_seq_len = seq_len  # track the cached capacity

    def rotate_half(self, x):
        """Rotate half the tensor: (x1, x2) -> (-x2, x1)."""
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int):
        """Rotate q and k by their absolute positions [0, seq_len)."""
        # BUGFIX: the old code silently sliced a too-short cache when
        # seq_len exceeded max_seq_len, producing a broadcast shape error
        # downstream; grow the tables on demand instead.
        if seq_len > self.max_seq_len:
            self._build_cache(seq_len)

        cos = self.cos_cache[:, :, :seq_len, :].to(q.dtype)
        sin = self.sin_cache[:, :, :seq_len, :].to(q.dtype)

        q_rot = (q * cos) + (self.rotate_half(q) * sin)
        k_rot = (k * cos) + (self.rotate_half(k) * sin)
        return q_rot, k_rot

4.4 Multi-Head Causal Attention with KV Cache

class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with KV-cache support.

    With `use_cache=True` it returns the (k, v) tensors of this call so
    the next call can pass them as `past_kv` and only process the new
    tokens.
    """

    def __init__(self, config: "GPTConfig"):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head

        # Single fused projection for Q, K and V
        self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        if config.use_rope:
            self.rope = RotaryEmbedding(self.head_dim, config.max_seq_len)
        else:
            self.rope = None

        # Lower-triangular causal mask, sliced per forward pass
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
            .view(1, 1, config.max_seq_len, config.max_seq_len)
        )

    def forward(
        self,
        x: torch.Tensor,
        use_cache: bool = False,
        past_kv=None
    ):
        B, T, C = x.shape
        past_len = 0 if past_kv is None else past_kv[0].shape[2]

        qkv = self.qkv_proj(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        # Split heads: [B, T, C] -> [B, n_head, T, head_dim]
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        if self.rope is not None:
            # BUGFIX: with a KV cache the new tokens sit at absolute
            # positions [past_len, past_len + T), not [0, T); rotate with
            # the matching slice of the cos/sin tables.
            if past_len + T > self.rope.max_seq_len:
                self.rope._build_cache(past_len + T)
            cos = self.rope.cos_cache[:, :, past_len:past_len + T, :].to(q.dtype)
            sin = self.rope.sin_cache[:, :, past_len:past_len + T, :].to(q.dtype)
            q = (q * cos) + (self.rope.rotate_half(q) * sin)
            k = (k * cos) + (self.rope.rotate_half(k) * sin)

        # KV cache handling: prepend cached keys/values along time
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        present_kv = (k, v) if use_cache else None
        kv_len = k.shape[2]

        # BUGFIX: SDPA's is_causal=True aligns the mask at position 0
        # (top-left), which is wrong once cached keys are present — a
        # single-token decode would attend only to the first key. Align
        # the queries to the *end* of the key sequence instead.
        if past_len == 0:
            attn_mask = None
            is_causal = True
        elif T == 1:
            # One new query after a cache: it may attend to everything.
            attn_mask = None
            is_causal = False
        else:
            # Bottom-right aligned causal mask for T queries over kv_len keys
            attn_mask = self.causal_mask[:, :, past_len:kv_len, :kv_len].bool()
            is_causal = False

        # Scaled Dot-Product Attention (uses Flash Attention if available)
        if hasattr(F, "scaled_dot_product_attention"):
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_dropout.p if self.training else 0.0,
                is_causal=is_causal
            )
        else:
            # Manual implementation; mask rows offset by past_len so query i
            # corresponds to absolute position past_len + i
            scale = 1.0 / math.sqrt(self.head_dim)
            attn = (q @ k.transpose(-2, -1)) * scale
            mask = self.causal_mask[:, :, past_len:kv_len, :kv_len]
            attn = attn.masked_fill(mask == 0, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_dropout(attn)
            y = attn @ v

        # Merge heads: [B, n_head, T, head_dim] -> [B, T, C]
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.out_proj(y))

        return y, present_kv

4.5 SwiGLU Feed-Forward Network

class SwiGLUFFN(nn.Module):
    """FFN with SwiGLU activation (Llama-style).

    Computes down_proj( SiLU(gate_proj(x)) * up_proj(x) ). The hidden
    width is scaled by 2/3 so the three-matrix SwiGLU has roughly the
    same parameter count as a classic two-matrix FFN.
    """

    def __init__(self, config: "GPTConfig"):
        super().__init__()
        hidden_dim = int(config.n_embd * config.ffn_mult * 2 / 3)
        # Round up to nearest 64 for hardware efficiency
        hidden_dim = ((hidden_dim + 63) // 64) * 64

        self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: SiLU(gate_proj(x)) gates up_proj(x) elementwise.
        # (The previous comment "gate * SiLU(x)" mis-stated the formula.)
        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.dropout(self.down_proj(gate * up))


class GPTFFN(nn.Module):
    """Classic GPT-style two-layer feed-forward network with GELU."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden_dim = config.n_embd * config.ffn_mult  # expand, then project back
        self.fc1 = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
        self.fc2 = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.fc1(x))
        return self.dropout(self.fc2(hidden))

4.6 Transformer Block

class TransformerBlock(nn.Module):
    """Pre-RMSNorm transformer block: x + Attn(norm(x)), then x + FFN(norm(x))."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.norm1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.norm2 = RMSNorm(config.n_embd)
        self.ffn = SwiGLUFFN(config) if config.use_swiglu else GPTFFN(config)

    def forward(self, x: torch.Tensor, use_cache: bool = False, past_kv=None):
        # Attention sub-layer: pre-norm, then residual add
        attn_out, present_kv = self.attn(
            self.norm1(x),
            use_cache=use_cache,
            past_kv=past_kv
        )
        h = x + attn_out
        # Feed-forward sub-layer: pre-norm, then residual add
        h = h + self.ffn(self.norm2(h))
        return h, present_kv

4.7 Complete GPT Model

class MiniGPT(nn.Module):
    """Full miniGPT model: token embedding -> N transformer blocks -> LM head.

    Supports an optional KV cache (`use_cache` / `past_kvs`) for fast
    autoregressive decoding.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        # Embedding layers
        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)

        # Learned absolute positions only when RoPE is disabled
        if not config.use_rope:
            self.pos_emb = nn.Embedding(config.max_seq_len, config.n_embd)

        self.emb_dropout = nn.Dropout(config.dropout)

        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layer)
        ])

        # Final normalization
        self.norm_final = RMSNorm(config.n_embd)

        # Language model head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying (input embeddings == output embeddings); saves
        # vocab_size * n_embd parameters
        self.lm_head.weight = self.token_emb.weight

        # Weight initialization (applied after tying, so both share it)
        self.apply(self._init_weights)

        # Special scaling for residual stream: projections that feed the
        # residual path get std shrunk with depth (GPT-2-style) so the
        # residual variance stays bounded as layers stack
        for pn, p in self.named_parameters():
            if pn.endswith("out_proj.weight") or pn.endswith("down_proj.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

        print(f"Parameters: {self.count_params() / 1e6:.1f}M")

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def count_params(self) -> int:
        """Total number of model parameters."""
        return sum(p.numel() for p in self.parameters())

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor = None,
        use_cache: bool = False,
        past_kvs: list = None
    ):
        """Run the model.

        Args:
            input_ids: [B, T] token ids.
            targets: optional [B, T] next-token labels; when given, mean
                cross-entropy loss is returned alongside full logits.
            use_cache: collect per-layer (k, v) tensors for reuse.
            past_kvs: per-layer (k, v) tensors from a previous call.

        Returns:
            (logits, loss) when targets is given; otherwise
            (last-token logits of shape [B, 1, vocab], present_kvs).
        """
        B, T = input_ids.shape
        device = input_ids.device

        # Token embeddings
        x = self.token_emb(input_ids)  # [B, T, n_embd]

        # Positional embeddings (only when not using RoPE)
        # NOTE(review): positions always start at 0 here; with a KV cache
        # the new tokens should be offset by the cached length — confirm
        # before combining learned positions with past_kvs.
        if not self.config.use_rope:
            positions = torch.arange(T, device=device)
            x = x + self.pos_emb(positions)

        x = self.emb_dropout(x)

        # Transformer layers, threading the per-layer cache through
        present_kvs = []
        for i, layer in enumerate(self.layers):
            past_kv = past_kvs[i] if past_kvs is not None else None
            x, present_kv = layer(x, use_cache=use_cache, past_kv=past_kv)
            present_kvs.append(present_kv)

        x = self.norm_final(x)

        if targets is not None:
            # Training: compute loss over full sequence
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1  # positions labeled -1 are excluded from the loss
            )
            return logits, loss
        else:
            # Inference: only need last token logits
            logits = self.lm_head(x[:, [-1], :])
            return logits, present_kvs

    @classmethod
    def from_pretrained(cls, model_path: str):
        """Rebuild a model from a Trainer checkpoint (config + weights)."""
        # NOTE(review): newer PyTorch defaults torch.load to
        # weights_only=True, which rejects the pickled GPTConfig stored
        # here — may need weights_only=False; verify torch version.
        checkpoint = torch.load(model_path, map_location="cpu")
        config = checkpoint["config"]
        model = cls(config)
        model.load_state_dict(checkpoint["model"])
        return model

5. Pretraining Implementation

5.1 Learning Rate Scheduler

import math

def get_cosine_schedule_with_warmup(
    optimizer,
    warmup_steps: int,
    total_steps: int,
    min_lr_ratio: float = 0.1
):
    """Cosine LR schedule with linear warmup.

    Returns a LambdaLR whose multiplier rises linearly from 0 to 1 over
    `warmup_steps`, then follows a cosine decay, never dropping below
    `min_lr_ratio` of the base learning rate.
    """
    from torch.optim.lr_scheduler import LambdaLR

    def lr_lambda(current_step: int):
        # Linear warmup phase
        if current_step < warmup_steps:
            return current_step / max(1, warmup_steps)
        # Cosine decay phase, floored at min_lr_ratio
        progress = (current_step - warmup_steps) / max(1, total_steps - warmup_steps)
        return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return LambdaLR(optimizer, lr_lambda)

5.2 Training Loop

import torch
import torch.nn as nn
from torch.optim import AdamW
import os
import time

class Trainer:
    """miniGPT pretraining trainer.

    Wraps mixed-precision training with gradient accumulation, gradient
    clipping, a warmup+cosine LR schedule, periodic validation and
    checkpointing. `config` is a plain dict; required keys:
    "learning_rate", "num_epochs". Optional keys (with defaults):
    device, weight_decay, warmup_ratio, gradient_accumulation_steps,
    max_grad_norm, log_interval, eval_interval, save_interval.
    """

    def __init__(self, model: MiniGPT, train_dataloader, val_dataloader, config: dict):
        self.model = model
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.config = config
        self.device = config.get("device", "cuda" if torch.cuda.is_available() else "cpu")

        # Mixed precision: prefer bf16 when supported; the loss scaler is
        # only enabled for fp16 (bf16 needs no scaling).
        # NOTE(review): torch.cuda.amp.GradScaler is deprecated in recent
        # PyTorch in favor of torch.amp.GradScaler("cuda") — check version.
        self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        self.scaler = torch.cuda.amp.GradScaler(enabled=(self.dtype == torch.float16))
        self.ctx = torch.amp.autocast(device_type="cuda", dtype=self.dtype)

        self.optimizer = self._create_optimizer()

        # NOTE(review): total_steps counts *batches*, but the scheduler is
        # stepped once per optimizer step; with gradient accumulation > 1
        # the cosine schedule finishes early by that factor — verify intent.
        total_steps = len(train_dataloader) * config["num_epochs"]
        warmup_steps = int(total_steps * config.get("warmup_ratio", 0.05))
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            warmup_steps=warmup_steps,
            total_steps=total_steps
        )

        self.step = 0  # counts optimizer steps, not batches
        self.best_val_loss = float("inf")

    def _create_optimizer(self):
        """Separate params with/without weight decay.

        1-D tensors (norm gains, biases) and embedding weights are
        excluded from decay; everything else decays.
        """
        decay_params = []
        no_decay_params = []

        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if param.dim() == 1 or "emb" in name:
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optim_groups = [
            {"params": decay_params, "weight_decay": self.config.get("weight_decay", 0.1)},
            {"params": no_decay_params, "weight_decay": 0.0}
        ]

        # NOTE(review): fused=True requires CUDA tensors — this fails on
        # CPU-only runs; gate on the device if CPU support is needed.
        return AdamW(
            optim_groups,
            lr=self.config["learning_rate"],
            betas=(0.9, 0.95),
            fused=True
        )

    def compute_val_loss(self) -> float:
        """Average loss over at most 50 validation batches.

        Assumes a non-empty validation loader (divides by num_batches).
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = min(len(self.val_dl), 50)

        with torch.no_grad():
            for i, (x, y) in enumerate(self.val_dl):
                if i >= num_batches:
                    break
                x, y = x.to(self.device), y.to(self.device)
                with self.ctx:
                    _, loss = self.model(x, y)
                total_loss += loss.item()

        self.model.train()  # restore training mode (dropout back on)
        return total_loss / num_batches

    def save_checkpoint(self, path: str):
        """Serialize model, optimizer, scheduler and progress to `path`."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save({
            "step": self.step,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "config": self.model.config,
            "best_val_loss": self.best_val_loss
        }, path)
        print(f"Checkpoint saved: {path}")

    def train(self):
        """Run the full training loop (num_epochs passes over train_dl)."""
        self.model.to(self.device)
        self.model.train()

        num_epochs = self.config["num_epochs"]
        grad_accum = self.config.get("gradient_accumulation_steps", 1)
        max_grad_norm = self.config.get("max_grad_norm", 1.0)
        log_interval = self.config.get("log_interval", 100)
        eval_interval = self.config.get("eval_interval", 500)
        save_interval = self.config.get("save_interval", 1000)

        print(f"Training on {self.device}, dtype={self.dtype}")

        for epoch in range(num_epochs):
            t0 = time.time()

            for batch_idx, (x, y) in enumerate(self.train_dl):
                x, y = x.to(self.device), y.to(self.device)

                with self.ctx:
                    _, loss = self.model(x, y)
                    loss = loss / grad_accum  # normalize for accumulation

                # Scaled backward (scaling is a no-op under bf16)
                self.scaler.scale(loss).backward()

                # Optimizer step every grad_accum micro-batches
                if (batch_idx + 1) % grad_accum == 0:
                    self.scaler.unscale_(self.optimizer)  # clip in true scale
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.scheduler.step()
                    self.optimizer.zero_grad(set_to_none=True)
                    self.step += 1

                # NOTE(review): when grad_accum > 1, self.step is unchanged
                # on most micro-batches, so these interval checks can
                # re-trigger several times at the same step value.
                if self.step % log_interval == 0:
                    lr = self.optimizer.param_groups[0]["lr"]
                    elapsed = time.time() - t0
                    tokens_per_sec = log_interval * x.shape[0] * x.shape[1] / elapsed
                    print(
                        f"Epoch {epoch} | Step {self.step} | "
                        f"Loss: {loss.item() * grad_accum:.4f} | "
                        f"LR: {lr:.6f} | "
                        f"Tokens/s: {tokens_per_sec:.0f}"
                    )
                    t0 = time.time()

                if self.step % eval_interval == 0:
                    val_loss = self.compute_val_loss()
                    print(f"Val Loss: {val_loss:.4f} | Perplexity: {math.exp(val_loss):.2f}")
                    if val_loss < self.best_val_loss:
                        self.best_val_loss = val_loss
                        self.save_checkpoint("./checkpoints/best_model.pt")

                if self.step % save_interval == 0:
                    self.save_checkpoint(f"./checkpoints/step_{self.step}.pt")

6. Text Generation

6.1 Decoding Strategies

import torch
import torch.nn.functional as F

class TextGenerator:
    """miniGPT text generator.

    Implements greedy, temperature, top-k, top-p and beam-search
    decoding. Every strategy re-runs the full prefix each step (no KV
    cache is used here). MiniGPT in inference mode returns logits for
    the last position only, shape [B, 1, vocab], so `logits[:, -1, :]`
    below selects that single position.
    """

    def __init__(self, model: MiniGPT, tokenizer: GPTTokenizer, device: str = "cuda"):
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()  # disable dropout for generation

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 200,
        strategy: str = "top_p",
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
        num_beams: int = 4
    ) -> str:
        """Generate a continuation of `prompt`.

        Args:
            prompt: text to continue.
            max_new_tokens: budget; generation may stop early on <|endoftext|>.
            strategy: "greedy", "temperature", "top_k", "top_p" or "beam".
            temperature: logit divisor used by the sampling strategies.
            top_k / top_p / num_beams: strategy-specific knobs.

        Returns:
            Only the newly generated text (the prompt is stripped).

        Raises:
            ValueError: for an unknown strategy name.
        """
        input_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        input_ids = torch.tensor(input_ids, dtype=torch.long, device=self.device).unsqueeze(0)

        if strategy == "greedy":
            generated = self._greedy_decode(input_ids, max_new_tokens)
        elif strategy == "temperature":
            generated = self._temperature_sampling(input_ids, max_new_tokens, temperature)
        elif strategy == "top_k":
            generated = self._top_k_sampling(input_ids, max_new_tokens, temperature, top_k)
        elif strategy == "top_p":
            generated = self._top_p_sampling(input_ids, max_new_tokens, temperature, top_p)
        elif strategy == "beam":
            generated = self._beam_search(input_ids, max_new_tokens, num_beams)
        else:
            raise ValueError(f"Unknown strategy: {strategy}")

        # Drop the prompt tokens; decode only the continuation
        new_tokens = generated[0][input_ids.shape[1]:].tolist()
        return self.tokenizer.decode(new_tokens)

    def _greedy_decode(self, input_ids, max_new_tokens):
        """Greedy decoding: always pick the highest probability token"""
        current_ids = input_ids.clone()
        for _ in range(max_new_tokens):
            logits, _ = self.model(current_ids)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            current_ids = torch.cat([current_ids, next_token], dim=1)
            # Stop at end-of-text
            if next_token.item() == self.tokenizer.eot_token:
                break
        return current_ids

    def _temperature_sampling(self, input_ids, max_new_tokens, temperature):
        """Temperature sampling: scale logits before softmax.

        Lower temperature sharpens the distribution; higher flattens it.
        """
        current_ids = input_ids.clone()
        for _ in range(max_new_tokens):
            logits, _ = self.model(current_ids)
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            current_ids = torch.cat([current_ids, next_token], dim=1)
            if next_token.item() == self.tokenizer.eot_token:
                break
        return current_ids

    def _top_k_sampling(self, input_ids, max_new_tokens, temperature, top_k):
        """Top-k sampling: sample only from the top k tokens"""
        current_ids = input_ids.clone()
        for _ in range(max_new_tokens):
            logits, _ = self.model(current_ids)
            logits = logits[:, -1, :] / temperature

            # Keep the k largest logits, set everything else to -inf
            top_k_logits, top_k_indices = torch.topk(logits, k=top_k, dim=-1)
            filtered_logits = torch.full_like(logits, float("-inf"))
            filtered_logits.scatter_(1, top_k_indices, top_k_logits)

            probs = F.softmax(filtered_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            current_ids = torch.cat([current_ids, next_token], dim=1)
            if next_token.item() == self.tokenizer.eot_token:
                break
        return current_ids

    def _top_p_sampling(self, input_ids, max_new_tokens, temperature, top_p):
        """Top-p (Nucleus) sampling: sample from the top tokens covering probability mass p"""
        current_ids = input_ids.clone()
        for _ in range(max_new_tokens):
            logits, _ = self.model(current_ids)
            logits = logits[:, -1, :] / temperature

            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above top_p;
            # subtracting each token's own probability keeps the first
            # token that crosses the threshold in the nucleus
            sorted_indices_to_remove = cumulative_probs - F.softmax(sorted_logits, dim=-1) > top_p
            sorted_logits[sorted_indices_to_remove] = float("-inf")

            # Scatter back to vocabulary order (sorted_indices is a full
            # permutation, so every position is overwritten)
            logits = torch.zeros_like(logits).scatter_(1, sorted_indices, sorted_logits)
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            current_ids = torch.cat([current_ids, next_token], dim=1)
            if next_token.item() == self.tokenizer.eot_token:
                break
        return current_ids

    def _beam_search(self, input_ids, max_new_tokens, num_beams):
        """Beam search: explore multiple candidates simultaneously.

        Candidates are ranked by total log-probability normalized by the
        full sequence length (prompt included). Stops when the current
        best beam ends with <|endoftext|>.
        """
        B = input_ids.shape[0]
        assert B == 1, "Beam search only supports batch size 1"

        # Each beam is (token ids, cumulative log-probability)
        beams = [(input_ids, 0.0)]

        for _ in range(max_new_tokens):
            all_candidates = []
            for beam_ids, beam_score in beams:
                logits, _ = self.model(beam_ids)
                log_probs = F.log_softmax(logits[:, -1, :], dim=-1)
                top_log_probs, top_ids = log_probs.topk(num_beams)

                # Expand each beam with its num_beams best continuations
                for i in range(num_beams):
                    token = top_ids[0, i].unsqueeze(0).unsqueeze(0)
                    new_ids = torch.cat([beam_ids, token], dim=1)
                    new_score = beam_score + top_log_probs[0, i].item()
                    all_candidates.append((new_ids, new_score))

            # Length-normalized ranking; keep the best num_beams
            all_candidates.sort(key=lambda x: x[1] / x[0].shape[1], reverse=True)
            beams = all_candidates[:num_beams]

            if beams[0][0][0, -1].item() == self.tokenizer.eot_token:
                break

        return beams[0][0]

7. Perplexity Evaluation

import math
import torch

@torch.no_grad()
def compute_perplexity(
    model: "MiniGPT",
    tokenizer: "GPTTokenizer",
    text: str,
    device: str = "cuda",
    seq_len: int = 512
) -> float:
    """Compute the perplexity of `text` under `model`.

    The text is tokenized and scored in non-overlapping windows of
    `seq_len` tokens; a trailing window shorter than seq_len + 1 tokens
    is skipped.

    Raises:
        ValueError: if the text is too short to form a single window.
    """
    model.eval()
    model.to(device)

    tokens = tokenizer.encode(text, add_special_tokens=False)
    total_loss = 0.0
    total_tokens = 0

    for i in range(0, len(tokens) - seq_len, seq_len):
        chunk = tokens[i:i + seq_len + 1]
        if len(chunk) < seq_len + 1:
            break  # trailing partial window: skip

        x = torch.tensor(chunk[:-1], dtype=torch.long, device=device).unsqueeze(0)
        y = torch.tensor(chunk[1:], dtype=torch.long, device=device).unsqueeze(0)

        _, loss = model(x, y)
        total_loss += loss.item() * seq_len
        total_tokens += seq_len

    # BUGFIX: texts shorter than seq_len + 1 tokens previously crashed
    # with ZeroDivisionError; fail with a clear error instead.
    if total_tokens == 0:
        raise ValueError(
            f"Text too short for perplexity: need at least {seq_len + 1} tokens, got {len(tokens)}"
        )

    avg_loss = total_loss / total_tokens
    return math.exp(avg_loss)


# Usage: load the best checkpoint, then score and sample from it
model = MiniGPT.from_pretrained("./checkpoints/best_model.pt")
tokenizer = GPTTokenizer()
generator = TextGenerator(model, tokenizer)

test_text = "The quick brown fox jumps over the lazy dog."
ppl = compute_perplexity(model, tokenizer, test_text)
print(f"Perplexity: {ppl:.2f}")

# Compare decoding strategies on the same prompt
for strategy in ["greedy", "top_k", "top_p"]:
    generated = generator.generate(
        prompt="Once upon a time",
        max_new_tokens=100,
        strategy=strategy,
        temperature=0.8
    )
    print(f"\n[{strategy}]: {generated}")

8. Scaling Up

8.1 Flash Attention 2

# pip install flash-attn

from flash_attn import flash_attn_qkvpacked_func

class FlashCausalAttention(nn.Module):
    """Causal self-attention backed by the Flash Attention 2 fused kernel."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head

        self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.dropout = config.dropout  # plain float; applied inside the kernel

    def forward(self, x: torch.Tensor):
        B, T, C = x.shape

        # Flash Attention consumes packed QKV of shape [B, T, 3, n_head, head_dim]
        packed_qkv = self.qkv_proj(x).view(B, T, 3, self.n_head, self.head_dim)

        attn_out = flash_attn_qkvpacked_func(
            packed_qkv,
            dropout_p=self.dropout if self.training else 0.0,
            causal=True
        )  # -> [B, T, n_head, head_dim]

        return self.out_proj(attn_out.reshape(B, T, C))

8.2 FSDP Distributed Training

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import functools

def setup_fsdp_training():
    """Initialize FSDP (ZeRO-3 style) distributed training.

    Initializes the NCCL process group, builds a GPT-medium-sized model,
    and wraps it in FSDP with bf16 mixed precision, sharding at
    TransformerBlock boundaries.

    Returns:
        (model, rank, world_size): FSDP-wrapped model plus process info.

    NOTE(review): uses the *global* rank as the CUDA device index, which
    is only correct for single-node runs — multi-node jobs should use
    LOCAL_RANK for `set_device`/`device_id`.
    """
    # Local import so the module stays importable without distributed extras.
    from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    config = GPTConfig(
        n_embd=1024,
        n_layer=24,
        n_head=16,
        vocab_size=50257
    )
    model = MiniGPT(config)

    # Wrap policy: shard at TransformerBlock boundaries
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock}
    )

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16
        ),
        # BUG FIX: sharding_strategy must be a ShardingStrategy enum member,
        # not the string "FULL_SHARD" (FSDP rejects plain strings).
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # Equivalent to ZeRO-3
        device_id=rank
    )

    return model, rank, world_size

8.3 Chinchilla Scaling Laws

The Chinchilla paper (Hoffmann et al., 2022) gives the following relationship between optimal model size and training tokens:

Optimal training tokens ≈ 20 × number of model parameters

def chinchilla_optimal_tokens(model_params: int) -> int:
    """Return the compute-optimal training-token budget for a model.

    Hoffmann et al. (2022) found roughly 20 training tokens per model
    parameter to be compute-optimal.

    Args:
        model_params: number of trainable parameters.

    Returns:
        Recommended number of training tokens.
    """
    tokens_per_parameter = 20
    return tokens_per_parameter * model_params

# Well-known model sizes mapped to their parameter counts.
models = {
    "124M (GPT-2 Small)": 124e6,
    "1.3B": 1.3e9,
    "7B (Llama-2 7B)": 7e9,
    "13B": 13e9,
    "70B (Llama-2 70B)": 70e9,
}

print("Recommended training tokens by model size:")
for name, params in models.items():
    # Parameter counts above are floats for readability; convert once here.
    optimal_tokens = chinchilla_optimal_tokens(int(params))
    billions = optimal_tokens / 1e9
    print(f"  {name}: {billions:.1f}B tokens")

9. Instruction Tuning

9.1 SFT Data Formats

# ChatML format
CHATML_SYSTEM = "<|im_start|>system\n{system}<|im_end|>\n"
CHATML_USER = "<|im_start|>user\n{user}<|im_end|>\n"
CHATML_ASSISTANT = "<|im_start|>assistant\n{assistant}<|im_end|>\n"

def format_chatml(system: str, user: str, assistant: str, add_generation_prompt: bool = False) -> str:
    """Render one system/user/assistant exchange in ChatML.

    Args:
        system: system prompt text.
        user: user message text.
        assistant: assistant response text (ignored when
            ``add_generation_prompt`` is True).
        add_generation_prompt: when True, end the string with an *open*
            assistant header so the model can generate the response.

    Returns:
        The formatted conversation string. When ``add_generation_prompt``
        is True, the result is an exact prefix of the string produced with
        ``add_generation_prompt=False`` — required by the SFT prompt-length
        masking in InstructionDataset.

    BUG FIX: previously the empty assistant turn was emitted *and* the open
    header appended, so the generation prompt was not a prefix of the full
    conversation and prompt-token masking was misaligned.
    """
    formatted = CHATML_SYSTEM.format(system=system)
    formatted += CHATML_USER.format(user=user)
    if add_generation_prompt:
        # Open assistant header only; model output follows at generation time.
        formatted += "<|im_start|>assistant\n"
    else:
        formatted += CHATML_ASSISTANT.format(assistant=assistant)
    return formatted


# Llama-3 format
def format_llama3(system: str, user: str, assistant: str, add_generation_prompt: bool = False) -> str:
    """Render one system/user/assistant exchange in the Llama 3 chat format.

    Args:
        system: system prompt text.
        user: user message text.
        assistant: assistant response text (ignored when
            ``add_generation_prompt`` is True).
        add_generation_prompt: when True, end with an *open* assistant
            header so the model can generate the response.

    Returns:
        The formatted conversation string. With ``add_generation_prompt``
        the result is an exact prefix of the full conversation string —
        required by the SFT prompt-length masking in InstructionDataset.

    BUG FIX: previously the full (empty) assistant turn was emitted and the
    open header was appended with a stray leading newline, so the
    generation prompt was not a prefix of the full conversation.
    """
    formatted = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>\n"
    formatted += f"<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>\n"
    if add_generation_prompt:
        formatted += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    else:
        formatted += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant}<|eot_id|>"
    return formatted

9.2 Instruction Dataset Class

class InstructionDataset(Dataset):
    """SFT dataset: only compute loss on assistant responses.

    Each sample pairs token ids with labels in which every prompt token
    (system + user + open assistant header) is set to -100 so the loss is
    computed only over the assistant's response tokens.

    NOTE(review): the masking assumes the tokenized prompt is an exact
    prefix of the tokenized full conversation — verify that ``format_func``
    guarantees this (its ``add_generation_prompt=True`` output must be a
    string prefix of the full conversation and tokenize to a prefix).

    Args:
        data: list of dicts with "instruction", "response", optional "system".
        tokenizer: tokenizer providing ``encode(text, add_special_tokens=...)``.
        max_seq_len: hard cap on sequence length (token ids are truncated).
        format_func: chat-template renderer (ChatML by default).
    """

    def __init__(
        self,
        data: list,
        tokenizer: GPTTokenizer,
        max_seq_len: int = 2048,
        format_func=format_chatml
    ):
        self.samples = []

        for item in data:
            # Tokenize the full conversation (prompt + assistant response).
            full_text = format_func(
                system=item.get("system", "You are a helpful assistant."),
                user=item["instruction"],
                assistant=item["response"]
            )
            full_ids = tokenizer.encode(full_text, add_special_tokens=False)

            # Tokenize the prompt alone to locate where the response starts.
            prompt_text = format_func(
                system=item.get("system", "You are a helpful assistant."),
                user=item["instruction"],
                assistant="",
                add_generation_prompt=True
            )
            prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
            prompt_len = len(prompt_ids)

            if len(full_ids) > max_seq_len:
                full_ids = full_ids[:max_seq_len]

            # BUG FIX: clamp the mask to the (possibly truncated) sequence.
            # Previously `labels[:prompt_len] = [-100] * prompt_len` on a
            # shorter list *lengthened* labels past input_ids; and a fully
            # masked sample contributes no supervised tokens (NaN mean loss),
            # so skip it entirely.
            mask_len = min(prompt_len, len(full_ids))
            if mask_len == len(full_ids):
                continue  # truncation removed all response tokens

            # Mask prompt tokens in labels (-100 = ignored by the loss).
            labels = full_ids.copy()
            labels[:mask_len] = [-100] * mask_len

            self.samples.append({
                "input_ids": full_ids,
                "labels": labels
            })

    def __len__(self):
        """Number of usable (non-fully-masked) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (input_ids, labels) as int64 tensors."""
        sample = self.samples[idx]
        return (
            torch.tensor(sample["input_ids"], dtype=torch.long),
            torch.tensor(sample["labels"], dtype=torch.long)
        )

9.3 SFT Training Loop

def train_sft(
    base_model_path: str,
    data_path: str,
    output_dir: str,
    num_epochs: int = 3,
    learning_rate: float = 2e-5,
    batch_size: int = 4
):
    """Supervised Fine-Tuning (SFT).

    Loads a pretrained checkpoint, fine-tunes on instruction data (loss is
    restricted to assistant tokens via -100 labels), and saves the result
    under ``output_dir``.
    """
    import json

    with open(data_path) as f:
        data = json.load(f)

    tokenizer = GPTTokenizer()
    dataset = InstructionDataset(data, tokenizer, max_seq_len=2048)

    def collate_fn(batch):
        # Right-pad ids with 0 and labels with -100 (ignored by the loss)
        # up to the longest sequence in the batch.
        longest = max(len(sample[0]) for sample in batch)
        padded_ids = torch.zeros(len(batch), longest, dtype=torch.long)
        padded_labels = torch.full((len(batch), longest), -100, dtype=torch.long)
        for row, (ids, lbls) in enumerate(batch):
            padded_ids[row, :len(ids)] = ids
            padded_labels[row, :len(lbls)] = lbls
        return padded_ids, padded_labels

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    model = MiniGPT.from_pretrained(base_model_path)
    device = "cuda"
    model.to(device)

    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    total_steps = len(dataloader) * num_epochs
    warmup_steps = int(total_steps * 0.1)  # 10% linear warmup
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    model.train()
    for epoch in range(num_epochs):
        for step, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            logits, loss = model(inputs, targets)
            loss.backward()
            # Clip gradients before the optimizer step for stability.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            if step % 50 == 0:
                print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")

    os.makedirs(output_dir, exist_ok=True)
    torch.save({
        "model": model.state_dict(),
        "config": model.config
    }, f"{output_dir}/sft_model.pt")
    print(f"SFT complete. Model saved: {output_dir}")

10. Open-Source LLM Architecture Analysis

10.1 Llama 3 Architecture

Llama 3 is Meta's open-source LLM with these key characteristics:

| Feature | Details |
| --- | --- |
| Architecture | Decoder-only Transformer |
| Normalization | Pre-RMSNorm |
| Activation | SwiGLU |
| Position Encoding | RoPE (theta=500,000) |
| Vocabulary Size | 128,256 |
| Context Length | 128K (Llama 3.1+) |
| Attention | GQA (8B, 70B models) |
# Llama 3 style config (8B model)
_llama3_8b_hparams = {
    "vocab_size": 128256,
    "max_seq_len": 8192,
    "n_embd": 4096,
    "n_layer": 32,
    "n_head": 32,
    "use_rope": True,     # rotary position embeddings
    "use_swiglu": True,   # SwiGLU FFN activation
    "bias": False,        # Llama omits linear-layer biases
}
llama3_config = GPTConfig(**_llama3_8b_hparams)

10.2 Grouped Query Attention (GQA)

class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (used in Llama 3, Mistral).

    Uses fewer key/value heads than query heads; each KV head is shared by
    ``n_head // n_kv_heads`` query heads, shrinking the KV cache while
    keeping full query expressiveness.

    Args:
        config: model config providing n_head, n_embd, max_seq_len.
        n_kv_heads: number of key/value heads; must evenly divide
            ``config.n_head``.

    Raises:
        ValueError: if ``config.n_head`` is not divisible by ``n_kv_heads``
            (previously this silently produced a wrong repeat factor).
    """

    def __init__(self, config: GPTConfig, n_kv_heads: int = 8):
        super().__init__()
        if config.n_head % n_kv_heads != 0:
            raise ValueError(
                f"n_head ({config.n_head}) must be divisible by "
                f"n_kv_heads ({n_kv_heads})"
            )
        self.n_head = config.n_head
        self.n_kv_heads = n_kv_heads
        self.n_rep = config.n_head // n_kv_heads  # query heads per KV head
        self.head_dim = config.n_embd // config.n_head

        # K/V projections are smaller than Q — that is the GQA saving.
        self.q_proj = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.n_embd, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.n_embd, n_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        # Llama-3 uses RoPE base theta = 500,000.
        self.rope = RotaryEmbedding(self.head_dim, config.max_seq_len, base=500000)

    def forward(self, x: torch.Tensor):
        B, T, C = x.shape

        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        q, k = self.rope(q, k, T)

        # Repeat KV heads to match number of query heads (core of GQA):
        # (B, n_kv, T, hd) -> (B, n_kv, n_rep, T, hd) -> (B, n_head, T, hd)
        k = k.unsqueeze(2).expand(B, self.n_kv_heads, self.n_rep, T, self.head_dim)
        k = k.reshape(B, self.n_head, T, self.head_dim)
        v = v.unsqueeze(2).expand(B, self.n_kv_heads, self.n_rep, T, self.head_dim)
        v = v.reshape(B, self.n_head, T, self.head_dim)

        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(y)

10.3 Mixture of Experts (MoE)

DeepSeek and Mixtral use Mixture of Experts to increase model capacity without proportionally increasing compute.

class MoEFFN(nn.Module):
    """Mixture of Experts FFN (simplified version)

    Routes each token to its top-k experts and sums their outputs weighted
    by renormalized router probabilities. Simplified relative to
    production MoE: no load-balancing auxiliary loss and no per-expert
    capacity limit.
    """

    def __init__(self, config: GPTConfig, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        # Router: linear gate producing one logit per expert, per token.
        self.router = nn.Linear(config.n_embd, num_experts, bias=False)

        # Expert FFNs (one independent SwiGLU FFN per expert).
        self.experts = nn.ModuleList([
            SwiGLUFFN(config) for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, C). Routing is per token, so flatten batch and time.
        B, T, C = x.shape
        x_flat = x.view(B * T, C)

        # Compute routing scores
        router_logits = self.router(x_flat)
        router_probs = F.softmax(router_logits, dim=-1)

        # Select top-k experts
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
        # Renormalize so the k selected weights sum to 1 for each token.
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Compute expert outputs: for each of the k routing slots, gather
        # the tokens assigned to each expert via a boolean mask, run the
        # expert on just those rows, and scatter the weighted result back.
        output = torch.zeros_like(x_flat)
        for k in range(self.top_k):
            expert_idx = top_k_indices[:, k]
            expert_prob = top_k_probs[:, k].unsqueeze(-1)

            for e in range(self.num_experts):
                mask = (expert_idx == e)
                if mask.any():
                    expert_out = self.experts[e](x_flat[mask])
                    output[mask] += expert_prob[mask] * expert_out

        return output.view(B, T, C)

Putting It All Together

Here is a complete script that ties all components together:

import torch
from dataclasses import dataclass

# ── 1. Configuration ─────────────────────────────────────
# ── 1. Configuration ─────────────────────────────────────
# GPT-2 Small dimensions with the modern RoPE + SwiGLU variants enabled.
config = GPTConfig(
    vocab_size=50257,
    max_seq_len=1024,
    n_embd=768,
    n_layer=12,
    n_head=12,
    dropout=0.1,
    use_rope=True,
    use_swiglu=True,
)

# ── 2. Model ─────────────────────────────────────────────
model = MiniGPT(config)
print(f"Model parameters: {model.count_params() / 1e6:.1f}M")

# ── 3. Tokenizer ─────────────────────────────────────────
tokenizer = GPTTokenizer()

# ── 4. Data ───────────────────────────────────────────────
train_loader = create_dataloader("train.txt", tokenizer, batch_size=8, seq_len=1024)
val_loader = create_dataloader("val.txt", tokenizer, batch_size=8, seq_len=1024, shuffle=False)

# ── 5. Trainer ────────────────────────────────────────────
trainer_config = dict(
    num_epochs=5,
    learning_rate=6e-4,
    weight_decay=0.1,
    gradient_accumulation_steps=4,
    max_grad_norm=1.0,
    warmup_ratio=0.05,
    log_interval=100,
    eval_interval=500,
    save_interval=1000,
    device="cuda",
)

trainer = Trainer(model, train_loader, val_loader, trainer_config)
trainer.train()

# ── 6. Generation ─────────────────────────────────────────
generator = TextGenerator(model, tokenizer)

prompts = [
    "The history of artificial intelligence",
    "In the year 2050, scientists discovered",
    "The best way to learn programming is",
]

decoding_strategies = ("greedy", "top_p")
for prompt in prompts:
    print(f"\nPrompt: {prompt}")
    for strategy in decoding_strategies:
        text = generator.generate(prompt, max_new_tokens=100, strategy=strategy)
        print(f"  [{strategy}]: {text}")

Conclusion

In this guide, we built miniGPT from scratch and covered every core element of modern LLMs:

  • Tokenizer: BPE algorithm splits text into subword tokens
  • Embeddings: Token embeddings + RoPE positional encoding
  • Multi-Head Causal Attention: Q/K/V projections, causal mask, KV cache
  • FFN: SwiGLU activation for non-linearity
  • Pre-RMSNorm: Training stability
  • Text Generation: Greedy, Temperature, Top-k, Top-p, Beam Search
  • Pretraining: Cosine LR schedule, gradient clipping, mixed precision
  • SFT: Loss computed only on assistant response tokens
  • Scaling: Flash Attention, FSDP, Chinchilla laws

This foundation gives you the tools to understand and extend any modern open-source LLM. The concepts covered here directly map to the implementation details of Llama 3, Mistral, and Qwen.

References