The Evolution of Attention — MQA, GQA, FlashAttention, and Long Context
Author: Youngju Kim (@fjvbn20031)
- Introduction — Why Attention Is Expensive
- 1. Cost Analysis of Standard Attention
- 2. MQA and GQA — Share the KV Heads
- 3. FlashAttention — Reduce the IO
- 4. Attention Variants for Long Context
- 5. The Effect on Serving Memory
- 5.5. FlashAttention, the Causal Mask, and Later Improvements
- 6. Computing a Serving Memory Budget Yourself
- 7. Why Decode Is Memory-Bound, and What It Means
- 8. Attention Optimization by Serving Framework
- 9. A Guide to Choosing Attention Variants
- Pitfalls and Troubleshooting
- 9.5. How Attention Optimization Interacts with Other Efficiency Techniques
- 9.7. Correcting Common Misconceptions
- 10. How to Read Benchmark Numbers
- Closing
- References
Introduction — Why Attention Is Expensive
The self-attention of the Transformer is powerful but expensive. For sequence length N, it compares every pair of tokens, so the compute and intermediate memory scale with the square of N. Stretching context from 2K to 32K — a 16x increase — makes the attention score matrix 256x larger. Longer context translates directly into a cost explosion.
There is also a separate bottleneck at inference. Autoregressive generation produces tokens one at a time, so the KV cache that stores previous tokens' Keys/Values grows linearly with sequence length. In long-context serving, it is common for the KV cache to consume more GPU memory than the model weights.
This article tackles both bottlenecks — the quadratic cost of attention and the linear growth of the KV cache. MQA/GQA shrink the KV cache, FlashAttention reduces the memory IO of the attention computation, and techniques like sliding windows limit the attention scope itself.
1. Cost Analysis of Standard Attention
Compute and Memory
Here we summarize the core costs of standard multi-head attention.
symbols:
B = batch, H = number of heads, N = sequence length, d_k = dim per head, D = H x d_k
Q·K^T score matrix: (B, H, N, N)
- compute: O(B x H x N^2 x d_k)
- memory: O(B x H x N^2) <- square of N!
weighted sum with V after softmax: another O(B x H x N^2 x d_k)
The crux of the problem is the (N, N) score matrix. With N=8192, H=32, B=8, the score matrix for a single layer reaches roughly 34GB in FP16. The real bottleneck of a naive implementation is that this enormous intermediate matrix cannot fit in the GPU's fast on-chip memory (SRAM), so it has to be written to and read back from slow off-chip memory (HBM).
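The 34GB figure can be checked with a few lines of arithmetic. A minimal sketch, assuming FP16 (2 bytes per element) and counting only one layer's score matrix; `score_matrix_bytes` is an illustrative helper, not a library function. The second calculation confirms the 256x growth from the introduction.

```python
def score_matrix_bytes(batch, heads, seq_len, bytes_per_element=2):
    # one layer's (B, H, N, N) score matrix; FP16 = 2 bytes per element
    return batch * heads * seq_len * seq_len * bytes_per_element

size = score_matrix_bytes(batch=8, heads=32, seq_len=8192)
print(f"{size:,} bytes = {size / 1e9:.1f} GB")

# stretching context 2K -> 32K (16x) grows the matrix by 16^2
ratio = score_matrix_bytes(1, 1, 32768) // score_matrix_bytes(1, 1, 2048)
print(f"{ratio}x larger")
```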
The KV Cache at Inference
At inference, the KV cache size is estimated as follows.
KV cache bytes
= 2 x L x N x D x B x bytes_per_element
(2 for K and V, L is the number of layers)
ex: L=32, D=4096, N=8192, B=1, FP16
about 4.3 GB
raising the batch to 32 -> about 137 GB -> exceeds single-GPU memory
In other words, even if you want to raise the batch for throughput, the KV cache eats up memory and caps the batch size. This is the direct motivation for MQA/GQA.
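The KV cache formula is easy to check in code. A minimal sketch, assuming the configuration above with FP16 storage and D split as 32 KV heads x d_k=128 (the split is an assumption; only D=4096 matters for this total); `kv_cache_bytes` is an illustrative name. Exposing the KV head count as its own parameter also makes visible the lever that MQA/GQA pull on.

```python
def kv_cache_bytes(layers, seq_len, kv_heads, head_dim, batch, bytes_per_element=2):
    # 2 for K and V; kv_heads * head_dim plays the role of D in the formula
    return 2 * layers * seq_len * kv_heads * head_dim * batch * bytes_per_element

# head_dim=128 is assumed so that kv_heads * head_dim = D = 4096
single = kv_cache_bytes(layers=32, seq_len=8192, kv_heads=32, head_dim=128, batch=1)
batched = kv_cache_bytes(layers=32, seq_len=8192, kv_heads=32, head_dim=128, batch=32)
print(f"B=1 : {single / 1e9:.1f} GB")
print(f"B=32: {batched / 1e9:.1f} GB")
```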
2. MQA and GQA — Share the KV Heads
The Core Idea
Standard multi-head attention (MHA) keeps H Key/Value heads, one for each of the H Query heads. MQA (Multi-Query Attention) flips this. It keeps all H Query heads but uses only a single Key/Value head that every Query head shares.
GQA (Grouped-Query Attention) is the middle ground. It splits the Query heads into G groups, with each group sharing one KV head. G=1 is MQA, and G=H is the same as MHA.
MHA : H Query heads, H KV heads (largest KV cache)
GQA : H Query heads, G KV heads (G < H, middle ground)
MQA : H Query heads, 1 KV head (smallest KV cache)
ex) when H=32
MHA: 32 KV heads -> baseline
GQA(G=8): 8 KV heads -> KV cache 1/4
MQA: 1 KV head -> KV cache 1/32
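Which KV head a given Query head reads is a simple integer division. A minimal sketch, assuming the contiguous grouping convention (Query heads 0..rep-1 share KV head 0, the next rep heads share KV head 1, and so on), which is what `repeat_interleave` produces in the PyTorch code in this section; `kv_head_for` is a hypothetical helper name.

```python
def kv_head_for(query_head, num_query_heads, num_kv_heads):
    # Query heads must split evenly into groups, one group per KV head
    if num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be a multiple of num_kv_heads")
    group_size = num_query_heads // num_kv_heads
    return query_head // group_size

H = 32
for name, G in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    mapping = [kv_head_for(h, H, G) for h in range(H)]
    print(f"{name}: {G} KV heads, KV cache {G}/{H} of MHA, "
          f"query heads 0-7 -> KV heads {mapping[:8]}")
```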
Why It Works
The KV cache scales with the number of KV heads. Reducing KV heads from H to G shrinks the KV cache to G/H of its size. Reducing it to 1 as in MQA shrinks it dramatically, letting you fit a longer context or a larger batch into the same memory.
The cost is expressiveness. Reducing KV heads too far (MQA) can slightly degrade quality. So in practice, GQA with something like G=8 is widely adopted as a balance point that shrinks the KV cache substantially with almost no quality loss. Many modern models use GQA, including Llama 2 70B, Llama 3, and Qwen2 and later.
```python
import torch
import torch.nn.functional as F

def gqa_attention(q, k, v, num_kv_groups):
    # q: (B, H, N, d_k)
    # k, v: (B, G, N, d_k), G = num_kv_groups
    B, H, N, d_k = q.shape
    G = num_kv_groups
    assert H % G == 0, "Query heads H must be a multiple of KV heads G"
    rep = H // G  # how many Query heads share each KV head
    # repeat KV heads to match the number of Query heads
    # (Query heads g*rep .. g*rep + rep - 1 all read KV head g)
    k = k.repeat_interleave(rep, dim=1)  # (B, H, N, d_k)
    v = v.repeat_interleave(rep, dim=1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v)  # (B, H, N, d_k); no causal mask applied
```
Comparison Table
| Method | KV heads | KV cache size | Quality | Representative models |
|---|---|---|---|---|
| MHA | H | baseline (largest) | highest | early GPT, original paper |
| GQA | G (e.g., 8) | about G/H | close to MHA | Llama 2 70B, Llama 3, Qwen2 |
| MQA | 1 | about 1/H (smallest) | may drop slightly | PaLM, some lightweight models |
3. FlashAttention — Reduce the IO
The Real Bottleneck Is Memory Movement, Not Compute
The slowness of standard attention comes less from the floating-point operations themselves and more from the round trips of writing and re-reading the enormous (N, N) score matrix to slow HBM. The GPU's SRAM is very fast but small; HBM is large but slow. Since the score matrix does not fit in SRAM, it shuttles to and from HBM.
The core of FlashAttention is an IO-aware algorithm. It never materializes the full (N, N) score matrix in HBM. Instead it splits Q, K, V into small tiles (blocks), loads them into SRAM, and computes attention block by block while incrementally updating the softmax online.
standard attention:
1) S = Q·K^T -> store full (N, N) in HBM
2) P = softmax(S) -> read and write (N, N) again
3) O = P·V -> read again
=> HBM round trips O(N^2)
FlashAttention (tiling):
split Q, K, V into blocks -> load into SRAM
compute partial scores per block -> accumulate via online softmax
never builds the full (N, N) matrix
=> HBM round trips drop sharply, memory also reduced to O(N)
The Intuition of Online Softmax
Softmax usually needs the whole row to compute its normalization constant. FlashAttention carries along the maximum and sum seen so far while traversing blocks; when a new block arrives, it safely rescales and accumulates. This yields exactly the same result without ever holding the full matrix in memory at once.
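The rescaling step can be seen in isolation on a single query row. A minimal pure-Python sketch, assuming one query's scores and scalar values (real kernels do this per row on vectors, inside SRAM tiles); the function names are illustrative. It walks the scores block by block, keeping only the running max, running sum, and running weighted sum, and produces the same output as a full softmax computed in one pass.

```python
import math

def naive_weighted_sum(scores, values):
    # reference: softmax over the whole row at once, then the weighted sum
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    return sum(e * v for e, v in zip(exps, values)) / sum(exps)

def online_weighted_sum(scores, values, block_size):
    running_max = -math.inf  # max score seen so far
    running_sum = 0.0        # sum of exp(score - running_max)
    acc = 0.0                # sum of exp(score - running_max) * value
    for start in range(0, len(scores), block_size):
        s_block = scores[start:start + block_size]
        v_block = values[start:start + block_size]
        new_max = max(running_max, max(s_block))
        # rescale everything accumulated under the old max to the new max
        scale = math.exp(running_max - new_max)  # exp(-inf) = 0 on the first block
        running_sum *= scale
        acc *= scale
        for s, v in zip(s_block, v_block):
            w = math.exp(s - new_max)
            running_sum += w
            acc += w * v
        running_max = new_max
    return acc / running_sum  # normalize once at the end

scores = [0.3, 2.1, -1.0, 4.2, 0.0, 1.5, 3.3]
values = [1.0, -2.0, 0.5, 3.0, 4.0, -1.0, 2.0]
print(naive_weighted_sum(scores, values))
print(online_weighted_sum(scores, values, block_size=3))
```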
```python
import numpy as np

# conceptual sketch: real FlashAttention is a fused CUDA kernel that keeps each
# tile in on-chip SRAM; here small NumPy slices stand in for those tiles
# the key is block traversal + online softmax accumulation
def flash_attention_concept(Q, K, V, block_size):
    N, d = Q.shape
    O = np.zeros((N, d))
    for i in range(0, N, block_size):  # traverse Q blocks
        q_block = Q[i:i + block_size]
        rows = q_block.shape[0]
        running_max = np.full(rows, -np.inf)  # per-row max seen so far
        running_sum = np.zeros(rows)  # per-row softmax denominator
        acc = np.zeros((rows, d))  # unnormalized output rows
        for j in range(0, N, block_size):  # traverse K, V blocks
            k_block = K[j:j + block_size]
            v_block = V[j:j + block_size]
            s = q_block @ k_block.T / np.sqrt(d)  # only small block scores
            new_max = np.maximum(running_max, s.max(axis=-1))
            # rescale the previous accumulator and sum to the new max
            scale = np.exp(running_max - new_max)
            p = np.exp(s - new_max[:, None])
            running_sum = running_sum * scale + p.sum(axis=-1)
            acc = acc * scale[:, None] + p @ v_block
            running_max = new_max
        O[i:i + block_size] = acc / running_sum[:, None]  # finalize: normalize
    return O
```
In real implementations, this is selected automatically as a backend in PyTorch's scaled_dot_product_attention, or provided by FlashAttention libraries/kernels. You rarely write the pseudocode yourself; you just enable the library.
Summary of Effects
| Item | Standard attention | FlashAttention |
|---|---|---|
| Intermediate memory | scales with square of N | linear in N |
| HBM round trips | many | drastically reduced by tiling |
| Accuracy | baseline | exact (not an approximation); differs only by floating-point rounding |
| Speedup | baseline | grows with sequence length; largest on long sequences |
The crucial point is that FlashAttention is not an approximation — it computes exactly the same attention, just faster and with less memory. It is a free optimization with no quality loss, which is why it has effectively become standard.
4. Attention Variants for Long Context
KV cache reduction (GQA/MQA) and IO optimization (FlashAttention) make standard attention efficient, but the attention scope is still the full sequence. When context stretches to hundreds of thousands of tokens, you need techniques that limit the scope itself.
Sliding-Window Attention
Each token attends only to the most recent W tokens rather than to everything. The attention cost drops from the square of N to N times W, and the KV cache is capped by the window size. Distant information still propagates indirectly: each layer extends a token's reach by up to W positions, so after k stacked layers the effective receptive field is roughly k x W tokens. Models such as Mistral adopt this.
full attention (N=8):
each query sees all keys -> cost O(N^2)
sliding window (W=3):
query5 sees only key3, key4, key5
query6 sees only key4, key5, key6
-> cost O(N x W), KV cache capped at W
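The visibility rule and the capped cache can both be sketched in a few lines. A minimal illustration, assuming 0-indexed positions and a causal window that includes the query itself (so query 5 sees keys 3, 4, 5 when W=3, as in the diagram); `visible_keys` is an illustrative name, and `collections.deque(maxlen=W)` merely stands in for a rolling KV buffer.

```python
from collections import deque

def visible_keys(query_pos, window):
    # causal sliding window: the query itself plus the window - 1 keys before it
    return list(range(max(0, query_pos - window + 1), query_pos + 1))

N, W = 8, 3
for q in range(N):
    print(f"query{q} sees keys {visible_keys(q, W)}")

# the KV cache only ever needs the last W entries
kv_cache = deque(maxlen=W)
for pos in range(N):
    kv_cache.append(f"kv{pos}")  # the oldest entry is evicted automatically
print(list(kv_cache))  # ['kv5', 'kv6', 'kv7']

# score entries computed: bounded by N x W instead of growing as N x N
print(sum(len(visible_keys(q, W)) for q in range(N)), "vs", N * N)
```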
Other Long-Context Approaches
- Global + local mix: Most tokens see only a local window, while a few global tokens see everything, securing a path for distant information.
- Hierarchical / sparse attention: Distant tokens are attended sparsely and nearby tokens densely, sparsifying the pattern.
- Positional-encoding extrapolation: Rescaling RoPE (NTK-aware scaling, YaRN, etc.) lets a model handle inputs longer than its training length. This is covered in the positional encoding article.
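The global + local mix can be expressed as a boolean attention mask. A minimal sketch of a bidirectional (encoder-style) variant, assuming position 0 is the only global token and a local radius of 1; `can_attend` and its parameters are hypothetical, and real models choose which tokens are global, and whether the pattern is causal, in model-specific ways. Any two distant tokens are linked within two layers through the global token.

```python
def can_attend(i, j, radius, global_tokens):
    # global tokens attend to every position, and every position attends to them
    if i in global_tokens or j in global_tokens:
        return True
    # all other pairs: only within the local window |i - j| <= radius
    return abs(i - j) <= radius

N, RADIUS, GLOBAL = 8, 1, {0}
for i in range(N):
    row = "".join("x" if can_attend(i, j, RADIUS, GLOBAL) else "." for j in range(N))
    print(f"query{i}: {row}")
```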
5. The Effect on Serving Memory
All of these choices ultimately come down to GPU memory budget and throughput. In a 2026 serving stack, the following elements work together.
GPU memory = model weights + KV cache + activations/work buffers
shrinking the KV cache (GQA/MQA, sliding window, KV quantization)
-> larger batch possible -> higher throughput
-> can accommodate longer context
- Continuous (in-flight) batching: New requests are slotted in the moment a request finishes, keeping the GPU busy. This is the 2026 standard.
- Paged KV cache: The KV cache is split into fixed-size blocks managed like OS virtual memory (PagedAttention), eliminating fragmentation and raising memory utilization.
- KV quantization: Storing the KV cache in FP8/INT4 cuts memory by half or more. Since the decode stage is memory-bound, the effect is large.
- Prefill/decode separation and chunked prefill: Compute-bound prefill and memory-bound decode can be run on separate resources, or a long prefill can be split into chunks interleaved with decode steps to smooth out latency.
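The paged KV cache idea can be sketched as a block table plus a free list. A toy allocator, assuming a block size of 4 tokens and a pool of 8 physical blocks; the class and method names are hypothetical and do not mirror vLLM's internals. Each sequence holds a list of physical block IDs, a new block is taken only when the last one fills up, and a finished sequence returns its blocks to the pool, so no memory is reserved up front for the maximum possible length.

```python
class PagedKVAllocator:
    """Toy bookkeeping for a paged KV cache: token positions -> physical blocks."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))  # pool of physical block IDs
        self.block_tables = {}  # seq_id -> list of physical block IDs
        self.lengths = {}       # seq_id -> number of tokens stored

    def append_token(self, seq_id):
        table = self.block_tables.setdefault(seq_id, [])
        length = self.lengths.get(seq_id, 0)
        if length % self.block_size == 0:  # last block is full (or none yet)
            if not self.free_blocks:
                raise MemoryError("KV cache pool exhausted")
            table.append(self.free_blocks.pop(0))
        self.lengths[seq_id] = length + 1
        # physical location of this token: (block ID, offset inside the block)
        return table[-1], length % self.block_size

    def free(self, seq_id):
        # freed blocks go back to the pool and can be reused by any sequence
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

alloc = PagedKVAllocator(num_blocks=8, block_size=4)
for _ in range(6):
    alloc.append_token("A")  # 6 tokens -> 2 blocks
for _ in range(3):
    alloc.append_token("B")  # 3 tokens -> 1 block
print(alloc.block_tables)  # {'A': [0, 1], 'B': [2]}
alloc.free("A")
print(alloc.free_blocks)  # [3, 4, 5, 6, 7, 0, 1]
```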
| Technique | What it reduces | Throughput effect | Quality effect |
|---|---|---|---|
| GQA/MQA | KV cache (head count) | rises by enabling larger batch | almost none ~ slight |
| FlashAttention | attention IO / intermediate memory | direct speedup | none (identical result) |
| Sliding window | KV cache / attention scope | rises on long context | caution for long-range deps |
| KV quantization | KV cache precision | rises by enabling larger batch | slight (config-dependent) |
| Paged KV cache | memory fragmentation | rises via utilization | none |
5.5. FlashAttention, the Causal Mask, and Later Improvements
FlashAttention's tiling also fits well with the causal mask. In a decoder, key blocks that lie in the future relative to a query block need not be computed at all, so they can be skipped block by block, reducing computation further.
block skipping in causal attention:
for query block i, key block j
blocks with j > i (entirely future) -> skip entirely
blocks with j == i -> triangular mask only inside the block
blocks with j < i -> compute fully
-> attention computation drops roughly by half when causal
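The three cases can be counted directly. A small sketch, assuming the sequence is split into `num_blocks` equal query blocks and key blocks (the helper name is illustrative); it classifies every (i, j) block pair, and the fraction of blocks that must be visited approaches one half as the number of blocks grows.

```python
def classify_causal_blocks(num_blocks):
    counts = {"skip": 0, "diagonal": 0, "full": 0}
    for i in range(num_blocks):        # query block
        for j in range(num_blocks):    # key block
            if j > i:
                counts["skip"] += 1      # entirely in the future
            elif j == i:
                counts["diagonal"] += 1  # triangular mask inside the block
            else:
                counts["full"] += 1      # entirely in the past
    return counts

for n in (4, 16, 64):
    c = classify_causal_blocks(n)
    visited = c["diagonal"] + c["full"]
    print(n, c, f"visited {visited / n**2:.2f} of all blocks")
```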
FlashAttention has raised hardware utilization further over successive versions. The core idea (tiling + online softmax + IO awareness) stays the same, but improved warp/thread scheduling and work partitioning fill the compute units of modern GPUs more thoroughly. In practice, libraries or frameworks automatically select the implementation suited to the hardware, so the user only needs to confirm that it is "on."
direction of generational progress (conceptual):
gen 1: no materialization of (N,N) + online softmax -> memory O(N), fewer HBM round trips
later: rearranged parallelization axes, improved warp-level work partitioning
-> higher GPU occupancy with the same algorithm
The key point is that none of these improvements change what is computed. Every generation computes exact attention, with outputs that differ only at the level of floating-point rounding; the difference is only speed and memory efficiency.
6. Computing a Serving Memory Budget Yourself
To turn theory into practical intuition, you have to plug in numbers yourself. Assume a 7B-class model on a single 80GB GPU and work out how large you can grow the batch.
assumptions:
model weights: FP16, about 14 GB
L=32, D=4096, FP16 (2 bytes)
context N=8192
out of 80 GB GPU memory:
weights 14 GB
activations/work buffers about 6 GB (rough)
headroom for KV cache = 80 - 14 - 6 = about 60 GB
KV cache per request for MHA (32 KV heads):
= 2 x 32 x 8192 x 4096 x 2 (bytes)
about 4.3 GB
-> 60 / 4.3 = about 13 concurrent requests
switching to GQA (G=8, 8 KV heads), KV cache per request:
= above x 8/32 = about 1.1 GB
-> 60 / 1.1 = about 54 concurrent requests
Merely reducing KV heads from 32 to 8 raised the number of concurrent requests roughly 4x. Since decode throughput scales roughly with the number of concurrent requests, a single switch to GQA substantially lowers serving cost on the same GPU. Adding KV quantization (FP8) on top halves the per-request KV cache again, roughly doubling concurrency once more.
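The same budget can be recomputed in code, which makes it easy to try other head counts or KV precisions. A sketch under the assumptions listed above (60 GB of headroom, L=32, N=8192, and d_k=128 so that 32 KV heads give D=4096); the FP8 row assumes 1 byte per KV element, and the names are illustrative. Without intermediate rounding the GQA case comes out at 55 requests rather than the 54 obtained from the rounded 1.1 GB.

```python
HEADROOM_BYTES = 60e9  # 80 GB - 14 GB weights - ~6 GB buffers (rough)

def kv_bytes_per_request(kv_heads, bytes_per_element,
                         layers=32, seq_len=8192, head_dim=128):
    # 2 for K and V; head_dim=128 is assumed so that 32 heads give D = 4096
    return 2 * layers * seq_len * kv_heads * head_dim * bytes_per_element

for label, kv_heads, nbytes in [
    ("MHA, FP16 KV", 32, 2),
    ("GQA G=8, FP16 KV", 8, 2),
    ("GQA G=8, FP8 KV", 8, 1),
]:
    per_req = kv_bytes_per_request(kv_heads, nbytes)
    concurrent = int(HEADROOM_BYTES // per_req)
    print(f"{label:18s} {per_req / 1e9:5.2f} GB/request -> {concurrent} concurrent requests")
```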
7. Why Decode Is Memory-Bound, and What It Means
Inference splits into two stages: prefill, which processes the whole input prompt at once, and decode, which produces tokens one at a time. The two stages have opposite performance characteristics.
prefill (processing N prompt tokens):
large matmul -> compute (FLOPs) bound
fills the GPU tensor cores
decode (generating one token at a time):
thin matrix-vector product -> memory bandwidth bound
each step reads weights + KV cache from memory
little compute, but memory reads are the bottleneck
The fact that decode is memory-bound sets the optimization direction. Since memory reads, not compute, are the bottleneck, it is effective to (1) make weights small (quantization), (2) make the KV cache small (GQA/MQA, KV quantization), and (3) do more work per memory read (larger batch, speculative decoding).
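A back-of-the-envelope model shows why a larger batch helps memory-bound decode. A rough sketch, assuming every decode step reads all weights once plus each request's KV cache, that compute time is negligible, and a hypothetical memory bandwidth of 2 TB/s (substitute your GPU's figure); the 14 GB of weights and roughly 1.07 GB of KV per request (GQA, 8K context) come from the budget example in section 6. Because the weight read is shared by the whole batch, overall tokens/s rises with batch size, while each step (and so each request's per-token latency) gets slower.

```python
BANDWIDTH = 2e12         # bytes/s, hypothetical; substitute your GPU's figure
WEIGHT_BYTES = 14e9      # FP16 7B-class weights (section 6)
KV_PER_REQUEST = 1.07e9  # GQA G=8, 8K context, FP16 (section 6)

def decode_step(batch):
    # lower bound on step time: weights read once per step (shared by the batch)
    # plus each request's own KV cache; compute time is ignored
    bytes_per_step = WEIGHT_BYTES + batch * KV_PER_REQUEST
    step_time = bytes_per_step / BANDWIDTH
    # (per-request TPOT lower bound in seconds, overall tokens/s)
    return step_time, batch / step_time

for batch in (1, 8, 32, 54):
    step_time, tokens_per_s = decode_step(batch)
    print(f"batch {batch:2d}: {step_time * 1000:5.1f} ms/step, ~{tokens_per_s:5.0f} tokens/s")
```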
2026 techniques attacking memory-bound decode:
continuous batching : slot a new request immediately where one finished
paged KV cache : manage KV in blocks to eliminate fragmentation
FP8/INT4 quantization: smaller weights/KV -> fewer memory reads
speculative decoding: a small draft model proposes several tokens; the main model verifies them in one pass
often reported at around 2-3x speedup in the memory-bound regime
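How much speculative decoding helps depends on how often the main model accepts the draft's tokens. A sketch of the idealized estimate used in the speculative decoding literature (e.g., Leviathan et al., 2023), assuming each draft token is accepted independently with probability `alpha` and `k` tokens are drafted per step: the expected number of tokens emitted per main-model pass is (1 - alpha^(k+1)) / (1 - alpha). It ignores the draft model's own cost, so real wall-clock gains are lower.

```python
def expected_tokens_per_pass(alpha, k):
    """Idealized expected tokens emitted per main-model forward pass."""
    if alpha == 1.0:
        return k + 1  # every draft token accepted, plus one token from the main model
    return (1 - alpha ** (k + 1)) / (1 - alpha)

for alpha in (0.5, 0.7, 0.9):
    print(f"alpha={alpha}: {expected_tokens_per_pass(alpha, k=4):.2f} tokens per pass")
```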
8. Attention Optimization by Serving Framework
The major 2026 serving frameworks implement the above techniques in their own ways.
| Framework | Characteristics | Attention/KV strengths |
|---|---|---|
| vLLM | general-purpose, wide model/hardware support | original PagedAttention, continuous batching |
| TensorRT-LLM | compilation-based, NVIDIA-optimized | high throughput on H100 (some comparisons report roughly 15-30% higher) |
| SGLang | RadixAttention | prefix cache reuse, strong on multi-turn |
All three frameworks build on FlashAttention-family kernels and a paged KV cache by default and support GQA/MQA models. They differ in compilation optimization level, prefix cache reuse strategy, and breadth of supported hardware. vLLM shines in generality and wide support, TensorRT-LLM in peak performance on NVIDIA hardware, and SGLang in multi-turn/structured workloads that share a lot of prefix.
9. A Guide to Choosing Attention Variants
Here is how to choose a combination by situation.
general text LLM serving:
GQA model + FlashAttention + paged KV cache + continuous batching
(the default for most)
memory tight and throughput is top priority:
add KV quantization (FP8) + more aggressive batching
ultra-long context (hundreds of thousands of tokens):
sliding-window/sparse-attention model + position extrapolation (NTK/YaRN)
+ paged KV cache is essential
multi-turn chat / workloads with much shared prefix:
large gains from prefix cache reuse (e.g., RadixAttention)
The core principle is to combine three axes to fit the situation: "tame standard attention with IO optimization (FlashAttention), the KV cache with head sharing and quantization, and the attention scope with windowing/sparsification."
Pitfalls and Troubleshooting
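- GQA head mismatch: If the number of KV heads G does not evenly divide the number of Query heads H, rep = H // G truncates, the repeated K/V end up with fewer heads than Q, and the subsequent matmul raises a shape error. H must be a multiple of G.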
- Overlooking MQA quality loss: Going to MQA purely for memory can hurt quality on certain tasks. Start with GQA (G=8, etc.) to strike a balance.
- FlashAttention not applied: If the library/backend is off, you can hit OOM or slowness on long sequences. In PyTorch, verify that the appropriate backend is selected.
- Long-range loss with sliding windows: If the window is too small, information from the front of a document cannot propagate to the back. Consider the relationship between window size and number of layers.
- KV quantization precision pitfall: Quantizing KV to too few bits can degrade quality on long context due to accumulated error. FP8 is often safer than INT4.
- Skipping batch memory estimation: The KV cache scales linearly with batch. If you raise the batch without precomputing memory using the formula above, you will hit OOM in production.
9.5. How Attention Optimization Interacts with Other Efficiency Techniques
Attention/KV optimization is not used in isolation; it works together with other efficiency techniques across the model and serving stack. Let us note a few interactions.
relationship with quantization (weights):
weight quantization reduces model memory; KV quantization reduces the KV cache
-> they are independent and can be applied together
-> together they accommodate larger batches / longer contexts
relationship with MoE:
MoE splits the FFN into multiple experts and activates only some
-> the attention part is unchanged, so it is orthogonal to GQA/FlashAttention
-> note: MoE has fewer active parameters but a large total parameter (memory) count
relationship with speculative decoding:
a draft model produces several tokens ahead and the main model verifies in one pass
-> KV cache management gets more complex, but the speedup on memory-bound decode is large
-> shrinking KV with GQA leaves memory headroom for both draft and verification
The key is that most of these techniques are mutually orthogonal. Attention variants (GQA/FlashAttention) make "attention cheap," quantization makes "weights/KV small," MoE makes "compute conditional," and speculative decoding makes "decode fast." Since they attack different axes, they can be stacked together, and real production serving is built by combining them.
9.7. Correcting Common Misconceptions
Here we clear up common misconceptions about attention efficiency.
misconception 1: "FlashAttention is an approximation, so quality drops a little"
-> wrong. It computes exact attention; the output differs from a naive implementation only by floating-point rounding. No quality loss.
misconception 2: "Using MQA is always faster than GQA"
-> MQA has the smallest KV cache, but GQA (e.g., G=8) already comes close in decode speed, and MQA's quality loss may require a larger model or retraining to recover.
The net benefit depends on the workload.
misconception 3: "For long context, you just grow the KV cache"
-> beyond memory, you must also handle position extrapolation (quality) and attention cost (compute).
misconception 4: "Bigger batch is always better"
-> throughput rises, but per-request latency (TPOT) can increase. Balance against the SLA.
10. How to Read Benchmark Numbers
When reading serving-optimization articles, you are flooded with figures like "X% throughput improvement" or "Y ms latency." To interpret these correctly, you must distinguish a few axes.
key metrics:
throughput: tokens per unit time (across all users)
latency:
TTFT (time to first token): time to the first token -> heavily influenced by prefill
TPOT (time per output token): the gap between tokens -> heavily influenced by decode
concurrency: number of requests handled at once (directly tied to batch)
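These definitions become concrete when computed from per-token timestamps. A minimal sketch, assuming you have recorded each request's send time and the arrival time of every output token (the record format and `request_metrics` are hypothetical); TPOT is taken as the mean gap between consecutive tokens after the first, which is one common convention.

```python
def request_metrics(sent_at, token_times):
    ttft = token_times[0] - sent_at  # time to first token
    gaps = len(token_times) - 1
    # mean gap between consecutive output tokens (0.0 for a single token)
    tpot = (token_times[-1] - token_times[0]) / gaps if gaps else 0.0
    return ttft, tpot

requests = [
    (0.0, [0.50, 0.60, 0.70, 0.80]),  # (send time, token arrival times) in seconds
    (0.2, [0.90, 1.05, 1.20]),
]
for sent, times in requests:
    ttft, tpot = request_metrics(sent, times)
    print(f"TTFT {ttft * 1000:.0f} ms, TPOT {tpot * 1000:.0f} ms")

# throughput: all output tokens divided by the wall-clock span of the run
total_tokens = sum(len(times) for _, times in requests)
wall_time = max(times[-1] for _, times in requests) - min(sent for sent, _ in requests)
print(f"throughput {total_tokens / wall_time:.1f} tokens/s")
```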
Even the same optimization improves different metrics. For example, raising the batch increases overall throughput but can increase per-request latency. KV cache reduction (GQA) raises throughput by enabling larger batches, FlashAttention reduces TTFT in prefill and on long sequences, and speculative decoding mainly reduces TPOT (decode latency).
| Metric | Techniques that mainly improve it | Caution |
|---|---|---|
| Throughput | GQA, larger batch, KV quantization | possible trade-off with latency |
| TTFT | FlashAttention, chunked prefill | felt strongly on long prompts |
| TPOT | speculative decoding, KV optimization | effect varies with acceptance rate |
When citing benchmark numbers, the crux is "on what hardware, with what model, at what sequence length and batch" the measurement was taken. Since the same technique can show very different gains depending on conditions, understanding the mechanism of which axis is improved and how matters more than the absolute numbers.
Closing
The two bottlenecks of standard attention — quadratic compute cost and a linear KV cache — are tackled with different tools. FlashAttention reduces the attention's memory IO while preserving accuracy, giving an effectively free speedup. GQA/MQA share KV heads to shrink the KV cache, enabling larger batches and longer contexts. Sliding windows and sparse attention limit the scope itself to tame the cost of ultra-long contexts.
In practice, these are used together. Shrinking KV with GQA, accelerating attention with FlashAttention, and using every bit of memory with continuous batching and a paged KV cache is the baseline combination for serving in 2026. In the next article we dive deep into positional encoding, especially RoPE and length extrapolation.
References
- Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (arXiv 2205.14135): https://arxiv.org/abs/2205.14135
- Vaswani et al., "Attention Is All You Need" (arXiv 1706.03762): https://arxiv.org/abs/1706.03762
- vLLM official docs (PagedAttention, KV cache): https://docs.vllm.ai
- vLLM repository: https://github.com/vllm-project/vllm
- SGLang repository (RadixAttention, prefix cache): https://github.com/sgl-project/sglang
- TensorRT-LLM repository: https://github.com/NVIDIA/TensorRT-LLM
- PyTorch scaled_dot_product_attention docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
- Hugging Face Transformers docs: https://huggingface.co/docs/transformers/index