
FlashAttention: Optimizing Attention Through GPU Memory Hierarchy


1. Introduction

Self-Attention, the core component of the Transformer architecture, computes relationships between all token pairs in a sequence. While this operation provides powerful representational capacity, it suffers from a fundamental limitation: both time and memory complexity grow as O(N²) with respect to the sequence length N. For state-of-the-art LLMs such as GPT-4, LLaMA, and Gemini to handle long contexts exceeding 128K tokens, this O(N²) bottleneck must be effectively addressed.

FlashAttention (Dao et al., 2022) solves this problem without any approximation. The core idea is simple yet profound: rather than reducing the computational cost of the attention operation itself, it minimizes data movement (IO) between GPU memory hierarchy levels. In this article, we systematically analyze the principles of FlashAttention from a GPU hardware perspective and trace its evolution through FlashAttention-2 and FlashAttention-3.


2. The Memory Problem of Standard Attention

2.1 Standard Attention Computation Flow

Standard Self-Attention is computed as follows. Given inputs Q, K, V ∈ ℝ^{N×d}:

S = QK^T ∈ ℝ^{N×N}
P = softmax(S) ∈ ℝ^{N×N}
O = PV ∈ ℝ^{N×d}

Here, N is the sequence length and d is the head dimension.

2.2 Memory Complexity Analysis

The crux of the problem lies in the intermediate matrices S and P. These matrices are of size N×N, requiring quadratic memory with respect to sequence length. To put this in concrete numbers:

| Sequence Length (N) | Attention Matrix Size | FP16 Memory |
|---|---|---|
| 1,024 | 1M elements | 2 MB |
| 4,096 | 16.7M elements | 33 MB |
| 16,384 | 268M elements | 536 MB |
| 65,536 | 4.3B elements | 8.6 GB |
| 131,072 | 17.2B elements | 34.4 GB |

These figures are for a single head and single batch. In multi-head attention, multiplying by the number of heads h and the batch size B results in significantly larger actual memory consumption. At sequence length 65,536, even a single head consumes a substantial portion of the HBM on an A100 80GB GPU.

2.3 HBM Bottleneck

In the standard attention implementation, these N×N matrices are materialized in GPU HBM (High Bandwidth Memory). That is, S = QK^T is computed and written to HBM, then read back for softmax; the result P is written to HBM, and then read back for O = PV. The total number of HBM reads and writes in this process is Ω(Nd + N²).

The real reason this operation is slow on actual GPUs is that memory access, not compute, is the bottleneck. The A100 GPU delivers 312 TFLOPS (FP16) of compute throughput, while its HBM bandwidth is only about 2 TB/s. Attention is a classic memory-bound operation due to its low arithmetic intensity (ratio of compute to memory access).
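A back-of-the-envelope estimate makes this concrete. The following sketch (with simplifying assumptions: FP16 data, counting only the two large matmuls as FLOPs, and only the Q/K/V/O traffic plus the S and P round-trips as IO) compares attention's arithmetic intensity to the A100's compute-to-bandwidth ratio:

```python
# Rough arithmetic-intensity estimate for standard attention (illustrative only)
N, d = 4096, 64          # sequence length, head dimension
bytes_per_el = 2         # FP16

# FLOPs: QK^T and PV are each 2*N*N*d multiply-adds; softmax FLOPs ignored
flops = 2 * (2 * N * N * d)

# HBM traffic in elements: read Q, K, V and write O (4*N*d total),
# plus write/read of both S and P (4*N*N total)
elements_moved = 4 * N * d + 4 * N * N
bytes_moved = elements_moved * bytes_per_el

intensity = flops / bytes_moved     # FLOPs per byte actually achieved
a100_ratio = 312e12 / 2.0e12        # ~156 FLOPs/byte needed to be compute-bound

print(f"arithmetic intensity: {intensity:.1f} FLOPs/byte")
print(f"A100 compute/bandwidth ratio: {a100_ratio:.0f} FLOPs/byte")
```

The measured intensity (~30 FLOPs/byte) falls far below the ~156 FLOPs/byte the hardware would need to keep its Tensor Cores busy, which is exactly what "memory-bound" means.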


3. GPU Memory Hierarchy

Understanding FlashAttention requires precise knowledge of the GPU memory hierarchy.

3.1 HBM (High Bandwidth Memory)

  • Capacity: 40GB or 80GB on the A100
  • Bandwidth: Approximately 1.5-2.0 TB/s (A100 80GB SXM: 2,039 GB/s)
  • Access latency: Approximately 200-600 cycles
  • Role: The GPU's main memory. All data including model parameters, input tensors, and output tensors are stored here

3.2 SRAM (On-chip Shared Memory)

  • Capacity: Approximately 192KB per SM on the A100, approximately 20MB total (108 SMs)
  • Bandwidth: Approximately 19 TB/s
  • Access latency: Approximately 20-30 cycles
  • Role: High-speed on-chip memory within each Streaming Multiprocessor (SM)

3.3 The Critical Asymmetry

A dramatic asymmetry exists between SRAM and HBM:

| Property | SRAM | HBM |
|---|---|---|
| Bandwidth | ~19 TB/s | ~2 TB/s |
| Capacity | ~20 MB | 40-80 GB |
| Access Latency | 20-30 cycles | 200-600 cycles |

SRAM is approximately 10x faster than HBM, but approximately 4,000x smaller in capacity. FlashAttention's key insight is to actively exploit this asymmetry: instead of materializing the entire N×N matrix in HBM, performing computations in small blocks that fit in SRAM can dramatically reduce HBM accesses.


4. IO Complexity Analysis

4.1 IO Complexity of Standard Attention

Standard attention exhibits the following HBM access pattern:

  1. Read Q, K from HBM, compute S = QK^T, write S to HBM: Θ(Nd + N²) IO
  2. Read S from HBM, compute P = softmax(S), write P to HBM: Θ(N²) IO
  3. Read P, V from HBM, compute O = PV, write O to HBM: Θ(Nd + N²) IO

Total HBM access: Θ(Nd + N²)

Since the sequence length N is typically much larger than the head dimension d (usually 64 or 128), the N² term dominates.

4.2 FlashAttention's IO Complexity

FlashAttention reduces HBM access through tiling to:

O(N²d² / M)

where M is the SRAM size. Intuitively, larger SRAM allows processing larger blocks at once, reducing HBM accesses.

4.3 Optimality Proof (Lower Bound)

The paper goes further to prove the following lower bound:

Theorem: For all SRAM sizes M where d ≤ M ≤ Nd, any algorithm computing exact attention requires Ω(N²d² / M) HBM accesses.

This means FlashAttention is optimal in terms of IO complexity: up to constant and polylogarithmic factors, it is impossible to compute exact attention with fewer HBM accesses.

4.4 Numerical Example

For the A100 with SRAM size M ≈ 192KB (≈ 98K FP16 elements), head dimension d = 64, and sequence length N = 4096, measuring IO in elements:

  • Standard attention IO: Θ(Nd + N²) ≈ 4096 × 64 + 4096² ≈ 17M elements
  • FlashAttention IO: Θ(N²d² / M) ≈ 4096² × 64² / 98,304 ≈ 0.7M elements (the exact count varies with block size)

In practice, since the N²-sized intermediate matrices are never written to HBM at all, the savings are even greater. The benefits become particularly pronounced as sequence length increases.
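The estimate above can be reproduced in a few lines (a sketch: IO is counted in matrix elements, and M is taken as the number of FP16 elements that fit in 192 KB of SRAM):

```python
def std_attention_io(N, d):
    # Θ(Nd + N²): Q/K/V/O traffic plus the S and P round-trips (dominant terms)
    return N * d + N * N

def flash_attention_io(N, d, M):
    # Θ(N²d² / M): larger SRAM => larger blocks => fewer passes over HBM
    return N * N * d * d // M

M = 192 * 1024 // 2   # 192 KB of SRAM expressed in FP16 elements (98,304)
N, d = 4096, 64
print(std_attention_io(N, d))        # ~17M elements
print(flash_attention_io(N, d, M))   # ~0.7M elements
```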


5. Tiling: Block-wise Computation That Fits in SRAM

5.1 Algorithm Overview

The core FlashAttention algorithm works as follows:

  1. Partition Q into T_r = ⌈N / B_r⌉ blocks Q_1, Q_2, …, Q_{T_r}, each of size B_r × d
  2. Partition K, V into T_c = ⌈N / B_c⌉ blocks K_1, …, K_{T_c} and V_1, …, V_{T_c}, each of size B_c × d
  3. Block sizes B_r, B_c are chosen to fit in SRAM of size M: B_c = ⌈M / (4d)⌉, B_r = min(⌈M / (4d)⌉, d)

5.2 Forward Pass Pseudocode

Algorithm: FlashAttention Forward Pass
---------------------------------------
Input: Q, K, V in HBM, SRAM size M
Output: O in HBM

1. Set block sizes: B_c = ceil(M / (4d)), B_r = min(ceil(M / (4d)), d)
2. Initialize O = zeros(N, d), l = zeros(N), m = -inf * ones(N) in HBM

3. for j = 1 to T_c:                        # Outer loop: K, V blocks
     Load K_j, V_j from HBM to SRAM

     for i = 1 to T_r:                      # Inner loop: Q blocks
       Load Q_i, O_i, l_i, m_i from HBM to SRAM

       # Perform block-wise computation in SRAM
       S_ij = Q_i @ K_j^T                   # (B_r x B_c)
       m_ij = rowmax(S_ij)
       P_ij = exp(S_ij - m_ij)
       l_ij = rowsum(P_ij)

       # Combine with statistics from previous blocks (Online Softmax)
       m_new = max(m_i, m_ij)
       l_new = exp(m_i - m_new) * l_i + exp(m_ij - m_new) * l_ij

       # Update output: rescale the previous accumulation and add this block
       O_i = diag(l_new)^(-1) * ( diag(l_i) * exp(m_i - m_new) * O_i
                                  + exp(m_ij - m_new) * P_ij @ V_j )

       # Update statistics
       m_i = m_new, l_i = l_new

       Write O_i, l_i, m_i back to HBM
     end for
   end for

4. return O
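To make the loop structure concrete, here is a minimal NumPy sketch of the forward pass (illustrative, not optimized: block sizes are plain keyword arguments rather than being derived from a real SRAM budget, and the 1/√d scaling is omitted, matching the article's formulas):

```python
import numpy as np

def flash_attention_forward(Q, K, V, B_r=64, B_c=64):
    """Block-wise exact attention with online softmax (FlashAttention-1 style)."""
    N, d = Q.shape
    O = np.zeros((N, d))
    l = np.zeros(N)              # running softmax denominator per row
    m = np.full(N, -np.inf)      # running max per row

    for j in range(0, N, B_c):                    # outer loop: K, V blocks
        Kj, Vj = K[j:j+B_c], V[j:j+B_c]
        for i in range(0, N, B_r):                # inner loop: Q blocks
            Qi = Q[i:i+B_r]
            S_ij = Qi @ Kj.T                      # (B_r, B_c) score block
            m_ij = S_ij.max(axis=1)
            P_ij = np.exp(S_ij - m_ij[:, None])
            l_ij = P_ij.sum(axis=1)

            m_i, l_i = m[i:i+B_r], l[i:i+B_r]
            m_new = np.maximum(m_i, m_ij)
            l_new = np.exp(m_i - m_new) * l_i + np.exp(m_ij - m_new) * l_ij

            # Rescale the previous (normalized) output, add this block, renormalize
            O[i:i+B_r] = ((np.exp(m_i - m_new) * l_i)[:, None] * O[i:i+B_r]
                          + np.exp(m_ij - m_new)[:, None] * (P_ij @ Vj))
            O[i:i+B_r] /= l_new[:, None]

            m[i:i+B_r], l[i:i+B_r] = m_new, l_new
    return O
```

Comparing `flash_attention_forward(Q, K, V)` against a naive `softmax(Q @ K.T) @ V` yields identical results up to floating-point rounding, since the algorithm is exact.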

5.3 Why This Works

The key point is that the N×N attention matrices S and P are never materialized in HBM. Each B_r × B_c block S_ij is computed within SRAM, immediately used for softmax statistics updates and output accumulation, and then discarded.

The mathematical technique that makes this possible is Online Softmax.


6. Online Softmax (Safe Softmax) Algorithm

6.1 The Problem with Standard Softmax

Softmax is a global operation. For a row vector x = [x_1, …, x_N]:

softmax(x_i) = e^{x_i} / Σ_{j=1}^{N} e^{x_j}

Computing this requires seeing the entire row at once to calculate the denominator sum. This is the fundamental barrier that makes tiling difficult -- looking only at block S_{i1} is insufficient to complete the softmax, because the denominator changes depending on the values in the remaining blocks S_{i2}, S_{i3}, …

Additionally, for numerical stability, "safe softmax" is used:

softmax(x_i) = e^{x_i − m} / Σ_{j=1}^{N} e^{x_j − m},   where m = max_j x_j

This also requires the global maximum m, necessitating a scan of the entire row first.

6.2 The Online Softmax Trick

The key idea of Online Softmax (Milakov & Gimelshein, 2018) is to compute softmax incrementally, block by block, while maintaining running statistics.

Two scalars are maintained per row:

  • m: the maximum of all elements seen so far (running max)
  • ℓ: the normalization constant so far (running sum of exponentials)

When a new block S_ij arrives:

  1. Compute the new block's local statistics: m̃ = rowmax(S_ij), P̃ = exp(S_ij − m̃), ℓ̃ = rowsum(P̃)
  2. Update the global maximum: m_new = max(m, m̃)
  3. Rescale the normalization constant: ℓ_new = e^{m − m_new} · ℓ + e^{m̃ − m_new} · ℓ̃
  4. Rescale the previous output and add the new block's contribution: O_new = (e^{m − m_new} · ℓ · O + e^{m̃ − m_new} · P̃ V_j) / ℓ_new

This process is mathematically exact, not an approximation. Regardless of the order in which blocks are processed, the final result matches standard attention, up to minor numerical differences caused by floating-point operation ordering.

6.3 Mathematical Justification

The core of the proof is the rescaling property of softmax:

e^{x_i − m′} / Σ_j e^{x_j − m′} = (e^{x_i − m} · e^{m − m′}) / (Σ_j e^{x_j − m} · e^{m − m′}) = e^{x_i − m} / Σ_j e^{x_j − m}

Even when the maximum is updated from m to m′, the same factor e^{m − m′} multiplies both numerator and denominator, so the ratio remains unchanged. This property allows the results from previous blocks to be safely rescaled to the new maximum.
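The incremental update can be checked numerically in a few lines (a sketch in NumPy; the input vector and block size are arbitrary choices):

```python
import numpy as np

def online_softmax(x, block=4):
    """Safe softmax computed block by block with running statistics (m, l)."""
    m, l = -np.inf, 0.0
    for start in range(0, len(x), block):
        xb = x[start:start+block]
        m_new = max(m, xb.max())
        # Rescale the running sum to the new max, then add this block's terms
        l = l * np.exp(m - m_new) + np.exp(xb - m_new).sum()
        m = m_new
    # Final normalization using the fully-updated statistics
    return np.exp(x - m) / l

x = np.array([1.0, 3.0, -2.0, 0.5, 7.0, 4.0, -1.0, 2.0])
ref = np.exp(x - x.max()); ref /= ref.sum()
print(np.allclose(online_softmax(x), ref))
```

No matter how the vector is split into blocks, the result agrees with the standard two-pass safe softmax.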


7. Backward Pass Recomputation Strategy

7.1 The Problem with Standard Backward Pass

In the standard attention backward pass, the intermediate matrices S and P saved during the forward pass are needed for gradient computation. Since their size is N×N, storing them in the forward pass and reading them back in the backward pass requires O(N²) memory.

7.2 FlashAttention's Recomputation

FlashAttention uses a variant of gradient checkpointing. During the forward pass, S and P are not saved. Instead, only the following are stored:

  • The final output O ∈ ℝ^{N×d}
  • The softmax normalization statistics m, ℓ ∈ ℝ^N (per-row maximum and sum)

During the backward pass, these statistics and the original Q, K, V are used to recompute the needed blocks of S and P in SRAM. This recomputation requires additional FLOPs but significantly reduces HBM access.
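The key fact enabling this is that P can be reconstructed exactly from Q, K and the saved per-row statistics. A small NumPy check (illustrative shapes only; the real kernel does this block by block in SRAM):

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 128, 64
Q, K = rng.standard_normal((2, N, d))

# Forward pass: compute the softmax once, but save only O(N) statistics
S = Q @ K.T
m = S.max(axis=1)                          # per-row max        (N values)
l = np.exp(S - m[:, None]).sum(axis=1)     # per-row denominator (N values)
P = np.exp(S - m[:, None]) / l[:, None]    # what the backward pass needs

# Backward pass: recompute S from Q, K, then rebuild P from m and l --
# no N×N matrix had to be stored between the two passes
P_recomputed = np.exp(Q @ K.T - m[:, None]) / l[:, None]
assert np.allclose(P, P_recomputed)
```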

7.3 The Paradoxical Effect of Recomputation

Typically, gradient checkpointing saves memory at the cost of speed. However, FlashAttention's recomputation actually improves speed as well. The reason is:

  • FLOPs increase: Since what was computed once in the forward pass is recomputed in the backward pass, total FLOPs slightly increase.
  • HBM IO decreases: The cost of writing and reading the N×N matrices S and P to and from HBM is eliminated.

On modern GPUs, HBM access is much slower than computation, so the benefit from IO reduction outweighs the FLOP increase. Experimental results show that the additional runtime overhead from recomputation is less than 5%, while memory usage decreases from O(N²) to O(N).

7.4 Memory Savings

| Sequence Length | Standard Attention Memory | FlashAttention Memory | Savings Ratio |
|---|---|---|---|
| 1K | ~2 MB | ~0.13 MB | ~15x |
| 2K | ~8 MB | ~0.26 MB | ~30x |
| 4K | ~33 MB | ~0.52 MB | ~63x |
| 8K | ~131 MB | ~1.04 MB | ~126x |

These savings enable processing longer sequences or using larger batch sizes with the same GPU memory.
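These figures follow directly from the tensor shapes involved. A quick sanity check (a sketch assuming FP16, a single head, d = 64, and counting only the dominant tensors):

```python
def attention_memory_mb(N, d=64, bytes_per_el=2):
    """Approximate attention memory in MiB: standard vs FlashAttention."""
    standard = N * N * bytes_per_el            # materialized N×N score matrix
    flash = (N * d + 2 * N) * bytes_per_el     # output O plus statistics m, l
    return standard / 2**20, flash / 2**20

for N in (1024, 2048, 4096, 8192):
    std, fl = attention_memory_mb(N)
    print(f"{N:>5}: standard {std:7.2f} MB, flash {fl:5.2f} MB, ~{std/fl:.0f}x")
```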


8. FlashAttention-2 Improvements

Dao (2023) introduced three key improvements in FlashAttention-2.

8.1 Minimizing Non-matmul FLOPs

The A100 GPU's Tensor Cores deliver 312 TFLOPS (FP16) for matrix multiplication (matmul), but non-matmul operations (softmax's exp, max, sum, etc.) run at 19.5 TFLOPS (FP32) -- approximately 16x slower. In FlashAttention-1, the proportion of non-matmul operations was significant.

FlashAttention-2 restructures the algorithm to minimize these non-matmul FLOPs. Specifically, it reduces the number of rescaling operations and performs softmax statistics updates more efficiently. The key change is performing the final rescaling only once at the end of the loop.

8.2 Improved Parallelism: Sequence Length Dimension Parallelization

FlashAttention-1 only parallelized across the batch and head dimensions. When the batch size was small or the number of heads was low, the GPU's SMs (Streaming Multiprocessors) were underutilized.

FlashAttention-2 also parallelizes across the sequence length dimension. By changing the outer loop to iterate over Q blocks (rather than K,VK, V blocks), each Q block can be processed by an independent thread block. This change significantly improves occupancy during the forward pass.

8.3 Work Partitioning Optimization

Work distribution among warps within a thread block was also improved:

  • FlashAttention-1: K, V are split across 4 warps, each warp independently computes QKTQK^T and then synchronizes results. This approach incurs communication and synchronization overhead through shared memory.
  • FlashAttention-2: Q is split across 4 warps, while K and V are shared by all warps. Since each warp computes outputs for different parts of Q independently, no inter-warp communication is needed.

8.4 Performance Results

Combining these three improvements:

  • Approximately 2x speedup over FlashAttention-1
  • Achieves 230 TFLOPS in FP16/BF16 on the A100 (approximately 73% of the theoretical maximum)
  • Up to 9x speedup over standard PyTorch attention
  • Approaches the efficiency of GEMM (matrix multiplication) operations

9. FlashAttention-3: Latest Advances

FlashAttention-3 (Shah et al., 2024) took a further step forward by leveraging the new hardware capabilities of the NVIDIA Hopper architecture (H100).

9.1 New Capabilities of the Hopper GPU

The H100 GPU provides the following key capabilities over the A100:

  • WGMMA (Warpgroup Matrix Multiply-Accumulate): A new Tensor Core instruction with much higher throughput than the A100's mma.sync
  • TMA (Tensor Memory Accelerator): A dedicated hardware unit for data transfers between global memory and shared memory, handling index computation and bounds checking in hardware

9.2 Three Key Techniques

1. Asynchronous Execution via Warp Specialization

Computation (WGMMA) and data movement (TMA) are assigned to different warp groups for pipelined, overlapping execution. While one warp group computes the current block, another prefetches data for the next block.

2. Interleaving of Matmul and Softmax

Previously, matmul was followed by softmax, then another matmul, in a sequential manner. FlashAttention-3 interleaves these so that matmul and softmax execute simultaneously on different hardware units. While Tensor Cores compute QKTQK^T for the next block, CUDA Cores process the softmax of the current block.

3. FP8 Low-precision Support

Leveraging the H100's FP8 Tensor Cores doubles throughput. Naive FP8 quantization degrades accuracy, but FlashAttention-3 addresses this with two techniques:

  • Block quantization: Maintaining separate scale factors per block to preserve dynamic range
  • Incoherent processing: Multiplying by a random orthogonal matrix to distribute outliers before quantization, achieving 2.6x lower numerical error compared to the FP8 baseline
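The benefit of per-block scaling can be illustrated with a toy experiment (a sketch: symmetric int8 quantization stands in for FP8, and the data is synthetic with a single injected outlier; none of this is FlashAttention-3's actual kernel code):

```python
import numpy as np

def quantize_dequantize(x, scale):
    # Symmetric round-to-nearest quantization to 8-bit levels, then dequantize
    q = np.clip(np.round(x / scale), -127, 127)
    return q * scale

rng = np.random.default_rng(0)
x = rng.standard_normal(4096).astype(np.float32)
x[17] = 60.0                        # one outlier inflates a global scale factor

# Per-tensor: a single scale, dominated by the outlier -> coarse grid everywhere
err_tensor = np.abs(quantize_dequantize(x, np.abs(x).max() / 127) - x).mean()

# Per-block: each 128-element block gets its own scale, confining the damage
blocks = x.reshape(-1, 128)
scales = np.abs(blocks).max(axis=1, keepdims=True) / 127
err_block = np.abs(quantize_dequantize(blocks, scales) - blocks).mean()

print(err_tensor, err_block)        # per-block error is much smaller
```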

9.3 Performance Results

FlashAttention-3 performance on the H100:

| Configuration | TFLOPS | GPU Utilization |
|---|---|---|
| FP16 FlashAttention-2 | ~400 | ~50% |
| FP16 FlashAttention-3 | ~740 | ~75% |
| FP8 FlashAttention-3 | ~1,200 | ~75% |

In FP16, it achieves a 1.5-2.0x speedup over FlashAttention-2, and in FP8, it approaches 1.2 PFLOPS.


10. Benchmarks: Speed and Memory Comparison

10.1 Attention Forward Pass Speed (A100 80GB, FP16)

Key figures reported in the FlashAttention paper and subsequent benchmarks:

| Sequence Length | Standard Attention | FlashAttention | FlashAttention-2 | Speedup (FA2 vs Std) |
|---|---|---|---|---|
| 512 | 12.2 ms | 3.5 ms | 1.9 ms | 6.4x |
| 1K | 45.8 ms | 7.8 ms | 4.1 ms | 11.2x |
| 2K | 178 ms | 18.9 ms | 9.8 ms | 18.2x |
| 4K | 710 ms | 52.3 ms | 27.1 ms | 26.2x |
| 8K | OOM | 145 ms | 75 ms | - |
| 16K | OOM | 520 ms | 270 ms | - |

The speedup becomes increasingly dramatic as sequence length grows. At 8K and above, standard attention fails with OOM (Out of Memory), while FlashAttention handles them without issue.

10.2 End-to-End Training Performance

| Model | Standard | FlashAttention | Speedup |
|---|---|---|---|
| BERT-large (seq 512) | 100% (MLPerf ref.) | 115% | 1.15x |
| GPT-2 (seq 1K) | 100% | 300% | 3.0x |
| Long-range Arena (seq 1K-4K) | 100% | 240% | 2.4x |

10.3 Memory Usage Comparison

FlashAttention's attention operation memory scales linearly with sequence length, a dramatic improvement over standard attention's quadratic scaling:

  • Sequence length 2K: approximately 10x memory savings
  • Sequence length 4K: approximately 20x memory savings
  • Sequence length 64K: standard attention causes OOM even on an A100 80GB, while FlashAttention runs normally

11. Integration with PyTorch torch.nn.functional.scaled_dot_product_attention

11.1 Native Integration

Starting with PyTorch 2.0, FlashAttention is natively integrated into torch.nn.functional.scaled_dot_product_attention (SDPA). Since PyTorch 2.2, FlashAttention-2 is used as the default backend.

import torch
import torch.nn.functional as F

# Basic usage - automatically selects FlashAttention backend
batch_size, num_heads, seq_len, head_dim = 4, 16, 2048, 64   # example sizes
query = torch.randn(batch_size, num_heads, seq_len, head_dim,
                    device='cuda', dtype=torch.float16)
key = torch.randn(batch_size, num_heads, seq_len, head_dim,
                  device='cuda', dtype=torch.float16)
value = torch.randn(batch_size, num_heads, seq_len, head_dim,
                    device='cuda', dtype=torch.float16)

# PyTorch automatically selects the optimal backend
output = F.scaled_dot_product_attention(query, key, value)

11.2 Explicit Backend Selection

You can force or exclude specific backends:

from torch.nn.attention import sdpa_kernel, SDPBackend

# Use FlashAttention backend only
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    output = F.scaled_dot_product_attention(query, key, value)

# Use Memory-efficient attention backend only
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    output = F.scaled_dot_product_attention(query, key, value)

# Use Math (naive) backend - for debugging
with sdpa_kernel(SDPBackend.MATH):
    output = F.scaled_dot_product_attention(query, key, value)

# Use CuDNN backend (PyTorch 2.2+)
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    output = F.scaled_dot_product_attention(query, key, value)

11.3 Using with Causal Mask

Causal masking, essential for autoregressive generation in LLMs, is also supported:

# Apply causal mask with is_causal=True
# FlashAttention handles the causal mask within the fused kernel, requiring no extra memory
output = F.scaled_dot_product_attention(
    query, key, value,
    is_causal=True
)

# Using a custom attention mask
attn_mask = torch.tril(torch.ones(seq_len, seq_len, device='cuda', dtype=torch.bool))
output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=attn_mask
)

11.4 Backend Selection Criteria

Conditions for PyTorch SDPA to select the FlashAttention backend:

  • dtype: float16 or bfloat16 (float32 is not supported)
  • device: CUDA GPU (CPU is not supported)
  • head dimension: Maximum 256 (for FlashAttention-2)
  • attention mask: Boolean mask or is_causal=True are supported; arbitrary float masks are not supported

If these conditions are not met, PyTorch automatically falls back to the memory-efficient attention or math backend.

11.5 Practical Tips

# Check which backend is being used
import torch.backends.cuda

# Check the enabled state of each backend
print(f"Flash SDP enabled: {torch.backends.cuda.flash_sdp_enabled()}")
print(f"Mem efficient SDP enabled: {torch.backends.cuda.mem_efficient_sdp_enabled()}")
print(f"Math SDP enabled: {torch.backends.cuda.math_sdp_enabled()}")

# Globally disable a specific backend
torch.backends.cuda.enable_flash_sdp(False)  # Disable FlashAttention
torch.backends.cuda.enable_mem_efficient_sdp(True)

11.6 Using the flash-attn Library Directly

In addition to PyTorch's native SDPA, you can use Tri Dao's flash-attn package directly. This package provides more features than PyTorch SDPA (e.g., sliding window attention, ALiBi, cross-attention optimization):

# pip install flash-attn
import torch
from flash_attn import flash_attn_func

# Note: flash_attn expects shape (batch, seqlen, nheads, headdim),
# FP16/BF16 tensors on CUDA -- example sizes below
q = torch.randn(2, 4096, 16, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 16, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 16, 64, device='cuda', dtype=torch.float16)
output = flash_attn_func(q, k, v, causal=True)

12. Summary and Key Takeaways

The key lesson of FlashAttention is that FLOP complexity alone does not determine performance. On modern GPUs, memory access patterns dominate actual execution time, and IO-aware algorithm design is decisive for practical performance.

The main contributions can be summarized as follows:

  1. IO-Aware Design Principle: Algorithm design that exploits the asymmetry of the GPU memory hierarchy (HBM vs SRAM)
  2. Tiling + Online Softmax: Block-wise computation that fits in SRAM, eliminating HBM materialization of the N×N matrix
  3. Recomputation Strategy: Recomputing intermediate values in the backward pass to reduce memory from O(N²) to O(N), while simultaneously improving speed
  4. Optimality Proof: Proving the lower bound from an IO complexity perspective to establish the algorithm's optimality
  5. Exact Computation: Maintaining exact attention without approximation despite all optimizations

FlashAttention is a rare piece of research that combines theoretical elegance with practical effectiveness, and it has become core infrastructure for modern LLM training and inference. Thanks to its native integration in PyTorch, its benefits can be enjoyed simply by calling F.scaled_dot_product_attention without any additional implementation.


References