필사 모드: FlashAttention & Efficient Attention Deep Dive — Tiling, Online Softmax, PagedAttention, GQA 완전 정복 (2025)
한국어TL;DR
- **Naive attention**은 $N \times N$ attention matrix를 HBM에 저장. 시퀀스 길이 $N$에 대해 **$O(N^2)$ 메모리**. LLM 긴 컨텍스트에서 치명적.
- **FlashAttention** (Tri Dao, 2022): 핵심 통찰은 attention이 **memory-bound**이지 compute-bound 아니라는 것. **Tiling + Online Softmax**로 SRAM만 쓰고 HBM에 N×N 행렬을 절대 쓰지 않음.
- **Online Softmax**: Softmax를 한 번에 계산하는 대신 블록 단위로 누적. 수학적으로 동등한데 IO가 완전히 다름.
- **FlashAttention-2** (2023): 병렬화 개선, Forward/Backward 최적화. 2-3배 빠름.
- **FlashAttention-3** (2024, Hopper): Tensor Core + 비동기 WGMMA + warp 특수화. H100에서 75% 효율.
- **PagedAttention** (vLLM, 2023): OS 가상 메모리 아이디어를 KV cache에. **2-4배 throughput**.
- **Grouped Query Attention (GQA)**: Query head 여러 개가 같은 K/V head 공유. **KV cache 크기 1/N**.
- **Sliding Window Attention**: 컨텍스트 윈도우 제한. Mistral, Gemma가 사용.
- **Ring Attention**: 여러 GPU로 sequence 분할. 10M 토큰 컨텍스트 가능.
1. Attention의 병목 문제
1.1 Transformer의 심장
Transformer의 핵심 연산:
$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$
- $Q$, $K$, $V$: (N, d) 크기. N = 시퀀스 길이, d = 헤드 차원.
- $QK^T$: (N, N) 점수 행렬.
- Softmax: 행별 정규화.
- 결과에 $V$ 곱: (N, d) 출력.
이것이 Transformer의 품질의 원천이지만, 동시에 **가장 큰 비용**.
1.2 O(N²) 메모리
$QK^T$ 크기 = $N \times N$. 시퀀스 길이에 제곱 비례.
N = 1,000: N² = 1,000,000 = 1 M 요소
N = 10,000: N² = 100 M
N = 100,000: N² = 10 G ← GPU 메모리 초과
FP16이면 요소당 2 바이트. 100,000 토큰 컨텍스트는 **20 GB**의 attention matrix만으로 차지. 한 레이어에. 한 헤드에.
실제로는 layer × head × batch → 수백 GB. **불가능**.
1.3 Compute vs Memory
Naive attention 구현:
S = Q @ K.T # (N, N)
P = softmax(S) # (N, N)
O = P @ V # (N, d)
각 단계에서 **HBM 읽기/쓰기**.
- $Q$, $K$, $V$ 읽기: $O(Nd)$
- $S$ 쓰기: $O(N^2)$
- $S$ 읽기, $P$ 쓰기: $O(N^2)$
- $P$ 읽기, $V$ 읽기, $O$ 쓰기: $O(N^2 + Nd)$
HBM 총 I/O: **$O(N^2)$**.
GPU의 HBM 대역폭이 아무리 빨라도(3 TB/s on H100) $N^2$ 데이터를 옮기는 데 시간이 걸린다. **대부분 시간을 HBM ↔ SRAM 복사에 씀**. 컴퓨트 유닛은 대부분 놀고 있다.
1.4 "Memory-bound, not compute-bound"
이것이 Tri Dao의 핵심 통찰이었다. "Attention 성능을 올리려면 FLOPs를 줄이는 게 아니라 **HBM 접근을 줄여야** 한다."
2. GPU 메모리 계층 복습
2.1 두 가지 메모리
**HBM** (High Bandwidth Memory):
- GPU의 "main memory". 수십~수백 GB.
- 대역폭 ~3 TB/s (H100).
- 지연 시간 ~500 cycle.
**SRAM** (Shared Memory / L1):
- SM 내부의 **on-chip** 메모리.
- 매우 작음 (각 SM당 ~228 KB on H100).
- 대역폭 ~19 TB/s (H100).
- 지연 시간 ~20 cycle.
**SRAM이 6배 이상 빠르다**. 그리고 **지연 시간이 25배 적다**.
2.2 위치의 중요성
데이터가 SRAM에 있으면 compute unit이 놀 시간 없이 곧바로 처리.
HBM에 있으면 SRAM으로 가져오는 동안 기다려야 함.
Attention 최적화의 목표: **SRAM에서 가능한 한 모든 계산을 완료**.
문제: SRAM이 너무 작아 $N \times N$ 행렬이 들어가지 않는다.
해결: **Tiling**.
3. FlashAttention의 핵심 아이디어
3.1 Tiling
행렬을 **블록**으로 나누어 순차 처리:
Q: (N, d) → Q_1, Q_2, ..., Q_{N/B_r} (각 (B_r, d))
K, V: (N, d) → K_1, K_2, ..., K_{N/B_c} (각 (B_c, d))
블록 크기 $B_r$, $B_c$는 SRAM에 맞게 선택.
3.2 Naive Blockwise
각 $Q_i$에 대해 전체 K, V를 순회해서 $O_i$ 계산:
for i in range(N / B_r):
Q_i = load(Q, i)
for j in range(N / B_c):
K_j, V_j = load(K, V, j)
S_ij = Q_i @ K_j.T
하지만 softmax는 전체 행에 걸쳐...?
**문제**: softmax는 각 행의 **모든 값**을 알아야 정규화 가능. 블록 단위로는 지역적 정보만.
3.3 Online Softmax — 핵심 트릭
Naive softmax:
$$
\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$
**수치 안정화** 버전 (max 뺌):
$$
\text{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \quad m = \max_j x_j
$$
이것을 **블록 단위로** 계산할 수 있다 — 만약 "현재까지의 max와 sum"을 유지하면.
**Incremental softmax**:
m_prev = -∞ # 현재까지 본 max
l_prev = 0 # 현재까지 본 sum
For each block:
m_block = max(block)
m_new = max(m_prev, m_block)
기존 누적치 rescale
l_prev = l_prev * exp(m_prev - m_new)
현재 블록
l_block = sum(exp(block - m_new))
합침
l_prev = l_prev + l_block
m_prev = m_new
마지막에 `softmax = exp(x - m_prev) / l_prev`.
**수학적으로 정확**. 단, 각 블록 처리 시 이전 누적을 재조정해야 함.
3.4 FlashAttention 알고리즘
초기화: O = 0, l = 0, m = -∞ (출력, sum, max)
For j in range(N / B_c): # K, V 블록
K_j, V_j ← HBM 로드 (SRAM으로)
For i in range(N / B_r): # Q 블록
Q_i ← HBM 로드 (SRAM으로)
S_ij = Q_i @ K_j.T # SRAM 내부 계산
Online softmax 업데이트
m_ij = rowmax(S_ij)
P_ij = exp(S_ij - m_ij)
l_ij = rowsum(P_ij)
이전과 결합
m_new = max(m_i, m_ij)
l_new = exp(m_i - m_new) * l_i + exp(m_ij - m_new) * l_ij
O 업데이트
O_i = diag(l_new)^-1 * (diag(l_i) * exp(m_i - m_new) * O_i
+ exp(m_ij - m_new) * P_ij * V_j)
m_i = m_new
l_i = l_new
O_i 저장 (HBM)
핵심:
- $Q$, $K$, $V$만 HBM에서 읽기.
- $S = QK^T$ 전체 행렬을 **절대 HBM에 쓰지 않음**.
- 각 블록 내부 계산은 SRAM.
- $O$ 최종 결과만 HBM.
3.5 HBM I/O 분석
Naive attention: $O(Nd + N^2)$ HBM 접근.
FlashAttention: $O(N^2 d^2 / M)$ HBM 접근. $M$은 SRAM 크기.
예: $N = 8192$, $d = 128$, $M \approx 100 KB$.
- Naive: ~67 M bytes (FP16).
- FlashAttention: ~4 M bytes.
**17배 적은 HBM I/O**. 실제 속도 향상 5-10배.
3.6 정확성
"근사"가 아니다. **수학적으로 정확한 attention**. Online softmax 덕분에 블록 단위 처리가 전체 softmax와 동일한 결과.
4. Backward Pass
훈련 시 gradient 계산도 attention.
4.1 Naive Backward
Forward 시 $P = \text{softmax}(QK^T)$를 저장했다가 backward에서 사용. **$O(N^2)$ 메모리 저장**.
4.2 FlashAttention Backward
**Recomputation**: forward에서 $P$를 저장하지 않고, backward에서 다시 계산.
Forward: Q, K, V, O, l, m만 저장 (O(N) 메모리, not O(N²))
Backward: 필요할 때 S, P 재계산
FLOP 증가 (약 2x compute), 메모리 감소 (O(N²) → O(N)). 전체적으로 여전히 빠름 — **memory-bound**라서.
4.3 결과
LLaMA 훈련 같은 워크로드:
- 메모리: 10배 감소.
- 속도: 2-3배 빠름.
- 긴 컨텍스트(8K-32K) 훈련 가능.
5. FlashAttention-2
2023년 Tri Dao의 후속. 주요 개선:
5.1 더 나은 병렬화
**FlashAttention-1**: 시퀀스 차원으로만 병렬.
**FlashAttention-2**: Batch, head, sequence 3차원 병렬 → GPU 활용 ↑.
5.2 Warp 간 작업 재분배
각 warp가 하는 일을 재설계. 동기화 오버헤드 감소.
5.3 Non-matmul FLOP 감소
Softmax 관련 non-matmul 연산 (rescaling 등)을 줄임. Tensor Core가 아닌 CUDA Core 부하가 병목이었는데 이를 감소.
5.4 성능
FlashAttention-1 대비 **2배 빠름**. Standard attention 대비 5-10배.
LLaMA-2 13B 훈련에서 사용. 긴 컨텍스트(16K, 32K) 훈련 실용화.
6. FlashAttention-3 (2024, Hopper)
NVIDIA Hopper 아키텍처(H100)에 특화된 버전.
6.1 새 Hopper 기능 활용
**WGMMA** (Warp Group Matrix Multiply-Accumulate): 비동기 행렬 곱. 새 instruction.
**TMA** (Tensor Memory Accelerator): 전용 async copy 엔진.
**FP8 Tensor Core**: 정밀도 낮춤으로 대역폭/속도 ↑.
6.2 알고리즘 변경
**Producer-Consumer 모델**: 일부 warp가 memory load 담당 (producer), 다른 warp가 compute 담당 (consumer). 오버랩.
**Warp Specialization**: Warp마다 특정 역할 고정. 스케줄링 유연성 감소, 성능 ↑.
6.3 결과
H100에서 **peak FP16 throughput의 75%** 달성. 이전 flash-attention은 35% 정도.
GPT-4 훈련 수준 워크로드에 필수.
7. PagedAttention — vLLM
2023년 UC Berkeley Sky Lab (Woosuk Kwon et al.)의 **vLLM**.
7.1 KV Cache 문제
LLM 추론 시 각 이전 토큰의 K, V를 저장해야 한다 — 다음 토큰 생성에 필요.
**LLaMA-13B**, 2048 context:
- K: (2048, 40, 128) per layer × 40 layers × FP16 = 단일 시퀀스당 **800 MB**.
- V: 같은 크기.
- Total: ~1.6 GB per sequence.
100 동시 사용자 → 160 GB. 80 GB GPU로 불가능.
7.2 Fragmentation 문제
전통적 할당:
- 각 요청에 **max_length**만큼 미리 할당 (예: 2048 토큰).
- 실제 생성은 평균 500 토큰 → **75% 낭비**.
- **Internal fragmentation**.
또한 요청마다 다른 크기 → **external fragmentation**.
7.3 OS 가상 메모리 아이디어
**운영체제의 가상 메모리**는 이 문제를 수십 년 전에 해결했다:
- 프로세스 메모리를 **고정 크기 페이지**로 분할.
- **페이지 테이블**로 논리 주소 → 물리 페이지 매핑.
- 필요 시에만 물리 페이지 할당.
- 페이지는 서로 **공유** 가능 (shared memory, copy-on-write).
PagedAttention: **이 모든 아이디어를 LLM KV cache에 적용**.
7.4 구조
KV cache를 block(예: 16 tokens)으로 분할
Sequence A: [block 0] → [block 1] → [block 2]
↓ ↓ ↓
Physical Physical Physical
Page 42 Page 7 Page 13
Sequence B: [block 0] → [block 1]
↓ ↓
Physical Physical
Page 42 Page 3 ← 공유!
(copy-on-write)
- 논리적으로 연속된 블록이 물리적으로 흩어져 있어도 OK.
- 필요할 때만 할당.
- 공통 prefix는 공유 (copy-on-write).
7.5 이점
**1. 메모리 효율**:
- Fragmentation 거의 제거.
- 100 users → 100 GB → 30 GB (실제 측정).
**2. Sharing**:
- 같은 시스템 프롬프트를 여러 요청이 공유.
- 같은 조건의 beam search branch가 공유.
**3. Throughput**:
- 기존 대비 **2-4배 더 많은 동시 요청** 처리.
- GPU 활용률 ↑.
7.6 vLLM의 영향
vLLM은 2023년 말 공개, 순식간에 **LLM 서빙의 표준**이 됐다. 거의 모든 open source LLM 서버(TGI, Ollama, Llama.cpp도 일부)가 비슷한 아이디어 채택.
**2024**: PagedAttention의 구현 변형, CUDA 커널 튜닝 논문 봇물.
8. Multi-Query / Grouped-Query Attention
8.1 Multi-Head Attention (원본)
원래 Transformer:
- $H$개의 query head.
- 각자 자기 $Q_h$, $K_h$, $V_h$.
- 독립적 attention 계산 후 concat.
예: LLaMA-2 7B에 $H = 32$. 32 Q head + 32 K head + 32 V head.
8.2 KV Cache 부담
KV cache 크기: $2 \cdot H \cdot L \cdot d$ per token.
- $H = 32$, $L = 32$ layers, $d = 128$ → 토큰당 524 KB (FP16).
- 10,000 토큰 컨텍스트 → 5.2 GB per sequence.
8B 모델에서 모델 자체 16 GB, KV cache 5 GB, batch 여러 개 → GPU 메모리 한계.
8.3 Multi-Query Attention (MQA)
**모든 query head가 하나의 K, V head를 공유**.
- H개의 Q head, 1개의 K head, 1개의 V head.
- KV cache 크기: **1/H**.
- 계산량은 거의 같음.
단점: 품질 약간 하락. "모든 head가 같은 K/V를 보면 표현력 감소."
8.4 Grouped-Query Attention (GQA)
LLaMA-2/3에서 사용. 타협:
- $H$개의 Q head.
- $G$개의 K, V head ($G < H$, 보통 $H / 4$ 또는 $H / 8$).
- 각 K, V head를 $H / G$개의 Q head가 공유.
예: LLaMA-3 8B, $H = 32$, $G = 8$ → 4 Q heads per KV head.
- KV cache: **1/4**.
- 품질: MHA와 거의 동일.
- 속도: 추론 빠름.
8.5 KV Cache 비교
동일 모델 크기에서:
- **MHA**: 100% (기준).
- **GQA (G=H/4)**: 25%.
- **MQA**: 1/H (3%).
**4배 메모리 감소**로 4배 긴 컨텍스트 또는 4배 동시 요청.
LLaMA-2 34B부터 GQA 사용 → LLaMA-3 표준. Mistral, Mixtral도 같은 접근.
9. Sliding Window Attention
9.1 아이디어
각 토큰이 **최근 N개 토큰**만 본다. 전체가 아니라.
Token at position i attends to:
positions [i - W, i - W + 1, ..., i]
$W$는 window size (예: 4096).
9.2 장점
- **계산**: $O(NW)$ 대신 $O(N^2)$. $N \gg W$면 훨씬 빠름.
- **KV cache**: 최근 $W$개만 유지 → **고정 크기**. 시퀀스 무한 확장 가능.
9.3 단점
- **멀리 있는 토큰 못 봄**. 이론상 장거리 의존성 손실.
- 하지만 여러 layer 거치면 **effective receptive field**가 누적. 예: 12 layers × 4096 window = 49152 receptive field.
9.4 사용
**Mistral 7B**: Sliding window 4096. 긴 컨텍스트에 매우 효율.
**Gemma**: Local + global attention 번갈아.
**Longformer**: 대부분 local, 몇 개 토큰만 global.
9.5 조합
많은 현대 LLM이 **local + global attention**을 섞는다:
- 대부분 layer: local (sliding window).
- 일부 layer: full global.
- 결과: 긴 컨텍스트 + 효율.
10. Flash Decoding
10.1 문제
LLM 추론의 두 단계:
1. **Prefill**: 입력 전체를 한 번에 처리. Batch compute-intensive.
2. **Decode**: 토큰 하나씩 생성. 각 단계에서 한 토큰만 compute.
Decode는 **본질적으로 sequential** — 이전 토큰의 출력이 다음의 입력. 그리고 **memory-bound** — 작은 compute, 큰 KV cache 읽기.
10.2 GPU 활용률 문제
Decode 시 단일 토큰 attention은 작은 연산. GPU가 거의 놀고 있음. 특히 long context:
Decode step: q (1 token) × K (context 10000 tokens) → score (10000)
이건 Matrix-vector 곱 (GEMV). Compute는 작고 memory bandwidth가 병목.
10.3 Flash Decoding 아이디어
**KV cache를 chunk로 나누고 각 chunk를 별도 SM에서 병렬 처리**. Decode 시점에 parallelism 발굴.
Sequence K, V 분할: chunks of size C
각 chunk에 대해: local softmax 계산 (in parallel)
최종: partial 결과 집계
Flash Attention의 online softmax 아이디어를 decode에 적용.
10.4 결과
Long context decode에서 **8배 속도**. LLaMA 13B의 16K 컨텍스트 decode가 실용화.
11. Ring Attention
11.1 거대한 컨텍스트 문제
1M, 10M 토큰 컨텍스트 — 단일 GPU 메모리 초과.
11.2 기본 아이디어
**시퀀스를 여러 GPU로 분할**. 각 GPU가 sequence의 일부.
GPU 0: Q[0:N/4]
GPU 1: Q[N/4:N/2]
GPU 2: Q[N/2:3N/4]
GPU 3: Q[3N/4:N]
K, V도 같은 방식으로 분할.
11.3 통신 패턴
각 Q 블록은 **모든 K, V 블록**을 봐야 한다. 순차적으로 "ring"으로 전달:
GPU 0 ← K_0, V_0 (local)
GPU 1 ← K_1, V_1 (local)
GPU 2 ← K_2, V_2 (local)
GPU 3 ← K_3, V_3 (local)
Round 1: Ring rotate
GPU 0 ← K_3, V_3
GPU 1 ← K_0, V_0
GPU 2 ← K_1, V_1
GPU 3 ← K_2, V_2
Round 2, 3: 계속 rotate
총 N rounds 후 모든 Q가 모든 K, V를 봤음.
각 rotate 동안 compute와 통신 **겹침** (async).
11.4 결과
**Ring Attention** (Liu et al. 2023):
- 10M 토큰 컨텍스트 훈련 가능.
- Gemini 1.5 Pro의 1M+ context의 기반 기술 중 하나.
Sora, Llama-3 긴 context에도 유사 기법.
12. Sparse Attention
12.1 아이디어
"대부분의 attention weight는 작다" — 실제로 몇 개 토큰에 집중. 나머지는 sparse로 처리.
12.2 패턴
**Block Sparse**: 미리 정의된 패턴.
- Strided: 일정 간격.
- Fixed: 국소 + 전역 혼합.
**Longformer**: Sliding window + 선택적 global tokens.
**BigBird**: Random + window + global. 이론적 universal approximator.
**Sparse Transformer** (OpenAI 2019): Stride + local.
12.3 한계
Sparse 패턴이 고정이라 workload에 따라 품질 변동. Dense attention만큼 일반화되지 않음. 주류 LLM은 대부분 dense + efficient 기법(FlashAttention + GQA).
12.4 학습된 Sparsity
최근 연구: 학습 중 attention 패턴을 모델이 선택. 예: Routing Attention, Switch Transformer에 힌트.
아직 주류는 아님.
13. 다른 효율화 기법
13.1 Linear Attention
Attention의 $O(N^2)$를 **$O(N)$**로 만드는 수학적 트릭.
Core: softmax를 feature map $\phi$로 근사.
$$
\text{Attention} = \text{softmax}(QK^T)V \approx \phi(Q) (\phi(K)^T V)
$$
$\phi(K)^T V$를 먼저 계산 ($O(Nd^2)$), 그 다음 $\phi(Q)$와 곱함 ($O(Nd^2)$).
- **Performer** (2020): Random feature로 softmax 근사.
- **Linformer** (2020): K, V를 저차원으로 project.
단점: 품질 손실 (특히 긴 sequence). 주류 LLM 채택 거의 없음.
13.2 KV Cache Quantization
KV cache 메모리 절반 또는 더 줄이기:
- **FP16 → INT8**: 2배 감소.
- **FP16 → INT4**: 4배 감소.
- 품질 손실 최소화 for recent KV cache.
13.3 KV Cache Offloading
KV cache를 **CPU DRAM** 또는 **NVMe**에 저장. 필요할 때 GPU로 가져옴.
**vLLM**: CPU ↔ GPU swap 지원.
**DeepSpeed**: 더 공격적.
속도 손실이 있지만 메모리 제한이 극복된다. Long context 서비스에 유용.
13.4 Speculative Decoding
작은 drafter 모델이 여러 토큰을 predictor. 큰 모델이 validator. 올바른 만큼 한 번에 수락.
**Medusa**, **Lookahead decoding** 같은 variant들. 2-3배 추론 속도.
Attention 최적화 직접 아니지만, LLM 서빙 효율 향상에 핵심.
14. 생태계
14.1 구현
**FlashAttention** (공식, Tri Dao): https://github.com/Dao-AILab/flash-attention. PyTorch CUDA extension.
**xformers** (Meta): 다양한 efficient attention 구현 모음. FlashAttention 통합.
**vLLM** (UC Berkeley): PagedAttention 구현. LLM 서빙.
**TRT-LLM** (NVIDIA): TensorRT 기반 고성능 추론. FlashAttention, PagedAttention 포함.
**SGLang**: vLLM 대안. 다른 최적화 전략.
**Triton**: FlashAttention을 Triton으로 재구현 가능. Tri Dao의 튜토리얼.
14.2 PyTorch 통합
PyTorch 2.0+:
from torch.nn.functional import scaled_dot_product_attention
output = scaled_dot_product_attention(q, k, v, is_causal=True)
자동으로 FlashAttention 2 사용 (설치됐으면). 완전 투명.
14.3 HuggingFace
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3-8B",
torch_dtype=torch.float16,
attn_implementation="flash_attention_2" # 명시적
)
대부분 최신 모델이 FlashAttention 2/3 지원.
15. 미래 방향
15.1 하드웨어 진화
**Blackwell** (B100/B200, 2024): FP4 Tensor Core, 더 큰 SRAM, TMEM(Tensor Memory) 새 유닛.
FlashAttention 4 추정: Blackwell 특화 버전이 곧 나올 것.
15.2 긴 컨텍스트
1M, 10M, 100M 컨텍스트가 목표. Ring Attention + hierarchical attention 조합.
Gemini 1.5 이미 2M. Gemini 2.0+ 더 큰.
15.3 Post-Attention?
일부 연구자는 "attention은 O(N²)라는 근본 한계가 있으니 완전히 새 아키텍처가 필요"라고 주장.
- **State Space Models (Mamba)**: $O(N)$ 추론.
- **RWKV**: RNN의 부활.
- **Hyena**: Long convolution 기반.
하지만 2025년 현재 dense attention + efficient 기법이 품질 면에서 여전히 승.
15.4 Quantization 극단
INT2, ternary, 1-bit. KV cache + weights 양쪽에 적용. Sub-billion GPU 메모리로 대형 모델.
16. 학습 리소스
**논문**:
- "FlashAttention: Fast and Memory-Efficient Exact Attention" — Tri Dao et al. 2022.
- "FlashAttention-2" — Tri Dao 2023.
- "FlashAttention-3" — Shah et al. 2024.
- "Efficient Memory Management for LLM Serving with PagedAttention" — vLLM paper.
- "Ring Attention" — Liu et al. 2023.
- "Efficient Streaming Language Models with Attention Sinks" — 2023.
**블로그**:
- Tri Dao의 개인 블로그 / Princeton 페이지.
- Hugging Face 블로그 (LLM 최적화 시리즈).
- Together.ai 블로그 (serving 최적화).
- UC Berkeley Sky Lab.
**코드**:
- Dao-AILab/flash-attention.
- vllm-project/vllm.
- NVIDIA/TensorRT-LLM.
- huggingface/text-generation-inference (TGI).
**강의**:
- "GPU MODE" (Discord + YouTube).
- MIT 6.5940 "TinyML" (Han Song).
- Stanford CS229M.
17. 요약 — 한 장 정리
┌─────────────────────────────────────────────────────┐
│ Efficient Attention Cheat Sheet │
├─────────────────────────────────────────────────────┤
│ 문제: │
│ Naive attention: O(N²) 메모리 │
│ HBM I/O 병목 (memory-bound) │
│ │
│ FlashAttention (Tri Dao 2022): │
│ Tiling: 블록 단위 처리 │
│ Online softmax: incremental normalization │
│ SRAM에서 계산, HBM에 N² 쓰지 않음 │
│ 정확 (근사 아님) │
│ Backward: recomputation │
│ 5-10x 빠름 │
│ │
│ FlashAttention-2 (2023): │
│ 병렬화 개선 │
│ Warp 작업 재분배 │
│ 2-3x 추가 속도 │
│ │
│ FlashAttention-3 (Hopper 2024): │
│ WGMMA (비동기) │
│ Producer-Consumer │
│ Warp specialization │
│ FP8 Tensor Core │
│ H100의 75% 효율 │
│ │
│ PagedAttention (vLLM 2023): │
│ OS 가상 메모리 → KV cache │
│ 블록 단위 할당 │
│ Block table (page table 유사) │
│ 공유 (copy-on-write) │
│ 2-4x throughput │
│ │
│ MQA / GQA: │
│ Multi-Query: 1 K/V head │
│ Grouped-Query: G (< H) K/V heads │
│ KV cache 1/N 감소 │
│ LLaMA-2/3, Mistral 사용 │
│ │
│ Sliding Window: │
│ 최근 W 토큰만 │
│ O(NW) 계산, fixed KV cache │
│ Mistral, Gemma │
│ │
│ Ring Attention: │
│ 시퀀스 여러 GPU 분할 │
│ K/V를 ring으로 rotate │
│ 10M+ context │
│ │
│ Flash Decoding: │
│ KV를 chunk로 분할 │
│ Decode time에 병렬 │
│ Long context 8x │
│ │
│ Sparse Attention: │
│ Longformer, BigBird │
│ 품질 trade-off │
│ │
│ 기타: │
│ KV cache quantization (INT8/4) │
│ KV offloading (CPU/NVMe) │
│ Speculative decoding │
│ │
│ 통합: │
│ PyTorch scaled_dot_product_attention │
│ HuggingFace attn_implementation │
│ xformers, vLLM, TRT-LLM, SGLang │
└─────────────────────────────────────────────────────┘
18. 퀴즈
**A.** 기존 attention 최적화는 "FLOP을 줄이자"(linear attention, sparse attention 등)에 집중했지만 실제 병목은 **HBM I/O**였다. GPU의 HBM 대역폭(3 TB/s)이 빠르지만 $O(N^2)$ 크기의 $S = QK^T$ 행렬을 HBM에 쓰고 다시 읽는 것만으로 수 GB를 옮긴다. Compute units(Tensor Core)는 데이터를 기다리며 대부분 시간 idle. Tri Dao의 통찰: "FLOP은 충분히 빠르니 **HBM 접근을 줄이면** 자동으로 빨라진다". 이 관점에서 tiling + online softmax로 SRAM만 사용하도록 알고리즘을 재설계. FLOP은 오히려 약간 증가(backward의 recomputation)하지만 HBM I/O는 17배 감소 → 5-10배 속도 향상. 시스템 엔지니어링의 교훈: "**병목을 오해하면 최적화가 소용없다**".
**A.** Softmax는 전체 행의 모든 값이 필요해서(max와 sum 계산) 블록 단위로 쪼갤 수 없어 보인다. 트릭: **현재까지의 max와 sum을 유지하면서 새 블록이 올 때마다 재조정**. 새 블록의 max가 이전 max보다 크면 이전 sum에 `exp(old_max - new_max)`를 곱해 "scale down" 하고, 새 블록의 contribution을 더한다. 최종 결과는 **수학적으로 전체 softmax와 정확히 동등**. 이 덕분에 $S$ 행렬 전체를 메모리에 저장할 필요 없이 블록을 순차로 SRAM에 로드하며 $O$(출력)를 점진적으로 업데이트 가능. 수치적으로도 안전 — max 뺌으로 overflow 방지. 이 아이디어는 1996년 "Keeping a running softmax" 수치해석 논문에 있었지만 FlashAttention이 **GPU tiling과 결합**해서 현대 AI 가속의 핵심 기법이 됐다.
**A.** **네 가지 핵심 OS 개념**: (1) **고정 크기 페이지** — KV cache를 연속 블록으로 할당하지 않고 작은 블록(예: 16 tokens)으로 나눔, (2) **페이지 테이블** — 논리적 토큰 위치에서 물리적 블록으로의 매핑. 각 시퀀스가 자기 block table을 가짐, (3) **On-demand 할당** — 필요한 시점에만 물리 블록 확보, over-allocation 방지, (4) **공유 + Copy-on-Write** — 같은 시스템 프롬프트로 시작하는 여러 요청이 같은 물리 블록을 공유, divergence 시에만 복사. 결과: fragmentation이 거의 없고(80% → 5% 미만), 공유로 메모리 절약, throughput 2-4배 증가. 운영체제가 1960년대에 해결한 문제를 2023년에 LLM serving에 적용한 사례 — "이미 아는 문제의 새 영역 응용"이 연구의 가장 실용적인 형태.
**A.** **품질과 효율의 균형**. **MHA**(Multi-Head Attention): H개의 Q head + H개의 K/V head → 최고 품질, 최대 KV cache. **MQA**(Multi-Query): H개의 Q head + 1개의 K/V head → 최소 KV cache(1/H), 품질 약간 하락 ("모든 head가 같은 K/V를 보면 표현력 감소"). **GQA**: H개의 Q head + G개의 K/V head (G ≈ H/4 또는 H/8) → 중간. KV cache는 4-8배 감소하지만 품질은 MHA와 거의 동일. LLaMA-2 70B와 LLaMA-3 전체 라인업이 GQA 채택: Q head 32 + KV head 8 → 4:1 비율. 실측: KV cache 4배 감소로 동일 GPU에서 4배 더 긴 컨텍스트 또는 4배 더 많은 동시 요청. 품질 손실 측정 불가 수준. "완전한 분리도, 완전한 공유도 아닌 중간"이 최선인 흔한 설계 원칙.
**A.** **근사 알고리즘이 아니라는 보장**. Efficient attention 계열의 다른 접근들(Linear Attention, Performer, Linformer, Sparse Attention)은 대부분 **근사** — 수학적으로 $O(N^2)$ 결과와 다른 값을 출력. 품질 손실이 있고, 특정 워크로드에 잘 맞지 않을 수 있으며, 훈련된 모델의 attention과 완전 호환 안 됨. FlashAttention은 **비트 단위 재현이 가능한 수학적 정확성** — online softmax의 수치적 안정성 덕분에 naive attention과 동등한 결과. 이것이 결정적 — 기존 훈련된 모델(LLaMA, GPT)에 **drop-in replacement**로 사용 가능. 프레임워크가 자동으로 FlashAttention을 켜도 모델 출력이 완전히 같음. 성능 향상이 "부작용 없는 업그레이드"가 된다. 근사 방법들이 널리 채택 안 된 이유: 작은 차이도 downstream task에서 예측 불가능한 영향.
**A.** **Sequence parallelism + 통신-계산 오버랩**. 단일 GPU는 큰 sequence의 attention matrix를 저장 불가(10M² = 100조 entry). Ring Attention: sequence를 N개 GPU로 분할(각 1/N). 각 GPU는 자기 Q 블록과 local K/V로 시작, 그 다음 K/V를 **ring 토폴로지로 순환**(GPU 0 → 1 → 2 → ... → 0). 매 회전마다 새 K/V 블록을 받으며 부분 attention 계산. **Online softmax**를 사용해 부분 결과를 정확하게 합쳐 최종 출력. 중요한 점: **통신이 compute와 동시 진행** — 현재 블록 계산하는 동안 다음 블록 수신. Async로 latency 숨김. N rounds 후 모든 Q가 모든 K/V를 봤고, 각 GPU는 자기 Q 블록의 출력만 가지면 됨. Gemini 1.5 Pro의 1-2M 컨텍스트, Sora의 긴 비디오 시퀀스가 이런 기법 기반. "분산 시스템이 메모리 한계를 극복"의 완벽한 사례.
**A.** **추론 단계별 특성이 다름**. Prefill 단계(입력 전체 처리)는 compute-heavy — 많은 Q, K, V를 한 번에 GEMM. Decode 단계(토큰 하나씩 생성)는 **memory-bandwidth 제한** — 매 step에 single query × long K/V context로 GEMV(행렬-벡터 곱). GEMM은 arithmetic intensity 높고 Tensor Core 활용 가능, GEMV는 compute/memory 비율 낮아 HBM bandwidth가 병목. 한 토큰 decode 시 KV cache 전체 읽기만 수 ms. **Flash Attention**은 여전히 유용하지만 GEMM 최적화라 decode에서 이득 작음. **Flash Decoding**이 해결책: KV cache를 chunk로 분할하고 각 chunk를 별도 SM에서 병렬 처리 → decode에 parallelism 도입. Online softmax로 partial 결과 합산. 장문 컨텍스트 decode에서 8배 속도. 관찰: "LLM 서빙 최적화는 prefill과 decode를 따로 생각해야 함". 이 둘이 근본적으로 다른 병목.
이 글이 도움이 됐다면 다음 포스트도 확인해 보세요:
- "Transformer Architecture Deep Dive" — Attention의 기본.
- "CUDA GPU Programming Deep Dive" — 커널 최적화의 기반.
- "Diffusion Models Deep Dive" — 또 다른 attention-heavy 워크로드.
- "RDMA & NCCL" — 멀티 GPU 훈련의 통신 기반.
현재 단락 (1/417)
- **Naive attention**은 $N \times N$ attention matrix를 HBM에 저장. 시퀀스 길이 $N$에 대해 **$O(N^2)$ 메모리**. LLM ...