Skip to content

Split View: FlashAttention: GPU 메모리 계층을 활용한 어텐션 최적화 분석

|

FlashAttention: GPU 메모리 계층을 활용한 어텐션 최적화 분석

1. 들어가며

Transformer 아키텍처의 핵심인 Self-Attention은 시퀀스 내 모든 토큰 쌍 간의 관계를 계산한다. 이 연산은 강력한 표현력을 제공하지만, 시퀀스 길이 NN에 대해 시간 및 메모리 복잡도가 O(N2)O(N^2)으로 증가하는 근본적인 한계를 갖는다. GPT-4, LLaMA, Gemini 등 최신 LLM이 128K 이상의 긴 컨텍스트를 처리하려면, 이 O(N2)O(N^2) 병목을 실질적으로 해결해야 한다.

FlashAttention(Dao et al., 2022)은 이 문제를 근사(approximation) 없이 해결한다. 핵심 아이디어는 단순하면서도 깊다: attention 연산 자체의 계산량을 줄이는 것이 아니라, GPU 메모리 계층 간 데이터 이동(IO)을 최소화하는 것이다. 이 글에서는 FlashAttention의 원리를 GPU 하드웨어 관점에서 체계적으로 분석하고, FlashAttention-2와 FlashAttention-3까지의 발전을 살펴본다.


2. Standard Attention의 메모리 문제

2.1 Standard Attention 연산 흐름

Standard Self-Attention은 다음과 같이 계산된다. 입력 Q,K,VRN×dQ, K, V \in \mathbb{R}^{N \times d}에 대해:

S=QKTRN×NS = QK^T \in \mathbb{R}^{N \times N} P=softmax(S)RN×NP = \text{softmax}(S) \in \mathbb{R}^{N \times N} O=PVRN×dO = PV \in \mathbb{R}^{N \times d}

여기서 NN은 시퀀스 길이, dd는 head dimension이다.

2.2 메모리 복잡도 분석

문제의 핵심은 중간 행렬 SSPP에 있다. 이 행렬들의 크기는 N×NN \times N이며, 시퀀스 길이에 대해 이차(quadratic) 메모리를 요구한다. 구체적인 수치를 계산하면:

시퀀스 길이 (NN)Attention 행렬 크기FP16 메모리
1,0241M 원소2 MB
4,09616.7M 원소33 MB
16,384268M 원소536 MB
65,5364.3B 원소8.6 GB
131,07217.2B 원소34.4 GB

이 수치는 단일 head, 단일 배치에 대한 것이다. Multi-head attention에서 head 수 hh를 곱하고, 배치 크기 BB를 곱하면 실제 메모리 사용량은 훨씬 커진다. 시퀀스 길이 65,536에서 이미 단일 head만으로 A100 80GB GPU의 HBM 상당 부분을 소비하게 된다.

2.3 HBM 병목 현상

Standard attention의 구현에서는 이 N×NN \times N 행렬을 GPU HBM(High Bandwidth Memory)에 materialization한다. 즉, S=QKTS = QK^T를 계산하여 HBM에 쓰고, softmax를 위해 다시 읽고, 결과 PP를 HBM에 쓰고, O=PVO = PV를 위해 다시 읽는다. 이 과정에서 HBM에 대한 읽기/쓰기 횟수는 Ω(Nd+N2)\Omega(Nd + N^2)이 된다.

실제 GPU에서 이 연산이 느린 이유는 계산(compute)이 아니라 메모리 접근(memory access)이 병목이기 때문이다. A100 GPU의 계산 처리량은 312 TFLOPS(FP16)인 반면, HBM 대역폭은 약 2 TB/s에 불과하다. Attention 연산은 arithmetic intensity(연산량/메모리 접근량 비율)가 낮아 전형적인 memory-bound 연산이다.


3. GPU 메모리 계층 구조

FlashAttention을 이해하려면 GPU의 메모리 계층을 정확히 알아야 한다.

3.1 HBM (High Bandwidth Memory)

  • 용량: A100 기준 40GB 또는 80GB
  • 대역폭: 약 1.5-2.0 TB/s (A100 80GB SXM: 2,039 GB/s)
  • 접근 지연: 약 200-600 사이클
  • 역할: GPU의 메인 메모리. 모델 파라미터, 입력 텐서, 출력 텐서 등 모든 데이터가 저장됨

3.2 SRAM (On-chip Shared Memory)

  • 용량: A100 기준 SM당 약 192KB, 전체 약 20MB (108개 SM)
  • 대역폭: 약 19 TB/s
  • 접근 지연: 약 20-30 사이클
  • 역할: 각 Streaming Multiprocessor(SM) 내의 고속 온칩 메모리

3.3 핵심 비대칭성

SRAM과 HBM 사이에는 극적인 비대칭이 존재한다:

특성SRAMHBM
대역폭~19 TB/s~2 TB/s
용량~20 MB40-80 GB
접근 지연20-30 사이클200-600 사이클

SRAM은 HBM보다 약 10배 빠르지만, 용량은 약 4000배 작다. FlashAttention의 핵심 통찰은 이 비대칭성을 적극 활용하는 것이다: N×NN \times N 행렬 전체를 HBM에 materialization하는 대신, SRAM에 들어가는 작은 블록 단위로 연산을 수행하면 HBM 접근을 극적으로 줄일 수 있다.


4. IO Complexity 분석

4.1 Standard Attention의 IO Complexity

Standard attention은 다음과 같은 HBM 접근 패턴을 보인다:

  1. Q,KQ, K를 HBM에서 읽어 S=QKTS = QK^T 계산 -> SS를 HBM에 쓰기: Θ(Nd+N2)\Theta(Nd + N^2) IO
  2. SS를 HBM에서 읽어 P=softmax(S)P = \text{softmax}(S) 계산 -> PP를 HBM에 쓰기: Θ(N2)\Theta(N^2) IO
  3. P,VP, V를 HBM에서 읽어 O=PVO = PV 계산 -> OO를 HBM에 쓰기: Θ(Nd+N2)\Theta(Nd + N^2) IO

총 HBM 접근량: Θ(Nd+N2)\Theta(Nd + N^2)

시퀀스 길이 NN이 head dimension dd (보통 64 또는 128)보다 훨씬 크므로, N2N^2 항이 지배적이 된다.

4.2 FlashAttention의 IO Complexity

FlashAttention은 tiling을 통해 HBM 접근량을 다음으로 줄인다:

O(N2d2M)O\left(\frac{N^2 d^2}{M}\right)

여기서 MM은 SRAM 크기이다. 직관적으로, SRAM이 클수록 더 큰 블록을 한 번에 처리할 수 있어 HBM 접근이 줄어든다.

4.3 최적성 증명 (Lower Bound)

논문은 더 나아가 다음의 하한(lower bound)을 증명한다:

Theorem: dMNdd \leq M \leq Nd인 모든 SRAM 크기 MM에 대해, exact attention을 계산하는 어떤 알고리즘도 Ω(N2d2/M)\Omega(N^2 d^2 / M)의 HBM 접근이 필요하다.

이는 FlashAttention이 **IO complexity 관점에서 최적(optimal)**임을 의미한다. 상수 인자나 다항 로그 인자를 제외하면, 더 적은 HBM 접근으로 exact attention을 계산하는 것은 불가능하다.

4.4 수치 예시

A100의 SRAM 크기 M192M \approx 192KB, head dimension d=64d = 64, 시퀀스 길이 N=4096N = 4096일 때:

  • Standard attention IO: Θ(Nd+N2)4096×64+4096217M\Theta(Nd + N^2) \approx 4096 \times 64 + 4096^2 \approx 17M 원소
  • FlashAttention IO: Θ(N2d2/M)40962×642/(192×512)7M\Theta(N^2 d^2 / M) \approx 4096^2 \times 64^2 / (192 \times 512) \approx 7M 원소 (블록 크기에 따라 달라짐)

실제로는 N2N^2 크기의 중간 행렬이 HBM에 전혀 기록되지 않으므로, 절약 효과는 더 크다. 특히 시퀀스 길이가 길어질수록 효과가 극대화된다.


5. Tiling 기법: SRAM에 맞는 블록 단위 연산

5.1 알고리즘 개요

FlashAttention의 핵심 알고리즘은 다음과 같다:

  1. QQTr=N/BrT_r = \lceil N / B_r \rceil개의 블록으로 나눈다: Q1,Q2,,QTrQ_1, Q_2, \ldots, Q_{T_r}, 각 블록 크기 Br×dB_r \times d
  2. K,VK, VTc=N/BcT_c = \lceil N / B_c \rceil개의 블록으로 나눈다: K1,,KTcK_1, \ldots, K_{T_c}V1,,VTcV_1, \ldots, V_{T_c}, 각 블록 크기 Bc×dB_c \times d
  3. 블록 크기 Br,BcB_r, B_c는 SRAM 크기 MM에 맞게 설정: Bc=M/(4d)B_c = \lceil M / (4d) \rceil, Br=min(M/(4d),d)B_r = \min(\lceil M / (4d) \rceil, d)

5.2 Forward Pass 의사코드

Algorithm: FlashAttention Forward Pass
---------------------------------------
Input: Q, K, V in HBM, SRAM size M
Output: O in HBM

1. 블록 크기 설정: B_c = ceil(M / 4d), B_r = min(ceil(M / 4d), d)
2. O = zeros(N, d), l = zeros(N), m = -inf * ones(N)HBM에 초기화

3. for j = 1 to T_c:                        # Outer loop: K, V 블록
     K_j, V_jHBM에서 SRAM으로 로드

     for i = 1 to T_r:                      # Inner loop: Q 블록
       Q_i, O_i, l_i, m_i 를 HBM에서 SRAM으로 로드

       # SRAM에서 블록 단위 연산 수행
       S_ij = Q_i @ K_j^T                   # (B_r x B_c)
       m_ij = rowmax(S_ij)
       P_ij = exp(S_ij - m_ij)
       l_ij = rowsum(P_ij)

       # 이전 블록의 통계치와 결합 (Online Softmax)
       m_new = max(m_i, m_ij)
       l_new = exp(m_i - m_new) * l_i + exp(m_ij - m_new) * l_ij

       # 출력 업데이트 (rescaling 포함)
       O_i = diag(exp(m_i - m_new))^(-1) * (diag(l_i) * O_i)
             + diag(exp(m_ij - m_new))^(-1) * P_ij @ V_j
       O_i = diag(l_new)^(-1) * O_i

       # 통계치 업데이트
       m_i = m_new, l_i = l_new

       O_i, l_i, m_i 를 HBM에 다시 쓰기
     end for
   end for

4. return O

5.3 왜 이것이 작동하는가

핵심은 N×NN \times N 크기의 attention 행렬 SSPPHBM에 전혀 materialization되지 않는다는 점이다. 각 Br×BcB_r \times B_c 블록 SijS_{ij}는 SRAM 내에서 계산되고, 즉시 softmax 통계치 업데이트와 출력 누적에 사용된 후 폐기된다.

이를 가능하게 하는 수학적 기법이 바로 Online Softmax이다.


6. Online Softmax (Safe Softmax) 알고리즘

6.1 Standard Softmax의 문제

Softmax는 **전역 연산(global operation)**이다. 행 벡터 x=[x1,,xN]x = [x_1, \ldots, x_N]에 대해:

softmax(xi)=exij=1Nexj\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}

이를 계산하려면 분모의 합계를 위해 전체 행을 한 번에 봐야 한다. 이것이 tiling을 어렵게 만드는 근본적인 장벽이다 -- 블록 Si1S_{i1}만 보고는 softmax를 완성할 수 없다. 나머지 블록 Si2,Si3,S_{i2}, S_{i3}, \ldots의 값에 따라 분모가 달라지기 때문이다.

또한 수치 안정성을 위해 "safe softmax"를 사용한다:

softmax(xi)=eximj=1Nexjm,m=maxjxj\text{softmax}(x_i) = \frac{e^{x_i - m}}{\sum_{j=1}^{N} e^{x_j - m}}, \quad m = \max_j x_j

이 역시 전역 최댓값 mm이 필요하므로 전체 행을 먼저 스캔해야 한다.

6.2 Online Softmax 트릭

Online Softmax(Milakov & Gimelshein, 2018)의 핵심 아이디어는 running statistics를 유지하면서 블록 단위로 점진적(incremental)으로 softmax를 계산하는 것이다.

두 개의 스칼라를 행마다 유지한다:

  • mm: 지금까지 본 원소들의 최댓값 (running max)
  • \ell: 지금까지의 정규화 상수 (running sum of exponentials)

새로운 블록 SijS_{ij}가 들어오면:

  1. 새 블록의 행별 최댓값 계산: m~=rowmax(Sij)\tilde{m} = \text{rowmax}(S_{ij})
  2. 전역 최댓값 업데이트: mnew=max(m,m~)m_{\text{new}} = \max(m, \tilde{m})
  3. 이전 정규화 상수를 rescale: new=emmnew+em~mnew~\ell_{\text{new}} = e^{m - m_{\text{new}}} \cdot \ell + e^{\tilde{m} - m_{\text{new}}} \cdot \tilde{\ell}
  4. 이전 출력도 rescale: Onew=emmnewO+em~mnewP~VjnewO_{\text{new}} = \frac{e^{m - m_{\text{new}}} \cdot \ell \cdot O + e^{\tilde{m} - m_{\text{new}}} \cdot \tilde{P}V_j}{\ell_{\text{new}}}

이 과정은 **수학적으로 정확(exact)**하다. 근사가 아니다. 블록을 어떤 순서로 처리하든, 최종 결과는 standard attention과 비트 단위로 동일하다(부동소수점 연산 순서에 따른 미세한 수치 차이 제외).

6.3 수학적 정당성

증명의 핵심은 softmax의 rescaling property이다:

eximjexjm=eximemmjexjmemm=eximjexjm\frac{e^{x_i - m'}}{\sum_j e^{x_j - m'}} = \frac{e^{x_i - m} \cdot e^{m - m'}}{\sum_j e^{x_j - m} \cdot e^{m - m'}} = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}

최댓값이 mm에서 mm'으로 업데이트되더라도, 분자와 분모에 동일한 factor emme^{m - m'}가 곱해지므로 비율은 변하지 않는다. 이 성질 덕분에 이전 블록의 결과를 새로운 최댓값 기준으로 안전하게 rescale할 수 있다.


7. Backward Pass의 Recomputation 전략

7.1 Standard Backward Pass의 문제

Standard attention의 backward pass에서는 gradient 계산을 위해 forward pass에서 저장한 중간 행렬 SSPP가 필요하다. 이들의 크기가 N×NN \times N이므로, forward에서 저장하고 backward에서 다시 읽는 것은 O(N2)O(N^2) 메모리를 요구한다.

7.2 FlashAttention의 Recomputation

FlashAttention은 gradient checkpointing의 변형을 사용한다. Forward pass에서 SSPP를 저장하지 않고, 대신 다음만 저장한다:

  • 최종 출력 ORN×dO \in \mathbb{R}^{N \times d}
  • Softmax 정규화 통계치 m,RNm, \ell \in \mathbb{R}^{N} (행별 최댓값과 합계)

Backward pass에서는 이 통계치와 원본 Q,K,VQ, K, V를 사용하여 **SSPP의 필요한 블록을 SRAM에서 다시 계산(recompute)**한다. 이 recomputation은 추가적인 FLOP을 요구하지만, HBM 접근을 크게 줄인다.

7.3 Recomputation의 역설적 효과

일반적으로 gradient checkpointing은 메모리를 절약하는 대신 속도를 희생한다. 그러나 FlashAttention의 recomputation은 오히려 속도까지 향상시킨다. 이유는 다음과 같다:

  • FLOP은 증가한다: forward에서 한 번 계산한 것을 backward에서 다시 계산하므로, 총 FLOP은 약간 증가한다.
  • HBM IO는 감소한다: N×NN \times N 크기의 SS, PP를 HBM에 쓰고 읽는 비용이 사라진다.

현대 GPU에서는 HBM 접근이 계산보다 훨씬 느리므로, FLOP 증가량보다 IO 감소의 이득이 더 크다. 실험 결과, recomputation으로 인한 추가 런타임 오버헤드는 5% 미만이면서, 메모리 사용량은 O(N2)O(N^2)에서 O(N)O(N)으로 감소한다.

7.4 메모리 절약 효과

시퀀스 길이Standard Attention 메모리FlashAttention 메모리절약 비율
1K~2 MB~0.13 MB~15x
2K~8 MB~0.26 MB~30x
4K~33 MB~0.52 MB~63x
8K~131 MB~1.04 MB~126x

이 절약 효과 덕분에 동일한 GPU 메모리로 더 긴 시퀀스를 처리하거나, 더 큰 배치 크기를 사용할 수 있다.


8. FlashAttention-2 개선점

Dao(2023)는 FlashAttention-2에서 세 가지 핵심 개선을 도입했다.

8.1 Non-matmul FLOP 최소화

A100 GPU의 Tensor Core는 행렬 곱셈(matmul)에 대해 312 TFLOPS(FP16)를 제공하지만, non-matmul 연산(softmax의 exp, max, sum 등)은 19.5 TFLOPS(FP32)로 약 16배 느리다. FlashAttention-1에서는 non-matmul 연산 비중이 상당했다.

FlashAttention-2는 알고리즘을 재구성하여 이러한 non-matmul FLOP을 최소화한다. 구체적으로, rescaling 연산의 횟수를 줄이고, softmax 통계치 업데이트를 더 효율적으로 수행한다. 최종 rescaling을 루프 마지막에 한 번만 수행하도록 변경한 것이 핵심이다.

8.2 Parallelism 향상: Sequence Length 차원 병렬화

FlashAttention-1은 batch 차원과 head 차원에서만 병렬화했다. 배치 크기가 작거나 head 수가 적으면 GPU의 SM(Streaming Multiprocessor)을 충분히 활용하지 못했다.

FlashAttention-2는 시퀀스 길이 차원에서도 병렬화한다. 외부 루프를 Q 블록 기준으로 변경하여 (K,VK, V 블록이 아닌 QQ 블록을 외부 루프로), 각 Q 블록을 독립적인 thread block에서 처리할 수 있게 했다. 이 변경으로 forward pass에서의 occupancy가 크게 향상된다.

8.3 Work Partitioning 최적화

Thread block 내에서 warp 간의 작업 분배도 개선되었다:

  • FlashAttention-1: K, V를 4개의 warp에 분할, 각 warp가 독립적으로 QKTQK^T 계산 후 결과를 동기화. 이 방식은 shared memory를 통한 통신과 동기화 오버헤드가 발생한다.
  • FlashAttention-2: Q를 4개의 warp에 분할, K와 V는 모든 warp가 공유. 각 warp는 Q의 다른 부분에 대해 독립적으로 출력을 계산하므로, warp 간 통신이 불필요하다.

8.4 성능 결과

이 세 가지 개선을 합치면:

  • FlashAttention-1 대비 약 2배 speedup
  • A100에서 FP16/BF16 기준 230 TFLOPS 달성 (이론적 최대의 약 73%)
  • Standard PyTorch attention 대비 최대 9배 speedup
  • GEMM(행렬 곱셈) 연산의 효율에 근접

9. FlashAttention-3 최신 발전

FlashAttention-3(Shah et al., 2024)는 NVIDIA Hopper 아키텍처(H100)의 새로운 하드웨어 기능을 활용하여 한 단계 더 발전했다.

9.1 Hopper GPU의 새로운 기능

H100 GPU는 A100에 비해 다음의 핵심 기능을 제공한다:

  • WGMMA (Warpgroup Matrix Multiply-Accumulate): A100의 mma.sync보다 훨씬 높은 처리량을 가진 새로운 Tensor Core 명령어
  • TMA (Tensor Memory Accelerator): Global memory와 Shared memory 간 데이터 전송을 전담하는 하드웨어 유닛. 인덱스 계산과 경계 검사를 하드웨어에서 처리

9.2 세 가지 핵심 기법

1. Warp Specialization을 통한 비동기 실행

연산(WGMMA)과 데이터 이동(TMA)을 서로 다른 warp group에 할당하여 파이프라인 방식으로 중첩(overlap) 실행한다. 한 warp group이 현재 블록을 계산하는 동안, 다른 warp group이 다음 블록의 데이터를 prefetch한다.

2. Matmul과 Softmax의 Interleaving

기존에는 matmul이 끝난 후 softmax를 수행하고, 다시 matmul을 수행하는 순차적 방식이었다. FlashAttention-3는 이를 인터리빙하여, matmul과 softmax가 서로 다른 하드웨어 유닛에서 동시에 실행되도록 한다. Tensor Core가 다음 블록의 QKTQK^T를 계산하는 동안, CUDA Core가 현재 블록의 softmax를 처리한다.

3. FP8 Low-precision 지원

H100의 FP8 Tensor Core를 활용하여 처리량을 2배 높인다. 단순히 FP8로 양자화하면 정확도가 떨어지지만, FlashAttention-3는 두 가지 기법으로 이를 해결한다:

  • Block quantization: 블록 단위로 별도의 스케일 팩터를 유지하여 동적 범위를 보존
  • Incoherent processing: 랜덤 직교 행렬을 곱하여 outlier를 분산시킨 후 양자화. 이를 통해 FP8 baseline 대비 2.6배 낮은 수치 오차 달성

9.3 성능 결과

H100에서의 FlashAttention-3 성능:

설정TFLOPSGPU 활용률
FP16 FlashAttention-2~400~50%
FP16 FlashAttention-3~740~75%
FP8 FlashAttention-3~1,200~75%

FP16에서 FlashAttention-2 대비 1.5-2.0배 speedup, FP8에서는 1.2 PFLOPS에 근접하는 성능을 달성했다.


10. 벤치마크: 속도/메모리 비교

10.1 Attention Forward Pass 속도 (A100 80GB, FP16)

FlashAttention 논문과 후속 벤치마크에서 보고된 주요 수치는 다음과 같다:

시퀀스 길이Standard AttentionFlashAttentionFlashAttention-2Speedup (FA2 vs Std)
51212.2 ms3.5 ms1.9 ms6.4x
1K45.8 ms7.8 ms4.1 ms11.2x
2K178 ms18.9 ms9.8 ms18.2x
4K710 ms52.3 ms27.1 ms26.2x
8KOOM145 ms75 ms-
16KOOM520 ms270 ms-

시퀀스 길이가 길어질수록 speedup이 더 극적으로 증가한다. 8K 이상에서는 standard attention이 OOM(Out of Memory)으로 실행 자체가 불가능하지만, FlashAttention은 문제없이 처리한다.

10.2 End-to-End 학습 성능

모델StandardFlashAttentionSpeedup
BERT-large (seq 512)100% (MLPerf 기준)115%1.15x
GPT-2 (seq 1K)100%300%3.0x
Long-range Arena (seq 1K-4K)100%240%2.4x

10.3 메모리 사용량 비교

FlashAttention의 attention 연산 메모리는 시퀀스 길이에 대해 **선형(linear)**으로, standard attention의 이차(quadratic) 대비 극적인 개선이다:

  • 시퀀스 길이 2K: 약 10배 메모리 절약
  • 시퀀스 길이 4K: 약 20배 메모리 절약
  • 시퀀스 길이 64K: standard attention은 A100 80GB에서도 OOM, FlashAttention은 정상 동작

11. PyTorch torch.nn.functional.scaled_dot_product_attention 연동

11.1 네이티브 통합

PyTorch 2.0부터 FlashAttention이 torch.nn.functional.scaled_dot_product_attention (SDPA)에 네이티브로 통합되어 있다. PyTorch 2.2부터는 FlashAttention-2가 기본 backend로 사용된다.

import torch
import torch.nn.functional as F

# 기본 사용법 - 자동으로 FlashAttention backend 선택
query = torch.randn(batch_size, num_heads, seq_len, head_dim,
                    device='cuda', dtype=torch.float16)
key = torch.randn(batch_size, num_heads, seq_len, head_dim,
                  device='cuda', dtype=torch.float16)
value = torch.randn(batch_size, num_heads, seq_len, head_dim,
                    device='cuda', dtype=torch.float16)

# PyTorch가 자동으로 최적의 backend를 선택한다
output = F.scaled_dot_product_attention(query, key, value)

11.2 Backend 명시적 선택

특정 backend를 강제로 사용하거나 제외할 수 있다:

from torch.nn.attention import sdpa_kernel, SDPBackend

# FlashAttention backend만 사용
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    output = F.scaled_dot_product_attention(query, key, value)

# Memory-efficient attention backend만 사용
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    output = F.scaled_dot_product_attention(query, key, value)

# Math (naive) backend 사용 - 디버깅용
with sdpa_kernel(SDPBackend.MATH):
    output = F.scaled_dot_product_attention(query, key, value)

# CuDNN backend 사용 (PyTorch 2.2+)
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    output = F.scaled_dot_product_attention(query, key, value)

11.3 Causal Mask와 함께 사용

LLM의 autoregressive generation에 필수적인 causal mask도 지원된다:

# is_causal=True로 causal mask 적용
# FlashAttention은 causal mask를 fused kernel 내에서 처리하여 추가 메모리 불필요
output = F.scaled_dot_product_attention(
    query, key, value,
    is_causal=True
)

# 커스텀 attention mask 사용
attn_mask = torch.tril(torch.ones(seq_len, seq_len, device='cuda', dtype=torch.bool))
output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=attn_mask
)

11.4 Backend 선택 조건

PyTorch SDPA가 FlashAttention backend를 선택하기 위한 조건:

  • dtype: float16 또는 bfloat16 (float32는 불가)
  • device: CUDA GPU (CPU 지원 불가)
  • head dimension: 최대 256 (FlashAttention-2 기준)
  • attention mask: boolean mask 또는 is_causal=True 지원, 임의의 float mask는 비지원

이 조건을 만족하지 않으면, PyTorch는 자동으로 memory-efficient attention 또는 math backend로 fallback한다.

11.5 실무 적용 팁

# 어떤 backend가 사용되는지 확인
import torch.backends.cuda

# 각 backend의 활성화 상태 확인
print(f"Flash SDP enabled: {torch.backends.cuda.flash_sdp_enabled()}")
print(f"Mem efficient SDP enabled: {torch.backends.cuda.mem_efficient_sdp_enabled()}")
print(f"Math SDP enabled: {torch.backends.cuda.math_sdp_enabled()}")

# 전역적으로 특정 backend 비활성화
torch.backends.cuda.enable_flash_sdp(False)  # FlashAttention 비활성화
torch.backends.cuda.enable_mem_efficient_sdp(True)

11.6 flash-attn 라이브러리 직접 사용

PyTorch 네이티브 SDPA 외에도, Tri Dao의 flash-attn 패키지를 직접 사용할 수 있다. 이 패키지는 PyTorch SDPA보다 더 많은 기능(예: sliding window attention, ALiBi, cross-attention 최적화)을 제공한다:

# pip install flash-attn
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim) 형태
output = flash_attn_func(q, k, v, causal=True)

12. 정리 및 핵심 교훈

FlashAttention의 핵심 교훈은 알고리즘의 FLOP 복잡도만이 성능을 결정하지 않는다는 것이다. 현대 GPU에서는 메모리 접근 패턴이 실제 실행 시간을 지배하며, IO-aware 알고리즘 설계가 실용적 성능에 결정적이다.

주요 기여를 요약하면:

  1. IO-Aware 설계 원칙: GPU 메모리 계층(HBM vs SRAM)의 비대칭성을 활용한 알고리즘 설계
  2. Tiling + Online Softmax: SRAM에 맞는 블록 단위 연산으로 N×NN \times N 행렬의 HBM materialization 제거
  3. Recomputation 전략: Backward pass에서 중간값을 재계산하여 O(N2)O(N^2) -> O(N)O(N) 메모리 절약, 동시에 속도 향상
  4. 최적성 증명: IO complexity 관점에서 하한을 증명하여 알고리즘의 최적성 입증
  5. Exact Computation: 모든 최적화에도 불구하고 근사 없는 exact attention 유지

FlashAttention은 이론적 아름다움과 실용적 효과를 동시에 갖춘 드문 연구로, 현대 LLM 학습과 추론의 핵심 인프라가 되었다. PyTorch의 네이티브 통합으로 인해, 별도의 구현 없이도 F.scaled_dot_product_attention 호출만으로 그 혜택을 누릴 수 있다.


References

FlashAttention: Optimizing Attention Through GPU Memory Hierarchy

1. Introduction

Self-Attention, the core component of the Transformer architecture, computes relationships between all token pairs in a sequence. While this operation provides powerful representational capacity, it suffers from a fundamental limitation: both time and memory complexity grow as O(N2)O(N^2) with respect to the sequence length NN. For state-of-the-art LLMs such as GPT-4, LLaMA, and Gemini to handle long contexts exceeding 128K tokens, this O(N2)O(N^2) bottleneck must be effectively addressed.

FlashAttention (Dao et al., 2022) solves this problem without any approximation. The core idea is simple yet profound: rather than reducing the computational cost of the attention operation itself, it minimizes data movement (IO) between GPU memory hierarchy levels. In this article, we systematically analyze the principles of FlashAttention from a GPU hardware perspective and trace its evolution through FlashAttention-2 and FlashAttention-3.


2. The Memory Problem of Standard Attention

2.1 Standard Attention Computation Flow

Standard Self-Attention is computed as follows. Given inputs Q,K,VRN×dQ, K, V \in \mathbb{R}^{N \times d}:

S=QKTRN×NS = QK^T \in \mathbb{R}^{N \times N} P=softmax(S)RN×NP = \text{softmax}(S) \in \mathbb{R}^{N \times N} O=PVRN×dO = PV \in \mathbb{R}^{N \times d}

Here, NN is the sequence length and dd is the head dimension.

2.2 Memory Complexity Analysis

The crux of the problem lies in the intermediate matrices SS and PP. These matrices are of size N×NN \times N, requiring quadratic memory with respect to sequence length. To put this in concrete numbers:

Sequence Length (NN)Attention Matrix SizeFP16 Memory
1,0241M elements2 MB
4,09616.7M elements33 MB
16,384268M elements536 MB
65,5364.3B elements8.6 GB
131,07217.2B elements34.4 GB

These figures are for a single head and single batch. In multi-head attention, multiplying by the number of heads hh and the batch size BB results in significantly larger actual memory consumption. At sequence length 65,536, even a single head consumes a substantial portion of the HBM on an A100 80GB GPU.

2.3 HBM Bottleneck

In the standard attention implementation, these N×NN \times N matrices are materialized in GPU HBM (High Bandwidth Memory). That is, S=QKTS = QK^T is computed and written to HBM, then read back for softmax, the result PP is written to HBM, and then read back for O=PVO = PV. The total number of HBM reads and writes in this process is Ω(Nd+N2)\Omega(Nd + N^2).

The real reason this operation is slow on actual GPUs is that memory access, not compute, is the bottleneck. The A100 GPU delivers 312 TFLOPS (FP16) of compute throughput, while its HBM bandwidth is only about 2 TB/s. Attention is a classic memory-bound operation due to its low arithmetic intensity (ratio of compute to memory access).


3. GPU Memory Hierarchy

Understanding FlashAttention requires precise knowledge of the GPU memory hierarchy.

3.1 HBM (High Bandwidth Memory)

  • Capacity: 40GB or 80GB on the A100
  • Bandwidth: Approximately 1.5-2.0 TB/s (A100 80GB SXM: 2,039 GB/s)
  • Access latency: Approximately 200-600 cycles
  • Role: The GPU's main memory. All data including model parameters, input tensors, and output tensors are stored here

3.2 SRAM (On-chip Shared Memory)

  • Capacity: Approximately 192KB per SM on the A100, approximately 20MB total (108 SMs)
  • Bandwidth: Approximately 19 TB/s
  • Access latency: Approximately 20-30 cycles
  • Role: High-speed on-chip memory within each Streaming Multiprocessor (SM)

3.3 The Critical Asymmetry

A dramatic asymmetry exists between SRAM and HBM:

PropertySRAMHBM
Bandwidth~19 TB/s~2 TB/s
Capacity~20 MB40-80 GB
Access Latency20-30 cycles200-600 cycles

SRAM is approximately 10x faster than HBM, but approximately 4,000x smaller in capacity. FlashAttention's key insight is to actively exploit this asymmetry: instead of materializing the entire N×NN \times N matrix in HBM, performing computations in small blocks that fit in SRAM can dramatically reduce HBM accesses.


4. IO Complexity Analysis

4.1 IO Complexity of Standard Attention

Standard attention exhibits the following HBM access pattern:

  1. Read Q,KQ, K from HBM, compute S=QKTS = QK^T, write SS to HBM: Θ(Nd+N2)\Theta(Nd + N^2) IO
  2. Read SS from HBM, compute P=softmax(S)P = \text{softmax}(S), write PP to HBM: Θ(N2)\Theta(N^2) IO
  3. Read P,VP, V from HBM, compute O=PVO = PV, write OO to HBM: Θ(Nd+N2)\Theta(Nd + N^2) IO

Total HBM access: Θ(Nd+N2)\Theta(Nd + N^2)

Since the sequence length NN is typically much larger than the head dimension dd (usually 64 or 128), the N2N^2 term dominates.

4.2 FlashAttention's IO Complexity

FlashAttention reduces HBM access through tiling to:

O(N2d2M)O\left(\frac{N^2 d^2}{M}\right)

where MM is the SRAM size. Intuitively, larger SRAM allows processing larger blocks at once, reducing HBM accesses.

4.3 Optimality Proof (Lower Bound)

The paper goes further to prove the following lower bound:

Theorem: For all SRAM sizes MM where dMNdd \leq M \leq Nd, any algorithm computing exact attention requires Ω(N2d2/M)\Omega(N^2 d^2 / M) HBM accesses.

This means FlashAttention is optimal in terms of IO complexity. Excluding constant and polylogarithmic factors, it is impossible to compute exact attention with fewer HBM accesses.

4.4 Numerical Example

For the A100 with SRAM size M192M \approx 192KB, head dimension d=64d = 64, and sequence length N=4096N = 4096:

  • Standard attention IO: Θ(Nd+N2)4096×64+4096217M\Theta(Nd + N^2) \approx 4096 \times 64 + 4096^2 \approx 17M elements
  • FlashAttention IO: Θ(N2d2/M)40962×642/(192×512)7M\Theta(N^2 d^2 / M) \approx 4096^2 \times 64^2 / (192 \times 512) \approx 7M elements (varies with block size)

In practice, since the N2N^2-sized intermediate matrices are never written to HBM at all, the savings are even greater. The benefits become particularly pronounced as sequence length increases.


5. Tiling: Block-wise Computation That Fits in SRAM

5.1 Algorithm Overview

The core FlashAttention algorithm works as follows:

  1. Partition QQ into Tr=N/BrT_r = \lceil N / B_r \rceil blocks: Q1,Q2,,QTrQ_1, Q_2, \ldots, Q_{T_r}, each of size Br×dB_r \times d
  2. Partition K,VK, V into Tc=N/BcT_c = \lceil N / B_c \rceil blocks: K1,,KTcK_1, \ldots, K_{T_c} and V1,,VTcV_1, \ldots, V_{T_c}, each of size Bc×dB_c \times d
  3. Block sizes Br,BcB_r, B_c are set to fit in SRAM of size MM: Bc=M/(4d)B_c = \lceil M / (4d) \rceil, Br=min(M/(4d),d)B_r = \min(\lceil M / (4d) \rceil, d)

5.2 Forward Pass Pseudocode

Algorithm: FlashAttention Forward Pass
---------------------------------------
Input: Q, K, V in HBM, SRAM size M
Output: O in HBM

1. Set block sizes: B_c = ceil(M / 4d), B_r = min(ceil(M / 4d), d)
2. Initialize O = zeros(N, d), l = zeros(N), m = -inf * ones(N) in HBM

3. for j = 1 to T_c:                        # Outer loop: K, V blocks
     Load K_j, V_j from HBM to SRAM

     for i = 1 to T_r:                      # Inner loop: Q blocks
       Load Q_i, O_i, l_i, m_i from HBM to SRAM

       # Perform block-wise computation in SRAM
       S_ij = Q_i @ K_j^T                   # (B_r x B_c)
       m_ij = rowmax(S_ij)
       P_ij = exp(S_ij - m_ij)
       l_ij = rowsum(P_ij)

       # Combine with statistics from previous blocks (Online Softmax)
       m_new = max(m_i, m_ij)
       l_new = exp(m_i - m_new) * l_i + exp(m_ij - m_new) * l_ij

       # Update output (with rescaling)
       O_i = diag(exp(m_i - m_new))^(-1) * (diag(l_i) * O_i)
             + diag(exp(m_ij - m_new))^(-1) * P_ij @ V_j
       O_i = diag(l_new)^(-1) * O_i

       # Update statistics
       m_i = m_new, l_i = l_new

       Write O_i, l_i, m_i back to HBM
     end for
   end for

4. return O

5.3 Why This Works

The key point is that the N×NN \times N attention matrices SS and PP are never materialized in HBM. Each Br×BcB_r \times B_c block SijS_{ij} is computed within SRAM, immediately used for softmax statistics updates and output accumulation, and then discarded.

The mathematical technique that makes this possible is Online Softmax.


6. Online Softmax (Safe Softmax) Algorithm

6.1 The Problem with Standard Softmax

Softmax is a global operation. For a row vector x=[x1,,xN]x = [x_1, \ldots, x_N]:

softmax(xi)=exij=1Nexj\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}

Computing this requires seeing the entire row at once to calculate the denominator sum. This is the fundamental barrier that makes tiling difficult -- looking only at block Si1S_{i1} is insufficient to complete the softmax, because the denominator changes depending on the values in the remaining blocks Si2,Si3,S_{i2}, S_{i3}, \ldots.

Additionally, for numerical stability, "safe softmax" is used:

softmax(xi)=eximj=1Nexjm,m=maxjxj\text{softmax}(x_i) = \frac{e^{x_i - m}}{\sum_{j=1}^{N} e^{x_j - m}}, \quad m = \max_j x_j

This also requires the global maximum mm, necessitating a scan of the entire row first.

6.2 The Online Softmax Trick

The key idea of Online Softmax (Milakov & Gimelshein, 2018) is to compute softmax incrementally, block by block, while maintaining running statistics.

Two scalars are maintained per row:

  • mm: the maximum of all elements seen so far (running max)
  • \ell: the normalization constant so far (running sum of exponentials)

When a new block SijS_{ij} arrives:

  1. Compute the row-wise maximum of the new block: m~=rowmax(Sij)\tilde{m} = \text{rowmax}(S_{ij})
  2. Update the global maximum: mnew=max(m,m~)m_{\text{new}} = \max(m, \tilde{m})
  3. Rescale the previous normalization constant: new=emmnew+em~mnew~\ell_{\text{new}} = e^{m - m_{\text{new}}} \cdot \ell + e^{\tilde{m} - m_{\text{new}}} \cdot \tilde{\ell}
  4. Rescale the previous output: Onew=emmnewO+em~mnewP~VjnewO_{\text{new}} = \frac{e^{m - m_{\text{new}}} \cdot \ell \cdot O + e^{\tilde{m} - m_{\text{new}}} \cdot \tilde{P}V_j}{\ell_{\text{new}}}

This process is mathematically exact. It is not an approximation. Regardless of the order in which blocks are processed, the final result is identical to standard attention bit-for-bit (except for minor numerical differences due to floating-point operation ordering).

6.3 Mathematical Justification

The core of the proof is the rescaling property of softmax:

eximjexjm=eximemmjexjmemm=eximjexjm\frac{e^{x_i - m'}}{\sum_j e^{x_j - m'}} = \frac{e^{x_i - m} \cdot e^{m - m'}}{\sum_j e^{x_j - m} \cdot e^{m - m'}} = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}

Even when the maximum is updated from mm to mm', the same factor emme^{m - m'} multiplies both numerator and denominator, so the ratio remains unchanged. This property allows the results from previous blocks to be safely rescaled to the new maximum.


7. Backward Pass Recomputation Strategy

7.1 The Problem with Standard Backward Pass

In the standard attention backward pass, the intermediate matrices SS and PP saved during the forward pass are needed for gradient computation. Since their size is N×NN \times N, storing them in the forward pass and reading them back in the backward pass requires O(N2)O(N^2) memory.

7.2 FlashAttention's Recomputation

FlashAttention uses a variant of gradient checkpointing. During the forward pass, SS and PP are not saved. Instead, only the following are stored:

  • The final output ORN×dO \in \mathbb{R}^{N \times d}
  • The softmax normalization statistics m,RNm, \ell \in \mathbb{R}^{N} (per-row maximum and sum)

During the backward pass, these statistics and the original Q,K,VQ, K, V are used to recompute the needed blocks of SS and PP in SRAM. This recomputation requires additional FLOPs but significantly reduces HBM access.

7.3 The Paradoxical Effect of Recomputation

Typically, gradient checkpointing saves memory at the cost of speed. However, FlashAttention's recomputation actually improves speed as well. The reason is:

  • FLOPs increase: Since what was computed once in the forward pass is recomputed in the backward pass, total FLOPs slightly increase.
  • HBM IO decreases: The cost of writing and reading the N×NN \times N-sized SS and PP to and from HBM is eliminated.

On modern GPUs, HBM access is much slower than computation, so the benefit from IO reduction outweighs the FLOP increase. Experimental results show that the additional runtime overhead from recomputation is less than 5%, while memory usage decreases from O(N2)O(N^2) to O(N)O(N).

7.4 Memory Savings

Sequence LengthStandard Attention MemoryFlashAttention MemorySavings Ratio
1K~2 MB~0.13 MB~15x
2K~8 MB~0.26 MB~30x
4K~33 MB~0.52 MB~63x
8K~131 MB~1.04 MB~126x

These savings enable processing longer sequences or using larger batch sizes with the same GPU memory.


8. FlashAttention-2 Improvements

Dao (2023) introduced three key improvements in FlashAttention-2.

8.1 Minimizing Non-matmul FLOPs

The A100 GPU's Tensor Cores deliver 312 TFLOPS (FP16) for matrix multiplication (matmul), but non-matmul operations (softmax's exp, max, sum, etc.) run at 19.5 TFLOPS (FP32) -- approximately 16x slower. In FlashAttention-1, the proportion of non-matmul operations was significant.

FlashAttention-2 restructures the algorithm to minimize these non-matmul FLOPs. Specifically, it reduces the number of rescaling operations and performs softmax statistics updates more efficiently. The key change is performing the final rescaling only once at the end of the loop.

8.2 Improved Parallelism: Sequence Length Dimension Parallelization

FlashAttention-1 only parallelized across the batch and head dimensions. When the batch size was small or the number of heads was low, the GPU's SMs (Streaming Multiprocessors) were underutilized.

FlashAttention-2 also parallelizes across the sequence length dimension. By changing the outer loop to iterate over Q blocks (rather than K,VK, V blocks), each Q block can be processed by an independent thread block. This change significantly improves occupancy during the forward pass.

8.3 Work Partitioning Optimization

Work distribution among warps within a thread block was also improved:

  • FlashAttention-1: K, V are split across 4 warps, each warp independently computes QKTQK^T and then synchronizes results. This approach incurs communication and synchronization overhead through shared memory.
  • FlashAttention-2: Q is split across 4 warps, while K and V are shared by all warps. Since each warp computes outputs for different parts of Q independently, no inter-warp communication is needed.

8.4 Performance Results

Combining these three improvements:

  • Approximately 2x speedup over FlashAttention-1
  • Achieves 230 TFLOPS in FP16/BF16 on the A100 (approximately 73% of the theoretical maximum)
  • Up to 9x speedup over standard PyTorch attention
  • Approaches the efficiency of GEMM (matrix multiplication) operations

9. FlashAttention-3: Latest Advances

FlashAttention-3 (Shah et al., 2024) took a further step forward by leveraging the new hardware capabilities of the NVIDIA Hopper architecture (H100).

9.1 New Capabilities of the Hopper GPU

The H100 GPU provides the following key capabilities over the A100:

  • WGMMA (Warpgroup Matrix Multiply-Accumulate): A new Tensor Core instruction with much higher throughput than the A100's mma.sync
  • TMA (Tensor Memory Accelerator): A dedicated hardware unit for data transfers between global memory and shared memory, handling index computation and bounds checking in hardware

9.2 Three Key Techniques

1. Asynchronous Execution via Warp Specialization

Computation (WGMMA) and data movement (TMA) are assigned to different warp groups for pipelined, overlapping execution. While one warp group computes the current block, another prefetches data for the next block.

2. Interleaving of Matmul and Softmax

Previously, matmul was followed by softmax, then another matmul, in a sequential manner. FlashAttention-3 interleaves these so that matmul and softmax execute simultaneously on different hardware units. While Tensor Cores compute QKTQK^T for the next block, CUDA Cores process the softmax of the current block.

3. FP8 Low-precision Support

Leveraging the H100's FP8 Tensor Cores doubles throughput. Naive FP8 quantization degrades accuracy, but FlashAttention-3 addresses this with two techniques:

  • Block quantization: Maintaining separate scale factors per block to preserve dynamic range
  • Incoherent processing: Multiplying by a random orthogonal matrix to distribute outliers before quantization, achieving 2.6x lower numerical error compared to the FP8 baseline

9.3 Performance Results

FlashAttention-3 performance on the H100:

ConfigurationTFLOPSGPU Utilization
FP16 FlashAttention-2~400~50%
FP16 FlashAttention-3~740~75%
FP8 FlashAttention-3~1,200~75%

In FP16, it achieves a 1.5-2.0x speedup over FlashAttention-2, and in FP8, it approaches 1.2 PFLOPS.


10. Benchmarks: Speed and Memory Comparison

10.1 Attention Forward Pass Speed (A100 80GB, FP16)

Key figures reported in the FlashAttention paper and subsequent benchmarks:

Sequence LengthStandard AttentionFlashAttentionFlashAttention-2Speedup (FA2 vs Std)
51212.2 ms3.5 ms1.9 ms6.4x
1K45.8 ms7.8 ms4.1 ms11.2x
2K178 ms18.9 ms9.8 ms18.2x
4K710 ms52.3 ms27.1 ms26.2x
8KOOM145 ms75 ms-
16KOOM520 ms270 ms-

The speedup becomes increasingly dramatic as sequence length grows. At 8K and above, standard attention fails with OOM (Out of Memory), while FlashAttention handles them without issue.

10.2 End-to-End Training Performance

ModelStandardFlashAttentionSpeedup
BERT-large (seq 512)100% (MLPerf ref.)115%1.15x
GPT-2 (seq 1K)100%300%3.0x
Long-range Arena (seq 1K-4K)100%240%2.4x

10.3 Memory Usage Comparison

FlashAttention's attention operation memory scales linearly with sequence length, a dramatic improvement over standard attention's quadratic scaling:

  • Sequence length 2K: approximately 10x memory savings
  • Sequence length 4K: approximately 20x memory savings
  • Sequence length 64K: standard attention causes OOM even on an A100 80GB, while FlashAttention runs normally

11. Integration with PyTorch torch.nn.functional.scaled_dot_product_attention

11.1 Native Integration

Starting with PyTorch 2.0, FlashAttention is natively integrated into torch.nn.functional.scaled_dot_product_attention (SDPA). Since PyTorch 2.2, FlashAttention-2 is used as the default backend.

import torch
import torch.nn.functional as F

# Basic usage - automatically selects FlashAttention backend
query = torch.randn(batch_size, num_heads, seq_len, head_dim,
                    device='cuda', dtype=torch.float16)
key = torch.randn(batch_size, num_heads, seq_len, head_dim,
                  device='cuda', dtype=torch.float16)
value = torch.randn(batch_size, num_heads, seq_len, head_dim,
                    device='cuda', dtype=torch.float16)

# PyTorch automatically selects the optimal backend
output = F.scaled_dot_product_attention(query, key, value)

11.2 Explicit Backend Selection

You can force or exclude specific backends:

from torch.nn.attention import sdpa_kernel, SDPBackend

# Use FlashAttention backend only
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    output = F.scaled_dot_product_attention(query, key, value)

# Use Memory-efficient attention backend only
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    output = F.scaled_dot_product_attention(query, key, value)

# Use Math (naive) backend - for debugging
with sdpa_kernel(SDPBackend.MATH):
    output = F.scaled_dot_product_attention(query, key, value)

# Use CuDNN backend (PyTorch 2.2+)
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    output = F.scaled_dot_product_attention(query, key, value)

11.3 Using with Causal Mask

Causal masking, essential for autoregressive generation in LLMs, is also supported:

# Apply causal mask with is_causal=True
# FlashAttention handles the causal mask within the fused kernel, requiring no extra memory
output = F.scaled_dot_product_attention(
    query, key, value,
    is_causal=True
)

# Using a custom attention mask
attn_mask = torch.tril(torch.ones(seq_len, seq_len, device='cuda', dtype=torch.bool))
output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=attn_mask
)

11.4 Backend Selection Criteria

Conditions for PyTorch SDPA to select the FlashAttention backend:

  • dtype: float16 or bfloat16 (float32 is not supported)
  • device: CUDA GPU (CPU is not supported)
  • head dimension: Maximum 256 (for FlashAttention-2)
  • attention mask: Boolean mask or is_causal=True are supported; arbitrary float masks are not supported

If these conditions are not met, PyTorch automatically falls back to the memory-efficient attention or math backend.

11.5 Practical Tips

# Check which backend is being used
import torch.backends.cuda

# Check the enabled state of each backend
print(f"Flash SDP enabled: {torch.backends.cuda.flash_sdp_enabled()}")
print(f"Mem efficient SDP enabled: {torch.backends.cuda.mem_efficient_sdp_enabled()}")
print(f"Math SDP enabled: {torch.backends.cuda.math_sdp_enabled()}")

# Globally disable a specific backend
torch.backends.cuda.enable_flash_sdp(False)  # Disable FlashAttention
torch.backends.cuda.enable_mem_efficient_sdp(True)

11.6 Using the flash-attn Library Directly

In addition to PyTorch's native SDPA, you can use Tri Dao's flash-attn package directly. This package provides more features than PyTorch SDPA (e.g., sliding window attention, ALiBi, cross-attention optimization):

# pip install flash-attn
from flash_attn import flash_attn_func

# Shape: (batch, seqlen, nheads, headdim)
output = flash_attn_func(q, k, v, causal=True)

12. Summary and Key Takeaways

The key lesson of FlashAttention is that FLOP complexity alone does not determine performance. On modern GPUs, memory access patterns dominate actual execution time, and IO-aware algorithm design is decisive for practical performance.

The main contributions can be summarized as follows:

  1. IO-Aware Design Principle: Algorithm design that exploits the asymmetry of the GPU memory hierarchy (HBM vs SRAM)
  2. Tiling + Online Softmax: Block-wise computation that fits in SRAM, eliminating HBM materialization of the N×NN \times N matrix
  3. Recomputation Strategy: Recomputing intermediate values in the backward pass to reduce memory from O(N2)O(N^2) to O(N)O(N), while simultaneously improving speed
  4. Optimality Proof: Proving the lower bound from an IO complexity perspective to establish the algorithm's optimality
  5. Exact Computation: Maintaining exact attention without approximation despite all optimizations

FlashAttention is a rare piece of research that combines theoretical elegance with practical effectiveness, and it has become core infrastructure for modern LLM training and inference. Thanks to its native integration in PyTorch, its benefits can be enjoyed simply by calling F.scaled_dot_product_attention without any additional implementation.


References