Dissecting the Transformer — From Attention to KV Cache
Author: Youngju Kim (@fjvbn20031)
- Introduction — Why the Transformer, Still
- The Full Structure at a Glance
- 1. Embedding — Turning Tokens into Vectors
- 2. Positional Encoding — Injecting Order
- 3. Self-Attention — The Core Mechanism
- 4. Multi-Head Attention — Seeing with Many Eyes
- 5. Feed-Forward Network — Per-Position Transformation
- 6. Residual Connections and LayerNorm — Stabilizing Deep Nets
- 7. Tensor Flow Through One Block
- 8. Estimating the Parameter Count
- 9. Inference and the KV Cache — Why It Is Needed
- 10. Encoder, Decoder, and the Variants
- 11. Following Attention Through a Small Numeric Example
- 12. The Output Layer — From Logits to Tokens
- 13. The Difference in Compute Patterns Between Training and Inference
- Pitfalls and Troubleshooting
- Closing
- References
Introduction — Why the Transformer, Still
Since the 2017 paper "Attention Is All You Need," the Transformer has become the default backbone of nearly every deep learning field — not just NLP, but vision, speech, code, and multimodal systems. As of 2026, almost every large language model (LLM) we use is essentially a variant of the Transformer decoder. Modern techniques like GQA, RoPE, FlashAttention, and MoE are all refinements bolted onto the original; the core skeleton has not changed.
Because of this, once you truly understand the Transformer, reading new models and papers becomes a matter of spotting "what changed." Conversely, if your grasp of self-attention and the KV cache is fuzzy, you will get stuck at every step — serving optimization, fine-tuning, and beyond.
This article dissects one Transformer block from start to finish. It connects, as a single flow, what tensor shape each operation takes in and produces, how many parameters appear, and why the KV cache shows up at inference. Equations are written in plain notation without dollar signs, and code is shown in working form.
The Full Structure at a Glance
The flow from an input token sequence to the next-token distribution looks like this.
```
[input token IDs]  (integer sequence, length N)
        |
        v
[token embedding + positional encoding]  ->  X: (B, N, D)
        |
        v
+---------------------------+
|  Transformer block x L    |
|                           |
|    LayerNorm              |
|    Multi-Head Attention   |
|    + residual             |
|                           |
|    LayerNorm              |
|    Feed-Forward Network   |
|    + residual             |
+---------------------------+
        |
        v
[final LayerNorm]
        |
        v
[unembedding / LM head]  ->  logits: (B, N, V)
        |
        v
[softmax]  ->  next-token probability distribution
```
The symbols are used as follows.
B = batch size
N = sequence length (number of tokens)
D = model dimension (d_model)
H = number of attention heads
d_k = dimension per head = D / H
V = vocabulary size
L = number of blocks (layers)
GPT-family models stack only the decoder block L times in the diagram above. Encoder-decoder structures (the original paper, T5, etc.) keep separate encoder and decoder stacks. This article focuses on the decoder while clearly pointing out the differences from the encoder.
1. Embedding — Turning Tokens into Vectors
The very first step turns integer token IDs into D-dimensional vectors. This is simply a lookup table of size (V, D).
```python
import torch
import torch.nn as nn

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, token_ids):
        # token_ids: (B, N) -> (B, N, D)
        return self.embed(token_ids) * (self.d_model ** 0.5)
```
The original paper multiplies the embedding weights by sqrt(D). A common explanation is that this keeps the embedding's magnitude comparable to the positional encoding added next. The embedding holds V * D parameters. For a 50k vocabulary and D=4096, the embedding table alone has about 205 million parameters (50,000 * 4,096).
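To confirm that the embedding really is just a table lookup, the sketch below checks two things. First, nn.Embedding returns the same rows as multiplying a one-hot matrix by the weight table. Second, the parameter count is V * D. The sizes V=1000 and D=64 are arbitrary illustrative values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
V, D = 1000, 64  # illustrative sizes, not a real model's
embed = nn.Embedding(V, D)

token_ids = torch.tensor([[5, 42, 7]])  # (B=1, N=3)
looked_up = embed(token_ids)            # (1, 3, 64): rows 5, 42, 7 of the table

# Same result as a matrix product: one-hot (1, 3, V) times weight table (V, D)
one_hot = F.one_hot(token_ids, num_classes=V).float()
via_matmul = one_hot @ embed.weight

num_params = sum(p.numel() for p in embed.parameters())
print(looked_up.shape, torch.allclose(looked_up, via_matmul), num_params)
```

The lookup is preferred in practice because it avoids materializing an (N, V) one-hot matrix. The parameter formula is the same one used in the text: 50_000 * 4_096 = 204_800_000.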
2. Positional Encoding — Injecting Order
Self-attention by itself has no sense of order. Permuting the input tokens simply permutes the outputs in the same way (permutation equivariance), so without extra information the model cannot tell which token came first. Positional information must therefore be injected explicitly. The most classic approach is absolute positional encoding built from sine and cosine functions.
PE(pos, 2i) = sin( pos / 10000^(2i / D) )
PE(pos, 2i+1) = cos( pos / 10000^(2i / D) )
pos = position index (0, 1, 2, ...)
i = index of the (sin, cos) dimension pair, 0 <= i < D/2
Each position gets a unique pattern from a combination of sine waves at different frequencies, and this pattern is added to the embedding. Modern models more often replace absolute positions with one of two alternatives. RoPE (rotary position embedding) rotates Q and K. ALiBi adds a distance-dependent bias to the attention scores. We cover these in a separate article. For now, just remember that positional information is always injected somewhere on the way into the attention computation.
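The two formulas above translate directly into a small table builder. This is a minimal sketch that assumes D is even, so that every sin dimension has a cos partner. The function name is illustrative.

```python
import math
import torch

def sinusoidal_positional_encoding(max_len, d_model):
    """Return a (max_len, d_model) table: sin on even dims, cos on odd dims."""
    assert d_model % 2 == 0, "assumes an even model dimension"
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)       # 0, 2, 4, ...
    inv_freq = torch.exp(-math.log(10000.0) * two_i / d_model)     # 1 / 10000^(2i/D)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * inv_freq)  # PE(pos, 2i)
    pe[:, 1::2] = torch.cos(pos * inv_freq)  # PE(pos, 2i+1)
    return pe

pe = sinusoidal_positional_encoding(max_len=128, d_model=16)
x = torch.randn(2, 10, 16)   # (B, N, D) token embeddings
x = x + pe[: x.size(1)]      # add positions; broadcasts over the batch
print(pe.shape, x.shape)
```

Computing 1 / 10000^(2i/D) through exp and log is only a numerically convenient form of the same quantity. Because the table depends only on position, it is usually built once and reused.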
3. Self-Attention — The Core Mechanism
Query, Key, Value
The intuition behind self-attention is this. Each token poses "what am I looking for" (Query), every token holds a label "this is what I am" (Key), and the actual content to retrieve sits in the Value. The higher the similarity between a Query and a Key, the more of that Value is retrieved.
For input X: (B, N, D), three linear transformations produce Q, K, V.
Q = X · W_Q W_Q: (D, D) -> Q: (B, N, D)
K = X · W_K W_K: (D, D) -> K: (B, N, D)
V = X · W_V W_V: (D, D) -> V: (B, N, D)
Scaled Dot-Product Attention
The core equation is as follows (written without dollar signs).
Attention(Q, K, V) = softmax( (Q · K^T) / sqrt(d_k) ) · V
Q · K^T : (B, N, N) -- similarity scores for all token pairs
/ sqrt(d_k) : scaling (prevents scores from growing too large)
softmax : normalize along rows -> attention weights
· V : produce output as a weighted average -> (B, N, D)
Dividing by sqrt(d_k) matters. Suppose the components of Q and K are roughly independent, with mean 0 and variance 1. Their dot product then has variance d_k. When d_k is large, the softmax inputs become very large, which pushes the softmax into a saturated region where gradients are near zero and training destabilizes. Dividing by the standard deviation sqrt(d_k) brings the variance back to about 1.
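The variance argument is easy to check empirically. The sketch below draws Q and K entries as independent standard normals, which is an idealizing assumption rather than what a trained model produces. It then compares the variance of raw and scaled dot products.

```python
import torch

torch.manual_seed(0)
for d_k in (16, 64, 256, 1024):
    # Assumption: Q and K entries are i.i.d. with mean 0 and variance 1
    q = torch.randn(4000, d_k)
    k = torch.randn(4000, d_k)
    dots = (q * k).sum(dim=-1)   # 4,000 raw dot products
    scaled = dots / d_k ** 0.5   # what scaled dot-product attention feeds to softmax
    print(f"d_k={d_k:5d}  var(raw)={dots.var().item():8.1f}  var(scaled)={scaled.var().item():.2f}")
```

The raw variance tracks d_k, while the scaled variance stays near 1 regardless of the head dimension.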
```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (B, H, N, d_k)
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    # scores: (B, H, N, N)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    out = torch.matmul(weights, v)  # (B, H, N, d_k)
    return out, weights
```
Causal Mask — Hiding the Future
In a decoder (GPT family), the token at position i must not see tokens that come after i, because that would amount to peeking at the answer during training. So the upper-triangular part (future positions) of the attention score matrix is filled with negative infinity, making the post-softmax weights zero.
```
causal mask (N=4, 1 = allowed, 0 = blocked):

          key0  key1  key2  key3
query0     1     0     0     0
query1     1     1     0     0
query2     1     1     1     0
query3     1     1     1     1
```
The encoder (BERT, the encoder stack of the original paper) has no such mask. It is bidirectional attention where every token can see every other. This difference is the fundamental distinction between decoder (generation) and encoder (understanding/representation) models.
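The mask in the grid above is simply a lower-triangular matrix, so torch.tril builds it directly. The sketch below applies that mask before softmax, exactly as the scaled-dot-product code does. It then compares the result with PyTorch's fused F.scaled_dot_product_attention using is_causal=True. That function assumes PyTorch 2.0 or later. The tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, d_k = 2, 4, 5, 8
q, k, v = (torch.randn(B, H, N, d_k) for _ in range(3))

# 1 = allowed, 0 = blocked: lower triangle including the diagonal
causal_mask = torch.tril(torch.ones(N, N, dtype=torch.int64))

scores = q @ k.transpose(-2, -1) / d_k ** 0.5                 # (B, H, N, N)
scores = scores.masked_fill(causal_mask == 0, float("-inf"))  # mask BEFORE softmax
weights = F.softmax(scores, dim=-1)
out = weights @ v                                             # (B, H, N, d_k)

# Fused PyTorch version; is_causal=True applies the same lower-triangular mask
reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(causal_mask)
print(torch.allclose(out, reference, atol=1e-5))
```

An encoder would simply skip the masked_fill line, which gives bidirectional attention.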
4. Multi-Head Attention — Seeing with Many Eyes
A single attention head produces one set of attention weights per token, so it can capture only one pattern of relationships at a time. Multi-head attention splits the D dimensions into H heads, letting each head perform attention independently in d_k = D / H dimensions. One head may learn grammatical dependencies and another semantic associations, so the heads can specialize.
1) Split Q, K, V into H heads:
(B, N, D) -> (B, N, H, d_k) -> (B, H, N, d_k)
2) Scaled dot-product attention per head:
(B, H, N, d_k) -> (B, H, N, d_k)
3) Merge heads back:
(B, H, N, d_k) -> (B, N, H, d_k) -> (B, N, D)
4) Output projection:
(B, N, D) · W_O -> (B, N, D)
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, N, D = x.shape
        H, d_k = self.num_heads, self.d_k

        def split_heads(t):
            return t.view(B, N, H, d_k).transpose(1, 2)  # (B, H, N, d_k)

        q = split_heads(self.w_q(x))
        k = split_heads(self.w_k(x))
        v = split_heads(self.w_v(x))
        out, _ = scaled_dot_product_attention(q, k, v, mask)
        # (B, H, N, d_k) -> (B, N, D)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.w_o(out)
```
Multi-head attention has four (D, D) matrices — W_Q, W_K, W_V, W_O — so, ignoring biases, the parameter count is 4 * D * D. For D=4096, that is about 67 million.
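The 4 * D * D count can be checked against PyTorch's built-in nn.MultiheadAttention. That module packs W_Q, W_K, and W_V into a single (3D, D) in_proj_weight, but the total is the same. The sketch uses a small D=512 to keep memory low. With biases enabled, which is the default, each of the four projections also adds D bias terms.

```python
import torch.nn as nn

def count_params(module):
    return sum(p.numel() for p in module.parameters())

D, H = 512, 8  # small illustrative sizes
mha = nn.MultiheadAttention(embed_dim=D, num_heads=H)  # bias=True by default

weights_only = 4 * D * D  # W_Q, W_K, W_V, W_O
biases = 4 * D            # one bias vector per projection
print(count_params(mha), weights_only, weights_only + biases)
```

For D=4096, the weights alone come to 4 * 4096 * 4096 = 67,108,864, which is the "about 67 million" figure above.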
5. Feed-Forward Network — Per-Position Transformation
This is a 2-layer MLP applied independently at each position to the attention output. It usually expands the intermediate dimension to 4 times D and then shrinks it back.
FFN(x) = activation(x · W_1 + b_1) · W_2 + b_2
W_1: (D, 4D) -> intermediate (B, N, 4D)
activation: ReLU / GELU / SwiGLU, etc.
W_2: (4D, D) -> back to (B, N, D)
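The FFN parameter count is roughly 2 * D * 4D = 8 * D * D, so the FFN (8 D^2) holds twice as many parameters per block as attention (4 D^2). Many modern models replace the plain activation with SwiGLU, a gated variant that adds a third weight matrix. They then often shrink the intermediate size to about 8D/3 so that the parameter count stays near 8 D^2. MoE is exactly the technique of splitting this FFN into multiple experts and activating only some of them for each token.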
```python
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # expand, usually d_ff = 4 * d_model
        self.fc2 = nn.Linear(d_ff, d_model)  # shrink back to d_model
        self.act = nn.GELU()

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))
```
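For comparison, here is a minimal sketch of a SwiGLU-style gated FFN. The structure W_2(SiLU(x W_1) * (x W_3)) is the commonly used form. The bias-free layers and the d_ff = 8D/3 sizing are assumptions that follow common practice; real models usually round d_ff to a hardware-friendly multiple. The class name is illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    """Gated FFN: W_2( SiLU(x W_1) * (x W_3) ), without biases (an assumption)."""
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # "up" projection
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # "down" projection

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

D = 48
d_ff = 8 * D // 3  # three matrices instead of two, so shrink d_ff to keep ~8 * D * D params
ffn = SwiGLUFeedForward(D, d_ff)
x = torch.randn(2, 10, D)
print(ffn(x).shape, sum(p.numel() for p in ffn.parameters()), 8 * D * D)
```

With three (D, 8D/3) matrices, the total is 3 * D * 8D/3 = 8 * D * D, matching the plain 4x FFN.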
6. Residual Connections and LayerNorm — Stabilizing Deep Nets
Stacking dozens of blocks causes vanishing gradients and training instability. Two devices solve this. First, the residual connection adds the input directly to the sublayer output, creating a shortcut so gradients can pass through a deep net. Second, LayerNorm normalizes each token vector to stabilize the distribution of activations.
The original paper used a post-norm structure with normalization after the sublayer, but this destabilizes deep models, so nearly everyone now uses pre-norm, which normalizes before the sublayer.
post-norm (original paper):
x = LayerNorm(x + Sublayer(x))
pre-norm (modern standard):
x = x + Sublayer(LayerNorm(x))
Pre-norm is more stable because the residual path never passes through a normalization layer. The signal and its gradient can therefore flow along an identity path even in deep nets. Many modern models use RMSNorm instead of LayerNorm. RMSNorm skips the mean subtraction and rescales by the root mean square only, which saves a little computation.
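A minimal RMSNorm can be written in a few lines. This sketch divides each token vector by its root mean square plus a small eps, then applies a learned per-dimension scale. It has no mean subtraction and no bias. The eps value 1e-6 is an assumed default.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """y = x / sqrt(mean(x^2) + eps) * weight, computed per token vector."""
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))  # learned scale, starts at 1

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight

x = torch.randn(2, 5, 16) * 3 + 1  # deliberately off-center and large
y = RMSNorm(16)(x)
print(y.pow(2).mean(dim=-1))       # mean square is about 1 for every token
```

Unlike LayerNorm, the output here is not re-centered: only its scale is normalized.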
```python
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff)

    def forward(self, x, mask=None):
        # pre-norm structure
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ffn(self.ln2(x))
        return x
```
7. Tensor Flow Through One Block
Here is how shapes change while passing through a single decoder block.
input x : (B, N, D)
LayerNorm(x) : (B, N, D)
W_Q/W_K/W_V projection : (B, N, D) each
split heads : (B, H, N, d_k)
Q·K^T : (B, H, N, N)
softmax · V : (B, H, N, d_k)
merge heads : (B, N, D)
W_O projection : (B, N, D)
add residual : (B, N, D)
LayerNorm(x) : (B, N, D)
FFN expand : (B, N, 4D)
FFN shrink : (B, N, D)
add residual : (B, N, D)
output : (B, N, D)
After passing through a block, the shape is preserved as (B, N, D). That is why the same block can be repeated and stacked L times.
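Because each block maps (B, N, D) to (B, N, D), blocks can be chained freely. The sketch below assumes that PyTorch's built-in nn.TransformerEncoderLayer, configured with norm_first=True (pre-norm) and given a causal mask, can stand in for a decoder-only block. It has no cross-attention. The sketch stacks L of these layers and checks two things. The shape is preserved, and changing the last token leaves every earlier position's output untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, N, D, H, L = 2, 6, 64, 4, 3

blocks = nn.ModuleList(
    nn.TransformerEncoderLayer(d_model=D, nhead=H, dim_feedforward=4 * D,
                               dropout=0.0, batch_first=True, norm_first=True)
    for _ in range(L)
)
# Boolean attention mask: True marks a blocked (future) position
causal = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)

def run(x):
    for block in blocks:
        x = block(x, src_mask=causal)  # (B, N, D) in, (B, N, D) out
    return x

x = torch.randn(B, N, D)
with torch.no_grad():
    y = run(x)
    x_changed = x.clone()
    x_changed[:, -1] += torch.randn(B, D)  # perturb only the last token
    y_changed = run(x_changed)

print(y.shape)                                                   # (B, N, D) preserved
print(torch.allclose(y[:, :-1], y_changed[:, :-1], atol=1e-5))  # earlier positions unaffected
```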
8. Estimating the Parameter Count
The main parameters of one block are summarized below.
| Component | Parameter count (approx.) |
|---|---|
| Attention (W_Q, W_K, W_V, W_O) | 4 * D * D |
| FFN (expand + shrink) | 8 * D * D |
| Block total | about 12 * D * D |
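For D=4096 and L=32 blocks, each block has about 12 * 4096 * 4096, or roughly 201 million, parameters. The blocks together come to about 6.4 billion. Adding the embedding table (V * D), plus the LM head if its weights are not tied to the embedding, brings the total to the scale of the "7B models" we often hear about. Biases and normalization weights are small enough to ignore, which makes this a handy way to gauge model size intuitively.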
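The estimate can be wrapped in a small helper. It uses the same simplifications as the table: a 4x FFN, and no biases or normalization weights. The vocabulary size of 32,000 is an assumed illustrative value, because the text does not fix V. The function name is hypothetical.

```python
def estimate_params(d_model, n_layers, vocab_size, tied_embeddings=False):
    """Rough count: 12 * D^2 per block plus the embedding (and LM head if untied).

    Ignores biases and normalization weights; assumes d_ff = 4 * d_model.
    """
    per_block = 12 * d_model * d_model   # 4 D^2 attention + 8 D^2 FFN
    blocks = n_layers * per_block
    embeddings = vocab_size * d_model    # input embedding table
    lm_head = 0 if tied_embeddings else vocab_size * d_model
    return blocks, blocks + embeddings + lm_head

blocks, total = estimate_params(d_model=4096, n_layers=32, vocab_size=32_000)
print(f"blocks: {blocks / 1e9:.2f}B, total: {total / 1e9:.2f}B")
```

With these assumptions, the blocks account for about 6.44B parameters. An untied embedding and LM head raise the total to about 6.70B.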
9. Inference and the KV Cache — Why It Is Needed
During training, the entire sequence is processed in parallel at once. But inference (generation) produces tokens one at a time, autoregressively. Recomputing attention from scratch for each generated token wastefully recomputes the Key/Value of all previous tokens every time.
The KV cache stores the K and V of already-processed tokens in memory. When a new token arrives, only that token's Q, K, and V are computed. Its K and V are appended to the cache, and its Q attends over all the cached K and V.
without KV cache (recompute):
when generating token t -> recompute K, V for tokens 1..t (O(t) projection work per step)
with KV cache:
when generating token t -> compute only K, V of token t and append to cache (O(1) projection work per step)
cached K, V for 1..t-1 are reused as is
either way, the new token's attention still reads all t cached K, V (O(t) per step)
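The equivalence between the two approaches can be checked with a toy single-head, single-layer example. Tokens arrive one at a time. Each token's K and V are computed once and appended to a cache, and the new Q attends over the whole cache. The result must match full causal attention recomputed over the entire sequence. The random projection weights are placeholders, not a trained model. Appending to Python lists is a simplification; real implementations preallocate cache tensors.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
D = 16
W_q, W_k, W_v = (torch.randn(D, D) / D ** 0.5 for _ in range(3))  # toy projection weights

def attend(q, k, v):
    # q: (1, D) for the newest token; k, v: (t, D) for all tokens so far.
    # No mask needed: the cache only holds past and current tokens.
    weights = F.softmax(q @ k.T / D ** 0.5, dim=-1)  # (1, t)
    return weights @ v                               # (1, D)

xs = torch.randn(8, D)  # hidden states of 8 tokens, arriving one at a time
k_cache, v_cache, cached_outs = [], [], []
for t in range(xs.size(0)):
    x_t = xs[t : t + 1]          # (1, D): only the new token
    k_cache.append(x_t @ W_k)    # compute K, V once and append to the cache
    v_cache.append(x_t @ W_v)
    cached_outs.append(attend(x_t @ W_q, torch.cat(k_cache), torch.cat(v_cache)))
cached_outs = torch.cat(cached_outs)  # (8, D)

# Reference: full causal attention recomputed over the whole sequence
q_all, k_all, v_all = xs @ W_q, xs @ W_k, xs @ W_v
mask = torch.tril(torch.ones(8, 8, dtype=torch.bool))
scores = (q_all @ k_all.T / D ** 0.5).masked_fill(~mask, float("-inf"))
full_outs = F.softmax(scores, dim=-1) @ v_all
print(torch.allclose(cached_outs, full_outs, atol=1e-5))
```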
The memory occupied by the KV cache can be estimated as follows.
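KV cache size (bytes)
= 2 (K and V) x L (layers) x N (sequence length) x D (dimension)
x B (batch) x bytes_per_element
(with standard multi-head attention D = H x d_k, because every head keeps its own K and V)
Example: L=32, N=8192, D=4096, B=1, FP16 (2 bytes)
= 2 x 32 x 8192 x 4096 x 1 x 2
= 4,294,967,296 bytes, about 4.3 GB (4 GiB)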
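The same formula fits in a small calculator. Writing the width as (KV heads x head_dim) instead of D makes the effect of GQA visible. The split of D=4096 into 32 heads of 128 dimensions, and the 8-KV-head GQA variant, are hypothetical configurations chosen only for illustration.

```python
def kv_cache_bytes(n_layers, seq_len, batch, n_kv_heads, head_dim, bytes_per_elem=2):
    """2 (K and V) * layers * tokens * batch * (KV heads * head_dim) * bytes per element."""
    return 2 * n_layers * seq_len * batch * n_kv_heads * head_dim * bytes_per_elem

# The example above: plain multi-head attention, 32 heads * 128 dims = D = 4096
mha = kv_cache_bytes(n_layers=32, seq_len=8192, batch=1, n_kv_heads=32, head_dim=128)
# Hypothetical GQA variant: the same 32 query heads share only 8 KV heads
gqa = kv_cache_bytes(n_layers=32, seq_len=8192, batch=1, n_kv_heads=8, head_dim=128)
print(mha, mha / 1e9, mha / 2**30)  # bytes, GB, GiB
print(gqa / 2**30)                  # GiB with 4x fewer KV heads
```

Because the cache scales with the number of KV heads, sharing each KV head across four query heads cuts it by 4x.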
As the sequence lengthens and the batch grows, the KV cache grows linearly, and in long-context serving the KV cache can consume more memory than the model weights. The techniques to solve this are exactly GQA/MQA (shrinking the cache by sharing KV heads), PagedAttention (virtual-memory-style block management), and so on — covered in depth in the follow-up article.
10. Encoder, Decoder, and the Variants
Here we summarize the difference between the original paper's encoder-decoder structure and the decoder-only structure that dominates today.
| Aspect | Encoder | Decoder (GPT family) | Encoder-Decoder (T5, etc.) |
|---|---|---|---|
| Attention direction | bidirectional | causal | encoder bidirectional, decoder causal + cross-attention |
| Mask | none | causal mask | causal mask in decoder |
| Typical use | understanding/representation (classification, embeddings) | generation | translation, summarization, sequence-to-sequence |
| Representative models | BERT | GPT, Llama, Qwen | T5, BART |
In cross-attention, the Queries come from the decoder while the Keys and Values come from the encoder's output. This suits tasks like translation, where the input (source) and output (target) are different sequences. Pure generative models, by contrast, stack only decoder blocks, a simple yet powerful design.
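The only mechanical difference from self-attention is where Q, K, and V come from. This sketch uses PyTorch's nn.MultiheadAttention with random tensors standing in for real encoder and decoder states. The sizes are arbitrary. Because source and target lengths differ, the attention matrix is (target length x source length) rather than square.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D, H = 32, 4
cross_attn = nn.MultiheadAttention(embed_dim=D, num_heads=H, batch_first=True)

encoder_out = torch.randn(2, 7, D)  # source: 7 tokens -> Keys and Values
decoder_h = torch.randn(2, 5, D)    # target prefix: 5 tokens -> Queries

# query from the decoder; key and value from the encoder output.
# No causal mask: every target position may look at the whole source.
out, weights = cross_attn(decoder_h, encoder_out, encoder_out)
print(out.shape, weights.shape)     # weights are averaged over heads by default
```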
11. Following Attention Through a Small Numeric Example
Abstract equations alone are hard to grasp. Let us trace the flow with a tiny example. Say we have 3 tokens, dimension D=4, and 1 head.
input X: (N=3, D=4)
x0 = [1, 0, 1, 0]
x1 = [0, 1, 0, 1]
x2 = [1, 1, 0, 0]
1) Build Q, K, V via linear transforms (weights are learned values)
-> Q, K, V: each (3, 4)
2) scores = Q · K^T -> (3, 3) matrix
each element score[i][j] is the raw score of how much token i attends to token j
3) scale by dividing by sqrt(d_k)=2
4) apply causal mask (if a decoder):
score[0][1], score[0][2], score[1][2] -> negative infinity
5) softmax per row -> attention weights (each row sums to 1)
6) weights · V -> output (3, 4)
output of token i = weighted average of the V of itself and past tokens
The key is that each row of the score matrix is a distribution of "how much this token references other tokens," and the causal mask zeroes out the future in that distribution. The output is ultimately a mix of past tokens' Values tailored to one's own focus.
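The six steps can be run end to end on the same X. To make the numbers easy to follow, this sketch assumes W_Q = W_K = W_V = identity, so Q = K = V = X. That choice is purely for illustration; in a trained model these are learned (D, D) matrices.

```python
import torch
import torch.nn.functional as F

X = torch.tensor([[1., 0., 1., 0.],
                  [0., 1., 0., 1.],
                  [1., 1., 0., 0.]])  # (N=3, D=4)

# Illustrative assumption: identity projections, so Q = K = V = X
Q = K = V = X
d_k = X.size(-1)                      # one head, so d_k = D = 4

scores = Q @ K.T / d_k ** 0.5         # (3, 3), divided by sqrt(4) = 2
mask = torch.tril(torch.ones(3, 3, dtype=torch.bool))
scores = scores.masked_fill(~mask, float("-inf"))
weights = F.softmax(scores, dim=-1)   # each row sums to 1
out = weights @ V                     # (3, 4)

print(scores)
print(weights)
print(out)
```

With these toy weights, token 0 can see only itself, so its output is exactly x0. Token 1's scores against x0 and x1 are 0 and 1, so softmax splits its attention roughly 0.27 / 0.73 between them.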
12. The Output Layer — From Logits to Tokens
The (B, N, D) representation that passes through the last block goes through the LM head (usually a (D, V) matrix that often shares weights with the embedding) to become scores (logits) over the entire vocabulary.
final representation: (B, N, D)
multiply by LM head (D, V): -> logits: (B, N, V)
softmax (at the last position): -> next-token probability distribution (V,)
At generation time, only the logits of the last position are used to pick the next token. The way you pick (sampling) governs output diversity.
decoding strategy summary:
greedy : pick the highest-probability token each step -> deterministic, can be monotonous
temperature: divide logits by T to flatten (T>1) / sharpen (T<1) the distribution
top-k : sample only from the top k tokens
top-p : sample only from tokens up to cumulative probability p (nucleus)
```python
import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=1.0, top_k=None):
    # logits: (V,) scores at the last position
    logits = logits / max(temperature, 1e-6)  # new tensor, so the in-place edit below is safe
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[-1]] = float('-inf')  # keep only the top-k scores
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)
```
This output layer and decoding strategy are the final link that turns a trained Transformer into an actual generator. Even with the same model structure, how you sample greatly changes the character of the results.
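The sampling function above covers temperature and top-k but not top-p. Below is one way to sketch nucleus filtering for a 1-D logits vector. It keeps the smallest set of most-likely tokens whose cumulative probability reaches p. The function names are illustrative, and implementations differ slightly in how they treat the boundary token.

```python
import torch
import torch.nn.functional as F

def top_p_filter(logits, top_p):
    """Set to -inf every token outside the smallest top set with total prob >= top_p."""
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    remove = mass_before >= top_p  # the set already reached top_p without this token
    remove[0] = False              # always keep the single most likely token
    filtered = logits.clone()
    filtered[sorted_idx[remove]] = float("-inf")
    return filtered

def sample_top_p(logits, top_p=0.9, temperature=1.0):
    logits = logits / max(temperature, 1e-6)
    probs = F.softmax(top_p_filter(logits, top_p), dim=-1)
    return torch.multinomial(probs, num_samples=1)

torch.manual_seed(0)
logits = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05]))
print(top_p_filter(logits, top_p=0.7))  # tokens 2 and 3 become -inf
samples = {sample_top_p(logits, top_p=0.7).item() for _ in range(200)}
print(samples)                          # only tokens 0 and 1 can appear
```

Unlike top-k, the number of surviving tokens adapts to the distribution. A confident distribution keeps only a few tokens, and a flat one keeps many.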
13. The Difference in Compute Patterns Between Training and Inference
Even for the same Transformer, training and inference have different compute characteristics. Knowing this difference reveals the starting point for serving optimization.
| Aspect | Training (or prefill) | Inference decode stage |
|---|---|---|
| Processing unit | full sequence in parallel | one token at a time, sequentially |
| Bottleneck | compute (matmul) bound | memory bandwidth bound |
| Matrix shape | large matrix-matrix products (all N tokens at once) | thin, matrix-vector-like products (one new token per sequence) |
| Optimization point | FLOPs, tensor-core utilization | KV cache, memory bandwidth |
The key point is that the decode stage is memory-bound. Each generated token requires reading all model weights and the whole KV cache from memory, yet only a small amount of arithmetic is done on them. Memory reads, not compute, therefore become the bottleneck. That is why quantization (smaller weights) and KV cache optimization have a large effect on decode throughput. The 2026 serving stack's continuous (in-flight) batching, paged KV cache, FP8/INT4 quantization, and speculative decoding all attack this memory-bound characteristic.
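One way to see why decode is memory-bound is to compute the arithmetic intensity, in FLOPs per byte moved, of a matmul against a (D x D) weight matrix. This is an idealized estimate. It assumes each operand is read once and the output is written once, with FP16 (2 bytes) elements. It ignores KV cache reads, which only add to decode's memory traffic. The function name is illustrative.

```python
def matmul_intensity(m, k, n, bytes_per_elem=2):
    """FLOPs per byte for (m x k) @ (k x n), counting each operand read and
    the output written once (an idealized lower bound on memory traffic)."""
    flops = 2 * m * k * n  # one multiply and one add per term
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved

D = 4096
# Prefill: 2048 prompt tokens hit the (D x D) weights at once
print(f"prefill, 2048 tokens: {matmul_intensity(2048, D, D):.0f} FLOPs/byte")
# Decode: one new token per sequence, batch of 1 vs batch of 32
print(f"decode, batch 1:      {matmul_intensity(1, D, D):.2f} FLOPs/byte")
print(f"decode, batch 32:     {matmul_intensity(32, D, D):.1f} FLOPs/byte")
```

In decode, reading the weights dominates the traffic, and each weight is used for only one multiply-add per sequence in the batch. Intensity stays near the batch size, while GPUs can perform far more FLOPs per second than they can move bytes. Batching more sequences, or shrinking the bytes per weight, are exactly the levers the techniques above pull.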
Pitfalls and Troubleshooting
- Missing softmax scaling: Forgetting to divide by sqrt(d_k) makes training diverge when head dimensions are large, or causes attention to concentrate excessively on one token.
- Wrong mask placement: The causal mask must fill scores with negative infinity before softmax. Multiplying by zero after softmax breaks normalization.
- Confusing pre-norm and post-norm: Training a deep model from scratch with post-norm tends to diverge without warmup. Use pre-norm unless you have a specific reason not to.
- KV cache dtype mismatch: If the weights are FP16 but the KV cache is kept in FP32, the cache takes twice the memory it needs. Match the dtype unless you are deliberately quantizing.
- Positional encoding length overflow: Using absolute positional encoding on inputs longer than the training length causes a sharp drop in performance. For long contexts, consider RoPE-family methods and extrapolation techniques.
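- Missing transpose in head splitting: Forgetting the transpose that turns (B, N, H, d_k) into (B, H, N, d_k) silently scrambles the head and sequence dimensions. A typical example is viewing the projection output directly as (B, H, N, d_k). The shapes still line up, so nothing raises an error; the results are just quietly wrong.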
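The head-splitting pitfall is easy to demonstrate on a tiny tensor. Head h of token n should be the slice t[:, n, h*d_k:(h+1)*d_k] of the projection output. The correct view-then-transpose recovers that slice, while a direct view to (B, H, N, d_k) returns a different chunk of memory with the same shape. The sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
B, N, H, d_k = 1, 3, 2, 4
t = torch.randn(B, N, H * d_k)                  # projection output, (B, N, D)

correct = t.view(B, N, H, d_k).transpose(1, 2)  # (B, H, N, d_k)
wrong = t.view(B, H, N, d_k)                    # same shape, wrong layout

n, h = 1, 0
expected = t[:, n, h * d_k:(h + 1) * d_k]       # head h of token n
print(torch.equal(correct[:, h, n], expected))  # True
print(torch.equal(wrong[:, h, n], expected))    # False: tokens and heads got mixed

# Merging back needs .contiguous() (or .reshape) after the transpose
merged = correct.transpose(1, 2).contiguous().view(B, N, H * d_k)
print(torch.equal(merged, t))                   # True
```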
Closing
A Transformer block is ultimately the repetition of three simple ideas: (1) mixing information across tokens with attention, (2) doing a per-position nonlinear transformation with the FFN, and (3) stabilizing the deep net with residuals and normalization. On top of that, positional encoding injects order, and at inference the KV cache avoids recomputation.
With this skeleton in hand, everything falls into place cleanly: RoPE is a positional-encoding variant, GQA/MQA make the KV side of attention efficient, FlashAttention is an IO optimization of the attention computation, and MoE is conditional activation of the FFN. In the next articles we dive deep into the evolution of attention (MQA/GQA/FlashAttention) and into positional encoding (RoPE).
References
- Vaswani et al., "Attention Is All You Need" (arXiv:1706.03762): https://arxiv.org/abs/1706.03762
- Dao et al., "FlashAttention" (arXiv:2205.14135): https://arxiv.org/abs/2205.14135
- PyTorch Transformer docs: https://pytorch.org/docs/stable/nn.html#transformer-layers
- Hugging Face Transformers docs: https://huggingface.co/docs/transformers/index
- vLLM official docs (KV cache, serving): https://docs.vllm.ai
- vLLM repository: https://github.com/vllm-project/vllm
- The Illustrated Transformer (Jay Alammar): https://jalammar.github.io/illustrated-transformer/
- Qwen model repository: https://github.com/QwenLM