Skip to content

필사 모드: FlashAttention & Efficient Attention Deep Dive — Tiling, Online Softmax, PagedAttention, GQA 완전 정복 (2025)

한국어
0%
정확도 0%
💡 왼쪽 원문을 읽으면서 오른쪽에 따라 써보세요. Tab 키로 힌트를 받을 수 있습니다.
원문 렌더가 준비되기 전까지 텍스트 가이드로 표시합니다.

TL;DR

- **Naive attention**은 $N \times N$ attention matrix를 HBM에 저장. 시퀀스 길이 $N$에 대해 **$O(N^2)$ 메모리**. LLM 긴 컨텍스트에서 치명적.

- **FlashAttention** (Tri Dao, 2022): 핵심 통찰은 attention이 **memory-bound**이지 compute-bound 아니라는 것. **Tiling + Online Softmax**로 SRAM만 쓰고 HBM에 N×N 행렬을 절대 쓰지 않음.

- **Online Softmax**: Softmax를 한 번에 계산하는 대신 블록 단위로 누적. 수학적으로 동등한데 IO가 완전히 다름.

- **FlashAttention-2** (2023): 병렬화 개선, Forward/Backward 최적화. 2-3배 빠름.

- **FlashAttention-3** (2024, Hopper): Tensor Core + 비동기 WGMMA + warp 특수화. H100에서 75% 효율.

- **PagedAttention** (vLLM, 2023): OS 가상 메모리 아이디어를 KV cache에. **2-4배 throughput**.

- **Grouped Query Attention (GQA)**: Query head 여러 개가 같은 K/V head 공유. **KV cache 크기 1/N**.

- **Sliding Window Attention**: 컨텍스트 윈도우 제한. Mistral, Gemma가 사용.

- **Ring Attention**: 여러 GPU로 sequence 분할. 10M 토큰 컨텍스트 가능.

1. Attention의 병목 문제

1.1 Transformer의 심장

Transformer의 핵심 연산:

$$

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

$$

- $Q$, $K$, $V$: (N, d) 크기. N = 시퀀스 길이, d = 헤드 차원.

- $QK^T$: (N, N) 점수 행렬.

- Softmax: 행별 정규화.

- 결과에 $V$ 곱: (N, d) 출력.

이것이 Transformer의 품질의 원천이지만, 동시에 **가장 큰 비용**.

1.2 O(N²) 메모리

$QK^T$ 크기 = $N \times N$. 시퀀스 길이에 제곱 비례.

N = 1,000: N² = 1,000,000 = 1 M 요소

N = 10,000: N² = 100 M

N = 100,000: N² = 10 G ← GPU 메모리 초과

FP16이면 요소당 2 바이트. 100,000 토큰 컨텍스트는 **20 GB**의 attention matrix만으로 차지. 한 레이어에. 한 헤드에.

실제로는 layer × head × batch → 수백 GB. **불가능**.

1.3 Compute vs Memory

Naive attention 구현:

S = Q @ K.T # (N, N)

P = softmax(S) # (N, N)

O = P @ V # (N, d)

각 단계에서 **HBM 읽기/쓰기**.

- $Q$, $K$, $V$ 읽기: $O(Nd)$

- $S$ 쓰기: $O(N^2)$

- $S$ 읽기, $P$ 쓰기: $O(N^2)$

- $P$ 읽기, $V$ 읽기, $O$ 쓰기: $O(N^2 + Nd)$

HBM 총 I/O: **$O(N^2)$**.

GPU의 HBM 대역폭이 아무리 빨라도(3 TB/s on H100) $N^2$ 데이터를 옮기는 데 시간이 걸린다. **대부분 시간을 HBM ↔ SRAM 복사에 씀**. 컴퓨트 유닛은 대부분 놀고 있다.

1.4 "Memory-bound, not compute-bound"

이것이 Tri Dao의 핵심 통찰이었다. "Attention 성능을 올리려면 FLOPs를 줄이는 게 아니라 **HBM 접근을 줄여야** 한다."

2. GPU 메모리 계층 복습

2.1 두 가지 메모리

**HBM** (High Bandwidth Memory):

- GPU의 "main memory". 수십~수백 GB.

- 대역폭 ~3 TB/s (H100).

- 지연 시간 ~500 cycle.

**SRAM** (Shared Memory / L1):

- SM 내부의 **on-chip** 메모리.

- 매우 작음 (각 SM당 ~228 KB on H100).

- 대역폭 ~19 TB/s (H100).

- 지연 시간 ~20 cycle.

**SRAM이 6배 이상 빠르다**. 그리고 **지연 시간이 25배 적다**.

2.2 위치의 중요성

데이터가 SRAM에 있으면 compute unit이 놀 시간 없이 곧바로 처리.

HBM에 있으면 SRAM으로 가져오는 동안 기다려야 함.

Attention 최적화의 목표: **SRAM에서 가능한 한 모든 계산을 완료**.

문제: SRAM이 너무 작아 $N \times N$ 행렬이 들어가지 않는다.

해결: **Tiling**.

3. FlashAttention의 핵심 아이디어

3.1 Tiling

행렬을 **블록**으로 나누어 순차 처리:

Q: (N, d) → Q_1, Q_2, ..., Q_{N/B_r} (각 (B_r, d))

K, V: (N, d) → K_1, K_2, ..., K_{N/B_c} (각 (B_c, d))

블록 크기 $B_r$, $B_c$는 SRAM에 맞게 선택.

3.2 Naive Blockwise

각 $Q_i$에 대해 전체 K, V를 순회해서 $O_i$ 계산:

for i in range(N / B_r):

Q_i = load(Q, i)

for j in range(N / B_c):

K_j, V_j = load(K, V, j)

S_ij = Q_i @ K_j.T

하지만 softmax는 전체 행에 걸쳐...?

**문제**: softmax는 각 행의 **모든 값**을 알아야 정규화 가능. 블록 단위로는 지역적 정보만.

3.3 Online Softmax — 핵심 트릭

Naive softmax:

$$

\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}

$$

**수치 안정화** 버전 (max 뺌):

$$

\text{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \quad m = \max_j x_j

$$

이것을 **블록 단위로** 계산할 수 있다 — 만약 "현재까지의 max와 sum"을 유지하면.

**Incremental softmax**:

m_prev = -∞ # 현재까지 본 max

l_prev = 0 # 현재까지 본 sum

For each block:

m_block = max(block)

m_new = max(m_prev, m_block)

기존 누적치 rescale

l_prev = l_prev * exp(m_prev - m_new)

현재 블록

l_block = sum(exp(block - m_new))

합침

l_prev = l_prev + l_block

m_prev = m_new

마지막에 `softmax = exp(x - m_prev) / l_prev`.

**수학적으로 정확**. 단, 각 블록 처리 시 이전 누적을 재조정해야 함.

3.4 FlashAttention 알고리즘

초기화: O = 0, l = 0, m = -∞ (출력, sum, max)

For j in range(N / B_c): # K, V 블록

K_j, V_j ← HBM 로드 (SRAM으로)

For i in range(N / B_r): # Q 블록

Q_i ← HBM 로드 (SRAM으로)

S_ij = Q_i @ K_j.T # SRAM 내부 계산

Online softmax 업데이트

m_ij = rowmax(S_ij)

P_ij = exp(S_ij - m_ij)

l_ij = rowsum(P_ij)

이전과 결합

m_new = max(m_i, m_ij)

l_new = exp(m_i - m_new) * l_i + exp(m_ij - m_new) * l_ij

O 업데이트

O_i = diag(l_new)^-1 * (diag(l_i) * exp(m_i - m_new) * O_i

+ exp(m_ij - m_new) * P_ij * V_j)

m_i = m_new

l_i = l_new

O_i 저장 (HBM)

핵심:

- $Q$, $K$, $V$만 HBM에서 읽기.

- $S = QK^T$ 전체 행렬을 **절대 HBM에 쓰지 않음**.

- 각 블록 내부 계산은 SRAM.

- $O$ 최종 결과만 HBM.

3.5 HBM I/O 분석

Naive attention: $O(Nd + N^2)$ HBM 접근.

FlashAttention: $O(N^2 d^2 / M)$ HBM 접근. $M$은 SRAM 크기.

예: $N = 8192$, $d = 128$, $M \approx 100 KB$.

- Naive: ~67 M bytes (FP16).

- FlashAttention: ~4 M bytes.

**17배 적은 HBM I/O**. 실제 속도 향상 5-10배.

3.6 정확성

"근사"가 아니다. **수학적으로 정확한 attention**. Online softmax 덕분에 블록 단위 처리가 전체 softmax와 동일한 결과.

4. Backward Pass

훈련 시 gradient 계산도 attention.

4.1 Naive Backward

Forward 시 $P = \text{softmax}(QK^T)$를 저장했다가 backward에서 사용. **$O(N^2)$ 메모리 저장**.

4.2 FlashAttention Backward

**Recomputation**: forward에서 $P$를 저장하지 않고, backward에서 다시 계산.

Forward: Q, K, V, O, l, m만 저장 (O(N) 메모리, not O(N²))

Backward: 필요할 때 S, P 재계산

FLOP 증가 (약 2x compute), 메모리 감소 (O(N²) → O(N)). 전체적으로 여전히 빠름 — **memory-bound**라서.

4.3 결과

LLaMA 훈련 같은 워크로드:

- 메모리: 10배 감소.

- 속도: 2-3배 빠름.

- 긴 컨텍스트(8K-32K) 훈련 가능.

5. FlashAttention-2

2023년 Tri Dao의 후속. 주요 개선:

5.1 더 나은 병렬화

**FlashAttention-1**: 시퀀스 차원으로만 병렬.

**FlashAttention-2**: Batch, head, sequence 3차원 병렬 → GPU 활용 ↑.

5.2 Warp 간 작업 재분배

각 warp가 하는 일을 재설계. 동기화 오버헤드 감소.

5.3 Non-matmul FLOP 감소

Softmax 관련 non-matmul 연산 (rescaling 등)을 줄임. Tensor Core가 아닌 CUDA Core 부하가 병목이었는데 이를 감소.

5.4 성능

FlashAttention-1 대비 **2배 빠름**. Standard attention 대비 5-10배.

LLaMA-2 13B 훈련에서 사용. 긴 컨텍스트(16K, 32K) 훈련 실용화.

6. FlashAttention-3 (2024, Hopper)

NVIDIA Hopper 아키텍처(H100)에 특화된 버전.

6.1 새 Hopper 기능 활용

**WGMMA** (Warp Group Matrix Multiply-Accumulate): 비동기 행렬 곱. 새 instruction.

**TMA** (Tensor Memory Accelerator): 전용 async copy 엔진.

**FP8 Tensor Core**: 정밀도 낮춤으로 대역폭/속도 ↑.

6.2 알고리즘 변경

**Producer-Consumer 모델**: 일부 warp가 memory load 담당 (producer), 다른 warp가 compute 담당 (consumer). 오버랩.

**Warp Specialization**: Warp마다 특정 역할 고정. 스케줄링 유연성 감소, 성능 ↑.

6.3 결과

H100에서 **peak FP16 throughput의 75%** 달성. 이전 flash-attention은 35% 정도.

GPT-4 훈련 수준 워크로드에 필수.

7. PagedAttention — vLLM

2023년 UC Berkeley Sky Lab (Woosuk Kwon et al.)의 **vLLM**.

7.1 KV Cache 문제

LLM 추론 시 각 이전 토큰의 K, V를 저장해야 한다 — 다음 토큰 생성에 필요.

**LLaMA-13B**, 2048 context:

- K: (2048, 40, 128) per layer × 40 layers × FP16 = 단일 시퀀스당 **800 MB**.

- V: 같은 크기.

- Total: ~1.6 GB per sequence.

100 동시 사용자 → 160 GB. 80 GB GPU로 불가능.

7.2 Fragmentation 문제

전통적 할당:

- 각 요청에 **max_length**만큼 미리 할당 (예: 2048 토큰).

- 실제 생성은 평균 500 토큰 → **75% 낭비**.

- **Internal fragmentation**.

또한 요청마다 다른 크기 → **external fragmentation**.

7.3 OS 가상 메모리 아이디어

**운영체제의 가상 메모리**는 이 문제를 수십 년 전에 해결했다:

- 프로세스 메모리를 **고정 크기 페이지**로 분할.

- **페이지 테이블**로 논리 주소 → 물리 페이지 매핑.

- 필요 시에만 물리 페이지 할당.

- 페이지는 서로 **공유** 가능 (shared memory, copy-on-write).

PagedAttention: **이 모든 아이디어를 LLM KV cache에 적용**.

7.4 구조

KV cache를 block(예: 16 tokens)으로 분할

Sequence A: [block 0] → [block 1] → [block 2]

↓ ↓ ↓

Physical Physical Physical

Page 42 Page 7 Page 13

Sequence B: [block 0] → [block 1]

↓ ↓

Physical Physical

Page 42 Page 3 ← 공유!

(copy-on-write)

- 논리적으로 연속된 블록이 물리적으로 흩어져 있어도 OK.

- 필요할 때만 할당.

- 공통 prefix는 공유 (copy-on-write).

7.5 이점

**1. 메모리 효율**:

- Fragmentation 거의 제거.

- 100 users → 100 GB → 30 GB (실제 측정).

**2. Sharing**:

- 같은 시스템 프롬프트를 여러 요청이 공유.

- 같은 조건의 beam search branch가 공유.

**3. Throughput**:

- 기존 대비 **2-4배 더 많은 동시 요청** 처리.

- GPU 활용률 ↑.

7.6 vLLM의 영향

vLLM은 2023년 말 공개, 순식간에 **LLM 서빙의 표준**이 됐다. 거의 모든 open source LLM 서버(TGI, Ollama, Llama.cpp도 일부)가 비슷한 아이디어 채택.

**2024**: PagedAttention의 구현 변형, CUDA 커널 튜닝 논문 봇물.

8. Multi-Query / Grouped-Query Attention

8.1 Multi-Head Attention (원본)

원래 Transformer:

- $H$개의 query head.

- 각자 자기 $Q_h$, $K_h$, $V_h$.

- 독립적 attention 계산 후 concat.

예: LLaMA-2 7B에 $H = 32$. 32 Q head + 32 K head + 32 V head.

8.2 KV Cache 부담

KV cache 크기: $2 \cdot H \cdot L \cdot d$ per token.

- $H = 32$, $L = 32$ layers, $d = 128$ → 토큰당 524 KB (FP16).

- 10,000 토큰 컨텍스트 → 5.2 GB per sequence.

8B 모델에서 모델 자체 16 GB, KV cache 5 GB, batch 여러 개 → GPU 메모리 한계.

8.3 Multi-Query Attention (MQA)

**모든 query head가 하나의 K, V head를 공유**.

- H개의 Q head, 1개의 K head, 1개의 V head.

- KV cache 크기: **1/H**.

- 계산량은 거의 같음.

단점: 품질 약간 하락. "모든 head가 같은 K/V를 보면 표현력 감소."

8.4 Grouped-Query Attention (GQA)

LLaMA-2/3에서 사용. 타협:

- $H$개의 Q head.

- $G$개의 K, V head ($G < H$, 보통 $H / 4$ 또는 $H / 8$).

- 각 K, V head를 $H / G$개의 Q head가 공유.

예: LLaMA-3 8B, $H = 32$, $G = 8$ → 4 Q heads per KV head.

- KV cache: **1/4**.

- 품질: MHA와 거의 동일.

- 속도: 추론 빠름.

8.5 KV Cache 비교

동일 모델 크기에서:

- **MHA**: 100% (기준).

- **GQA (G=H/4)**: 25%.

- **MQA**: 1/H (3%).

**4배 메모리 감소**로 4배 긴 컨텍스트 또는 4배 동시 요청.

LLaMA-2 34B부터 GQA 사용 → LLaMA-3 표준. Mistral, Mixtral도 같은 접근.

9. Sliding Window Attention

9.1 아이디어

각 토큰이 **최근 N개 토큰**만 본다. 전체가 아니라.

Token at position i attends to:

positions [i - W, i - W + 1, ..., i]

$W$는 window size (예: 4096).

9.2 장점

- **계산**: $O(NW)$ 대신 $O(N^2)$. $N \gg W$면 훨씬 빠름.

- **KV cache**: 최근 $W$개만 유지 → **고정 크기**. 시퀀스 무한 확장 가능.

9.3 단점

- **멀리 있는 토큰 못 봄**. 이론상 장거리 의존성 손실.

- 하지만 여러 layer 거치면 **effective receptive field**가 누적. 예: 12 layers × 4096 window = 49152 receptive field.

9.4 사용

**Mistral 7B**: Sliding window 4096. 긴 컨텍스트에 매우 효율.

**Gemma**: Local + global attention 번갈아.

**Longformer**: 대부분 local, 몇 개 토큰만 global.

9.5 조합

많은 현대 LLM이 **local + global attention**을 섞는다:

- 대부분 layer: local (sliding window).

- 일부 layer: full global.

- 결과: 긴 컨텍스트 + 효율.

10. Flash Decoding

10.1 문제

LLM 추론의 두 단계:

1. **Prefill**: 입력 전체를 한 번에 처리. Batch compute-intensive.

2. **Decode**: 토큰 하나씩 생성. 각 단계에서 한 토큰만 compute.

Decode는 **본질적으로 sequential** — 이전 토큰의 출력이 다음의 입력. 그리고 **memory-bound** — 작은 compute, 큰 KV cache 읽기.

10.2 GPU 활용률 문제

Decode 시 단일 토큰 attention은 작은 연산. GPU가 거의 놀고 있음. 특히 long context:

Decode step: q (1 token) × K (context 10000 tokens) → score (10000)

이건 Matrix-vector 곱 (GEMV). Compute는 작고 memory bandwidth가 병목.

10.3 Flash Decoding 아이디어

**KV cache를 chunk로 나누고 각 chunk를 별도 SM에서 병렬 처리**. Decode 시점에 parallelism 발굴.

Sequence K, V 분할: chunks of size C

각 chunk에 대해: local softmax 계산 (in parallel)

최종: partial 결과 집계

Flash Attention의 online softmax 아이디어를 decode에 적용.

10.4 결과

Long context decode에서 **8배 속도**. LLaMA 13B의 16K 컨텍스트 decode가 실용화.

11. Ring Attention

11.1 거대한 컨텍스트 문제

1M, 10M 토큰 컨텍스트 — 단일 GPU 메모리 초과.

11.2 기본 아이디어

**시퀀스를 여러 GPU로 분할**. 각 GPU가 sequence의 일부.

GPU 0: Q[0:N/4]

GPU 1: Q[N/4:N/2]

GPU 2: Q[N/2:3N/4]

GPU 3: Q[3N/4:N]

K, V도 같은 방식으로 분할.

11.3 통신 패턴

각 Q 블록은 **모든 K, V 블록**을 봐야 한다. 순차적으로 "ring"으로 전달:

GPU 0 ← K_0, V_0 (local)

GPU 1 ← K_1, V_1 (local)

GPU 2 ← K_2, V_2 (local)

GPU 3 ← K_3, V_3 (local)

Round 1: Ring rotate

GPU 0 ← K_3, V_3

GPU 1 ← K_0, V_0

GPU 2 ← K_1, V_1

GPU 3 ← K_2, V_2

Round 2, 3: 계속 rotate

총 N rounds 후 모든 Q가 모든 K, V를 봤음.

각 rotate 동안 compute와 통신 **겹침** (async).

11.4 결과

**Ring Attention** (Liu et al. 2023):

- 10M 토큰 컨텍스트 훈련 가능.

- Gemini 1.5 Pro의 1M+ context의 기반 기술 중 하나.

Sora, Llama-3 긴 context에도 유사 기법.

12. Sparse Attention

12.1 아이디어

"대부분의 attention weight는 작다" — 실제로 몇 개 토큰에 집중. 나머지는 sparse로 처리.

12.2 패턴

**Block Sparse**: 미리 정의된 패턴.

- Strided: 일정 간격.

- Fixed: 국소 + 전역 혼합.

**Longformer**: Sliding window + 선택적 global tokens.

**BigBird**: Random + window + global. 이론적 universal approximator.

**Sparse Transformer** (OpenAI 2019): Stride + local.

12.3 한계

Sparse 패턴이 고정이라 workload에 따라 품질 변동. Dense attention만큼 일반화되지 않음. 주류 LLM은 대부분 dense + efficient 기법(FlashAttention + GQA).

12.4 학습된 Sparsity

최근 연구: 학습 중 attention 패턴을 모델이 선택. 예: Routing Attention, Switch Transformer에 힌트.

아직 주류는 아님.

13. 다른 효율화 기법

13.1 Linear Attention

Attention의 $O(N^2)$를 **$O(N)$**로 만드는 수학적 트릭.

Core: softmax를 feature map $\phi$로 근사.

$$

\text{Attention} = \text{softmax}(QK^T)V \approx \phi(Q) (\phi(K)^T V)

$$

$\phi(K)^T V$를 먼저 계산 ($O(Nd^2)$), 그 다음 $\phi(Q)$와 곱함 ($O(Nd^2)$).

- **Performer** (2020): Random feature로 softmax 근사.

- **Linformer** (2020): K, V를 저차원으로 project.

단점: 품질 손실 (특히 긴 sequence). 주류 LLM 채택 거의 없음.

13.2 KV Cache Quantization

KV cache 메모리 절반 또는 더 줄이기:

- **FP16 → INT8**: 2배 감소.

- **FP16 → INT4**: 4배 감소.

- 품질 손실 최소화 for recent KV cache.

13.3 KV Cache Offloading

KV cache를 **CPU DRAM** 또는 **NVMe**에 저장. 필요할 때 GPU로 가져옴.

**vLLM**: CPU ↔ GPU swap 지원.

**DeepSpeed**: 더 공격적.

속도 손실이 있지만 메모리 제한이 극복된다. Long context 서비스에 유용.

13.4 Speculative Decoding

작은 drafter 모델이 여러 토큰을 predictor. 큰 모델이 validator. 올바른 만큼 한 번에 수락.

**Medusa**, **Lookahead decoding** 같은 variant들. 2-3배 추론 속도.

Attention 최적화 직접 아니지만, LLM 서빙 효율 향상에 핵심.

14. 생태계

14.1 구현

**FlashAttention** (공식, Tri Dao): https://github.com/Dao-AILab/flash-attention. PyTorch CUDA extension.

**xformers** (Meta): 다양한 efficient attention 구현 모음. FlashAttention 통합.

**vLLM** (UC Berkeley): PagedAttention 구현. LLM 서빙.

**TRT-LLM** (NVIDIA): TensorRT 기반 고성능 추론. FlashAttention, PagedAttention 포함.

**SGLang**: vLLM 대안. 다른 최적화 전략.

**Triton**: FlashAttention을 Triton으로 재구현 가능. Tri Dao의 튜토리얼.

14.2 PyTorch 통합

PyTorch 2.0+:

from torch.nn.functional import scaled_dot_product_attention

output = scaled_dot_product_attention(q, k, v, is_causal=True)

자동으로 FlashAttention 2 사용 (설치됐으면). 완전 투명.

14.3 HuggingFace

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(

"meta-llama/Llama-3-8B",

torch_dtype=torch.float16,

attn_implementation="flash_attention_2" # 명시적

)

대부분 최신 모델이 FlashAttention 2/3 지원.

15. 미래 방향

15.1 하드웨어 진화

**Blackwell** (B100/B200, 2024): FP4 Tensor Core, 더 큰 SRAM, TMEM(Tensor Memory) 새 유닛.

FlashAttention 4 추정: Blackwell 특화 버전이 곧 나올 것.

15.2 긴 컨텍스트

1M, 10M, 100M 컨텍스트가 목표. Ring Attention + hierarchical attention 조합.

Gemini 1.5 이미 2M. Gemini 2.0+ 더 큰.

15.3 Post-Attention?

일부 연구자는 "attention은 O(N²)라는 근본 한계가 있으니 완전히 새 아키텍처가 필요"라고 주장.

- **State Space Models (Mamba)**: $O(N)$ 추론.

- **RWKV**: RNN의 부활.

- **Hyena**: Long convolution 기반.

하지만 2025년 현재 dense attention + efficient 기법이 품질 면에서 여전히 승.

15.4 Quantization 극단

INT2, ternary, 1-bit. KV cache + weights 양쪽에 적용. Sub-billion GPU 메모리로 대형 모델.

16. 학습 리소스

**논문**:

- "FlashAttention: Fast and Memory-Efficient Exact Attention" — Tri Dao et al. 2022.

- "FlashAttention-2" — Tri Dao 2023.

- "FlashAttention-3" — Shah et al. 2024.

- "Efficient Memory Management for LLM Serving with PagedAttention" — vLLM paper.

- "Ring Attention" — Liu et al. 2023.

- "Efficient Streaming Language Models with Attention Sinks" — 2023.

**블로그**:

- Tri Dao의 개인 블로그 / Princeton 페이지.

- Hugging Face 블로그 (LLM 최적화 시리즈).

- Together.ai 블로그 (serving 최적화).

- UC Berkeley Sky Lab.

**코드**:

- Dao-AILab/flash-attention.

- vllm-project/vllm.

- NVIDIA/TensorRT-LLM.

- huggingface/text-generation-inference (TGI).

**강의**:

- "GPU MODE" (Discord + YouTube).

- MIT 6.5940 "TinyML" (Han Song).

- Stanford CS229M.

17. 요약 — 한 장 정리

┌─────────────────────────────────────────────────────┐

│ Efficient Attention Cheat Sheet │

├─────────────────────────────────────────────────────┤

│ 문제: │

│ Naive attention: O(N²) 메모리 │

│ HBM I/O 병목 (memory-bound) │

│ │

│ FlashAttention (Tri Dao 2022): │

│ Tiling: 블록 단위 처리 │

│ Online softmax: incremental normalization │

│ SRAM에서 계산, HBM에 N² 쓰지 않음 │

│ 정확 (근사 아님) │

│ Backward: recomputation │

│ 5-10x 빠름 │

│ │

│ FlashAttention-2 (2023): │

│ 병렬화 개선 │

│ Warp 작업 재분배 │

│ 2-3x 추가 속도 │

│ │

│ FlashAttention-3 (Hopper 2024): │

│ WGMMA (비동기) │

│ Producer-Consumer │

│ Warp specialization │

│ FP8 Tensor Core │

│ H100의 75% 효율 │

│ │

│ PagedAttention (vLLM 2023): │

│ OS 가상 메모리 → KV cache │

│ 블록 단위 할당 │

│ Block table (page table 유사) │

│ 공유 (copy-on-write) │

│ 2-4x throughput │

│ │

│ MQA / GQA: │

│ Multi-Query: 1 K/V head │

│ Grouped-Query: G (< H) K/V heads │

│ KV cache 1/N 감소 │

│ LLaMA-2/3, Mistral 사용 │

│ │

│ Sliding Window: │

│ 최근 W 토큰만 │

│ O(NW) 계산, fixed KV cache │

│ Mistral, Gemma │

│ │

│ Ring Attention: │

│ 시퀀스 여러 GPU 분할 │

│ K/V를 ring으로 rotate │

│ 10M+ context │

│ │

│ Flash Decoding: │

│ KV를 chunk로 분할 │

│ Decode time에 병렬 │

│ Long context 8x │

│ │

│ Sparse Attention: │

│ Longformer, BigBird │

│ 품질 trade-off │

│ │

│ 기타: │

│ KV cache quantization (INT8/4) │

│ KV offloading (CPU/NVMe) │

│ Speculative decoding │

│ │

│ 통합: │

│ PyTorch scaled_dot_product_attention │

│ HuggingFace attn_implementation │

│ xformers, vLLM, TRT-LLM, SGLang │

└─────────────────────────────────────────────────────┘

18. 퀴즈

**A.** 기존 attention 최적화는 "FLOP을 줄이자"(linear attention, sparse attention 등)에 집중했지만 실제 병목은 **HBM I/O**였다. GPU의 HBM 대역폭(3 TB/s)이 빠르지만 $O(N^2)$ 크기의 $S = QK^T$ 행렬을 HBM에 쓰고 다시 읽는 것만으로 수 GB를 옮긴다. Compute units(Tensor Core)는 데이터를 기다리며 대부분 시간 idle. Tri Dao의 통찰: "FLOP은 충분히 빠르니 **HBM 접근을 줄이면** 자동으로 빨라진다". 이 관점에서 tiling + online softmax로 SRAM만 사용하도록 알고리즘을 재설계. FLOP은 오히려 약간 증가(backward의 recomputation)하지만 HBM I/O는 17배 감소 → 5-10배 속도 향상. 시스템 엔지니어링의 교훈: "**병목을 오해하면 최적화가 소용없다**".

**A.** Softmax는 전체 행의 모든 값이 필요해서(max와 sum 계산) 블록 단위로 쪼갤 수 없어 보인다. 트릭: **현재까지의 max와 sum을 유지하면서 새 블록이 올 때마다 재조정**. 새 블록의 max가 이전 max보다 크면 이전 sum에 `exp(old_max - new_max)`를 곱해 "scale down" 하고, 새 블록의 contribution을 더한다. 최종 결과는 **수학적으로 전체 softmax와 정확히 동등**. 이 덕분에 $S$ 행렬 전체를 메모리에 저장할 필요 없이 블록을 순차로 SRAM에 로드하며 $O$(출력)를 점진적으로 업데이트 가능. 수치적으로도 안전 — max 뺌으로 overflow 방지. 이 아이디어는 1996년 "Keeping a running softmax" 수치해석 논문에 있었지만 FlashAttention이 **GPU tiling과 결합**해서 현대 AI 가속의 핵심 기법이 됐다.

**A.** **네 가지 핵심 OS 개념**: (1) **고정 크기 페이지** — KV cache를 연속 블록으로 할당하지 않고 작은 블록(예: 16 tokens)으로 나눔, (2) **페이지 테이블** — 논리적 토큰 위치에서 물리적 블록으로의 매핑. 각 시퀀스가 자기 block table을 가짐, (3) **On-demand 할당** — 필요한 시점에만 물리 블록 확보, over-allocation 방지, (4) **공유 + Copy-on-Write** — 같은 시스템 프롬프트로 시작하는 여러 요청이 같은 물리 블록을 공유, divergence 시에만 복사. 결과: fragmentation이 거의 없고(80% → 5% 미만), 공유로 메모리 절약, throughput 2-4배 증가. 운영체제가 1960년대에 해결한 문제를 2023년에 LLM serving에 적용한 사례 — "이미 아는 문제의 새 영역 응용"이 연구의 가장 실용적인 형태.

**A.** **품질과 효율의 균형**. **MHA**(Multi-Head Attention): H개의 Q head + H개의 K/V head → 최고 품질, 최대 KV cache. **MQA**(Multi-Query): H개의 Q head + 1개의 K/V head → 최소 KV cache(1/H), 품질 약간 하락 ("모든 head가 같은 K/V를 보면 표현력 감소"). **GQA**: H개의 Q head + G개의 K/V head (G ≈ H/4 또는 H/8) → 중간. KV cache는 4-8배 감소하지만 품질은 MHA와 거의 동일. LLaMA-2 70B와 LLaMA-3 전체 라인업이 GQA 채택: Q head 32 + KV head 8 → 4:1 비율. 실측: KV cache 4배 감소로 동일 GPU에서 4배 더 긴 컨텍스트 또는 4배 더 많은 동시 요청. 품질 손실 측정 불가 수준. "완전한 분리도, 완전한 공유도 아닌 중간"이 최선인 흔한 설계 원칙.

**A.** **근사 알고리즘이 아니라는 보장**. Efficient attention 계열의 다른 접근들(Linear Attention, Performer, Linformer, Sparse Attention)은 대부분 **근사** — 수학적으로 $O(N^2)$ 결과와 다른 값을 출력. 품질 손실이 있고, 특정 워크로드에 잘 맞지 않을 수 있으며, 훈련된 모델의 attention과 완전 호환 안 됨. FlashAttention은 **비트 단위 재현이 가능한 수학적 정확성** — online softmax의 수치적 안정성 덕분에 naive attention과 동등한 결과. 이것이 결정적 — 기존 훈련된 모델(LLaMA, GPT)에 **drop-in replacement**로 사용 가능. 프레임워크가 자동으로 FlashAttention을 켜도 모델 출력이 완전히 같음. 성능 향상이 "부작용 없는 업그레이드"가 된다. 근사 방법들이 널리 채택 안 된 이유: 작은 차이도 downstream task에서 예측 불가능한 영향.

**A.** **Sequence parallelism + 통신-계산 오버랩**. 단일 GPU는 큰 sequence의 attention matrix를 저장 불가(10M² = 100조 entry). Ring Attention: sequence를 N개 GPU로 분할(각 1/N). 각 GPU는 자기 Q 블록과 local K/V로 시작, 그 다음 K/V를 **ring 토폴로지로 순환**(GPU 0 → 1 → 2 → ... → 0). 매 회전마다 새 K/V 블록을 받으며 부분 attention 계산. **Online softmax**를 사용해 부분 결과를 정확하게 합쳐 최종 출력. 중요한 점: **통신이 compute와 동시 진행** — 현재 블록 계산하는 동안 다음 블록 수신. Async로 latency 숨김. N rounds 후 모든 Q가 모든 K/V를 봤고, 각 GPU는 자기 Q 블록의 출력만 가지면 됨. Gemini 1.5 Pro의 1-2M 컨텍스트, Sora의 긴 비디오 시퀀스가 이런 기법 기반. "분산 시스템이 메모리 한계를 극복"의 완벽한 사례.

**A.** **추론 단계별 특성이 다름**. Prefill 단계(입력 전체 처리)는 compute-heavy — 많은 Q, K, V를 한 번에 GEMM. Decode 단계(토큰 하나씩 생성)는 **memory-bandwidth 제한** — 매 step에 single query × long K/V context로 GEMV(행렬-벡터 곱). GEMM은 arithmetic intensity 높고 Tensor Core 활용 가능, GEMV는 compute/memory 비율 낮아 HBM bandwidth가 병목. 한 토큰 decode 시 KV cache 전체 읽기만 수 ms. **Flash Attention**은 여전히 유용하지만 GEMM 최적화라 decode에서 이득 작음. **Flash Decoding**이 해결책: KV cache를 chunk로 분할하고 각 chunk를 별도 SM에서 병렬 처리 → decode에 parallelism 도입. Online softmax로 partial 결과 합산. 장문 컨텍스트 decode에서 8배 속도. 관찰: "LLM 서빙 최적화는 prefill과 decode를 따로 생각해야 함". 이 둘이 근본적으로 다른 병목.

이 글이 도움이 됐다면 다음 포스트도 확인해 보세요:

- "Transformer Architecture Deep Dive" — Attention의 기본.

- "CUDA GPU Programming Deep Dive" — 커널 최적화의 기반.

- "Diffusion Models Deep Dive" — 또 다른 attention-heavy 워크로드.

- "RDMA & NCCL" — 멀티 GPU 훈련의 통신 기반.

현재 단락 (1/417)

- **Naive attention**은 $N \times N$ attention matrix를 HBM에 저장. 시퀀스 길이 $N$에 대해 **$O(N^2)$ 메모리**. LLM ...

작성 글자: 0원문 글자: 16,189작성 단락: 0/417