Dissecting the Transformer — From Attention to KV Cache

Introduction — Why the Transformer, Still

Since the 2017 paper "Attention Is All You Need," the Transformer has become the default backbone of nearly every deep learning field — not just NLP, but vision, speech, code, and multimodal systems. As of 2026, almost every large language model (LLM) we use is essentially a variant of the Transformer decoder. Modern techniques like GQA, RoPE, FlashAttention, and MoE are all refinements bolted onto the original; the core skeleton has not changed.

Because of this, once you truly understand the Transformer, reading new models and papers becomes a matter of spotting "what changed." Conversely, if your grasp of self-attention and the KV cache is fuzzy, you will get stuck at every step — serving optimization, fine-tuning, and beyond.

This article dissects one Transformer block from start to finish. It connects, as a single flow, what tensor shape each operation takes in and produces, how many parameters appear, and why the KV cache shows up at inference. Equations are written in plain notation without dollar signs, and code is shown in working form.

The Full Structure at a Glance

The flow from an input token sequence to the next-token distribution looks like this.

[input token IDs]  (integer sequence, length N)
      |
      v
[token embedding + positional encoding]   ->  X: (B, N, D)
      |
      v
+---------------------------+
|  Transformer block x L    |
|                           |
|  LayerNorm                |
|  Multi-Head Attention     |
|  + residual               |
|                           |
|  LayerNorm                |
|  Feed-Forward Network     |
|  + residual               |
+---------------------------+
      |
      v
[final LayerNorm]
      |
      v
[unembedding / LM head]   ->  logits: (B, N, V)
      |
      v
[softmax]                 ->  next-token probability distribution

The symbols are used as follows.

B = batch size
N = sequence length (number of tokens)
D = model dimension (d_model)
H = number of attention heads
d_k = dimension per head = D / H
V = vocabulary size
L = number of blocks (layers)

GPT-family models stack only the decoder block L times in the diagram above. Encoder-decoder structures (the original paper, T5, etc.) keep separate encoder and decoder stacks. This article focuses on the decoder while clearly pointing out the differences from the encoder.

1. Embedding — Turning Tokens into Vectors

The very first step turns integer token IDs into D-dimensional vectors. This is simply a lookup table of size (V, D).

import torch
import torch.nn as nn

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, token_ids):
        # token_ids: (B, N)  ->  (B, N, D)
        return self.embed(token_ids) * (self.d_model ** 0.5)

The original paper multiplies the embedding by sqrt(D) to match scale — a convention to keep its magnitude similar to the positional encoding. The number of embedding parameters is V * D. For a 50k vocabulary and D=4096, embeddings alone take roughly 200 million parameters.

2. Positional Encoding — Injecting Order

Self-attention by itself has no sense of order. Permuting the input tokens simply permutes the outputs in the same way (permutation equivariance), so the model cannot tell one ordering from another. Positional information must therefore be injected explicitly. The most classic approach is absolute positional encoding built from sine and cosine functions.

PE(pos, 2i)   = sin( pos / 10000^(2i / D) )
PE(pos, 2i+1) = cos( pos / 10000^(2i / D) )

pos = position index (0, 1, 2, ...)
i   = index of the sine/cosine dimension pair (0 <= i < D/2)

Each position gets a unique pattern built from sine waves of different frequencies, and this pattern is added to the token embedding. Modern models more often use RoPE (rotary position embedding), which rotates Q and K by position-dependent angles, or ALiBi, which adds a distance-proportional bias to the attention scores, instead of absolute positions. We cover these in a separate article. Here, just remember that positional information is always injected somewhere into the attention computation.
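Here is a minimal PyTorch sketch of the sinusoidal formula above. It assumes an even D. The function name `sinusoidal_positional_encoding` is illustrative and does not come from any library. The frequency term 1 / 10000^(2i / D) is computed in log space, which is a common, numerically convenient form.

```python
import math
import torch

def sinusoidal_positional_encoding(max_len, d_model):
    # Returns a (max_len, d_model) table; assumes d_model is even.
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)      # 2i = 0, 2, 4, ...
    freq = torch.exp(-math.log(10000.0) * two_i / d_model)         # 1 / 10000^(2i / D)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * freq)  # even dimensions: PE(pos, 2i)
    pe[:, 1::2] = torch.cos(pos * freq)  # odd dimensions:  PE(pos, 2i+1)
    return pe

pe = sinusoidal_positional_encoding(max_len=50, d_model=16)
print(pe.shape)
```

To use the table, add `pe[:N]` to the (B, N, D) embedding output. Broadcasting takes care of the batch dimension.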

3. Self-Attention — The Core Mechanism

Query, Key, Value

The intuition behind self-attention is this. Each token poses "what am I looking for" (Query), every token holds a label "this is what I am" (Key), and the actual content to retrieve sits in the Value. The higher the similarity between a Query and a Key, the more of that Value is retrieved.

For input X: (B, N, D), three linear transformations produce Q, K, V.

Q = X · W_Q     W_Q: (D, D)   ->  Q: (B, N, D)
K = X · W_K     W_K: (D, D)   ->  K: (B, N, D)
V = X · W_V     W_V: (D, D)   ->  V: (B, N, D)

Scaled Dot-Product Attention

The core equation is as follows (written without dollar signs).

Attention(Q, K, V) = softmax( (Q · K^T) / sqrt(d_k) ) · V

Q · K^T : (B, N, N)  -- similarity scores for all token pairs
/ sqrt(d_k) : scaling (prevents scores from growing too large)
softmax : normalize along rows -> attention weights
· V : produce output as a weighted average  ->  (B, N, D)

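The reason for dividing by sqrt(d_k) matters. Suppose the components of q and k are independent, each with mean 0 and variance 1. Their dot product q · k is then a sum of d_k such products, so it has variance d_k, and raw scores grow with the head dimension. When the softmax inputs become too large, the softmax saturates into a nearly one-hot region where gradients are close to zero, and training destabilizes. Dividing by sqrt(d_k), the standard deviation of the dot product, brings the variance back to about 1.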
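The variance argument can be checked empirically. This sketch assumes, as the argument does, that the q and k components are i.i.d. standard normal. The sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
d_k = 512
num_pairs = 10_000

# Assumption: every component of q and k is i.i.d. N(0, 1)
q = torch.randn(num_pairs, d_k)
k = torch.randn(num_pairs, d_k)

dots = (q * k).sum(dim=-1)   # one raw dot product per (q, k) pair
scaled = dots / d_k ** 0.5   # divide by the standard deviation sqrt(d_k)

print(f"raw variance    ~ {dots.var().item():.1f} (theory: d_k = {d_k})")
print(f"scaled variance ~ {scaled.var().item():.2f} (theory: 1)")
```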

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (B, H, N, d_k)
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    # scores: (B, H, N, N)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    out = torch.matmul(weights, v)  # (B, H, N, d_k)
    return out, weights

Causal Mask — Hiding the Future

In a decoder (GPT family), the token at position i must not see any token after position i, because during training that would amount to peeking at the answer. So the strictly upper-triangular part of the attention score matrix (the future positions above the diagonal) is filled with negative infinity. This makes the corresponding post-softmax weights exactly zero.

causal mask (N=4, 1=allowed, 0=blocked):

        key0  key1  key2  key3
query0   1     0     0     0
query1   1     1     0     0
query2   1     1     1     0
query3   1     1     1     1
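The mask in the table is simply a lower-triangular matrix of ones. The sketch below builds it with `torch.tril` and feeds it to a copy of the `scaled_dot_product_attention` function from earlier in this article. As a cross-check, it compares the result with PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0 or later), which applies the same causal masking when `is_causal=True`. The tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # Same computation as the function defined earlier; q, k, v: (B, H, N, d_k)
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return F.softmax(scores, dim=-1) @ v

N = 4
causal_mask = torch.tril(torch.ones(N, N, dtype=torch.long))  # 1 on and below the diagonal
print(causal_mask)

torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, N, 8) for _ in range(3))  # (B=1, H=2, N=4, d_k=8)
ours = scaled_dot_product_attention(q, k, v, causal_mask)  # (N, N) mask broadcasts over B, H
reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(ours, reference, atol=1e-5))
```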

The encoder (BERT, the encoder stack of the original paper) has no such mask. It is bidirectional attention where every token can see every other. This difference is the fundamental distinction between decoder (generation) and encoder (understanding/representation) models.

4. Multi-Head Attention — Seeing with Many Eyes

A single attention head gives each token only one attention distribution per layer, so it can capture only one pattern of relationships at a time. Multi-head attention splits the D dimensions into H heads and lets each head perform attention independently in d_k = D / H dimensions. Heads often specialize: one may track grammatical dependencies, another semantic associations.

1) Split Q, K, V into H heads:
   (B, N, D)  ->  (B, N, H, d_k)  ->  (B, H, N, d_k)

2) Scaled dot-product attention per head:
   (B, H, N, d_k)  ->  (B, H, N, d_k)

3) Merge heads back:
   (B, H, N, d_k)  ->  (B, N, H, d_k)  ->  (B, N, D)

4) Output projection:
   (B, N, D) · W_O  ->  (B, N, D)

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, N, D = x.shape
        H, d_k = self.num_heads, self.d_k

        def split_heads(t):
            return t.view(B, N, H, d_k).transpose(1, 2)  # (B, H, N, d_k)

        q = split_heads(self.w_q(x))
        k = split_heads(self.w_k(x))
        v = split_heads(self.w_v(x))

        out, _ = scaled_dot_product_attention(q, k, v, mask)
        # (B, H, N, d_k) -> (B, N, D)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.w_o(out)

Multi-head attention has four (D, D) matrices — W_Q, W_K, W_V, W_O — so, ignoring biases, the parameter count is 4 * D * D. For D=4096, that is about 67 million.
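The 4 * D * D count can be checked by instantiating the four projections. The sketch uses `bias=False` to match "ignoring biases" and a small D so the layers are cheap to create. The D = 4096 figure is just the closed form.

```python
import torch.nn as nn

def attention_param_count(d_model, bias=False):
    # Four (D, D) projections: W_Q, W_K, W_V, W_O
    projections = [nn.Linear(d_model, d_model, bias=bias) for _ in range(4)]
    return sum(p.numel() for layer in projections for p in layer.parameters())

print(attention_param_count(64))  # 16384 = 4 * 64 * 64
print(4 * 4096 * 4096)            # 67108864, about 67 million
```

With `bias=True`, each projection adds D more parameters. That is why biases are usually dropped from these estimates.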

5. Feed-Forward Network — Per-Position Transformation

This is a 2-layer MLP applied to the attention output independently at each position. It expands to an intermediate dimension d_ff, usually 4 times D, and then projects back down to D.

FFN(x) = activation(x · W_1 + b_1) · W_2 + b_2

W_1: (D, 4D)   ->  intermediate (B, N, 4D)
activation: ReLU / GELU / SwiGLU, etc.
W_2: (4D, D)   ->  back to (B, N, D)

Modern models often use SwiGLU (a gated variant) as the activation. The FFN parameter count is roughly 2 * D * 4D = 8 * D * D — in fact, the FFN (8 D^2) holds more parameters in a block than attention (4 D^2). MoE is exactly the technique of splitting this FFN into multiple experts and activating only some of them.
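Here is a sketch of the commonly used SwiGLU form, FFN(x) = (SiLU(x · W_1) ⊙ (x · W_3)) · W_2, without biases. Two details are assumptions that follow common practice (for example, the Llama family) and are not part of the FeedForward class in this article. The first is the third matrix W_3. The second is the choice d_ff ≈ (8/3) · D, which keeps the total near 8 * D * D despite the extra matrix.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    """Gated FFN sketch: (SiLU(x W_1) * (x W_3)) W_2, no biases."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate branch
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # linear ("up") branch
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # projection back to D

    def forward(self, x):
        # Element-wise product of the gated and linear branches, then shrink to D
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

d_model = 96
d_ff = int(8 * d_model / 3)  # 256: three (D x d_ff) matrices total about 8 * D * D
ffn = SwiGLUFeedForward(d_model, d_ff)
x = torch.randn(2, 5, d_model)
print(ffn(x).shape)
print(sum(p.numel() for p in ffn.parameters()), 8 * d_model * d_model)
```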

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.act = nn.GELU()

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

6. Residual Connections and LayerNorm — Stabilizing Deep Nets

Stacking dozens of blocks causes vanishing gradients and training instability. Two devices solve this. First, the residual connection adds the input directly to the sublayer output, creating a shortcut so gradients can pass through a deep net. Second, LayerNorm normalizes each token vector to stabilize the distribution of activations.

The original paper used a post-norm structure with normalization after the sublayer, but this destabilizes deep models, so nearly everyone now uses pre-norm, which normalizes before the sublayer.

post-norm (original paper):
   x = LayerNorm(x + Sublayer(x))

pre-norm (modern standard):
   x = x + Sublayer(LayerNorm(x))

Pre-norm is stable because the residual path is added without passing through normalization, so the signal propagates undistorted even in deep nets. Many modern models use RMSNorm, which omits mean subtraction, to save a bit of computation over LayerNorm.
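Here is a sketch of RMSNorm as described above. It rescales each vector by its root mean square and applies a learned gain, with no mean subtraction and no bias. The `eps` default of 1e-6 is an assumption, since implementations vary.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    # y = x / sqrt(mean(x^2) + eps) * weight  (no centering, no bias term)
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight

torch.manual_seed(0)
norm = RMSNorm(16)
x = torch.randn(3, 16) * 5 + 2  # deliberately large and off-center
y = norm(x)
print(y.pow(2).mean(dim=-1))    # each row now has mean square close to 1
```

Unlike LayerNorm, the output is not re-centered: the rows of `y` keep a nonzero mean.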

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff)

    def forward(self, x, mask=None):
        # pre-norm structure
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ffn(self.ln2(x))
        return x

7. Tensor Flow Through One Block

Here is how shapes change while passing through a single decoder block.

input x                         : (B, N, D)
LayerNorm(x)                    : (B, N, D)
  W_Q/W_K/W_V projection         : (B, N, D) each
  split heads                    : (B, H, N, d_k)
  Q·K^T                          : (B, H, N, N)
  softmax · V                    : (B, H, N, d_k)
  merge heads                    : (B, N, D)
  W_O projection                 : (B, N, D)
add residual                    : (B, N, D)
LayerNorm(x)                    : (B, N, D)
  FFN expand                     : (B, N, 4D)
  FFN shrink                     : (B, N, D)
add residual                    : (B, N, D)
output                          : (B, N, D)

After passing through a block, the shape is preserved as (B, N, D). That is why the same block can be repeated and stacked L times.
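The sketch below stacks L pre-norm blocks and checks the shape after each one. It swaps in PyTorch's built-in `nn.TransformerEncoderLayer` with `norm_first=True` (pre-norm) as a stand-in for the TransformerBlock class above. Combined with a causal mask, that layer has the same attention-plus-FFN layout as a decoder-only block. The sizes are arbitrary.

```python
import copy
import torch
import torch.nn as nn

B, N, D, H, L = 2, 10, 64, 4, 3

layer = nn.TransformerEncoderLayer(
    d_model=D, nhead=H, dim_feedforward=4 * D, dropout=0.0,
    activation="gelu", batch_first=True, norm_first=True,
)
blocks = nn.ModuleList([copy.deepcopy(layer) for _ in range(L)])

# Additive float mask: 0 where attention is allowed, -inf above the diagonal (the future)
causal = torch.triu(torch.full((N, N), float("-inf")), diagonal=1)

x = torch.randn(B, N, D)
for block in blocks:
    x = block(x, src_mask=causal)
    assert x.shape == (B, N, D)  # each block maps (B, N, D) -> (B, N, D)
print(x.shape)
```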

8. Estimating the Parameter Count

The main parameters of one block are summarized below.

Component                         Parameter count (approx.)
Attention (W_Q, W_K, W_V, W_O)    4 * D * D
FFN (expand + shrink)             8 * D * D
Block total                       about 12 * D * D

For D=4096 and L=32 blocks, each block holds about 12 * 4096 * 4096, roughly 201 million parameters, and the blocks together hold about 6.4 billion. Adding the embeddings (V * D, plus another V * D if the LM head is not tied to the embedding) brings the total to the scale of the "7B model" we often hear about. This estimate is a useful way to gauge model size intuitively.
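Here is the same back-of-the-envelope estimate written as a function. The vocabulary size V = 32,000 and the untied LM head are illustrative assumptions. Norm weights and biases are ignored, as in the table.

```python
def approx_param_count(d_model, n_layers, vocab_size, tied_embeddings=True):
    per_block = 12 * d_model * d_model  # 4 D^2 attention + 8 D^2 FFN
    blocks = n_layers * per_block
    embed = vocab_size * d_model        # token embedding table (V, D)
    head = 0 if tied_embeddings else vocab_size * d_model  # separate LM head (D, V)
    return per_block, blocks, blocks + embed + head

# Assumed V = 32,000 with an untied LM head
per_block, blocks, total = approx_param_count(4096, 32, vocab_size=32_000, tied_embeddings=False)
print(f"{per_block:,} per block, {blocks:,} in blocks, {total:,} total")
```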

9. Inference and the KV Cache — Why It Is Needed

During training, the entire sequence is processed in parallel at once. But inference (generation) produces tokens one at a time, autoregressively. Recomputing attention from scratch for each generated token wastefully recomputes the Key/Value of all previous tokens every time.

The KV cache keeps the K and V of already-processed tokens in memory. When a new token arrives, only that token's Q, K, and V are computed. Its K and V are appended to the cache, and its Q attends over all cached keys and values.

without KV cache (recompute):
   when generating token t -> recompute K, V for tokens 1..t  (O(t) projection work per step)

with KV cache:
   when generating token t -> compute only K, V of token t and append to cache
   cached K, V for 1..t-1 are reused as is  (O(1) projection work per step;
   the new token's query still attends over all t keys, so attention stays O(t))
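Here is a minimal single-layer, single-head sketch of the cached decode loop, checked against full recomputation with a causal mask. The random projection matrices and sizes are placeholders. A real implementation keeps one cache per layer (and per KV head) and usually preallocates it rather than concatenating lists at every step.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
D, T = 16, 6
# Placeholder projections; in a real model these are learned weights
w_q, w_k, w_v = (torch.randn(D, D) / D ** 0.5 for _ in range(3))
x = torch.randn(T, D)  # hidden states of T tokens, fed in one at a time

def attend(q, k, v):
    # q: (1, D) query of the newest token; k, v: (t, D) all keys/values so far
    weights = F.softmax(q @ k.T / D ** 0.5, dim=-1)
    return weights @ v

k_cache, v_cache, cached_out = [], [], []
for t in range(T):
    x_t = x[t : t + 1]                # (1, D): only the newest token
    k_cache.append(x_t @ w_k)         # compute K, V for this token only
    v_cache.append(x_t @ w_v)         # and append them to the cache
    k, v = torch.cat(k_cache), torch.cat(v_cache)  # (t + 1, D), reused as is
    cached_out.append(attend(x_t @ w_q, k, v))
cached_out = torch.cat(cached_out)    # (T, D)

# Reference: recompute everything for the full sequence with a causal mask
q_all, k_all, v_all = x @ w_q, x @ w_k, x @ w_v
scores = q_all @ k_all.T / D ** 0.5
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
full_out = F.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1) @ v_all
print(torch.allclose(cached_out, full_out, atol=1e-5))
```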

For standard multi-head attention, where K and V each have width D in every layer, the memory occupied by the KV cache can be estimated as follows.

KV cache size (bytes)
 = 2(K and V) x L(layers) x N(sequence length) x D(dimension)
   x B(batch) x bytes_per_element

ex: L=32, N=8192, D=4096, B=1, FP16 (2 bytes)
 = 2 x 32 x 8192 x 4096 x 1 x 2
 about 4.3 GB
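Here is the formula as a small calculator. The sketch generalizes it by writing the K/V width as n_kv_heads * head_dim, an assumption that reduces to D when n_kv_heads = H. This makes the effect of sharing KV heads visible. The 32-head, 128-dim configuration reproduces the D = 4096 example above. The 8-KV-head variant is hypothetical.

```python
def kv_cache_bytes(n_layers, seq_len, batch, n_kv_heads, head_dim, bytes_per_elem=2):
    # 2 (K and V) x L x N x B x KV width (n_kv_heads * head_dim) x bytes per element
    return 2 * n_layers * seq_len * batch * n_kv_heads * head_dim * bytes_per_elem

mha = kv_cache_bytes(32, 8192, 1, n_kv_heads=32, head_dim=128)  # KV width = D = 4096, FP16
gqa = kv_cache_bytes(32, 8192, 1, n_kv_heads=8, head_dim=128)   # hypothetical GQA, 8 KV heads
print(f"MHA: {mha / 1e9:.2f} GB, GQA: {gqa / 1e9:.2f} GB")
```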

As the sequence lengthens and the batch grows, the KV cache grows linearly, and in long-context serving the KV cache can consume more memory than the model weights. The techniques to solve this are exactly GQA/MQA (shrinking the cache by sharing KV heads), PagedAttention (virtual-memory-style block management), and so on — covered in depth in the follow-up article.

10. Encoder, Decoder, and the Variants

Here we summarize the difference between the original paper's encoder-decoder structure and the decoder-only structure that dominates today.

Aspect                 Encoder                       Decoder (GPT family)   Encoder-Decoder (T5, etc.)
Attention direction    bidirectional                 causal                 encoder bidirectional;
                                                                            decoder causal + cross-attention
Mask                   none                          causal mask            causal mask in decoder
Typical use            understanding/representation  generation             translation, summarization,
                       (classification, embeddings)                         sequence-to-sequence
Representative models  BERT                          GPT, Llama, Qwen       T5, BART

Cross-attention has the decoder attend using the encoder's output as Key/Value, which suits tasks like translation where input (source) and output (target) differ. Pure generative models, by contrast, stack only decoder blocks — simple yet powerful.

11. Following Attention Through a Small Numeric Example

Abstract equations alone are hard to grasp. Let us trace the flow with a tiny example. Say we have 3 tokens, dimension D=4, and 1 head.

input X: (N=3, D=4)
 x0 = [1, 0, 1, 0]
 x1 = [0, 1, 0, 1]
 x2 = [1, 1, 0, 0]

1) Build Q, K, V via linear transforms (weights are learned values)
   -> Q, K, V: each (3, 4)

2) scores = Q · K^T  ->  (3, 3) matrix
   each element score[i][j] is the raw score of how much token i attends to token j

3) scale by dividing by sqrt(d_k)=2

4) apply causal mask (if a decoder):
   score[0][1], score[0][2], score[1][2] -> negative infinity

5) softmax per row -> attention weights (each row sums to 1)

6) weights · V  ->  output (3, 4)
   output of token i = weighted average of the V of itself and past tokens

The key is that each row of the score matrix is a distribution of "how much this token references other tokens," and the causal mask zeroes out the future in that distribution. The output is ultimately a mix of past tokens' Values tailored to one's own focus.
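The steps above can be run end to end. To make the numbers traceable by hand, this sketch assumes identity weights (W_Q = W_K = W_V = I), so Q = K = V = X. Learned weights would change the values but not the procedure.

```python
import torch
import torch.nn.functional as F

X = torch.tensor([[1., 0., 1., 0.],   # x0
                  [0., 1., 0., 1.],   # x1
                  [1., 1., 0., 0.]])  # x2
# Assumption: identity projections, so Q = K = V = X
Q = K = V = X
d_k = X.size(-1)  # 4 with a single head, so sqrt(d_k) = 2

scores = Q @ K.T / d_k ** 0.5                       # steps 2-3: (3, 3) scaled scores
mask = torch.tril(torch.ones(3, 3, dtype=torch.bool))
scores = scores.masked_fill(~mask, float("-inf"))   # step 4: hide future tokens
weights = F.softmax(scores, dim=-1)                 # step 5: each row sums to 1
out = weights @ V                                   # step 6: (3, 4)

print(weights)
print(out)
```

Row 0 can only see itself, so its output is exactly x0. Row 2 spreads its weight over all three tokens and leans toward itself, since x2 · x2 = 2 is its largest raw score.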

12. The Output Layer — From Logits to Tokens

The (B, N, D) representation that passes through the last block goes through the LM head (usually a (D, V) matrix that often shares weights with the embedding) to become scores (logits) over the entire vocabulary.

final representation: (B, N, D)
multiply by LM head (D, V):  ->  logits: (B, N, V)
softmax (at the last position):  ->  next-token probability distribution (V,)

At generation time, only the logits of the last position are used to pick the next token. The way you pick (sampling) governs output diversity.

decoding strategy summary:
 greedy   : pick the highest-probability token each step -> deterministic, can be monotonous
 temperature: divide logits by T to flatten (T>1) / sharpen (T<1) the distribution
 top-k    : sample only from the top k tokens
 top-p    : sample only from tokens up to cumulative probability p (nucleus)

import torch
import torch.nn.functional as F

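def sample_next_token(logits, temperature=1.0, top_k=None):
    # logits: (V,) scores for the last position only
    # temperature: T > 1 flattens, T < 1 sharpens; the clamp avoids division by zero
    logits = logits / max(temperature, 1e-6)  # new tensor, so the caller's logits are untouched
    if top_k is not None:
        # keep the k largest logits and mask out everything below the k-th value
        top_k = min(top_k, logits.size(-1))
        v, _ = torch.topk(logits, top_k)
        logits[logits < v[-1]] = float('-inf')
    probs = F.softmax(logits, dim=-1)
    # draw one token ID from the (possibly truncated) distribution; shape (1,)
    return torch.multinomial(probs, num_samples=1)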
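The list above also names top-p (nucleus) sampling, which `sample_next_token` does not implement. Here is a minimal sketch that assumes the same 1-D logits input. The function name `sample_top_p` is hypothetical. It keeps the smallest set of highest-probability tokens whose cumulative probability reaches p, renormalizes, and samples.

```python
import torch
import torch.nn.functional as F

def sample_top_p(logits, top_p=0.9, temperature=1.0):
    # logits: (V,) scores for the last position
    probs = F.softmax(logits / max(temperature, 1e-6), dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop a token if the tokens ranked before it already reach top_p;
    # the most likely token is therefore always kept
    remove = (cumulative - sorted_probs) >= top_p
    sorted_probs = sorted_probs.masked_fill(remove, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum()     # renormalize the nucleus
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx[choice]                            # map back to vocabulary IDs

torch.manual_seed(0)
print(sample_top_p(torch.tensor([10.0, 0.0, 0.0]), top_p=0.9))
```

In the usage example, token 0 holds over 99.9% of the probability, so the nucleus contains only that token.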

This output layer and decoding strategy are the final link that turns a trained Transformer into an actual generator. Even with the same model structure, how you sample greatly changes the character of the results.

13. The Difference in Compute Patterns Between Training and Inference

Even for the same Transformer, training and inference have different compute characteristics. Knowing this difference reveals the starting point for serving optimization.

Aspect               Training (or prefill)            Inference decode stage
Processing unit      full sequence in parallel        one token at a time, sequentially
Bottleneck           compute (matmul) bound           memory-bandwidth bound
Matrix shape         large matrix-matrix products     thin matrix-vector products
Optimization point   FLOPs, tensor-core utilization   KV cache, memory bandwidth

The key point is that the decode stage of inference is memory-bound. Each generated token requires reading all of the model weights and the KV cache from memory, so memory reads, not compute, become the bottleneck. That is why quantization (smaller weights) and KV cache optimization have a large effect on decode throughput. The 2026 serving stack's continuous (in-flight) batching, paged KV cache, FP8/INT4 quantization, and speculative decoding all attack this memory-bound characteristic.
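Here is a back-of-the-envelope sketch of why decode is memory-bound. It rests on several assumptions:

- about 2 FLOPs per weight per generated token;
- FP16 weights, read once per step and shared across the batch;
- KV cache traffic ignored.

The 2 TB/s bandwidth is a hypothetical round number, not a specific GPU. Under these assumptions the arithmetic intensity works out to about B FLOPs per byte, far below the compute-to-bandwidth ratio of modern accelerators.

```python
def decode_step_estimate(n_params, batch, bytes_per_weight=2, bandwidth_bytes_per_s=2e12):
    # FLOPs: about 2 per weight per token (one multiply + one add)
    flops = 2 * n_params * batch
    # Bytes: every weight is read once per decode step, shared by the whole batch
    # (KV cache reads are ignored, so this is an optimistic bound)
    bytes_read = n_params * bytes_per_weight
    intensity = flops / bytes_read                       # FLOPs per byte, equals batch here
    min_step_time = bytes_read / bandwidth_bytes_per_s   # seconds, memory-bound lower bound
    return intensity, min_step_time

intensity, t = decode_step_estimate(7e9, batch=1)
print(f"{intensity:.1f} FLOPs/byte, at least {t * 1e3:.1f} ms per token")
```

Raising the batch size is the simplest lever on intensity. Continuous batching exploits this, while quantization shrinks `bytes_read` directly.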

Pitfalls and Troubleshooting

  • Missing score scaling: Forgetting to divide the scores by sqrt(d_k) saturates the softmax when head dimensions are large. Attention then collapses onto a single token, and training can become unstable or diverge.
  • Wrong mask placement: The causal mask must fill scores with negative infinity before softmax. Multiplying by zero after softmax breaks normalization.
  • Confusing pre-norm and post-norm: Training a deep model from scratch with post-norm tends to diverge without warmup. Use pre-norm unless you have a specific reason not to.
  • KV cache dtype mismatch: If the weights are FP16 but the KV cache is stored in FP32, the cache takes twice the memory it needs. Match the dtype unless you are deliberately quantizing.
  • Positional encoding length overflow: Using absolute positional encoding on inputs longer than the training length causes a sharp drop in performance. For long contexts, consider RoPE-family methods and extrapolation techniques.
  • Missing transpose in head splitting: Forgetting the transpose that turns (B, N, H, d_k) into (B, H, N, d_k) silently mixes the head and sequence dimensions, producing wrong results quietly.
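The last pitfall is easy to demonstrate. This sketch uses arbitrary sizes and compares two ways of splitting heads. The correct split is a view followed by a transpose; the buggy one views directly to (B, H, N, d_k). Both produce the same shape, so nothing fails, but only the first puts the right slice of D into each head.

```python
import torch

torch.manual_seed(0)
B, N, H, d_k = 2, 5, 4, 8
x = torch.randn(B, N, H * d_k)  # (B, N, D)

# Correct: (B, N, D) -> (B, N, H, d_k) -> (B, H, N, d_k)
heads = x.view(B, N, H, d_k).transpose(1, 2)
# Buggy: reinterprets the same memory as (B, H, N, d_k) with no transpose
wrong = x.view(B, H, N, d_k)

# Head h at position n should hold features h*d_k .. (h+1)*d_k of token n
h, n = 2, 3
print(torch.equal(heads[0, h, n], x[0, n, h * d_k:(h + 1) * d_k]))  # True
print(torch.equal(wrong[0, h, n], x[0, n, h * d_k:(h + 1) * d_k]))  # False

# Merging back undoes the correct split exactly
merged = heads.transpose(1, 2).contiguous().view(B, N, H * d_k)
print(torch.equal(merged, x))  # True
```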

Closing

A Transformer block is ultimately the repetition of three simple ideas: (1) mixing information across tokens with attention, (2) doing a per-position nonlinear transformation with the FFN, and (3) stabilizing the deep net with residuals and normalization. On top of that, positional encoding injects order, and at inference the KV cache avoids recomputation.

With this skeleton in hand, everything falls into place cleanly: RoPE is a positional-encoding variant, GQA/MQA make the KV side of attention efficient, FlashAttention is an IO optimization of the attention computation, and MoE is conditional activation of the FFN. In the next articles we dive deep into the evolution of attention (MQA/GQA/FlashAttention) and into positional encoding (RoPE).
