- Published on
FlashAttention 논문 분석: IO-Aware Exact Attention으로 Transformer 학습·추론 속도 혁신
- Authors
- Name
- 들어가며
- Standard Attention의 문제점: O(N^2) 메모리와 IO 병목
- GPU 메모리 계층 구조: SRAM vs HBM
- FlashAttention v1: Tiling + Recomputation
- FlashAttention-2: 병렬화와 Work Partitioning 개선
- FlashAttention-3: FP8, 비동기 파이프라이닝
- 성능 벤치마크 비교
- 실전 적용: PyTorch와 Triton 코드
- 트러블슈팅: 실패 사례와 복구
- 운영 시 주의사항
- 마치며
- 참고자료

들어가며
Transformer 아키텍처는 NLP, 컴퓨터 비전, 음성 처리 등 거의 모든 딥러닝 분야의 기반 모델로 자리잡았다. 그러나 Self-Attention 메커니즘의 메모리 복잡도와 반복적인 GPU 메모리 접근은 학습과 추론 모두에서 심각한 병목 지점을 형성한다. 특히 시퀀스 길이가 수천에서 수만 토큰으로 확장되는 현대 LLM 시대에, 이 병목은 더 이상 무시할 수 없는 수준에 이르렀다.
2022년, Stanford의 Tri Dao를 필두로 한 연구진은 이 문제를 알고리즘 관점이 아닌 하드웨어 IO 관점에서 공략하는 FlashAttention을 발표했다. FlashAttention은 기존 어텐션의 정확도를 전혀 손상시키지 않으면서(exact attention), GPU 메모리 계층 구조를 명시적으로 고려한 IO-aware 알고리즘 설계를 통해 학습 시 24배의 벽시계 속도 향상과 520배의 메모리 절감을 달성했다.
이후 FlashAttention-2(2023)에서는 GPU 내부의 워프(warp) 레벨 작업 분배를 최적화하여 A100에서 이론적 최대 FLOPS의 50~73%를 달성했고, FlashAttention-3(2024)에서는 Hopper 아키텍처(H100)의 비동기 실행과 FP8 텐서 코어를 활용하여 H100에서 740 TFLOPS/s(FP16)와 1.2 PFLOPS/s(FP8)라는 경이적인 성능을 기록했다.
이 글에서는 FlashAttention 시리즈의 전체 발전 과정을 논문 수준에서 분석한다. Standard Attention의 근본적 문제를 GPU 메모리 계층 관점에서 진단하고, v1의 tiling과 recomputation 전략, v2의 병렬화 개선, v3의 비동기 파이프라이닝과 저정밀도 지원까지 단계적으로 다룬다. 나아가 PyTorch 및 Triton 기반 실전 코드, 성능 벤치마크, 그리고 프로덕션 적용 시 마주치는 실패 사례와 복구 방법까지 포괄적으로 정리한다.
Standard Attention의 문제점: O(N^2) 메모리와 IO 병목
수학적 정의
Self-Attention의 수학적 정의는 다음과 같다.
여기서 이고, 은 시퀀스 길이, 는 헤드 차원이다. 문제는 중간 행렬 을 전체적으로 materialization(물리적 저장) 해야 한다는 점이다.
메모리 분석
시퀀스 길이별 어텐션 스코어 행렬의 메모리 요구량은 다음과 같다.
| 시퀀스 길이 (N) | 어텐션 행렬 크기 | FP16 메모리 | FP32 메모리 |
|---|---|---|---|
| 1,024 | 1M | 2 MB | 4 MB |
| 4,096 | 16.7M | 33 MB | 67 MB |
| 16,384 | 268M | 536 MB | 1.07 GB |
| 65,536 | 4.29B | 8.59 GB | 17.18 GB |
| 131,072 | 17.18B | 34.36 GB | 68.72 GB |
배치 크기와 헤드 수를 곱하면 실제 메모리 사용량은 이보다 훨씬 커진다. 예를 들어 batch=4, heads=32, N=16384인 경우, 어텐션 스코어만으로 약 68GB가 필요하여 A100 80GB의 거의 전부를 소비한다.
IO 병목의 실체
그런데 순수 연산 관점에서 보면 이야기가 달라진다. Self-Attention의 FLOPS는 이고, 메모리 접근량은 이다. 여기서 산술 강도(arithmetic intensity)를 계산하면 대략 정도인데, 현대 GPU의 FLOPS 대 메모리 대역폭 비율(ops:byte ratio)에 비해 상당히 낮다.
A100 GPU의 경우, 텐서 코어 FP16 성능은 312 TFLOPS/s이고 HBM 대역폭은 2TB/s이다. 이는 ops:byte ratio가 약 156이라는 뜻인데, 헤드 차원 가 일반적으로 64~128인 self-attention은 이 비율에 비해 연산 밀도가 매우 낮다. 즉, GPU는 연산을 기다리는 것이 아니라 데이터를 읽고 쓰는 것을 기다리고 있다. 이것이 바로 Standard Attention이 GPU 활용률 30% 이하에 머무르는 근본 원인이다.
import torch
import torch.nn.functional as F
import time
def standard_attention(Q, K, V, mask=None):
"""Standard Self-Attention: 전체 N x N 스코어 행렬을 materialization한다."""
d_k = Q.size(-1)
# S = Q @ K^T -> (batch, heads, N, N) 전체 메모리에 저장
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# softmax도 N x N 행렬 전체에 대해 수행
attn_weights = F.softmax(scores, dim=-1) # (batch, heads, N, N) 저장
# 최종 출력 계산
output = torch.matmul(attn_weights, V)
return output
# 벤치마크
batch, heads, seq_len, d_head = 4, 32, 4096, 64
Q = torch.randn(batch, heads, seq_len, d_head, device='cuda', dtype=torch.float16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)
# 워밍업
for _ in range(3):
_ = standard_attention(Q, K, V)
torch.cuda.synchronize()
# 측정
start = time.time()
for _ in range(100):
_ = standard_attention(Q, K, V)
torch.cuda.synchronize()
elapsed = (time.time() - start) / 100
print(f"Standard Attention: {elapsed*1000:.2f} ms/iter")
print(f"Peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
이 코드에서 scores 텐서가 의 HBM을 소비하며, 이 값을 HBM에서 읽고 쓰는 과정에서 대부분의 시간이 소모된다.
GPU 메모리 계층 구조: SRAM vs HBM
FlashAttention의 핵심 통찰은 GPU의 메모리 계층 구조를 명시적으로 활용하는 것이다. GPU에는 크게 두 가지 메모리 수준이 존재한다.
HBM (High Bandwidth Memory)
- 용량: 40
80GB (A100), 80141GB (H100/H200) - 대역폭: 1.5~3.35 TB/s
- 역할: 모델 가중치, 활성화 값, 옵티마이저 상태 등 대용량 데이터 저장
- 특성: 용량은 크지만 접근 지연 시간이 상대적으로 길다
SRAM (On-chip Static RAM)
- 용량: SM당 192KB (A100), 전체 약 20~40MB 수준
- 대역폭: ~19 TB/s (A100 기준)
- 역할: 각 Streaming Multiprocessor(SM)의 공유 메모리로 커널 실행 중 임시 데이터 저장
- 특성: HBM 대비 약 10배 빠른 접근 속도, 하지만 용량이 약 1000배 작다
| 메모리 계층 | 용량 (A100) | 대역폭 | 접근 지연 | 비유 |
|---|---|---|---|---|
| L1/SRAM | ~20MB (총합) | ~19 TB/s | ~28 사이클 | 책상 위 메모 |
| L2 Cache | 40MB | ~5 TB/s | ~200 사이클 | 서랍 |
| HBM | 80GB | ~2 TB/s | ~400 사이클 | 서고 |
| CPU RAM | ~1TB | ~50 GB/s | ~수천 사이클 | 도서관 |
Standard Attention의 문제는 중간 결과(, )를 모두 HBM에 쓰고 다시 읽는다는 점이다. 전체 어텐션 연산에서 HBM 접근 횟수는 이다. FlashAttention의 목표는 이 접근 횟수를 SRAM 활용을 통해 로 줄이는 것이다. 여기서 은 SRAM 크기이며, 일반적인 (64128)와 (약 100KB192KB)에서 이 값은 standard 접근보다 수 배에서 수십 배 작다.
FlashAttention v1: Tiling + Recomputation
핵심 아이디어
FlashAttention v1의 알고리즘은 두 가지 핵심 기법으로 구성된다.
Tiling (타일링): Q, K, V 행렬을 SRAM에 들어가는 크기의 블록으로 분할하여 블록 단위로 어텐션을 계산한다. 이 과정에서 N x N 크기의 어텐션 스코어 행렬을 HBM에 한 번도 전체적으로 저장하지 않는다.
Recomputation (재계산): 역전파(backward pass)에서 어텐션 스코어를 저장해두지 않고, 순전파에서 저장한 소프트맥스 정규화 통계량(최대값 과 합계 )만을 사용하여 필요할 때 다시 계산한다.
Online Softmax와 블록 단위 누적
전체 행에 대한 softmax를 블록 단위로 정확하게 계산하기 위해 Online Softmax 기법을 사용한다. 각 블록 를 처리한 후의 누적 출력은 다음과 같이 갱신된다.
여기서 와 는 현재 블록의 로컬 softmax 통계량이다. 이 수식을 통해 전체 softmax를 한 번에 계산하지 않고도 블록별로 점진적으로 정확한 결과를 누적할 수 있다.
알고리즘 의사코드
FlashAttention v1의 순전파 알고리즘을 Python 의사코드로 표현하면 다음과 같다.
import torch
import math
def flash_attention_forward(Q, K, V, block_size_q, block_size_kv):
"""FlashAttention v1 Forward Pass 의사코드.
실제 CUDA 커널에서는 하나의 GPU 커널 내에서 모든 연산이 수행되며,
중간 결과는 SRAM에만 유지된다.
"""
batch, heads, N, d = Q.shape
O = torch.zeros_like(Q) # 출력 누적
m = torch.full((batch, heads, N, 1), float('-inf'), device=Q.device) # 행별 최대값
l = torch.zeros((batch, heads, N, 1), device=Q.device) # 행별 합
# 외부 루프: Q를 블록 단위로 순회
for i in range(0, N, block_size_q):
qi = Q[:, :, i:i+block_size_q, :] # SRAM으로 로드
# 내부 루프: K, V를 블록 단위로 순회
for j in range(0, N, block_size_kv):
kj = K[:, :, j:j+block_size_kv, :] # SRAM으로 로드
vj = V[:, :, j:j+block_size_kv, :] # SRAM으로 로드
# 1. 로컬 어텐션 스코어 계산 (SRAM 내에서)
sij = torch.matmul(qi, kj.transpose(-2, -1)) / math.sqrt(d)
# 2. 로컬 softmax 통계량
mij_local = sij.max(dim=-1, keepdim=True).values
pij_local = torch.exp(sij - mij_local)
lij_local = pij_local.sum(dim=-1, keepdim=True)
# 3. Online Softmax로 글로벌 통계량 갱신
mi_old = m[:, :, i:i+block_size_q, :]
li_old = l[:, :, i:i+block_size_q, :]
oi_old = O[:, :, i:i+block_size_q, :]
mi_new = torch.maximum(mi_old, mij_local)
alpha = torch.exp(mi_old - mi_new)
beta = torch.exp(mij_local - mi_new)
li_new = alpha * li_old + beta * lij_local
# 4. 출력 누적 갱신
O[:, :, i:i+block_size_q, :] = (
alpha * li_old * oi_old + beta * torch.matmul(pij_local, vj)
) / li_new
m[:, :, i:i+block_size_q, :] = mi_new
l[:, :, i:i+block_size_q, :] = li_new
return O, m, l # m, l은 backward에서 recomputation에 사용
Recomputation 전략
역전파에서 기존 방식은 순전파에서 저장한 행렬()을 사용하여 그래디언트를 계산한다. FlashAttention은 이 행렬을 저장하지 않고, 순전파에서 저장한 과 통계량(각각 크기)만으로 역전파 시 와 를 블록 단위로 다시 계산한다.
이 재계산에 드는 추가 연산량은 전체 순전파 FLOPS의 약 33% 수준이지만, HBM 접근을 극적으로 줄임으로써 벽시계 시간 기준으로는 오히려 2~4배 빨라지는 결과를 얻는다. 이는 현대 GPU가 compute-bound가 아닌 memory-bound 상태에 있다는 사실을 역설적으로 증명하는 결과이다.
IO 복잡도 증명
논문에서 증명한 FlashAttention의 HBM 접근 횟수는 다음과 같다.
여기서 은 SRAM 크기이다. Standard Attention의 와 비교하면, , , 일 때 FlashAttention은 약 9배 적은 HBM 접근을 수행한다. 또한 논문은 어떤 exact attention 알고리즘도 보다 적은 HBM 접근을 달성할 수 없음을 증명하여, FlashAttention이 IO 복잡도 관점에서 최적(optimal) 임을 보였다.
FlashAttention-2: 병렬화와 Work Partitioning 개선
v1의 한계
FlashAttention v1은 A100에서 이론적 최대 FLOPS의 약 30~50%만 활용했다. 그 원인은 세 가지로 분석되었다.
- 비효율적인 non-matmul 연산: softmax의 rescaling, 마스킹 연산 등 행렬 곱이 아닌 연산이 GPU 텐서 코어를 활용하지 못한다.
- 시퀀스 길이 방향 병렬화 부재: 배치와 헤드 차원에서만 병렬화되어, 작은 배치/적은 헤드 수에서 GPU 점유율이 낮다.
- 워프 간 비효율적 작업 분배: 하나의 스레드 블록 내에서 워프들이 공유 메모리를 통해 불필요한 동기화를 수행한다.
주요 개선사항
FlashAttention-2는 다음 세 가지 최적화를 도입했다.
1. Non-matmul FLOPS 절감
Online softmax의 rescaling 연산을 재구성하여 불필요한 스케일링 횟수를 줄였다. 또한 causal masking의 경우, 마스크가 필요없는 블록은 마스킹 연산 자체를 건너뛰도록 개선했다.
2. 시퀀스 길이 방향 병렬화
Forward pass에서 외부 루프를 Q 블록(행)이 아닌 K/V 블록(열)으로 변경하여 시퀀스 차원에서도 병렬화가 가능하도록 했다. Backward pass에서는 Q 블록과 K/V 블록 모두에 대해 병렬화를 적용했다. 이를 통해 배치 크기가 1이고 헤드 수가 적은 경우에도 높은 GPU 점유율을 유지할 수 있게 되었다.
3. 워프 레벨 작업 분배 최적화
v1에서는 4개의 워프가 블록과 블록을 나눠 처리한 뒤, 결과를 공유 메모리를 통해 합산했다. v2에서는 4개의 워프가 모두 같은 Q 블록을 처리하되 서로 다른 K/V 블록을 담당하도록 변경했다. 이렇게 하면 각 워프의 결과를 합산할 필요 없이 독립적으로 Q의 출력에 누적할 수 있어 공유 메모리 동기화가 크게 줄어든다.
| 항목 | FlashAttention v1 | FlashAttention-2 |
|---|---|---|
| A100 FP16 달성률 | ~30-50% | ~50-73% |
| 최대 TFLOPS (A100) | ~170 | ~230 |
| 외부 루프 축 | Q 블록 (행) | K/V 블록 (열) |
| 워프 분할 방식 | Q와 KV 분할 | KV만 분할, Q 공유 |
| Causal 최적화 | 모든 블록 마스킹 | 불필요 블록 스킵 |
| 시퀀스 병렬화 | 미지원 | 지원 |
FlashAttention-3: FP8, 비동기 파이프라이닝
Hopper 아키텍처 활용
FlashAttention-3는 NVIDIA Hopper 아키텍처(H100)의 세 가지 핵심 하드웨어 기능을 활용한다.
1. WGMMA (Warpgroup Matrix-Multiply-Accumulate)
Hopper의 새로운 행렬 곱 명령어로, 이전 세대의 WMMA보다 더 큰 타일 크기와 높은 처리량을 제공한다. 특히 비동기 실행을 지원하여 데이터 이동과 연산을 동시에 수행할 수 있다.
2. TMA (Tensor Memory Accelerator)
HBM에서 공유 메모리(SRAM)로의 데이터 전송을 하드웨어 수준에서 비동기적으로 처리하는 전용 유닛이다. CPU가 DMA 컨트롤러에게 데이터 전송을 위임하는 것과 유사하게, GPU의 계산 유닛이 메모리 전송을 기다리지 않고 다른 작업을 수행할 수 있다.
3. FP8 텐서 코어
Hopper는 FP8(E4M3, E5M2) 형식을 하드웨어 수준에서 지원하며, FP16 대비 2배의 연산 처리량을 제공한다.
워프 특화(Warp Specialization)
FlashAttention-3의 핵심 기법 중 하나는 워프 특화(Warp Specialization) 이다. 하나의 CTA(Cooperative Thread Array) 내의 워프들을 **생산자(producer)**와 소비자(consumer) 역할로 나눈다.
- 생산자 워프: TMA를 사용하여 HBM에서 SRAM으로 K/V 블록을 비동기적으로 로드한다.
- 소비자 워프: WGMMA를 사용하여 이미 SRAM에 로드된 데이터로 행렬 곱과 softmax를 수행한다.
Hopper의 setmaxnreg 명령어를 통해 소비자 워프에 더 많은 레지스터를 동적으로 할당할 수 있어, 생산자 워프는 최소한의 리소스로 데이터 전송만 담당하고, 소비자 워프는 최대한의 레지스터를 활용하여 연산을 수행한다.
GEMM-Softmax 비동기 파이프라이닝
FlashAttention-3는 2-stage 파이프라인을 구성하여 GEMM 연산과 softmax 연산을 오버랩시킨다. 블록 의 softmax를 계산하는 동안, 블록 의 GEMM이 동시에 진행된다. 이러한 파이프라이닝은 softmax가 non-matmul 연산이라 텐서 코어를 사용하지 않는다는 점을 활용한 것으로, 텐서 코어의 유휴 시간을 최소화한다.
FP8 지원과 정확도 유지
FP8의 제한된 정밀도로 인한 정확도 하락을 두 가지 기법으로 완화한다.
Block Quantization: 각 블록별로 독립적인 스케일 팩터를 계산하여 동적 범위를 최대화한다. 전체 텐서에 단일 스케일을 적용하는 것보다 표현 범위가 크게 향상된다.
Incoherent Processing: 와 에 랜덤 직교 행렬을 곱하여 값의 분포를 균일하게 만든 뒤 양자화한다. 이론적으로 이 변환은 양자화 오차의 기댓값을 최소화하며, 어텐션 계산 후 역변환을 적용하여 정확도를 복원한다.
| 항목 | FlashAttention-2 | FlashAttention-3 (FP16) | FlashAttention-3 (FP8) |
|---|---|---|---|
| 대상 GPU | A100 | H100 | H100 |
| 최대 TFLOPS | ~230 | ~740 | ~1,200 |
| 이론적 활용률 | ~73% | ~75% | ~76% |
| 워프 전략 | 균일 분배 | 생산자/소비자 특화 | 생산자/소비자 특화 |
| 비동기 파이프라인 | 미지원 | GEMM-Softmax 오버랩 | GEMM-Softmax 오버랩 |
| FP8 정밀도 보정 | - | - | Block Quant + Incoherent |
성능 벤치마크 비교
학습 속도 비교 (GPT-2 모델 기준)
| 구성 | Standard Attention | FlashAttention v1 | FlashAttention-2 | FlashAttention-3 |
|---|---|---|---|---|
| GPU | A100 | A100 | A100 | H100 |
| GPT-2 Small (seq=1K) | 1.0x | 1.7x | 2.3x | 3.8x |
| GPT-2 Medium (seq=1K) | 1.0x | 1.8x | 2.5x | 4.1x |
| GPT-2 Large (seq=2K) | 1.0x | 2.1x | 2.8x | 4.6x |
| GPT-2 XL (seq=2K) | OOM | 1.0x (baseline) | 1.5x | 2.8x |
| seq=4K, d=64 | OOM | 1.0x | 1.7x | 3.2x |
| seq=16K, d=128 | OOM | 1.0x | 1.9x | 3.5x |
FlashAttention v1 대비 v2는 약 1.52배, v3는 FP16 기준 약 1.52배(H100 하드웨어 향상 포함 시 3~4배) 빨라졌다. 특히 시퀀스 길이가 길어질수록 개선 폭이 커진다.
추론 속도 비교 (Prefill 단계)
추론의 prefill 단계는 학습과 유사하게 전체 시퀀스에 대한 어텐션을 한 번에 계산하므로, FlashAttention의 이점이 크게 적용된다.
| 시퀀스 길이 | Standard (ms) | FlashAttention-2 (ms) | 속도 향상 |
|---|---|---|---|
| 512 | 0.8 | 0.5 | 1.6x |
| 2,048 | 5.2 | 2.1 | 2.5x |
| 8,192 | 78.3 | 18.6 | 4.2x |
| 32,768 | OOM | 285.0 | - |
실전 적용: PyTorch와 Triton 코드
PyTorch SDPA를 통한 FlashAttention 사용
PyTorch 2.0 이후 torch.nn.functional.scaled_dot_product_attention에 FlashAttention이 통합되어, 별도 라이브러리 설치 없이 사용할 수 있다.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
# 기본 사용: PyTorch가 자동으로 최적 백엔드 선택
batch, heads, seq_len, d_head = 2, 16, 4096, 64
Q = torch.randn(batch, heads, seq_len, d_head, device='cuda', dtype=torch.float16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)
# 자동 백엔드 선택 (FlashAttention이 지원되면 자동 사용)
output = F.scaled_dot_product_attention(Q, K, V)
# FlashAttention 백엔드만 강제 사용
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
output_flash = F.scaled_dot_product_attention(
Q, K, V,
dropout_p=0.1, # 학습 시 dropout
is_causal=True, # causal masking (디코더용)
scale=None, # 기본 1/sqrt(d_k) 사용
)
# 어떤 백엔드가 사용되는지 확인
print(f"Flash available: {torch.backends.cuda.flash_sdp_enabled()}")
print(f"Memory efficient available: {torch.backends.cuda.mem_efficient_sdp_enabled()}")
flash-attn 라이브러리 직접 사용
flash-attn 패키지를 직접 사용하면 더 세밀한 제어와 추가 기능을 활용할 수 있다.
# pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func
import torch
# 방법 1: Q, K, V 분리 입력
# 형태: (batch, seqlen, nheads, headdim) - 주의: PyTorch와 head/seq 순서가 다름
batch, seqlen, nheads, headdim = 2, 4096, 16, 64
Q = torch.randn(batch, seqlen, nheads, headdim, device='cuda', dtype=torch.float16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)
output = flash_attn_func(
Q, K, V,
dropout_p=0.0,
softmax_scale=None, # 기본 1/sqrt(headdim)
causal=True,
window_size=(-1, -1), # (-1, -1)은 전체 어텐션, (w, 0)은 sliding window
return_attn_probs=False,
)
# 방법 2: QKV packed 형태
# 형태: (batch, seqlen, 3, nheads, headdim)
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device='cuda', dtype=torch.float16)
output_packed = flash_attn_qkvpacked_func(
qkv,
dropout_p=0.0,
causal=True,
)
# 방법 3: Sliding Window Attention (FlashAttention-2+)
output_sliding = flash_attn_func(
Q, K, V,
causal=True,
window_size=(256, 0), # 왼쪽 256 토큰 윈도우
)
print(f"Output shape: {output.shape}") # (batch, seqlen, nheads, headdim)
Triton을 이용한 FlashAttention 커널 구현 스케치
OpenAI의 Triton 컴파일러를 사용하면 CUDA C를 직접 작성하지 않고도 Python으로 GPU 커널을 작성할 수 있다.
import triton
import triton.language as tl
import torch
@triton.jit
def flash_attention_kernel(
Q_ptr, K_ptr, V_ptr, O_ptr,
stride_qb, stride_qh, stride_qm, stride_qk,
stride_kb, stride_kh, stride_kn, stride_kk,
stride_vb, stride_vh, stride_vn, stride_vk,
stride_ob, stride_oh, stride_om, stride_ok,
N_CTX: tl.constexpr,
D_HEAD: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""FlashAttention Forward Kernel의 Triton 구현 스케치.
실제 프로덕션 커널은 더 많은 최적화가 포함되지만,
핵심 tiling 로직의 구조를 보여준다.
"""
# 현재 프로그램이 담당하는 Q 블록 인덱스
pid_m = tl.program_id(0)
pid_bh = tl.program_id(1)
# 오프셋 계산
off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
off_k = tl.arange(0, D_HEAD)
# Q 블록 로드 (HBM -> SRAM)
q = tl.load(Q_ptr + pid_bh * stride_qh + off_m[:, None] * stride_qm + off_k[None, :] * stride_qk,
mask=off_m[:, None] < N_CTX)
# 누적 변수 초기화
m_i = tl.full([BLOCK_M], value=float('-inf'), dtype=tl.float32)
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
o_i = tl.zeros([BLOCK_M, D_HEAD], dtype=tl.float32)
# K/V 블록 순회 (내부 루프)
for start_n in range(0, N_CTX, BLOCK_N):
curr_n = start_n + off_n
# K, V 블록 로드 (HBM -> SRAM)
k = tl.load(K_ptr + pid_bh * stride_kh + curr_n[:, None] * stride_kn + off_k[None, :] * stride_kk,
mask=curr_n[:, None] < N_CTX)
v = tl.load(V_ptr + pid_bh * stride_vh + curr_n[:, None] * stride_vn + off_k[None, :] * stride_vk,
mask=curr_n[:, None] < N_CTX)
# 로컬 어텐션 스코어 계산 (SRAM 내)
s = tl.dot(q, tl.trans(k)) * (D_HEAD ** -0.5)
# Online Softmax 갱신
m_ij = tl.max(s, axis=1)
m_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_new)
beta = tl.exp(m_ij - m_new)
l_new = alpha * l_i + beta * tl.sum(tl.exp(s - m_ij[:, None]), axis=1)
# 출력 누적 갱신
p = tl.exp(s - m_new[:, None])
o_i = alpha[:, None] * l_i[:, None] * o_i + tl.dot(p.to(tl.float16), v)
o_i = o_i / l_new[:, None]
m_i = m_new
l_i = l_new
# 결과를 HBM에 저장
tl.store(O_ptr + pid_bh * stride_oh + off_m[:, None] * stride_om + off_k[None, :] * stride_ok,
o_i.to(tl.float16), mask=off_m[:, None] < N_CTX)
HuggingFace Transformers 통합
HuggingFace 모델에서 FlashAttention을 활용하는 방법이다.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# FlashAttention-2를 사용하도록 모델 로드
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B",
torch_dtype=torch.float16,
attn_implementation="flash_attention_2", # 핵심 설정
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
# 벤치마크 비교
import time
text = "FlashAttention은 " * 2000 # 긴 시퀀스
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=4096).to("cuda")
# FlashAttention-2 추론
torch.cuda.reset_peak_memory_stats()
start = time.time()
with torch.no_grad():
outputs = model(**inputs)
torch.cuda.synchronize()
flash_time = time.time() - start
flash_mem = torch.cuda.max_memory_allocated() / 1e9
print(f"FlashAttention-2: {flash_time:.3f}s, Peak Memory: {flash_mem:.2f} GB")
# SDPA eager 모드와 비교 (재로드 필요)
model_eager = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B",
torch_dtype=torch.float16,
attn_implementation="eager",
device_map="auto",
)
torch.cuda.reset_peak_memory_stats()
start = time.time()
with torch.no_grad():
outputs_eager = model_eager(**inputs)
torch.cuda.synchronize()
eager_time = time.time() - start
eager_mem = torch.cuda.max_memory_allocated() / 1e9
print(f"Eager Attention: {eager_time:.3f}s, Peak Memory: {eager_mem:.2f} GB")
print(f"Speedup: {eager_time/flash_time:.2f}x, Memory Savings: {eager_mem/flash_mem:.2f}x")
트러블슈팅: 실패 사례와 복구
사례 1: CUDA 호환성 오류
증상: flash-attn 설치 시 빌드 실패 또는 RuntimeError: FlashAttention only supports Ampere GPUs or newer
원인: FlashAttention은 SM80(A100) 이상의 GPU 아키텍처가 필요하다. V100(SM70)이나 T4(SM75)에서는 동작하지 않는다.
해결:
- GPU 아키텍처 확인:
nvidia-smi또는torch.cuda.get_device_capability() - SM75 이하인 경우
torch.nn.functional.scaled_dot_product_attention의mem_efficient백엔드를 대안으로 사용한다. 이 백엔드(xformers 기반)는 SM50 이상에서 동작한다. - 설치 시
MAX_JOBS=4 pip install flash-attn --no-build-isolation으로 병렬 빌드 수를 제한하여 OOM 방지
사례 2: 텐서 형태 불일치
증상: RuntimeError: expected query to have shape (batch, seqlen, nheads, headdim)
원인: flash_attn 라이브러리는 (batch, seqlen, nheads, headdim) 형태를 기대하지만, PyTorch의 MultiheadAttention은 (batch, nheads, seqlen, headdim) 순서를 사용한다.
해결: q.transpose(1, 2) 또는 einops.rearrange(q, 'b h s d -> b s h d')로 변환한다.
사례 3: 시퀀스 길이 정렬 문제
증상: 특정 시퀀스 길이에서 결과가 부정확하거나 성능이 급격히 저하됨
원인: FlashAttention 커널은 블록 크기(일반적으로 64 또는 128)의 배수일 때 최적 성능을 발휘한다. 배수가 아닌 경우 패딩 처리가 발생한다.
해결: 가능하면 시퀀스 길이를 128의 배수로 맞추고, flash_attn의 flash_attn_varlen_func을 사용하여 가변 길이 시퀀스를 효율적으로 처리한다.
사례 4: GQA/MQA 지원 문제
증상: Grouped Query Attention이나 Multi-Query Attention 모델에서 FlashAttention 적용 시 오류
원인: 초기 버전에서는 Q, K, V의 헤드 수가 동일해야 했다.
해결: FlashAttention-2 이상에서는 GQA/MQA를 네이티브로 지원한다. flash_attn_func에 헤드 수가 다른 Q와 K/V를 직접 전달하면 내부적으로 K/V 헤드가 반복(repeat)되어 처리된다.
사례 5: Backward Pass NaN 발생
증상: 학습 중 loss가 NaN으로 발산
원인: FP16에서 매우 긴 시퀀스(32K+)를 처리할 때, softmax의 지수 함수에서 수치적 오버플로우가 발생할 수 있다.
해결: BF16 사용을 권장한다. BF16은 FP16과 동일한 메모리를 사용하면서도 지수 범위가 FP32와 동일하여 오버플로우에 강하다. torch.autocast(device_type='cuda', dtype=torch.bfloat16) 컨텍스트에서 실행한다.
운영 시 주의사항
메모리 예산 계획
FlashAttention은 어텐션 스코어 행렬의 메모리를 절약하지만, Q/K/V 텐서 자체와 출력 텐서의 메모리는 여전히 필요하다. 대략적인 메모리 예산은 다음과 같이 계산한다.
기존의 항이 으로 줄어든 것이지, 전체 메모리가 0이 되는 것이 아니다.
CUDA 그래프와의 호환성
FlashAttention은 torch.compile과 CUDA Graph와 호환되지만, 동적 시퀀스 길이를 사용하는 경우 CUDA Graph의 정적 그래프 제약과 충돌할 수 있다. 추론 서빙에서는 시퀀스 길이를 미리 정해진 버킷으로 패딩하여 CUDA Graph를 재활용하는 전략이 효과적이다.
모니터링 지표
프로덕션 환경에서 FlashAttention 적용 후 모니터링해야 할 핵심 지표는 다음과 같다.
- GPU HBM 사용률:
nvidia-smi의 memory utilization이 아닌 실제 할당량 모니터링 - SM 활용률:
dcgm-exporter를 통한 Streaming Multiprocessor 활용률 추적 - 커널 실행 시간:
torch.profiler또는nsight systems를 통한 어텐션 커널 레이턴시 추적 - 수치 정밀도: FP8 사용 시 출력의 상대 오차를 주기적으로 FP16 기준과 비교
버전 선택 가이드
| 상황 | 권장 선택 |
|---|---|
| A100 + PyTorch 2.0+ | PyTorch SDPA (자동 백엔드 선택) |
| A100 + 최대 성능 필요 | flash-attn 라이브러리 직접 사용 |
| H100 + FP16 | FlashAttention-3 |
| H100 + 최대 추론 처리량 | FlashAttention-3 FP8 |
| V100/T4 (구형 GPU) | xformers memory_efficient_attention |
| 커스텀 어텐션 변형 필요 | Triton으로 직접 구현 |
마치며
FlashAttention 시리즈는 알고리즘의 수학적 결과를 변경하지 않으면서도 하드웨어의 메모리 계층 구조를 깊이 이해하고 활용하는 것만으로 극적인 성능 향상을 달성한 모범적인 사례이다. "같은 연산을 수행하되, 데이터를 어떻게 이동시키는가"에 대한 고민이 실제 벽시계 시간에 2~4배의 차이를 만들어낸다는 사실은, 현대 GPU 프로그래밍에서 IO-awareness가 얼마나 중요한지를 웅변한다.
v1에서 tiling과 recomputation이라는 핵심 아이디어를 확립하고, v2에서 GPU 내부의 미시적인 작업 분배를 최적화하고, v3에서 차세대 하드웨어의 새로운 기능을 선제적으로 활용하는 발전 과정은, AI 시스템 최적화 연구가 나아가야 할 방향을 명확하게 보여준다. FlashAttention은 이제 PyTorch에 내장되어 대부분의 실무자가 의식하지 않아도 자동으로 사용하는 인프라 수준의 기술이 되었지만, 그 내부 동작 원리를 이해하는 것은 더 나은 모델 설계와 시스템 최적화의 출발점이 될 것이다.
참고자료
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (arXiv) - FlashAttention v1 원본 논문
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (arXiv) - FlashAttention-2 논문
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (arXiv) - FlashAttention-3 논문
- Dao-AILab/flash-attention GitHub Repository - 공식 구현 및 설치 가이드
- PyTorch Scaled Dot Product Attention 공식 문서 - PyTorch SDPA API 문서
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (PyTorch Blog) - PyTorch 공식 블로그의 FlashAttention-3 소개
- Stanford CRFM FlashAttention-2 해설 - Stanford의 FlashAttention-2 공식 해설
- Tri Dao FlashAttention-3 블로그 - 저자의 FlashAttention-3 해설