Author: Youngju Kim (@fjvbn20031)
Transformer Architecture Complete Analysis: From Attention to Modern LLMs
Google's 2017 paper "Attention is All You Need" completely transformed the natural language processing landscape. Moving away from sequential RNN and LSTM architectures, the Transformer — built entirely on the Attention mechanism — became the foundation for GPT, BERT, T5, LLaMA, and every major modern LLM. This guide takes you from first principles through the complete architecture, with PyTorch implementations at every step.
1. The Origins of Attention
1.1 Limitations of RNN/LSTM
Before Transformers, sequence modeling relied heavily on RNNs and LSTMs. These models propagated information through hidden states processed one time step at a time, with fundamental limitations:
Long-Range Dependency Problem
RNNs struggle to learn relationships between words that are far apart. In "The cat that sat on the mat was hungry," the connection between "cat" and "was hungry" must pass through every intermediate token. As sequences grow longer, information fades or gets overwritten. LSTMs alleviate this with gating, but provide no complete solution.
Sequential Processing
RNNs require step t to complete before step t+1 begins. This means the parallel processing power of modern GPUs goes completely unused for sequence-internal computation. You can parallelize across batches, but not within a sequence.
Vanishing/Exploding Gradients
Backpropagating through hundreds or thousands of time steps causes gradients to either shrink to zero or grow uncontrollably. Even with LSTM gates, sequences beyond a few hundred tokens are problematic.
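A tiny numerical illustration of this (not from the original text, and deliberately simplified): backpropagating through T recurrent steps multiplies the gradient by the step Jacobian T times, so with the recurrent weight's spectral norm pinned below 1, the gradient reaching the first hidden state collapses toward zero.

```python
import torch

# Toy vanishing-gradient demo: backprop through T tanh-RNN steps.
# The recurrent weight is rescaled to spectral norm 0.9, so each step can
# shrink the gradient by up to that factor; after 200 steps almost nothing
# of the loss signal reaches the initial hidden state.
torch.manual_seed(0)
T, d = 200, 32
W = torch.randn(d, d)
W = 0.9 * W / torch.linalg.matrix_norm(W, ord=2)  # pin spectral norm to 0.9
h0 = torch.zeros(d, requires_grad=True)
state = h0
for _ in range(T):
    state = torch.tanh(W @ state)
state.sum().backward()
grad_norm = h0.grad.norm().item()
print(f"gradient norm at step 0 after {T} steps: {grad_norm:.2e}")
```

With gating, LSTMs raise the number of steps a useful gradient survives, but the multiplicative structure of the recurrence remains.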
1.2 The Intuition Behind Attention
The Attention mechanism draws inspiration from how humans read. We do not give equal attention to every word — we focus more on words relevant to what we are currently processing.
Consider: "I saw the Eiffel Tower in Paris; it was breathtaking." When processing "it," the model should attend strongly to both "Eiffel Tower" and "Paris." Every word can directly interact with every other word, regardless of distance.
Bahdanau et al. (2014) introduced the first Attention mechanism as an auxiliary component in a seq2seq model. Vaswani et al. (2017) then made the decisive move: eliminate the RNN entirely and build everything from Attention alone.
2. Scaled Dot-Product Attention
2.1 The Q, K, V Framework
The Transformer uses three vectors — Query, Key, and Value — to implement Attention. Think of it as a soft database lookup:
- Query (Q): "What am I looking for?" — the current position's representation
- Key (K): "What information do I hold?" — each position's label/identifier
- Value (V): "What is actually stored?" — each position's content
We compute similarity scores between Queries and Keys, normalize them with softmax into attention weights, then use those weights to produce a weighted sum of the Values.
2.2 The Formula
Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
Where:
- Q: Query matrix (seq_len x d_k)
- K: Key matrix (seq_len x d_k)
- V: Value matrix (seq_len x d_v)
- d_k: Key/Query dimension
- sqrt(d_k): scaling factor
Why scale by sqrt(d_k)?
As d_k grows, so does the magnitude of the dot products: for q and k with independent, unit-variance components, the dot product has variance d_k, so typical scores scale like sqrt(d_k). Large scores push softmax into saturated regions with near-zero gradients. Dividing by sqrt(d_k) restores unit variance (under the same assumption), preventing gradient issues.
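This variance argument is easy to check empirically. The quick sketch below (illustrative, not part of the original text) samples unit-variance q and k at several dimensions and compares the variance of raw versus scaled dot products:

```python
import torch

# Empirical check of the scaling argument: with unit-variance q and k
# components, the raw dot product has variance ~ d_k, while dividing by
# sqrt(d_k) pins the variance near 1 regardless of dimension.
torch.manual_seed(0)
for d_k in (64, 256, 1024):
    q = torch.randn(10_000, d_k)
    k = torch.randn(10_000, d_k)
    raw = (q * k).sum(-1)
    scaled = raw / d_k ** 0.5
    print(f"d_k={d_k:5d}  var(raw)={raw.var().item():8.1f}  "
          f"var(scaled)={scaled.var().item():.3f}")
```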
2.3 Masking
Padding Mask: When sequences in a batch have different lengths, shorter sequences are padded. We add -inf to padded positions so softmax assigns them zero weight.
Causal Mask (Look-ahead Mask): Used in Decoders. Position i should only attend to positions 0 through i — never future positions. We fill the upper triangle of the score matrix with -inf.
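As a concrete illustration of both mask types, here is a small sketch (assuming pad token id 0, a common convention the text above does not fix). Positions where the mask is 0 are the ones filled with -inf before softmax:

```python
import torch

# Build a padding mask and a causal mask for a toy batch (pad id = 0),
# then combine them: a position is attendable only if it is not padding
# AND not in the future.
seq = torch.tensor([[5, 8, 3, 0, 0],      # length-3 sequence, 2 pad tokens
                    [7, 2, 9, 4, 1]])     # full-length sequence
pad_mask = (seq != 0).unsqueeze(1).unsqueeze(2)        # (batch, 1, 1, seq_len)
seq_len = seq.size(1)
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
causal_mask = causal.unsqueeze(0).unsqueeze(0)         # (1, 1, seq, seq)
combined = pad_mask & causal_mask                      # (batch, 1, seq, seq)
print(combined[0, 0].int())  # row i: which positions j token i may attend to
```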
2.4 PyTorch Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = None,
dropout_p: float = 0.0,
) -> tuple:
"""
Scaled Dot-Product Attention
Args:
query: (batch, heads, seq_len, d_k)
key: (batch, heads, seq_len, d_k)
value: (batch, heads, seq_len, d_v)
mask: (batch, 1, 1, seq_len) or (batch, 1, seq_len, seq_len)
dropout_p: dropout probability
Returns:
output: (batch, heads, seq_len, d_v)
attn_weights: (batch, heads, seq_len, seq_len)
"""
d_k = query.size(-1)
# Q * K^T / sqrt(d_k): (batch, heads, seq_len, seq_len)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# Apply mask
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax attention weights
attn_weights = F.softmax(scores, dim=-1)
# Handle NaN (all positions masked)
attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
# Dropout
if dropout_p > 0.0:
attn_weights = F.dropout(attn_weights, p=dropout_p)
# Weighted sum: (batch, heads, seq_len, d_v)
output = torch.matmul(attn_weights, value)
return output, attn_weights
# Quick test
batch_size = 2
num_heads = 8
seq_len = 10
d_k = 64
q = torch.randn(batch_size, num_heads, seq_len, d_k)
k = torch.randn(batch_size, num_heads, seq_len, d_k)
v = torch.randn(batch_size, num_heads, seq_len, d_k)
# Causal mask (lower triangular)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)
output, weights = scaled_dot_product_attention(q, k, v, mask=causal_mask)
print(f"Output shape: {output.shape}") # (2, 8, 10, 64)
print(f"Weights shape: {weights.shape}") # (2, 8, 10, 10)
3. Multi-Head Attention
3.1 Why Multiple Heads?
Single-head attention processes the input in only one representation subspace. Multi-Head Attention runs multiple independent attention operations in parallel, and different heads can learn different aspects of token relationships, for example:
- Head 1: syntactic relations (subject-verb agreement)
- Head 2: semantic similarity (synonyms, related concepts)
- Head 3: positional relations (neighboring tokens)
- Head 4: coreference (pronouns and their referents)
Each head uses d_k = d_model / num_heads, so the total computation is similar to a single full-sized attention.
3.2 Formula
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O
head_i = Attention(Q * W_Q_i, K * W_K_i, V * W_V_i)
3.3 Complete PyTorch Implementation
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Single large projection for efficiency
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
self._init_weights()
def _init_weights(self):
for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
nn.init.xavier_uniform_(module.weight)
def split_heads(self, x: torch.Tensor) -> torch.Tensor:
"""(batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)"""
batch_size, seq_len, _ = x.size()
x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
return x.transpose(1, 2)
def combine_heads(self, x: torch.Tensor) -> torch.Tensor:
"""(batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)"""
batch_size, _, seq_len, _ = x.size()
x = x.transpose(1, 2).contiguous()
return x.view(batch_size, seq_len, self.d_model)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = None,
) -> tuple:
"""
Args:
query, key, value: (batch, seq_len, d_model)
mask: (batch, 1, 1, seq_len) or (batch, 1, seq_len, seq_len)
Returns:
output: (batch, seq_len, d_model)
attn_weights: (batch, num_heads, seq_len, seq_len)
"""
Q = self.split_heads(self.W_q(query))
K = self.split_heads(self.W_k(key))
V = self.split_heads(self.W_v(value))
attn_output, attn_weights = scaled_dot_product_attention(
Q, K, V, mask=mask, dropout_p=self.dropout.p if self.training else 0.0
)
output = self.combine_heads(attn_output)
output = self.W_o(output)
return output, attn_weights
# Test
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
out, weights = mha(x, x, x)
print(f"MHA output: {out.shape}") # (2, 10, 512)
print(f"Attn weights: {weights.shape}") # (2, 8, 10, 10)
4. Positional Encoding
4.1 Why Position Information Matters
Attention is permutation-equivariant: shuffle the input tokens, and the outputs are simply shuffled the same way, with no change in content. Without positional encoding, "I ate rice" and "Rice ate I" yield the same set of token representations. We need to inject sequence order information.
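This order-blindness can be checked directly with PyTorch's built-in attention (a short illustrative sketch): permuting the input rows just permutes the output rows identically.

```python
import torch
import torch.nn.functional as F

# Self-attention with no positional signal: shuffling the input tokens
# merely shuffles the output rows the same way, so word order carries no
# information without positional encoding.
torch.manual_seed(0)
x = torch.randn(1, 5, 16)                 # (batch, seq_len, d)
perm = torch.tensor([3, 0, 4, 1, 2])
out = F.scaled_dot_product_attention(x, x, x)
out_perm = F.scaled_dot_product_attention(x[:, perm], x[:, perm], x[:, perm])
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))  # True
```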
4.2 Sinusoidal Positional Encoding
The original Transformer uses fixed sinusoidal functions — no learned parameters:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
Where pos is the token position and i is the dimension index.
Advantages:
- No parameters to learn
- Can extrapolate to sequences longer than seen during training
- PE(pos+k) can be expressed as a linear function of PE(pos), encoding relative position
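The third property can be verified numerically. For each sin/cos pair with angular frequency w_i, advancing the position by k is a fixed 2x2 rotation by k*w_i, so PE(pos + k) is a linear transform of PE(pos) that depends only on k (a small standalone check, using the paper's formula directly):

```python
import torch

# For each frequency w_i, (sin(pos*w_i), cos(pos*w_i)) rotated by k*w_i
# gives (sin((pos+k)*w_i), cos((pos+k)*w_i)): the angle-addition identities.
d_model, k = 8, 5
inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))

def pe(pos):
    ang = pos * inv_freq
    out = torch.empty(d_model)
    out[0::2], out[1::2] = torch.sin(ang), torch.cos(ang)
    return out

for pos in (3.0, 17.0):
    ang_k = k * inv_freq
    s, c = pe(pos)[0::2], pe(pos)[1::2]
    rotated_sin = s * torch.cos(ang_k) + c * torch.sin(ang_k)
    rotated_cos = c * torch.cos(ang_k) - s * torch.sin(ang_k)
    assert torch.allclose(rotated_sin, pe(pos + k)[0::2], atol=1e-6)
    assert torch.allclose(rotated_cos, pe(pos + k)[1::2], atol=1e-6)
print("PE(pos + k) is a fixed linear transform of PE(pos)")
```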
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
pe = torch.zeros(max_seq_len, d_model)
position = torch.arange(0, max_seq_len).unsqueeze(1).float()
        # 1 / 10000^(2i/d_model) = exp(-(2i) * ln(10000) / d_model)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term) # even dims: sin
pe[:, 1::2] = torch.cos(position * div_term) # odd dims: cos
pe = pe.unsqueeze(0) # (1, max_seq_len, d_model) for batch broadcasting
self.register_buffer('pe', pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""x: (batch, seq_len, d_model)"""
seq_len = x.size(1)
x = x + self.pe[:, :seq_len, :]
return self.dropout(x)
4.3 RoPE (Rotary Position Embedding)
RoPE has become the standard positional encoding for modern LLMs (LLaMA, GPT-NeoX, PaLM, Mistral). The key idea is to encode position by applying rotation matrices to Q and K vectors.
Core properties:
- Encodes relative rather than absolute positions
- Q and K dot products automatically depend on relative position
- Excellent extrapolation to longer sequences
- Applied only to Q and K — not V
def precompute_freqs_cis(dim: int, max_seq_len: int, theta: float = 10000.0):
"""Precompute RoPE frequency matrix"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len)
freqs = torch.outer(t, freqs) # (max_seq_len, dim/2)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex unit vectors
return freqs_cis
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Apply RoPE to query/key tensors of shape (batch, seq_len, heads, head_dim)"""
    xq_r = xq.float().reshape(*xq.shape[:-1], -1, 2)
    xk_r = xk.float().reshape(*xk.shape[:-1], -1, 2)
    xq_complex = torch.view_as_complex(xq_r)
    xk_complex = torch.view_as_complex(xk_r)
    # Slice to the current sequence length, then add singleton batch and head
    # dimensions so the (seq_len, head_dim/2) table broadcasts correctly
    freqs_cis = freqs_cis[:xq.shape[1]].view(1, xq.shape[1], 1, -1)
    xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
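The "relative position" property can be verified in isolation. This standalone sketch (my own check, using the same complex-number trick as the code above) rotates a query by position m and a key by position n and shows their score depends only on the offset m - n:

```python
import torch

# RoPE's key property: q rotated by angle m*w and k rotated by n*w give
# q_m * conj(k_n) = q * conj(k) * e^{i(m-n)w}, so the attention score
# (the real part of the sum) depends only on the offset m - n.
torch.manual_seed(0)
dim = 8
freqs = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
q = torch.view_as_complex(torch.randn(dim // 2, 2))
k = torch.view_as_complex(torch.randn(dim // 2, 2))

def rope_dot(m, n):
    q_m = q * torch.polar(torch.ones(dim // 2), m * freqs)
    k_n = k * torch.polar(torch.ones(dim // 2), n * freqs)
    # Real part of <q_m, conj(k_n)> equals the dot product of the real vectors
    return (q_m * k_n.conj()).real.sum().item()

print(rope_dot(3.0, 1.0), rope_dot(10.0, 8.0))  # same offset -> same score
```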
4.4 ALiBi (Attention with Linear Biases)
ALiBi adds a linear position bias to attention scores:
Score = Q * K^T / sqrt(d_k) - m * |i - j|
Where m is a head-specific slope. No position embeddings are needed, and it generalizes very well to longer sequences than seen during training.
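A short sketch of the bias construction (the helper name `alibi_bias` is mine). It follows the ALiBi paper's slope schedule for a power-of-two head count, where head h gets slope m_h = 2^(-8h/num_heads):

```python
import torch

# ALiBi bias: head h penalizes attention to position j from position i by
# m_h * (i - j). Stored as m_h * (j - i), which is negative in the causal
# part, and simply added to the scores before softmax.
def alibi_bias(num_heads: int, seq_len: int) -> torch.Tensor:
    slopes = torch.tensor([2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads)])
    pos = torch.arange(seq_len)
    rel = pos[None, :] - pos[:, None]         # j - i
    bias = slopes[:, None, None] * rel[None]  # (num_heads, seq_len, seq_len)
    return bias  # add to scores; combine with a causal mask for decoding

bias = alibi_bias(num_heads=8, seq_len=5)
print(bias.shape)       # torch.Size([8, 5, 5])
print(bias[0, 3, :4])   # head-0 slope is 0.5: tensor([-1.5000, -1.0000, -0.5000, 0.0000])
```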
5. Transformer Encoder
5.1 Encoder Architecture
The Encoder stacks N identical layers, each with two sub-layers:
- Multi-Head Self-Attention
- Position-wise Feed-Forward Network
Each sub-layer is wrapped with a residual connection and layer normalization.
5.2 Feed-Forward Network
The FFN is a two-layer MLP applied independently to each position:
FFN(x) = max(0, x * W_1 + b_1) * W_2 + b_2
The original paper uses d_model=512, d_ff=2048. Modern LLMs use SwiGLU and d_ff ≈ 2.67 * d_model.
5.3 Pre-LN vs Post-LN
# Post-LN (original Transformer)
x = LayerNorm(x + Sublayer(x))
# Pre-LN (modern standard)
x = x + Sublayer(LayerNorm(x))
Pre-LN is more training-stable and can often be trained with little or no learning-rate warmup, which is why modern models have adopted it.
5.4 Full Encoder Implementation
class FeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.net = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class EncoderLayer(nn.Module):
def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.ffn = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # Pre-LN Self-Attention (normalize once, reuse for Q, K, V)
        normed = self.norm1(x)
        attn_out, _ = self.self_attn(normed, normed, normed, mask=mask)
        x = x + self.dropout(attn_out)
        # Pre-LN Feed-Forward
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x
class TransformerEncoder(nn.Module):
def __init__(
self,
vocab_size: int,
d_model: int = 512,
num_heads: int = 8,
num_layers: int = 6,
d_ff: int = 2048,
max_seq_len: int = 512,
dropout: float = 0.1,
):
super().__init__()
self.d_model = d_model
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(d_model)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
"""x: (batch, seq_len) token indices"""
x = self.embedding(x) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
for layer in self.layers:
x = layer(x, mask)
return self.norm(x)
6. Transformer Decoder
6.1 Decoder Architecture
The Decoder has three sub-layers per layer:
- Masked Multi-Head Self-Attention (with causal mask)
- Multi-Head Cross-Attention (attends to Encoder output)
- Feed-Forward Network
6.2 Cross-Attention
In Cross-Attention:
- Query comes from the Decoder's current state
- Keys and Values come from the Encoder's output
This is how the Decoder learns which parts of the source sequence to focus on when generating each output token.
6.3 Autoregressive Generation
During inference, the Decoder generates tokens one at a time:
- Start with a [BOS] token
- Use all previously generated tokens as input
- Predict the next token
- Repeat until [EOS] is generated
class DecoderLayer(nn.Module):
def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.ffn = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
encoder_output: torch.Tensor,
src_mask: torch.Tensor = None,
tgt_mask: torch.Tensor = None,
) -> torch.Tensor:
        # 1. Masked Self-Attention (causal; normalize once, reuse for Q, K, V)
        normed = self.norm1(x)
        self_attn_out, _ = self.self_attn(normed, normed, normed, mask=tgt_mask)
        x = x + self.dropout(self_attn_out)
# 2. Cross-Attention over Encoder output
cross_attn_out, _ = self.cross_attn(
self.norm2(x), encoder_output, encoder_output, mask=src_mask
)
x = x + self.dropout(cross_attn_out)
# 3. Feed-Forward
x = x + self.dropout(self.ffn(self.norm3(x)))
return x
7. Full Transformer Implementation
class Transformer(nn.Module):
"""Complete Encoder-Decoder Transformer"""
def __init__(
self,
src_vocab_size: int,
tgt_vocab_size: int,
d_model: int = 512,
num_heads: int = 8,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
d_ff: int = 2048,
max_seq_len: int = 512,
dropout: float = 0.1,
):
super().__init__()
self.d_model = d_model
self.src_embedding = nn.Embedding(src_vocab_size, d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)
self.encoder_layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_encoder_layers)
])
self.decoder_layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_decoder_layers)
])
self.encoder_norm = nn.LayerNorm(d_model)
self.decoder_norm = nn.LayerNorm(d_model)
self.output_projection = nn.Linear(d_model, tgt_vocab_size)
self._init_weights()
def _init_weights(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def encode(self, src: torch.Tensor, src_mask: torch.Tensor = None) -> torch.Tensor:
x = self.src_embedding(src) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
for layer in self.encoder_layers:
x = layer(x, src_mask)
return self.encoder_norm(x)
def decode(
self,
tgt: torch.Tensor,
encoder_output: torch.Tensor,
src_mask: torch.Tensor = None,
tgt_mask: torch.Tensor = None,
) -> torch.Tensor:
x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
for layer in self.decoder_layers:
x = layer(x, encoder_output, src_mask, tgt_mask)
return self.decoder_norm(x)
def forward(
self,
src: torch.Tensor,
tgt: torch.Tensor,
src_mask: torch.Tensor = None,
tgt_mask: torch.Tensor = None,
) -> torch.Tensor:
encoder_output = self.encode(src, src_mask)
decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
return self.output_projection(decoder_output)
@torch.no_grad()
def generate(
self,
src: torch.Tensor,
bos_token_id: int,
eos_token_id: int,
max_new_tokens: int = 100,
temperature: float = 1.0,
) -> torch.Tensor:
"""Greedy decoding"""
self.eval()
device = src.device
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)  # assumes pad token id 0
encoder_output = self.encode(src, src_mask)
tgt = torch.tensor([[bos_token_id]], device=device)
for _ in range(max_new_tokens):
seq_len = tgt.size(1)
tgt_mask = torch.tril(
torch.ones(seq_len, seq_len, device=device)
).unsqueeze(0).unsqueeze(0)
            hidden = self.decode(tgt, encoder_output, src_mask, tgt_mask)
            logits = self.output_projection(hidden)  # project to vocab before argmax
            next_token_logits = logits[:, -1, :] / temperature
next_token = next_token_logits.argmax(dim=-1, keepdim=True)
tgt = torch.cat([tgt, next_token], dim=1)
if next_token.item() == eos_token_id:
break
return tgt
# Create and test model
model = Transformer(src_vocab_size=32000, tgt_vocab_size=32000)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")  # ~93M
src = torch.randint(1, 32000, (2, 20))
tgt = torch.randint(1, 32000, (2, 15))
logits = model(src, tgt)
print(f"Logits shape: {logits.shape}") # (2, 15, 32000)
8. Transformer Variants
8.1 BERT (Encoder-only)
BERT (Bidirectional Encoder Representations from Transformers) uses only the Encoder with two pretraining objectives:
Masked Language Modeling (MLM): 15% of tokens are selected at random; of these, 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged. The model predicts the original tokens using bidirectional context. This makes BERT excellent for understanding tasks.
Next Sentence Prediction (NSP): Predicts whether two sentences are consecutive. Later research showed NSP contributes little, and RoBERTa dropped it without performance loss.
BERT excels at classification, QA, and NER but cannot directly generate text.
8.2 GPT (Decoder-only)
GPT (Generative Pre-trained Transformer) uses only the Decoder stack with causal language modeling: predict the next token from all previous tokens.
The architecture is a simplified Decoder with no cross-attention — just masked self-attention and FFN. GPT-2, GPT-3, GPT-4, LLaMA, Mistral, and most modern LLMs follow this Decoder-only design.
8.3 T5 and BART (Encoder-Decoder)
T5 (Text-To-Text Transfer Transformer) unifies all NLP tasks as text-to-text. Translation, summarization, classification, and QA all use identical input-output format.
BART is pretrained as a denoising autoencoder with various corruption strategies including token masking, sentence shuffling, and document rotation.
8.4 Vision Transformer (ViT)
ViT splits an image into 16×16 patches, linearly projects each patch into a token embedding, adds positional embeddings, and feeds the sequence into a standard Transformer Encoder.
With large-scale pretraining, ViT matches and surpasses CNNs on image classification benchmarks.
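The patchify step described above can be sketched in a few lines (a minimal illustration, not ViT's full implementation): a Conv2d with kernel size equal to its stride is equivalent to slicing non-overlapping 16×16 patches and linearly projecting each one.

```python
import torch
import torch.nn as nn

# ViT patch embedding: Conv2d with kernel = stride = patch size slices the
# image into non-overlapping patches and projects each to d_model, producing
# a (batch, num_patches, d_model) token sequence for the Encoder.
patch_size, d_model = 16, 768
patch_embed = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)
img = torch.randn(2, 3, 224, 224)              # ImageNet-sized input
tokens = patch_embed(img)                      # (2, 768, 14, 14)
tokens = tokens.flatten(2).transpose(1, 2)     # (2, 196, 768)
print(tokens.shape)  # 224/16 = 14 patches per side, so 14*14 = 196 tokens
```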
9. Modern LLM Optimizations
9.1 RMSNorm
Modern LLMs replace LayerNorm with RMSNorm:
RMSNorm(x) = x / RMS(x) * g
RMS(x) = sqrt(mean(x^2) + epsilon)
No mean subtraction needed — it is faster with comparable performance. Used in LLaMA, Mistral, and Gemma.
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
return x / rms * self.weight
9.2 SwiGLU Activation
SwiGLU combines Swish and Gated Linear Units:
SwiGLU(x, W, V) = Swish(x * W) * (x * V)
Swish(x) = x * sigmoid(x) = x * sigma(x)
The FFN dimension is typically adjusted to d_ff = (2/3) * 4 * d_model ≈ 2.67 * d_model.
class SwiGLUFeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int = None):
super().__init__()
if d_ff is None:
d_ff = int(2 * 4 * d_model / 3)
d_ff = ((d_ff + 63) // 64) * 64 # round to multiple of 64
self.w1 = nn.Linear(d_model, d_ff, bias=False)
self.w2 = nn.Linear(d_ff, d_model, bias=False)
self.w3 = nn.Linear(d_model, d_ff, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))
9.3 Grouped Query Attention (GQA)
MHA: every head has independent Q, K, V. MQA: all heads share K, V. GQA: G groups share K, V within each group.
LLaMA 2, Mistral, and Gemma use GQA to reduce KV cache size while maintaining quality.
class GroupedQueryAttention(nn.Module):
def __init__(
self,
d_model: int,
num_heads: int,
num_kv_heads: int,
dropout: float = 0.0,
):
super().__init__()
assert num_heads % num_kv_heads == 0
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.num_rep = num_heads // num_kv_heads
self.d_head = d_model // num_heads
self.wq = nn.Linear(d_model, num_heads * self.d_head, bias=False)
self.wk = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
self.wv = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
self.wo = nn.Linear(num_heads * self.d_head, d_model, bias=False)
self.dropout = dropout
def repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
"""Repeat KV heads to match Q head count"""
if self.num_rep == 1:
return x
batch, n_kv, seq_len, d_head = x.shape
return x.unsqueeze(2).expand(
batch, n_kv, self.num_rep, seq_len, d_head
).reshape(batch, n_kv * self.num_rep, seq_len, d_head)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
batch, seq_len, _ = x.shape
xq = self.wq(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
xk = self.wk(x).view(batch, seq_len, self.num_kv_heads, self.d_head).transpose(1, 2)
xv = self.wv(x).view(batch, seq_len, self.num_kv_heads, self.d_head).transpose(1, 2)
xk = self.repeat_kv(xk)
xv = self.repeat_kv(xv)
scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(self.d_head)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
attn = F.dropout(attn, p=self.dropout, training=self.training)
out = torch.matmul(attn, xv)
out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
return self.wo(out)
9.4 KV Cache
During autoregressive generation, recomputing K and V for all previous tokens at every step is wasteful. KV Cache stores previously computed K and V tensors and reuses them:
class KVCacheAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.d_head = d_model // num_heads
self.wq = nn.Linear(d_model, d_model, bias=False)
self.wk = nn.Linear(d_model, d_model, bias=False)
self.wv = nn.Linear(d_model, d_model, bias=False)
self.wo = nn.Linear(d_model, d_model, bias=False)
self.cache_k = None
self.cache_v = None
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
batch, seq_len, _ = x.shape
xq = self.wq(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
xk = self.wk(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
xv = self.wv(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
if start_pos > 0 and self.cache_k is not None:
self.cache_k = torch.cat([self.cache_k, xk], dim=2)
self.cache_v = torch.cat([self.cache_v, xv], dim=2)
else:
self.cache_k = xk
self.cache_v = xv
scores = torch.matmul(xq, self.cache_k.transpose(-2, -1)) / math.sqrt(self.d_head)
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, self.cache_v)
out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
return self.wo(out)
10. Flash Attention
10.1 The Memory Problem with Standard Attention
The attention matrix is O(N^2) in memory. For N=8192, a single attention matrix requires 8192 x 8192 x 4 bytes ≈ 256MB in FP32, per head and per batch element. Writing and reading these matrices to/from GPU HBM creates a bandwidth bottleneck.
Modern GPUs can execute far more arithmetic per second than they can move data to and from HBM. Standard Attention is therefore memory-bandwidth-bound, meaning it runs much slower than its FLOP count alone suggests.
10.2 IO-Aware Attention Algorithm
Flash Attention (Dao et al., 2022) computes exact Attention without ever materializing the full attention matrix in HBM.
Core Idea: Tiling
Split Q, K, V into blocks that fit in SRAM (fast on-chip memory). Process one block at a time using the Online Softmax algorithm, which accumulates the correct softmax result without seeing all scores at once.
Algorithm sketch:
- Load Q_i block to SRAM
- For each K_j, V_j block: load to SRAM, compute S_ij = Q_i * K_j^T
- Use running max/sum to update softmax incrementally
- Accumulate into output O_i
Complexity:
- Memory: O(N) instead of O(N^2) — no full attention matrix stored
- FLOPs: O(N^2) — same as standard
- Wall-clock speed: 2-4x faster on A100 due to far fewer HBM reads/writes
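The tiling loop sketched above can be written in plain PyTorch as an educational reference (my own sketch: correct numerics, none of the SRAM or kernel-fusion benefits). A running row-max m and running denominator l are updated per K/V block, so the full N x N score matrix is never materialized:

```python
import torch

# Online-softmax tiled attention (single head, no mask): processes K/V in
# blocks, rescaling the partial output and denominator whenever the running
# max changes, and normalizes once at the end.
def tiled_attention(q, k, v, block=32):
    n, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    m = torch.full((n, 1), float('-inf'))    # running row max
    l = torch.zeros(n, 1)                    # running softmax denominator
    for j in range(0, k.shape[0], block):
        s = (q @ k[j:j + block].T) * scale             # scores for this block
        m_new = torch.maximum(m, s.max(-1, keepdim=True).values)
        correction = torch.exp(m - m_new)              # rescale old accumulators
        p = torch.exp(s - m_new)
        out = out * correction + p @ v[j:j + block]
        l = l * correction + p.sum(-1, keepdim=True)
        m = m_new
    return out / l

torch.manual_seed(0)
q, k, v = torch.randn(3, 128, 64)
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
print(torch.allclose(tiled_attention(q, k, v), ref, atol=1e-5))  # True
```

The block size here only affects numerics marginally; in the real kernel it is chosen so each tile fits in SRAM.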
10.3 Flash Attention Versions
Flash Attention 1 (2022): First IO-aware Attention formalization. Custom CUDA kernels for forward and backward pass.
Flash Attention 2 (2023): Outer loop over Q, inner loop over K/V (better parallelism). Optimized warp-level work partitioning. Roughly 2x additional speedup.
Flash Attention 3 (2024): Exploits H100's WGMMA (Warp Group Matrix Multiply-Accumulate) and TMA (Tensor Memory Accelerator) asynchronous copy. Another 1.5-2x speedup.
10.4 Usage
import torch
import torch.nn.functional as F
# PyTorch 2.0+ built-in Flash Attention
# On CUDA, automatically uses Flash Attention when possible
q = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
# is_causal=True uses causal mask efficiently
# Note: newer PyTorch releases deprecate sdp_kernel in favor of torch.nn.attention.sdpa_kernel
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False
):
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(f"Output shape: {output.shape}") # (2, 8, 1024, 64)
# Direct flash-attn package usage
# pip install flash-attn --no-build-isolation
try:
from flash_attn import flash_attn_qkvpacked_func
qkv = torch.randn(2, 1024, 3, 8, 64, device='cuda', dtype=torch.float16)
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
print(f"Flash Attention output: {out.shape}")
except ImportError:
print("flash-attn not installed")
class FlashAttentionMHA(nn.Module):
"""MHA using PyTorch's built-in Flash Attention"""
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
super().__init__()
self.num_heads = num_heads
self.d_head = d_model // num_heads
self.wqkv = nn.Linear(d_model, 3 * d_model, bias=False)
self.wo = nn.Linear(d_model, d_model, bias=False)
self.dropout = dropout
def forward(self, x: torch.Tensor, is_causal: bool = False) -> torch.Tensor:
batch, seq_len, d_model = x.shape
qkv = self.wqkv(x).view(batch, seq_len, 3, self.num_heads, self.d_head)
q, k, v = qkv.unbind(dim=2)
q = q.transpose(1, 2) # (batch, heads, seq_len, d_head)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)
out = out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
return self.wo(out)
11. Mixture of Experts (MoE)
11.1 The Core Idea
Mixture of Experts scales model capacity (parameter count) while keeping active computation constant per forward pass. The key: each token is routed to only a small subset of "expert" networks.
For an 8-expert MoE with Top-2 routing:
- Total parameters: all 8 expert FFNs live in memory, 8x the FFN parameters of a comparable dense model
- Active parameters per token: only 2 experts run, so per-token compute matches a much smaller dense model
This achieves a favorable tradeoff: larger models (more capacity/knowledge) without proportionally larger inference cost.
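A back-of-the-envelope check of this tradeoff, assuming the original Transformer's FFN sizes (d_model=512, d_ff=2048) and ignoring the comparatively tiny router:

```python
# Parameter/compute accounting for an 8-expert, Top-2 MoE FFN layer.
# Two weight matrices per FFN (biases ignored): d_model*d_ff each direction.
d_model, d_ff = 512, 2048
dense_ffn = 2 * d_model * d_ff          # params in one dense FFN
num_experts, top_k = 8, 2
total_params = num_experts * dense_ffn  # all experts held in memory
active_params = top_k * dense_ffn       # experts actually run per token
print(f"dense FFN:        {dense_ffn:,}")       # 2,097,152
print(f"MoE total:        {total_params:,}")    # 16,777,216
print(f"MoE active/token: {active_params:,}")   # 4,194,304
```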
11.2 Top-k Routing
class MoELayer(nn.Module):
"""Mixture of Experts FFN layer"""
def __init__(
self,
d_model: int,
d_ff: int,
num_experts: int = 8,
top_k: int = 2,
dropout: float = 0.0,
):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
# Router / gating network
self.router = nn.Linear(d_model, num_experts, bias=False)
# Experts
self.experts = nn.ModuleList([
SwiGLUFeedForward(d_model, d_ff)
for _ in range(num_experts)
])
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> tuple:
batch, seq_len, d_model = x.shape
x_flat = x.view(-1, d_model)
router_logits = self.router(x_flat)
router_probs = F.softmax(router_logits, dim=-1)
topk_probs, topk_indices = router_probs.topk(self.top_k, dim=-1)
topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
output = torch.zeros_like(x_flat)
for i in range(self.top_k):
expert_idx = topk_indices[:, i]
prob = topk_probs[:, i:i+1]
for e in range(self.num_experts):
token_mask = (expert_idx == e)
if token_mask.any():
expert_input = x_flat[token_mask]
expert_output = self.experts[e](expert_input)
output[token_mask] += prob[token_mask] * expert_output
aux_loss = self._load_balancing_loss(router_probs)
return self.dropout(output.view(batch, seq_len, d_model)), aux_loss
def _load_balancing_loss(self, router_probs: torch.Tensor) -> torch.Tensor:
"""Switch Transformer auxiliary loss to prevent expert collapse"""
num_tokens = router_probs.size(0)
avg_probs = router_probs.mean(dim=0)
top1_indices = router_probs.argmax(dim=-1)
expert_counts = torch.bincount(top1_indices, minlength=self.num_experts).float()
expert_fractions = expert_counts / num_tokens
return self.num_experts * (avg_probs * expert_fractions).sum()
11.3 Mixtral and DeepSeek
Mixtral 8x7B:
- 8 expert FFNs, Top-2 routing
- Total parameters: 46.7B; Active parameters per token: 12.9B
- Decoder-only stack with GQA + RoPE + SwiGLU; each layer's FFN is replaced by the expert layer
DeepSeek MoE:
- Fine-grained expert segmentation: more, smaller experts
- Shared expert: an always-active expert that every token passes through, in addition to its routed experts
- Advanced expert collapse prevention strategies
11.4 MoE Tradeoffs
Advantages:
- Larger model capacity per FLOPs
- Specialization of experts for different token types
- Efficient scaling
Disadvantages:
- All expert parameters must fit in memory even if only 2 are active
- Inter-device communication overhead when experts are sharded across GPUs (expert parallelism)
- Training instability from imbalanced routing without auxiliary losses
Summary
The Transformer has evolved from a seq2seq translation model into the universal architecture powering virtually all state-of-the-art AI systems.
Key concepts covered:
- Scaled Dot-Product Attention: soft database retrieval with Q/K/V
- Multi-Head Attention: parallel attention in multiple subspaces
- Positional Encoding: sinusoidal PE → RoPE → ALiBi
- Encoder/Decoder: foundation for BERT and GPT families
- Modern Optimizations: RMSNorm, SwiGLU, GQA, KV Cache
- Flash Attention: IO-aware memory-efficient exact attention
- MoE: sparse activation for efficient scaling
Next steps to explore:
- Speculative decoding for inference acceleration
- LoRA/QLoRA fine-tuning
- Alignment techniques (RLHF, DPO)
- Production serving with vLLM and TensorRT-LLM
References
- Vaswani et al. (2017). "Attention Is All You Need." — https://arxiv.org/abs/1706.03762
- Su et al. (2022). "RoFormer: Enhanced Transformer with Rotary Position Embedding." — https://arxiv.org/abs/2104.09864
- Dao et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention." — https://arxiv.org/abs/2205.14135
- Dao (2023). "FlashAttention-2: Faster Attention with Better Parallelism." — https://arxiv.org/abs/2307.08691
- Ainslie et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models." — https://arxiv.org/abs/2305.13245
- Shazeer (2020). "GLU Variants Improve Transformer." — https://arxiv.org/abs/2002.05202
- Jiang et al. (2024). "Mixtral of Experts." — https://arxiv.org/abs/2401.04088
- PyTorch scaled_dot_product_attention — https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
- Flash Attention GitHub — https://github.com/Dao-AILab/flash-attention