Transformer Architecture Complete Analysis: From Attention to Modern LLMs

Google's 2017 paper "Attention is All You Need" completely transformed the natural language processing landscape. Moving away from sequential RNN and LSTM architectures, the Transformer — built entirely on the Attention mechanism — became the foundation for GPT, BERT, T5, LLaMA, and every major modern LLM. This guide takes you from first principles through the complete architecture, with PyTorch implementations at every step.


1. The Origins of Attention

1.1 Limitations of RNN/LSTM

Before Transformers, sequence modeling relied heavily on RNNs and LSTMs. These models propagated information through hidden states processed one time step at a time, with fundamental limitations:

Long-Range Dependency Problem

RNNs struggle to learn relationships between words that are far apart. In "The cat that sat on the mat was hungry," the connection between "cat" and "was hungry" must pass through every intermediate token. As sequences grow longer, information fades or gets overwritten. LSTM gating alleviates the problem but does not fully solve it.

Sequential Processing

RNNs require step t to complete before step t+1 begins. This means the parallel processing power of modern GPUs goes completely unused for sequence-internal computation. You can parallelize across batches, but not within a sequence.

Vanishing/Exploding Gradients

Backpropagating through hundreds or thousands of time steps causes gradients to either shrink to zero or grow uncontrollably. Even with LSTM gates, sequences beyond a few hundred tokens are problematic.

1.2 The Intuition Behind Attention

The Attention mechanism draws inspiration from how humans read. We do not give equal attention to every word — we focus more on words relevant to what we are currently processing.

Consider: "I saw the Eiffel Tower in Paris; it was breathtaking." When processing "it," the model should attend strongly to both "Eiffel Tower" and "Paris." Every word can directly interact with every other word, regardless of distance.

Bahdanau et al. (2014) introduced the first Attention mechanism as an auxiliary component in a seq2seq model. Vaswani et al. (2017) then made the decisive move: eliminate the RNN entirely and build everything from Attention alone.


2. Scaled Dot-Product Attention

2.1 The Q, K, V Framework

The Transformer uses three vectors — Query, Key, and Value — to implement Attention. Think of it as a soft database lookup:

  • Query (Q): "What am I looking for?" — the current position's representation
  • Key (K): "What information do I hold?" — each position's label/identifier
  • Value (V): "What is actually stored?" — each position's content

We compute similarity scores between Queries and Keys, normalize them with softmax into attention weights, then use those weights to produce a weighted sum of the Values.

2.2 The Formula

Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V

Where:

  • Q: Query matrix (seq_len x d_k)
  • K: Key matrix (seq_len x d_k)
  • V: Value matrix (seq_len x d_v)
  • d_k: Key/Query dimension
  • sqrt(d_k): scaling factor

Why scale by sqrt(d_k)?

If the components of q and k are independent with unit variance, their dot product has variance d_k, so raw scores grow in magnitude as d_k grows. Large scores push softmax into its saturated region, where gradients are near zero. Dividing by sqrt(d_k) brings the variance of the scores back to 1, preventing gradient issues.
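This is easy to verify empirically. The snippet below (a quick sketch; exact numbers depend on the random seed) draws unit-variance q and k and compares the variance of raw versus scaled dot products:

```python
import torch

torch.manual_seed(0)

for d_k in (16, 64, 256):
    q = torch.randn(20_000, d_k)   # unit-variance components
    k = torch.randn(20_000, d_k)
    raw = (q * k).sum(dim=-1)      # raw dot products: variance grows like d_k
    scaled = raw / d_k ** 0.5      # scaled: variance stays near 1
    print(f"d_k={d_k:4d}  var(raw)={raw.var():8.1f}  var(scaled)={scaled.var():.3f}")
```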

2.3 Masking

Padding Mask: When sequences in a batch have different lengths, shorter sequences are padded. We add -inf to padded positions so softmax assigns them zero weight.

Causal Mask (Look-ahead Mask): Used in Decoders. Position i should only attend to positions 0 through i — never future positions. We fill the upper triangle of the score matrix with -inf.
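Concretely, both masks can be built from token ids as follows (a sketch that assumes padding id 0, matching the mask shapes used in the implementation below):

```python
import torch

pad_id = 0  # assumed padding token id
tokens = torch.tensor([[5, 7, 9, 0, 0],
                       [3, 2, 8, 6, 1]])  # (batch=2, seq_len=5)
seq_len = tokens.size(1)

# Padding mask: True where a real token sits; shape (batch, 1, 1, seq_len)
# broadcasts over heads and query positions.
padding_mask = (tokens != pad_id).unsqueeze(1).unsqueeze(2)

# Causal mask: lower triangle; shape (1, 1, seq_len, seq_len).
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)

# Decoder self-attention uses both: a key must be a real token AND not in the future.
combined = padding_mask & causal_mask  # broadcasts to (batch, 1, seq_len, seq_len)
print(combined.shape)  # torch.Size([2, 1, 5, 5])
```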

2.4 PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor = None,
    dropout_p: float = 0.0,
) -> tuple:
    """
    Scaled Dot-Product Attention

    Args:
        query: (batch, heads, seq_len, d_k)
        key:   (batch, heads, seq_len, d_k)
        value: (batch, heads, seq_len, d_v)
        mask:  (batch, 1, 1, seq_len) or (batch, 1, seq_len, seq_len)
        dropout_p: dropout probability

    Returns:
        output: (batch, heads, seq_len, d_v)
        attn_weights: (batch, heads, seq_len, seq_len)
    """
    d_k = query.size(-1)

    # Q * K^T / sqrt(d_k): (batch, heads, seq_len, seq_len)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Apply mask
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Softmax attention weights
    attn_weights = F.softmax(scores, dim=-1)

    # Handle NaN (all positions masked)
    attn_weights = torch.nan_to_num(attn_weights, nan=0.0)

    # Dropout
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)

    # Weighted sum: (batch, heads, seq_len, d_v)
    output = torch.matmul(attn_weights, value)

    return output, attn_weights


# Quick test
batch_size = 2
num_heads = 8
seq_len = 10
d_k = 64

q = torch.randn(batch_size, num_heads, seq_len, d_k)
k = torch.randn(batch_size, num_heads, seq_len, d_k)
v = torch.randn(batch_size, num_heads, seq_len, d_k)

# Causal mask (lower triangular)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

output, weights = scaled_dot_product_attention(q, k, v, mask=causal_mask)
print(f"Output shape: {output.shape}")       # (2, 8, 10, 64)
print(f"Weights shape: {weights.shape}")     # (2, 8, 10, 10)

3. Multi-Head Attention

3.1 Why Multiple Heads?

Single-head attention processes the input in only one representation subspace. Multi-Head Attention runs multiple independent attention operations in parallel, each learning different aspects of token relationships:

  • Head 1: syntactic relations (subject-verb agreement)
  • Head 2: semantic similarity (synonyms, related concepts)
  • Head 3: positional relations (neighboring tokens)
  • Head 4: coreference (pronouns and their referents)

Each head uses d_k = d_model / num_heads, so the total computation is similar to a single full-sized attention.

3.2 Formula

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O

head_i = Attention(Q * W_Q_i, K * W_K_i, V * W_V_i)

3.3 Complete PyTorch Implementation

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Single large projection for efficiency
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self):
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.xavier_uniform_(module.weight)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)"""
        batch_size, seq_len, _ = x.size()
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def combine_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)"""
        batch_size, _, seq_len, _ = x.size()
        x = x.transpose(1, 2).contiguous()
        return x.view(batch_size, seq_len, self.d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> tuple:
        """
        Args:
            query, key, value: (batch, seq_len, d_model)
            mask: (batch, 1, 1, seq_len) or (batch, 1, seq_len, seq_len)
        Returns:
            output: (batch, seq_len, d_model)
            attn_weights: (batch, num_heads, seq_len, seq_len)
        """
        Q = self.split_heads(self.W_q(query))
        K = self.split_heads(self.W_k(key))
        V = self.split_heads(self.W_v(value))

        attn_output, attn_weights = scaled_dot_product_attention(
            Q, K, V, mask=mask, dropout_p=self.dropout.p if self.training else 0.0
        )

        output = self.combine_heads(attn_output)
        output = self.W_o(output)

        return output, attn_weights


# Test
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
out, weights = mha(x, x, x)
print(f"MHA output: {out.shape}")      # (2, 10, 512)
print(f"Attn weights: {weights.shape}")  # (2, 8, 10, 10)

4. Positional Encoding

4.1 Why Position Information Matters

Attention is permutation-equivariant: shuffle the input tokens and the outputs are simply shuffled the same way; nothing in the mechanism itself encodes order. "I ate rice" and "Rice ate I" would produce the same attention patterns without positional encoding. We need to inject sequence order information.
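This is easy to demonstrate: shuffling the rows fed into a bare attention step just shuffles the output rows the same way (a toy sketch with Q = K = V and no learned projections):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(6, 16)             # 6 tokens, d_model=16; projections omitted

def attend(h):
    scores = h @ h.T / 16 ** 0.5   # toy self-attention with Q = K = V = h
    return F.softmax(scores, dim=-1) @ h

perm = torch.randperm(6)
out1 = attend(x)[perm]             # attend, then shuffle the rows
out2 = attend(x[perm])             # shuffle the rows, then attend
print(torch.allclose(out1, out2, atol=1e-5))  # True: order carries no signal
```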

4.2 Sinusoidal Positional Encoding

The original Transformer uses fixed sinusoidal functions — no learned parameters:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

Where pos is the token position and i is the dimension index.

Advantages:

  • No parameters to learn
  • Can extrapolate to sequences longer than seen during training
  • PE(pos+k) can be expressed as a linear function of PE(pos), encoding relative position

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()

        # 10000^(2i/d_model) = exp(2i * ln(10000) / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)  # even dims: sin
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dims: cos

        pe = pe.unsqueeze(0)  # (1, max_seq_len, d_model) for batch broadcasting
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (batch, seq_len, d_model)"""
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

4.3 RoPE (Rotary Position Embedding)

RoPE has become the standard positional encoding for modern LLMs (LLaMA, GPT-NeoX, PaLM, Mistral). The key idea is to encode position by applying rotation matrices to Q and K vectors.

Core properties:

  • Encodes relative rather than absolute positions
  • Q and K dot products automatically depend on relative position
  • Excellent extrapolation to longer sequences
  • Applied only to Q and K — not V

def precompute_freqs_cis(dim: int, max_seq_len: int, theta: float = 10000.0):
    """Precompute RoPE frequency matrix"""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len)
    freqs = torch.outer(t, freqs)  # (max_seq_len, dim/2)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex unit vectors
    return freqs_cis


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Apply RoPE to query and key tensors"""
    xq_r = xq.float().reshape(*xq.shape[:-1], -1, 2)
    xk_r = xk.float().reshape(*xk.shape[:-1], -1, 2)

    xq_complex = torch.view_as_complex(xq_r)
    xk_complex = torch.view_as_complex(xk_r)

    # Slice to the current sequence length, then reshape (seq_len, dim/2) ->
    # (1, seq_len, 1, dim/2) so it broadcasts against (batch, seq_len, heads, dim/2)
    freqs_cis = freqs_cis[:xq.shape[1]].view(1, xq.shape[1], 1, -1)
    xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(-2)

    return xq_out.type_as(xq), xk_out.type_as(xk)

4.4 ALiBi (Attention with Linear Biases)

ALiBi adds a linear position bias to attention scores:

Score(i, j) = Q_i * K_j^T / sqrt(d_k) - m * (i - j)

Where m is a head-specific slope and j <= i under the causal mask, so i - j >= 0. No position embeddings are needed, and ALiBi generalizes very well to sequences longer than those seen during training.
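A sketch of the head-specific slopes and the bias matrix (following the geometric sequence from the ALiBi paper; assumes the head count is a power of two):

```python
import torch

def alibi_slopes(num_heads: int) -> torch.Tensor:
    """Geometric sequence of slopes, e.g. 1/2, 1/4, ..., 1/256 for 8 heads
    (the ALiBi paper's choice for power-of-two head counts)."""
    start = 2.0 ** (-8.0 / num_heads)
    return torch.tensor([start ** (i + 1) for i in range(num_heads)])

def alibi_bias(num_heads: int, seq_len: int) -> torch.Tensor:
    """Bias -m * (i - j), shape (num_heads, seq_len, seq_len); the upper
    triangle is irrelevant because the causal mask removes it anyway."""
    pos = torch.arange(seq_len)
    rel = (pos.unsqueeze(1) - pos.unsqueeze(0)).clamp(min=0)  # rel[i, j] = i - j
    slopes = alibi_slopes(num_heads).view(num_heads, 1, 1)
    return -slopes * rel  # add this to Q*K^T/sqrt(d_k) before softmax

bias = alibi_bias(num_heads=8, seq_len=5)
print(bias.shape)     # torch.Size([8, 5, 5])
print(bias[0, 2, 0])  # head 0 (slope 1/2), distance 2: tensor(-1.)
```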


5. Transformer Encoder

5.1 Encoder Architecture

The Encoder stacks N identical layers, each with two sub-layers:

  1. Multi-Head Self-Attention
  2. Position-wise Feed-Forward Network

Each sub-layer is wrapped with a residual connection and layer normalization.

5.2 Feed-Forward Network

The FFN is a two-layer MLP applied independently to each position:

FFN(x) = max(0, x * W_1 + b_1) * W_2 + b_2

The original paper uses d_model=512, d_ff=2048. Modern LLMs use SwiGLU and d_ff ≈ 2.67 * d_model.

5.3 Pre-LN vs Post-LN

# Post-LN (original Transformer)
x = LayerNorm(x + Sublayer(x))

# Pre-LN (modern standard)
x = x + Sublayer(LayerNorm(x))

Pre-LN trains more stably and typically needs little or no learning rate warmup, which is why modern models have adopted it.

5.4 Full Encoder Implementation

class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # Pre-LN Self-Attention (normalize once, reuse for Q, K, V)
        normed = self.norm1(x)
        attn_out, _ = self.self_attn(normed, normed, normed, mask=mask)
        x = x + self.dropout(attn_out)

        # Pre-LN Feed-Forward
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x


class TransformerEncoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        max_seq_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """x: (batch, seq_len) token indices"""
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

6. Transformer Decoder

6.1 Decoder Architecture

The Decoder has three sub-layers per layer:

  1. Masked Multi-Head Self-Attention (with causal mask)
  2. Multi-Head Cross-Attention (attends to Encoder output)
  3. Feed-Forward Network

6.2 Cross-Attention

In Cross-Attention:

  • Query comes from the Decoder's current state
  • Keys and Values come from the Encoder's output

This is how the Decoder learns which parts of the source sequence to focus on when generating each output token.

6.3 Autoregressive Generation

During inference, the Decoder generates tokens one at a time:

  1. Start with a [BOS] token
  2. Use all previously generated tokens as input
  3. Predict the next token
  4. Repeat until [EOS] is generated

class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        # 1. Masked Self-Attention (causal); normalize once, reuse for Q, K, V
        normed = self.norm1(x)
        self_attn_out, _ = self.self_attn(normed, normed, normed, mask=tgt_mask)
        x = x + self.dropout(self_attn_out)

        # 2. Cross-Attention over Encoder output
        cross_attn_out, _ = self.cross_attn(
            self.norm2(x), encoder_output, encoder_output, mask=src_mask
        )
        x = x + self.dropout(cross_attn_out)

        # 3. Feed-Forward
        x = x + self.dropout(self.ffn(self.norm3(x)))
        return x

7. Full Transformer Implementation

class Transformer(nn.Module):
    """Complete Encoder-Decoder Transformer"""

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        d_ff: int = 2048,
        max_seq_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model

        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)

        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_encoder_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ])

        self.encoder_norm = nn.LayerNorm(d_model)
        self.decoder_norm = nn.LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)

        self._init_weights()

    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor = None) -> torch.Tensor:
        x = self.src_embedding(src) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return self.encoder_norm(x)

    def decode(
        self,
        tgt: torch.Tensor,
        encoder_output: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.decoder_norm(x)

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        return self.output_projection(decoder_output)

    @torch.no_grad()
    def generate(
        self,
        src: torch.Tensor,
        bos_token_id: int,
        eos_token_id: int,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """Greedy decoding"""
        self.eval()
        device = src.device

        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)  # assumes pad token id 0
        encoder_output = self.encode(src, src_mask)
        tgt = torch.tensor([[bos_token_id]], device=device)

        for _ in range(max_new_tokens):
            seq_len = tgt.size(1)
            tgt_mask = torch.tril(
                torch.ones(seq_len, seq_len, device=device)
            ).unsqueeze(0).unsqueeze(0)

            logits = self.decode(tgt, encoder_output, src_mask, tgt_mask)
            next_token_logits = logits[:, -1, :] / temperature
            next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token], dim=1)

            if next_token.item() == eos_token_id:  # .item() assumes batch size 1
                break

        return tgt


# Create and test model
model = Transformer(src_vocab_size=32000, tgt_vocab_size=32000)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")  # ~93M

src = torch.randint(1, 32000, (2, 20))
tgt = torch.randint(1, 32000, (2, 15))
logits = model(src, tgt)
print(f"Logits shape: {logits.shape}")  # (2, 15, 32000)

8. Transformer Variants

8.1 BERT (Encoder-only)

BERT (Bidirectional Encoder Representations from Transformers) uses only the Encoder with two pretraining objectives:

Masked Language Modeling (MLM): 15% of tokens are masked at random. The model predicts the masked tokens using bidirectional context. This makes BERT excellent for understanding tasks.
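In practice the 15% is further split by BERT's 80/10/10 rule: of the selected positions, 80% become [MASK], 10% are replaced by a random token, and 10% are left unchanged, so the model cannot rely on [MASK] always marking a prediction site. A sketch of that corruption step (token ids here are arbitrary):

```python
import torch

def mlm_corrupt(tokens, mask_token_id, vocab_size, mlm_prob=0.15):
    """BERT-style input corruption. Of the ~15% selected positions:
    80% -> [MASK], 10% -> random token, 10% -> kept unchanged."""
    labels = tokens.clone()
    selected = torch.rand_like(tokens, dtype=torch.float) < mlm_prob
    labels[~selected] = -100  # ignored by cross_entropy(ignore_index=-100)

    corrupted = tokens.clone()
    r = torch.rand_like(tokens, dtype=torch.float)
    corrupted[selected & (r < 0.8)] = mask_token_id
    swap = selected & (r >= 0.8) & (r < 0.9)
    corrupted[swap] = torch.randint_like(tokens, vocab_size)[swap]
    # the remaining 10% of selected positions keep their original token
    return corrupted, labels

torch.manual_seed(0)
tokens = torch.randint(5, 1000, (2, 128))
inputs, labels = mlm_corrupt(tokens, mask_token_id=4, vocab_size=1000)
print(inputs.shape, (labels != -100).float().mean())  # ~0.15 of positions selected
```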

Next Sentence Prediction (NSP): Predicts whether two sentences are consecutive. Later research showed NSP contributes little, and RoBERTa dropped it without performance loss.

BERT excels at classification, QA, and NER but cannot directly generate text.

8.2 GPT (Decoder-only)

GPT (Generative Pre-trained Transformer) uses only the Decoder stack with causal language modeling: predict the next token from all previous tokens.

The architecture is a simplified Decoder with no cross-attention — just masked self-attention and FFN. GPT-2, GPT-3, GPT-4, LLaMA, Mistral, and most modern LLMs follow this Decoder-only design.
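The causal LM objective itself is plain cross-entropy over shifted sequences: the logits at position t are scored against the token at position t+1. A minimal sketch with random stand-in logits:

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab_size = 2, 12, 100
logits = torch.randn(batch, seq_len, vocab_size)          # stand-in for model output
tokens = torch.randint(0, vocab_size, (batch, seq_len))

# Position t predicts token t+1: drop the last logit, drop the first label
shift_logits = logits[:, :-1, :].reshape(-1, vocab_size)  # (batch*(seq_len-1), vocab)
shift_labels = tokens[:, 1:].reshape(-1)                  # (batch*(seq_len-1),)
loss = F.cross_entropy(shift_logits, shift_labels)
print(loss.item())  # roughly ln(vocab_size) ≈ 4.6 for random logits
```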

8.3 T5 and BART (Encoder-Decoder)

T5 (Text-To-Text Transfer Transformer) unifies all NLP tasks as text-to-text. Translation, summarization, classification, and QA all use identical input-output format.

BART is pretrained as a denoising autoencoder with various corruption strategies including token masking, sentence shuffling, and document rotation.

8.4 Vision Transformer (ViT)

ViT splits an image into 16×16 patches, linearly projects each patch into a token embedding, adds positional embeddings, and feeds the sequence into a standard Transformer Encoder.
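The patch-splitting plus linear projection is commonly implemented as a single strided convolution whose kernel and stride both equal the patch size; a sketch with ViT-Base-style dimensions:

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into patches and project each to d_model, implemented
    as one strided convolution (kernel = stride = patch size)."""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, d_model=768):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, d_model,
                              kernel_size=patch_size, stride=patch_size)
        self.num_patches = (img_size // patch_size) ** 2

    def forward(self, x):
        x = self.proj(x)                     # (batch, d_model, H/16, W/16)
        return x.flatten(2).transpose(1, 2)  # (batch, num_patches, d_model)

patches = PatchEmbedding()(torch.randn(1, 3, 224, 224))
print(patches.shape)  # torch.Size([1, 196, 768]): 14x14 patches of dim 768
```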

With large-scale pretraining, ViT matches and surpasses CNNs on image classification benchmarks.


9. Modern LLM Optimizations

9.1 RMSNorm

Modern LLMs replace LayerNorm with RMSNorm:

RMSNorm(x) = x / RMS(x) * g
RMS(x) = sqrt(mean(x^2) + epsilon)

No mean subtraction needed — it is faster with comparable performance. Used in LLaMA, Mistral, and Gemma.

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
        return x / rms * self.weight

9.2 SwiGLU Activation

SwiGLU combines Swish and Gated Linear Units:

SwiGLU(x, W, V) = Swish(x * W) * (x * V)
Swish(x) = x * sigmoid(x)

The FFN dimension is typically adjusted to d_ff = (2/3) * 4 * d_model ≈ 2.67 * d_model.

class SwiGLUFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int = None):
        super().__init__()
        if d_ff is None:
            d_ff = int(2 * 4 * d_model / 3)
            d_ff = ((d_ff + 63) // 64) * 64  # round to multiple of 64

        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

9.3 Grouped Query Attention (GQA)

MHA gives every head its own Q, K, and V projections. MQA has all query heads share a single K/V head. GQA sits in between: the query heads are divided into G groups, and each group shares one K/V head.

LLaMA 2, Mistral, and Gemma use GQA to reduce KV cache size while maintaining quality.

class GroupedQueryAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_kv_heads: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        assert num_heads % num_kv_heads == 0

        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_rep = num_heads // num_kv_heads
        self.d_head = d_model // num_heads

        self.wq = nn.Linear(d_model, num_heads * self.d_head, bias=False)
        self.wk = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.wv = nn.Linear(d_model, num_kv_heads * self.d_head, bias=False)
        self.wo = nn.Linear(num_heads * self.d_head, d_model, bias=False)
        self.dropout = dropout

    def repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat KV heads to match Q head count"""
        if self.num_rep == 1:
            return x
        batch, n_kv, seq_len, d_head = x.shape
        return x.unsqueeze(2).expand(
            batch, n_kv, self.num_rep, seq_len, d_head
        ).reshape(batch, n_kv * self.num_rep, seq_len, d_head)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        batch, seq_len, _ = x.shape

        xq = self.wq(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        xk = self.wk(x).view(batch, seq_len, self.num_kv_heads, self.d_head).transpose(1, 2)
        xv = self.wv(x).view(batch, seq_len, self.num_kv_heads, self.d_head).transpose(1, 2)

        xk = self.repeat_kv(xk)
        xv = self.repeat_kv(xv)

        scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(self.d_head)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = F.dropout(attn, p=self.dropout, training=self.training)

        out = torch.matmul(attn, xv)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.wo(out)

9.4 KV Cache

During autoregressive generation, recomputing K and V for all previous tokens at every step is wasteful. KV Cache stores previously computed K and V tensors and reuses them:

class KVCacheAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)
        self.cache_k = None
        self.cache_v = None

    def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        batch, seq_len, _ = x.shape

        xq = self.wq(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        xk = self.wk(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        xv = self.wv(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)

        if start_pos > 0 and self.cache_k is not None:
            self.cache_k = torch.cat([self.cache_k, xk], dim=2)
            self.cache_v = torch.cat([self.cache_v, xv], dim=2)
        else:
            self.cache_k = xk
            self.cache_v = xv

        # NOTE: this sketch omits the causal mask needed when more than one
        # new token is processed per call (i.e., the prefill step)
        scores = torch.matmul(xq, self.cache_k.transpose(-2, -1)) / math.sqrt(self.d_head)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, self.cache_v)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.wo(out)

10. Flash Attention

10.1 The Memory Problem with Standard Attention

The attention matrix is O(N^2) in memory. For N=8192, the attention matrix alone requires 8192 x 8192 x 4 bytes ≈ 256MB in FP32. Writing and reading this matrix to/from GPU HBM creates a bandwidth bottleneck.
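The back-of-envelope arithmetic behind that 256MB figure, and how it scales with N:

```python
# Attention matrix for one (batch, head) pair in FP32: N^2 entries x 4 bytes
for n in (1024, 8192, 65536):
    mib = n * n * 4 / 2**20
    print(f"N={n:6d}: {mib:8.0f} MiB")  # 4, 256, and 16384 MiB (16 GiB)
```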

Modern GPUs can perform far more arithmetic per second than they can move data to and from HBM. Standard Attention is therefore memory-bandwidth-bound: it runs much slower than its FLOP count alone would suggest.

10.2 IO-Aware Attention Algorithm

Flash Attention (Dao et al., 2022) computes exact Attention without ever materializing the full attention matrix in HBM.

Core Idea: Tiling

Split Q, K, V into blocks that fit in SRAM (fast on-chip memory). Process one block at a time using the Online Softmax algorithm, which accumulates the correct softmax result without seeing all scores at once.

Algorithm sketch:

  1. Load Q_i block to SRAM
  2. For each K_j, V_j block: load to SRAM, compute S_ij = Q_i * K_j^T
  3. Use running max/sum to update softmax incrementally
  4. Accumulate into output O_i
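The running max/sum update in step 3 is the heart of the algorithm. A 1-D sketch (pure PyTorch, not the fused kernel) showing that processing scores block by block reproduces the standard softmax-weighted sum exactly:

```python
import torch

def online_softmax_weighted_sum(scores, values, block=4):
    """Blockwise softmax-weighted sum with a running max (m), running
    normalizer (l), and unnormalized accumulator (o), as in Flash Attention."""
    m = torch.tensor(float('-inf'))
    l = torch.tensor(0.0)
    o = torch.zeros(values.shape[-1])
    for start in range(0, scores.numel(), block):
        s = scores[start:start + block]
        v = values[start:start + block]
        m_new = torch.maximum(m, s.max())
        scale = torch.exp(m - m_new)  # rescale previous accumulators
        p = torch.exp(s - m_new)
        l = l * scale + p.sum()
        o = o * scale + p @ v
        m = m_new
    return o / l

torch.manual_seed(0)
scores = torch.randn(10)
values = torch.randn(10, 8)
ref = torch.softmax(scores, dim=0) @ values        # standard full-vector result
out = online_softmax_weighted_sum(scores, values)
print(torch.allclose(out, ref, atol=1e-5))         # True: exact, not approximate
```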

Complexity:

  • Memory: O(N) instead of O(N^2) — no full attention matrix stored
  • FLOPs: O(N^2) — same as standard
  • Wall-clock speed: 2-4x faster on A100 due to far fewer HBM reads/writes

10.3 Flash Attention Versions

Flash Attention 1 (2022): First IO-aware Attention formalization. Custom CUDA kernels for forward and backward pass.

Flash Attention 2 (2023): Outer loop over Q, inner loop over K/V (better parallelism). Optimized warp-level work partitioning. Roughly 2x additional speedup.

Flash Attention 3 (2024): Exploits H100's WGMMA (Warp Group Matrix Multiply-Accumulate) and TMA (Tensor Memory Accelerator) asynchronous copy. Another 1.5-2x speedup.

10.4 Usage

import torch
import torch.nn.functional as F

# PyTorch 2.0+ built-in Flash Attention
# On CUDA, automatically uses Flash Attention when possible
q = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)

# is_causal=True uses causal mask efficiently
# (torch.backends.cuda.sdp_kernel is deprecated in recent PyTorch releases
# in favor of torch.nn.attention.sdpa_kernel)
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    output = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(f"Output shape: {output.shape}")  # (2, 8, 1024, 64)

# Direct flash-attn package usage
# pip install flash-attn --no-build-isolation
try:
    from flash_attn import flash_attn_qkvpacked_func

    qkv = torch.randn(2, 1024, 3, 8, 64, device='cuda', dtype=torch.float16)
    out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
    print(f"Flash Attention output: {out.shape}")
except ImportError:
    print("flash-attn not installed")


class FlashAttentionMHA(nn.Module):
    """MHA using PyTorch's built-in Flash Attention"""

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.wqkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)
        self.dropout = dropout

    def forward(self, x: torch.Tensor, is_causal: bool = False) -> torch.Tensor:
        batch, seq_len, d_model = x.shape

        qkv = self.wqkv(x).view(batch, seq_len, 3, self.num_heads, self.d_head)
        q, k, v = qkv.unbind(dim=2)

        q = q.transpose(1, 2)  # (batch, heads, seq_len, d_head)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        out = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        out = out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        return self.wo(out)
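Because Flash Attention is exact (not an approximation), its output should match a naive attention implementation up to floating-point noise. A quick CPU sanity check (a sketch; on CPU, `scaled_dot_product_attention` falls back to the math kernel, but the result is the same computation):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, heads, seq_len, d_head) on CPU in float32
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# fused kernel (Flash Attention on CUDA, math fallback on CPU)
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# reference: explicit scores + causal mask + softmax
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
causal_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal_mask, float('-inf'))
manual = F.softmax(scores, dim=-1) @ v

print(torch.allclose(fused, manual, atol=1e-5))  # True
```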

11. Mixture of Experts (MoE)

11.1 The Core Idea

Mixture of Experts scales model capacity (parameter count) while keeping active computation constant per forward pass. The key: each token is routed to only a small subset of "expert" networks.

For an 8-expert MoE with Top-2 routing:

  • Total (stored) parameters: 8 expert FFNs = 8x the FFN parameters of a comparable dense model
  • Active parameters per token: only 2 expert FFNs run, so per-token FFN compute is 2x one expert, not 8x

This achieves a favorable tradeoff: larger models (more capacity/knowledge) without proportionally larger inference cost.
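The stored-vs-active gap is easy to sanity-check with back-of-the-envelope arithmetic. The sketch below counts only the per-layer FFN parameters, using assumed Mixtral-like dimensions (d_model=4096, d_ff=14336; a SwiGLU FFN has three projection matrices):

```python
# FFN-only parameter count for an 8-expert Top-2 MoE layer (assumed toy dims)
d_model, d_ff = 4096, 14336          # Mixtral-like dimensions (assumed)
ffn_params = 3 * d_model * d_ff      # SwiGLU FFN: gate, up, and down projections
total_params = 8 * ffn_params        # all 8 experts must be stored in memory
active_params = 2 * ffn_params       # only the Top-2 experts run per token
print(f"stored: {total_params / 1e9:.2f}B, active: {active_params / 1e9:.2f}B per layer")
```

Multiplied across a few dozen layers, this FFN gap accounts for most of the difference between a MoE model's total and active parameter counts; attention and embeddings add a shared remainder.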

11.2 Top-k Routing

class MoELayer(nn.Module):
    """Mixture of Experts FFN layer"""

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int = 8,
        top_k: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        # Router / gating network
        self.router = nn.Linear(d_model, num_experts, bias=False)

        # Experts
        self.experts = nn.ModuleList([
            SwiGLUFeedForward(d_model, d_ff)
            for _ in range(num_experts)
        ])

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch, seq_len, d_model = x.shape
        x_flat = x.view(-1, d_model)

        router_logits = self.router(x_flat)
        router_probs = F.softmax(router_logits, dim=-1)

        topk_probs, topk_indices = router_probs.topk(self.top_k, dim=-1)
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)

        output = torch.zeros_like(x_flat)

        for i in range(self.top_k):
            expert_idx = topk_indices[:, i]
            prob = topk_probs[:, i:i+1]

            for e in range(self.num_experts):
                token_mask = (expert_idx == e)
                if token_mask.any():
                    expert_input = x_flat[token_mask]
                    expert_output = self.experts[e](expert_input)
                    output[token_mask] += prob[token_mask] * expert_output

        aux_loss = self._load_balancing_loss(router_probs)
        return self.dropout(output.view(batch, seq_len, d_model)), aux_loss

    def _load_balancing_loss(self, router_probs: torch.Tensor) -> torch.Tensor:
        """Switch Transformer auxiliary loss to prevent expert collapse"""
        num_tokens = router_probs.size(0)
        avg_probs = router_probs.mean(dim=0)
        top1_indices = router_probs.argmax(dim=-1)
        expert_counts = torch.bincount(top1_indices, minlength=self.num_experts).float()
        expert_fractions = expert_counts / num_tokens
        return self.num_experts * (avg_probs * expert_fractions).sum()
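The routing step of the layer above can be exercised in isolation to see what the gate values look like (toy sizes; the logits stand in for the output of the router `Linear`):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, num_experts, top_k = 6, 8, 2
router_logits = torch.randn(num_tokens, num_experts)  # stand-in for router(x_flat)

router_probs = F.softmax(router_logits, dim=-1)
topk_probs, topk_indices = router_probs.topk(top_k, dim=-1)
# renormalize so each token's selected gate weights sum to 1
topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)

print(topk_indices[0])         # the two experts token 0 is routed to
print(topk_probs.sum(dim=-1))  # each row sums to 1 after renormalization
```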

11.3 Mixtral and DeepSeek

Mixtral 8x7B:

  • 8 expert FFNs, Top-2 routing
  • Total parameters: 46.7B; Active parameters per token: 12.9B
  • Uses grouped-query attention (GQA) + RoPE + SwiGLU, sharing its backbone with Mistral 7B

DeepSeek MoE:

  • Fine-grained expert segmentation: more, smaller experts
  • Shared expert: one or more always-active experts that process every token, in addition to the routed experts
  • Advanced expert collapse prevention strategies
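The shared-expert idea reduces to an always-on branch added to the routed output. A minimal sketch, assuming toy `Linear` experts in place of SwiGLU FFNs (class and member names are illustrative, not DeepSeek's actual code):

```python
import torch
import torch.nn as nn

class SharedExpertMoE(nn.Module):
    """Toy MoE block with one always-active shared expert (DeepSeek-style sketch)."""

    def __init__(self, d_model: int, num_routed: int = 4, top_k: int = 2):
        super().__init__()
        self.shared = nn.Linear(d_model, d_model)  # processes every token
        self.experts = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_routed))
        self.router = nn.Linear(d_model, num_routed, bias=False)
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_tokens, d_model)
        gates = torch.softmax(self.router(x), dim=-1)
        topk_g, topk_i = gates.topk(self.top_k, dim=-1)
        topk_g = topk_g / topk_g.sum(dim=-1, keepdim=True)

        routed = torch.zeros_like(x)
        for slot in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = topk_i[:, slot] == e
                if mask.any():
                    routed[mask] += topk_g[mask, slot:slot + 1] * expert(x[mask])

        # the shared expert's output is added for every token, bypassing the router
        return self.shared(x) + routed

x = torch.randn(5, 16)
print(SharedExpertMoE(16)(x).shape)  # torch.Size([5, 16])
```

Because the shared expert sees every token, it can absorb common knowledge, freeing the routed experts to specialize.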

11.4 MoE Tradeoffs

Advantages:

  • More parameters (capacity/knowledge) per FLOP of compute
  • Specialization of experts for different token types
  • Efficient scaling

Disadvantages:

  • All expert parameters must fit in memory even if only 2 are active
  • Inter-device communication overhead when experts are sharded across GPUs (expert parallelism)
  • Training instability from imbalanced routing without auxiliary losses

Summary

The Transformer has evolved from a seq2seq translation model into the universal architecture powering virtually all state-of-the-art AI systems.

Key concepts covered:

  1. Scaled Dot-Product Attention: soft database retrieval with Q/K/V
  2. Multi-Head Attention: parallel attention in multiple subspaces
  3. Positional Encoding: sinusoidal PE → RoPE → ALiBi
  4. Encoder/Decoder: foundation for BERT and GPT families
  5. Modern Optimizations: RMSNorm, SwiGLU, GQA, KV Cache
  6. Flash Attention: IO-aware memory-efficient exact attention
  7. MoE: sparse activation for efficient scaling

Next steps to explore:

  • Speculative decoding for inference acceleration
  • LoRA/QLoRA fine-tuning
  • Alignment techniques (RLHF, DPO)
  • Production serving with vLLM and TensorRT-LLM

References