Skip to content
Published on

FlashAttention & Efficient Attention Deep Dive — Tiling, Online Softmax, PagedAttention, GQA 완전 정복 (2025)

Authors

TL;DR

  • Naive attentionN×NN \times N attention matrix를 HBM에 저장. 시퀀스 길이 NN에 대해 O(N2)O(N^2) 메모리. LLM 긴 컨텍스트에서 치명적.
  • FlashAttention (Tri Dao, 2022): 핵심 통찰은 attention이 memory-bound이지 compute-bound 아니라는 것. Tiling + Online Softmax로 SRAM만 쓰고 HBM에 N×N 행렬을 절대 쓰지 않음.
  • Online Softmax: Softmax를 한 번에 계산하는 대신 블록 단위로 누적. 수학적으로 동등한데 IO가 완전히 다름.
  • FlashAttention-2 (2023): 병렬화 개선, Forward/Backward 최적화. 2-3배 빠름.
  • FlashAttention-3 (2024, Hopper): Tensor Core + 비동기 WGMMA + warp 특수화. H100에서 75% 효율.
  • PagedAttention (vLLM, 2023): OS 가상 메모리 아이디어를 KV cache에. 2-4배 throughput.
  • Grouped Query Attention (GQA): Query head 여러 개가 같은 K/V head 공유. KV cache 크기 1/N.
  • Sliding Window Attention: 컨텍스트 윈도우 제한. Mistral, Gemma가 사용.
  • Ring Attention: 여러 GPU로 sequence 분할. 10M 토큰 컨텍스트 가능.

1. Attention의 병목 문제

1.1 Transformer의 심장

Transformer의 핵심 연산:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
  • QQ, KK, VV: (N, d) 크기. N = 시퀀스 길이, d = 헤드 차원.
  • QKTQK^T: (N, N) 점수 행렬.
  • Softmax: 행별 정규화.
  • 결과에 VV 곱: (N, d) 출력.

이것이 Transformer의 품질의 원천이지만, 동시에 가장 큰 비용.

1.2 O(N²) 메모리

QKTQK^T 크기 = N×NN \times N. 시퀀스 길이에 제곱 비례.

N = 1,000:    N² = 1,000,000 = 1 M 요소
N = 10,000:   N² = 100 M
N = 100,000:  N² = 10 GGPU 메모리 초과

FP16이면 요소당 2 바이트. 100,000 토큰 컨텍스트는 20 GB의 attention matrix만으로 차지. 한 레이어에. 한 헤드에.

실제로는 layer × head × batch → 수백 GB. 불가능.

1.3 Compute vs Memory

Naive attention 구현:

S = Q @ K.T  # (N, N)
P = softmax(S)  # (N, N)
O = P @ V  # (N, d)

각 단계에서 HBM 읽기/쓰기.

  • QQ, KK, VV 읽기: O(Nd)O(Nd)
  • SS 쓰기: O(N2)O(N^2)
  • SS 읽기, PP 쓰기: O(N2)O(N^2)
  • PP 읽기, VV 읽기, OO 쓰기: O(N2+Nd)O(N^2 + Nd)

HBM 총 I/O: O(N2)O(N^2).

GPU의 HBM 대역폭이 아무리 빨라도(3 TB/s on H100) N2N^2 데이터를 옮기는 데 시간이 걸린다. 대부분 시간을 HBM ↔ SRAM 복사에 씀. 컴퓨트 유닛은 대부분 놀고 있다.

1.4 "Memory-bound, not compute-bound"

이것이 Tri Dao의 핵심 통찰이었다. "Attention 성능을 올리려면 FLOPs를 줄이는 게 아니라 HBM 접근을 줄여야 한다."


2. GPU 메모리 계층 복습

2.1 두 가지 메모리

HBM (High Bandwidth Memory):

  • GPU의 "main memory". 수십~수백 GB.
  • 대역폭 ~3 TB/s (H100).
  • 지연 시간 ~500 cycle.

SRAM (Shared Memory / L1):

  • SM 내부의 on-chip 메모리.
  • 매우 작음 (각 SM당 ~228 KB on H100).
  • 대역폭 ~19 TB/s (H100).
  • 지연 시간 ~20 cycle.

SRAM이 6배 이상 빠르다. 그리고 지연 시간이 25배 적다.

2.2 위치의 중요성

데이터가 SRAM에 있으면 compute unit이 놀 시간 없이 곧바로 처리. HBM에 있으면 SRAM으로 가져오는 동안 기다려야 함.

Attention 최적화의 목표: SRAM에서 가능한 한 모든 계산을 완료.

문제: SRAM이 너무 작아 N×NN \times N 행렬이 들어가지 않는다.

해결: Tiling.


3. FlashAttention의 핵심 아이디어

3.1 Tiling

행렬을 블록으로 나누어 순차 처리:

Q: (N, d)Q_1, Q_2, ..., Q_{N/B_r}  ( (B_r, d))
K, V: (N, d)K_1, K_2, ..., K_{N/B_c}  ( (B_c, d))

블록 크기 BrB_r, BcB_c는 SRAM에 맞게 선택.

3.2 Naive Blockwise

QiQ_i에 대해 전체 K, V를 순회해서 OiO_i 계산:

for i in range(N / B_r):
    Q_i = load(Q, i)
    for j in range(N / B_c):
        K_j, V_j = load(K, V, j)
        S_ij = Q_i @ K_j.T
        # 하지만 softmax는 전체 행에 걸쳐...?

문제: softmax는 각 행의 모든 값을 알아야 정규화 가능. 블록 단위로는 지역적 정보만.

3.3 Online Softmax — 핵심 트릭

Naive softmax:

softmax(x)i=exijexj\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}

수치 안정화 버전 (max 뺌):

softmax(x)i=eximjexjm,m=maxjxj\text{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \quad m = \max_j x_j

이것을 블록 단위로 계산할 수 있다 — 만약 "현재까지의 max와 sum"을 유지하면.

Incremental softmax:

m_prev = -∞  # 현재까지 본 max
l_prev = 0   # 현재까지 본 sum

For each block:
    m_block = max(block)
    m_new = max(m_prev, m_block)
    
    # 기존 누적치 rescale
    l_prev = l_prev * exp(m_prev - m_new)
    
    # 현재 블록
    l_block = sum(exp(block - m_new))
    
    # 합침
    l_prev = l_prev + l_block
    m_prev = m_new

마지막에 softmax = exp(x - m_prev) / l_prev.

수학적으로 정확. 단, 각 블록 처리 시 이전 누적을 재조정해야 함.

3.4 FlashAttention 알고리즘

초기화: O = 0, l = 0, m = -  (출력, sum, max)

For j in range(N / B_c):  # K, V 블록
    K_j, V_jHBM 로드 (SRAM으로)
    
    For i in range(N / B_r):  # Q 블록
        Q_iHBM 로드 (SRAM으로)
        
        S_ij = Q_i @ K_j.T    # SRAM 내부 계산
        
        # Online softmax 업데이트
        m_ij = rowmax(S_ij)
        P_ij = exp(S_ij - m_ij)
        l_ij = rowsum(P_ij)
        
        # 이전과 결합
        m_new = max(m_i, m_ij)
        l_new = exp(m_i - m_new) * l_i + exp(m_ij - m_new) * l_ij
        
        # O 업데이트
        O_i = diag(l_new)^-1 * (diag(l_i) * exp(m_i - m_new) * O_i 
                                + exp(m_ij - m_new) * P_ij * V_j)
        
        m_i = m_new
        l_i = l_new
    
    O_i 저장 (HBM)

핵심:

  • QQ, KK, VV만 HBM에서 읽기.
  • S=QKTS = QK^T 전체 행렬을 절대 HBM에 쓰지 않음.
  • 각 블록 내부 계산은 SRAM.
  • OO 최종 결과만 HBM.

3.5 HBM I/O 분석

Naive attention: O(Nd+N2)O(Nd + N^2) HBM 접근. FlashAttention: O(N2d2/M)O(N^2 d^2 / M) HBM 접근. MM은 SRAM 크기.

예: N=8192N = 8192, d=128d = 128, M100KBM \approx 100 KB.

  • Naive: ~67 M bytes (FP16).
  • FlashAttention: ~4 M bytes.

17배 적은 HBM I/O. 실제 속도 향상 5-10배.

3.6 정확성

"근사"가 아니다. 수학적으로 정확한 attention. Online softmax 덕분에 블록 단위 처리가 전체 softmax와 동일한 결과.


4. Backward Pass

훈련 시 gradient 계산도 attention.

4.1 Naive Backward

Forward 시 P=softmax(QKT)P = \text{softmax}(QK^T)를 저장했다가 backward에서 사용. O(N2)O(N^2) 메모리 저장.

4.2 FlashAttention Backward

Recomputation: forward에서 PP를 저장하지 않고, backward에서 다시 계산.

Forward: Q, K, V, O, l, m만 저장 (O(N) 메모리, not O(N²))
Backward: 필요할 때 S, P 재계산

FLOP 증가 (약 2x compute), 메모리 감소 (O(N²) → O(N)). 전체적으로 여전히 빠름 — memory-bound라서.

4.3 결과

LLaMA 훈련 같은 워크로드:

  • 메모리: 10배 감소.
  • 속도: 2-3배 빠름.
  • 긴 컨텍스트(8K-32K) 훈련 가능.

5. FlashAttention-2

2023년 Tri Dao의 후속. 주요 개선:

5.1 더 나은 병렬화

FlashAttention-1: 시퀀스 차원으로만 병렬. FlashAttention-2: Batch, head, sequence 3차원 병렬 → GPU 활용 ↑.

5.2 Warp 간 작업 재분배

각 warp가 하는 일을 재설계. 동기화 오버헤드 감소.

5.3 Non-matmul FLOP 감소

Softmax 관련 non-matmul 연산 (rescaling 등)을 줄임. Tensor Core가 아닌 CUDA Core 부하가 병목이었는데 이를 감소.

5.4 성능

FlashAttention-1 대비 2배 빠름. Standard attention 대비 5-10배.

LLaMA-2 13B 훈련에서 사용. 긴 컨텍스트(16K, 32K) 훈련 실용화.


6. FlashAttention-3 (2024, Hopper)

NVIDIA Hopper 아키텍처(H100)에 특화된 버전.

6.1 새 Hopper 기능 활용

WGMMA (Warp Group Matrix Multiply-Accumulate): 비동기 행렬 곱. 새 instruction.

TMA (Tensor Memory Accelerator): 전용 async copy 엔진.

FP8 Tensor Core: 정밀도 낮춤으로 대역폭/속도 ↑.

6.2 알고리즘 변경

Producer-Consumer 모델: 일부 warp가 memory load 담당 (producer), 다른 warp가 compute 담당 (consumer). 오버랩.

Warp Specialization: Warp마다 특정 역할 고정. 스케줄링 유연성 감소, 성능 ↑.

6.3 결과

H100에서 peak FP16 throughput의 75% 달성. 이전 flash-attention은 35% 정도.

GPT-4 훈련 수준 워크로드에 필수.


7. PagedAttention — vLLM

2023년 UC Berkeley Sky Lab (Woosuk Kwon et al.)의 vLLM.

7.1 KV Cache 문제

LLM 추론 시 각 이전 토큰의 K, V를 저장해야 한다 — 다음 토큰 생성에 필요.

LLaMA-13B, 2048 context:

  • K: (2048, 40, 128) per layer × 40 layers × FP16 = 단일 시퀀스당 800 MB.
  • V: 같은 크기.
  • Total: ~1.6 GB per sequence.

100 동시 사용자 → 160 GB. 80 GB GPU로 불가능.

7.2 Fragmentation 문제

전통적 할당:

  • 각 요청에 max_length만큼 미리 할당 (예: 2048 토큰).
  • 실제 생성은 평균 500 토큰 → 75% 낭비.
  • Internal fragmentation.

또한 요청마다 다른 크기 → external fragmentation.

7.3 OS 가상 메모리 아이디어

운영체제의 가상 메모리는 이 문제를 수십 년 전에 해결했다:

  • 프로세스 메모리를 고정 크기 페이지로 분할.
  • 페이지 테이블로 논리 주소 → 물리 페이지 매핑.
  • 필요 시에만 물리 페이지 할당.
  • 페이지는 서로 공유 가능 (shared memory, copy-on-write).

PagedAttention: 이 모든 아이디어를 LLM KV cache에 적용.

7.4 구조

KV cache를 block(: 16 tokens)으로 분할

Sequence A: [block 0][block 1][block 2]
                 ↓           ↓           ↓
          Physical    Physical    Physical
          Page 42     Page 7      Page 13

Sequence B: [block 0][block 1]
                 ↓           ↓
          Physical    Physical
          Page 42     Page 3    ← 공유!
          (copy-on-write)
  • 논리적으로 연속된 블록이 물리적으로 흩어져 있어도 OK.
  • 필요할 때만 할당.
  • 공통 prefix는 공유 (copy-on-write).

7.5 이점

1. 메모리 효율:

  • Fragmentation 거의 제거.
  • 100 users → 100 GB → 30 GB (실제 측정).

2. Sharing:

  • 같은 시스템 프롬프트를 여러 요청이 공유.
  • 같은 조건의 beam search branch가 공유.

3. Throughput:

  • 기존 대비 2-4배 더 많은 동시 요청 처리.
  • GPU 활용률 ↑.

7.6 vLLM의 영향

vLLM은 2023년 말 공개, 순식간에 LLM 서빙의 표준이 됐다. 거의 모든 open source LLM 서버(TGI, Ollama, Llama.cpp도 일부)가 비슷한 아이디어 채택.

2024: PagedAttention의 구현 변형, CUDA 커널 튜닝 논문 봇물.


8. Multi-Query / Grouped-Query Attention

8.1 Multi-Head Attention (원본)

원래 Transformer:

  • HH개의 query head.
  • 각자 자기 QhQ_h, KhK_h, VhV_h.
  • 독립적 attention 계산 후 concat.

예: LLaMA-2 7B에 H=32H = 32. 32 Q head + 32 K head + 32 V head.

8.2 KV Cache 부담

KV cache 크기: 2HLd2 \cdot H \cdot L \cdot d per token.

  • H=32H = 32, L=32L = 32 layers, d=128d = 128 → 토큰당 524 KB (FP16).
  • 10,000 토큰 컨텍스트 → 5.2 GB per sequence.

8B 모델에서 모델 자체 16 GB, KV cache 5 GB, batch 여러 개 → GPU 메모리 한계.

8.3 Multi-Query Attention (MQA)

모든 query head가 하나의 K, V head를 공유.

  • H개의 Q head, 1개의 K head, 1개의 V head.
  • KV cache 크기: 1/H.
  • 계산량은 거의 같음.

단점: 품질 약간 하락. "모든 head가 같은 K/V를 보면 표현력 감소."

8.4 Grouped-Query Attention (GQA)

LLaMA-2/3에서 사용. 타협:

  • HH개의 Q head.
  • GG개의 K, V head (G<HG < H, 보통 H/4H / 4 또는 H/8H / 8).
  • 각 K, V head를 H/GH / G개의 Q head가 공유.

예: LLaMA-3 8B, H=32H = 32, G=8G = 8 → 4 Q heads per KV head.

  • KV cache: 1/4.
  • 품질: MHA와 거의 동일.
  • 속도: 추론 빠름.

8.5 KV Cache 비교

동일 모델 크기에서:

  • MHA: 100% (기준).
  • GQA (G=H/4): 25%.
  • MQA: 1/H (3%).

4배 메모리 감소로 4배 긴 컨텍스트 또는 4배 동시 요청.

LLaMA-2 34B부터 GQA 사용 → LLaMA-3 표준. Mistral, Mixtral도 같은 접근.


9. Sliding Window Attention

9.1 아이디어

각 토큰이 최근 N개 토큰만 본다. 전체가 아니라.

Token at position i attends to:
    positions [i - W, i - W + 1, ..., i]

WW는 window size (예: 4096).

9.2 장점

  • 계산: O(NW)O(NW) 대신 O(N2)O(N^2). NWN \gg W면 훨씬 빠름.
  • KV cache: 최근 WW개만 유지 → 고정 크기. 시퀀스 무한 확장 가능.

9.3 단점

  • 멀리 있는 토큰 못 봄. 이론상 장거리 의존성 손실.
  • 하지만 여러 layer 거치면 effective receptive field가 누적. 예: 12 layers × 4096 window = 49152 receptive field.

9.4 사용

Mistral 7B: Sliding window 4096. 긴 컨텍스트에 매우 효율.

Gemma: Local + global attention 번갈아.

Longformer: 대부분 local, 몇 개 토큰만 global.

9.5 조합

많은 현대 LLM이 local + global attention을 섞는다:

  • 대부분 layer: local (sliding window).
  • 일부 layer: full global.
  • 결과: 긴 컨텍스트 + 효율.

10. Flash Decoding

10.1 문제

LLM 추론의 두 단계:

  1. Prefill: 입력 전체를 한 번에 처리. Batch compute-intensive.
  2. Decode: 토큰 하나씩 생성. 각 단계에서 한 토큰만 compute.

Decode는 본질적으로 sequential — 이전 토큰의 출력이 다음의 입력. 그리고 memory-bound — 작은 compute, 큰 KV cache 읽기.

10.2 GPU 활용률 문제

Decode 시 단일 토큰 attention은 작은 연산. GPU가 거의 놀고 있음. 특히 long context:

Decode step: q (1 token) × K (context 10000 tokens)score (10000)

이건 Matrix-vector 곱 (GEMV). Compute는 작고 memory bandwidth가 병목.

10.3 Flash Decoding 아이디어

KV cache를 chunk로 나누고 각 chunk를 별도 SM에서 병렬 처리. Decode 시점에 parallelism 발굴.

Sequence K, V 분할: chunks of size C
각 chunk에 대해: local softmax 계산 (in parallel)
최종: partial 결과 집계

Flash Attention의 online softmax 아이디어를 decode에 적용.

10.4 결과

Long context decode에서 8배 속도. LLaMA 13B의 16K 컨텍스트 decode가 실용화.


11. Ring Attention

11.1 거대한 컨텍스트 문제

1M, 10M 토큰 컨텍스트 — 단일 GPU 메모리 초과.

11.2 기본 아이디어

시퀀스를 여러 GPU로 분할. 각 GPU가 sequence의 일부.

GPU 0: Q[0:N/4]
GPU 1: Q[N/4:N/2]
GPU 2: Q[N/2:3N/4]
GPU 3: Q[3N/4:N]

K, V도 같은 방식으로 분할.

11.3 통신 패턴

각 Q 블록은 모든 K, V 블록을 봐야 한다. 순차적으로 "ring"으로 전달:

GPU 0K_0, V_0 (local)
GPU 1K_1, V_1 (local)
GPU 2K_2, V_2 (local)
GPU 3K_3, V_3 (local)

Round 1: Ring rotate
GPU 0K_3, V_3
GPU 1K_0, V_0
GPU 2K_1, V_1
GPU 3K_2, V_2

Round 2, 3: 계속 rotate

N rounds 후 모든 Q가 모든 K, V를 봤음.

각 rotate 동안 compute와 통신 겹침 (async).

11.4 결과

Ring Attention (Liu et al. 2023):

  • 10M 토큰 컨텍스트 훈련 가능.
  • Gemini 1.5 Pro의 1M+ context의 기반 기술 중 하나.

Sora, Llama-3 긴 context에도 유사 기법.


12. Sparse Attention

12.1 아이디어

"대부분의 attention weight는 작다" — 실제로 몇 개 토큰에 집중. 나머지는 sparse로 처리.

12.2 패턴

Block Sparse: 미리 정의된 패턴.

  • Strided: 일정 간격.
  • Fixed: 국소 + 전역 혼합.

Longformer: Sliding window + 선택적 global tokens.

BigBird: Random + window + global. 이론적 universal approximator.

Sparse Transformer (OpenAI 2019): Stride + local.

12.3 한계

Sparse 패턴이 고정이라 workload에 따라 품질 변동. Dense attention만큼 일반화되지 않음. 주류 LLM은 대부분 dense + efficient 기법(FlashAttention + GQA).

12.4 학습된 Sparsity

최근 연구: 학습 중 attention 패턴을 모델이 선택. 예: Routing Attention, Switch Transformer에 힌트.

아직 주류는 아님.


13. 다른 효율화 기법

13.1 Linear Attention

Attention의 O(N2)O(N^2)를 **O(N)O(N)**로 만드는 수학적 트릭.

Core: softmax를 feature map ϕ\phi로 근사.

Attention=softmax(QKT)Vϕ(Q)(ϕ(K)TV)\text{Attention} = \text{softmax}(QK^T)V \approx \phi(Q) (\phi(K)^T V)

ϕ(K)TV\phi(K)^T V를 먼저 계산 (O(Nd2)O(Nd^2)), 그 다음 ϕ(Q)\phi(Q)와 곱함 (O(Nd2)O(Nd^2)).

  • Performer (2020): Random feature로 softmax 근사.
  • Linformer (2020): K, V를 저차원으로 project.

단점: 품질 손실 (특히 긴 sequence). 주류 LLM 채택 거의 없음.

13.2 KV Cache Quantization

KV cache 메모리 절반 또는 더 줄이기:

  • FP16 → INT8: 2배 감소.
  • FP16 → INT4: 4배 감소.
  • 품질 손실 최소화 for recent KV cache.

13.3 KV Cache Offloading

KV cache를 CPU DRAM 또는 NVMe에 저장. 필요할 때 GPU로 가져옴.

vLLM: CPU ↔ GPU swap 지원. DeepSpeed: 더 공격적.

속도 손실이 있지만 메모리 제한이 극복된다. Long context 서비스에 유용.

13.4 Speculative Decoding

작은 drafter 모델이 여러 토큰을 predictor. 큰 모델이 validator. 올바른 만큼 한 번에 수락.

Medusa, Lookahead decoding 같은 variant들. 2-3배 추론 속도.

Attention 최적화 직접 아니지만, LLM 서빙 효율 향상에 핵심.


14. 생태계

14.1 구현

FlashAttention (공식, Tri Dao): https://github.com/Dao-AILab/flash-attention. PyTorch CUDA extension.

xformers (Meta): 다양한 efficient attention 구현 모음. FlashAttention 통합.

vLLM (UC Berkeley): PagedAttention 구현. LLM 서빙.

TRT-LLM (NVIDIA): TensorRT 기반 고성능 추론. FlashAttention, PagedAttention 포함.

SGLang: vLLM 대안. 다른 최적화 전략.

Triton: FlashAttention을 Triton으로 재구현 가능. Tri Dao의 튜토리얼.

14.2 PyTorch 통합

PyTorch 2.0+:

from torch.nn.functional import scaled_dot_product_attention

output = scaled_dot_product_attention(q, k, v, is_causal=True)

자동으로 FlashAttention 2 사용 (설치됐으면). 완전 투명.

14.3 HuggingFace

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3-8B",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2"  # 명시적
)

대부분 최신 모델이 FlashAttention 2/3 지원.


15. 미래 방향

15.1 하드웨어 진화

Blackwell (B100/B200, 2024): FP4 Tensor Core, 더 큰 SRAM, TMEM(Tensor Memory) 새 유닛.

FlashAttention 4 추정: Blackwell 특화 버전이 곧 나올 것.

15.2 긴 컨텍스트

1M, 10M, 100M 컨텍스트가 목표. Ring Attention + hierarchical attention 조합.

Gemini 1.5 이미 2M. Gemini 2.0+ 더 큰.

15.3 Post-Attention?

일부 연구자는 "attention은 O(N²)라는 근본 한계가 있으니 완전히 새 아키텍처가 필요"라고 주장.

  • State Space Models (Mamba): O(N)O(N) 추론.
  • RWKV: RNN의 부활.
  • Hyena: Long convolution 기반.

하지만 2025년 현재 dense attention + efficient 기법이 품질 면에서 여전히 승.

15.4 Quantization 극단

INT2, ternary, 1-bit. KV cache + weights 양쪽에 적용. Sub-billion GPU 메모리로 대형 모델.


16. 학습 리소스

논문:

  • "FlashAttention: Fast and Memory-Efficient Exact Attention" — Tri Dao et al. 2022.
  • "FlashAttention-2" — Tri Dao 2023.
  • "FlashAttention-3" — Shah et al. 2024.
  • "Efficient Memory Management for LLM Serving with PagedAttention" — vLLM paper.
  • "Ring Attention" — Liu et al. 2023.
  • "Efficient Streaming Language Models with Attention Sinks" — 2023.

블로그:

  • Tri Dao의 개인 블로그 / Princeton 페이지.
  • Hugging Face 블로그 (LLM 최적화 시리즈).
  • Together.ai 블로그 (serving 최적화).
  • UC Berkeley Sky Lab.

코드:

  • Dao-AILab/flash-attention.
  • vllm-project/vllm.
  • NVIDIA/TensorRT-LLM.
  • huggingface/text-generation-inference (TGI).

강의:

  • "GPU MODE" (Discord + YouTube).
  • MIT 6.5940 "TinyML" (Han Song).
  • Stanford CS229M.

17. 요약 — 한 장 정리

┌─────────────────────────────────────────────────────┐
Efficient Attention Cheat Sheet├─────────────────────────────────────────────────────┤
│ 문제:Naive attention: O(N²) 메모리                       │
HBM I/O 병목 (memory-bound)│                                                       │
FlashAttention (Tri Dao 2022):Tiling: 블록 단위 처리                               │
Online softmax: incremental normalization           │
SRAM에서 계산, HBMN² 쓰지 않음                    │
정확 (근사 아님)Backward: recomputation                            │
5-10x 빠름                                          │
│                                                       │
FlashAttention-2 (2023):│   병렬화 개선                                          │
Warp 작업 재분배                                     │
2-3x 추가 속도                                       │
│                                                       │
FlashAttention-3 (Hopper 2024):WGMMA (비동기)Producer-ConsumerWarp specialization                                 │
FP8 Tensor CoreH10075% 효율                                      │
│                                                       │
PagedAttention (vLLM 2023):OS 가상 메모리 → KV cache                           │
│   블록 단위 할당                                       │
Block table (page table 유사)공유 (copy-on-write)2-4x throughput                                    │
│                                                       │
MQA / GQA:Multi-Query: 1 K/V head                            │
Grouped-Query: G (< H) K/V heads                   │
KV cache 1/N 감소                                   │
LLaMA-2/3, Mistral 사용                             │
│                                                       │
Sliding Window:│   최근 W 토큰만                                        │
O(NW) 계산, fixed KV cache                          │
Mistral, Gemma│                                                       │
Ring Attention:│   시퀀스 여러 GPU 분할                                 │
K/V를 ring으로 rotate                               │
│   10M+ context                                        │
│                                                       │
Flash Decoding:KV를 chunk로 분할                                   │
Decode time에 병렬                                   │
Long context 8x                                    │
│                                                       │
Sparse Attention:Longformer, BigBird│   품질 trade-off                                      │
│                                                       │
│ 기타:KV cache quantization (INT8/4)KV offloading (CPU/NVMe)Speculative decoding                                │
│                                                       │
│ 통합:PyTorch scaled_dot_product_attention                │
HuggingFace attn_implementation                     │
│   xformers, vLLM, TRT-LLM, SGLang└─────────────────────────────────────────────────────┘

18. 퀴즈

Q1. FlashAttention이 "memory-bound"라는 통찰이 왜 결정적이었는가?

A. 기존 attention 최적화는 "FLOP을 줄이자"(linear attention, sparse attention 등)에 집중했지만 실제 병목은 HBM I/O였다. GPU의 HBM 대역폭(3 TB/s)이 빠르지만 O(N2)O(N^2) 크기의 S=QKTS = QK^T 행렬을 HBM에 쓰고 다시 읽는 것만으로 수 GB를 옮긴다. Compute units(Tensor Core)는 데이터를 기다리며 대부분 시간 idle. Tri Dao의 통찰: "FLOP은 충분히 빠르니 HBM 접근을 줄이면 자동으로 빨라진다". 이 관점에서 tiling + online softmax로 SRAM만 사용하도록 알고리즘을 재설계. FLOP은 오히려 약간 증가(backward의 recomputation)하지만 HBM I/O는 17배 감소 → 5-10배 속도 향상. 시스템 엔지니어링의 교훈: "병목을 오해하면 최적화가 소용없다".

Q2. Online Softmax가 어떻게 블록 단위 처리를 가능하게 하는가?

A. Softmax는 전체 행의 모든 값이 필요해서(max와 sum 계산) 블록 단위로 쪼갤 수 없어 보인다. 트릭: 현재까지의 max와 sum을 유지하면서 새 블록이 올 때마다 재조정. 새 블록의 max가 이전 max보다 크면 이전 sum에 exp(old_max - new_max)를 곱해 "scale down" 하고, 새 블록의 contribution을 더한다. 최종 결과는 수학적으로 전체 softmax와 정확히 동등. 이 덕분에 SS 행렬 전체를 메모리에 저장할 필요 없이 블록을 순차로 SRAM에 로드하며 OO(출력)를 점진적으로 업데이트 가능. 수치적으로도 안전 — max 뺌으로 overflow 방지. 이 아이디어는 1996년 "Keeping a running softmax" 수치해석 논문에 있었지만 FlashAttention이 GPU tiling과 결합해서 현대 AI 가속의 핵심 기법이 됐다.

Q3. PagedAttention이 OS 가상 메모리에서 빌려온 아이디어는?

A. 네 가지 핵심 OS 개념: (1) 고정 크기 페이지 — KV cache를 연속 블록으로 할당하지 않고 작은 블록(예: 16 tokens)으로 나눔, (2) 페이지 테이블 — 논리적 토큰 위치에서 물리적 블록으로의 매핑. 각 시퀀스가 자기 block table을 가짐, (3) On-demand 할당 — 필요한 시점에만 물리 블록 확보, over-allocation 방지, (4) 공유 + Copy-on-Write — 같은 시스템 프롬프트로 시작하는 여러 요청이 같은 물리 블록을 공유, divergence 시에만 복사. 결과: fragmentation이 거의 없고(80% → 5% 미만), 공유로 메모리 절약, throughput 2-4배 증가. 운영체제가 1960년대에 해결한 문제를 2023년에 LLM serving에 적용한 사례 — "이미 아는 문제의 새 영역 응용"이 연구의 가장 실용적인 형태.

Q4. Grouped-Query Attention이 Multi-Query와 Multi-Head의 타협점이 되는 이유는?

A. 품질과 효율의 균형. MHA(Multi-Head Attention): H개의 Q head + H개의 K/V head → 최고 품질, 최대 KV cache. MQA(Multi-Query): H개의 Q head + 1개의 K/V head → 최소 KV cache(1/H), 품질 약간 하락 ("모든 head가 같은 K/V를 보면 표현력 감소"). GQA: H개의 Q head + G개의 K/V head (G ≈ H/4 또는 H/8) → 중간. KV cache는 4-8배 감소하지만 품질은 MHA와 거의 동일. LLaMA-2 70B와 LLaMA-3 전체 라인업이 GQA 채택: Q head 32 + KV head 8 → 4:1 비율. 실측: KV cache 4배 감소로 동일 GPU에서 4배 더 긴 컨텍스트 또는 4배 더 많은 동시 요청. 품질 손실 측정 불가 수준. "완전한 분리도, 완전한 공유도 아닌 중간"이 최선인 흔한 설계 원칙.

Q5. FlashAttention이 정확하다는 말이 왜 중요한가?

A. 근사 알고리즘이 아니라는 보장. Efficient attention 계열의 다른 접근들(Linear Attention, Performer, Linformer, Sparse Attention)은 대부분 근사 — 수학적으로 O(N2)O(N^2) 결과와 다른 값을 출력. 품질 손실이 있고, 특정 워크로드에 잘 맞지 않을 수 있으며, 훈련된 모델의 attention과 완전 호환 안 됨. FlashAttention은 비트 단위 재현이 가능한 수학적 정확성 — online softmax의 수치적 안정성 덕분에 naive attention과 동등한 결과. 이것이 결정적 — 기존 훈련된 모델(LLaMA, GPT)에 drop-in replacement로 사용 가능. 프레임워크가 자동으로 FlashAttention을 켜도 모델 출력이 완전히 같음. 성능 향상이 "부작용 없는 업그레이드"가 된다. 근사 방법들이 널리 채택 안 된 이유: 작은 차이도 downstream task에서 예측 불가능한 영향.

Q6. Ring Attention이 10M 토큰 컨텍스트를 가능하게 하는 원리는?

A. Sequence parallelism + 통신-계산 오버랩. 단일 GPU는 큰 sequence의 attention matrix를 저장 불가(10M² = 100조 entry). Ring Attention: sequence를 N개 GPU로 분할(각 1/N). 각 GPU는 자기 Q 블록과 local K/V로 시작, 그 다음 K/V를 ring 토폴로지로 순환(GPU 0 → 1 → 2 → ... → 0). 매 회전마다 새 K/V 블록을 받으며 부분 attention 계산. Online softmax를 사용해 부분 결과를 정확하게 합쳐 최종 출력. 중요한 점: 통신이 compute와 동시 진행 — 현재 블록 계산하는 동안 다음 블록 수신. Async로 latency 숨김. N rounds 후 모든 Q가 모든 K/V를 봤고, 각 GPU는 자기 Q 블록의 출력만 가지면 됨. Gemini 1.5 Pro의 1-2M 컨텍스트, Sora의 긴 비디오 시퀀스가 이런 기법 기반. "분산 시스템이 메모리 한계를 극복"의 완벽한 사례.

Q7. LLM 서빙에서 "memory-bound decode"가 왜 flash attention 이상의 최적화를 필요로 하는가?

A. 추론 단계별 특성이 다름. Prefill 단계(입력 전체 처리)는 compute-heavy — 많은 Q, K, V를 한 번에 GEMM. Decode 단계(토큰 하나씩 생성)는 memory-bandwidth 제한 — 매 step에 single query × long K/V context로 GEMV(행렬-벡터 곱). GEMM은 arithmetic intensity 높고 Tensor Core 활용 가능, GEMV는 compute/memory 비율 낮아 HBM bandwidth가 병목. 한 토큰 decode 시 KV cache 전체 읽기만 수 ms. Flash Attention은 여전히 유용하지만 GEMM 최적화라 decode에서 이득 작음. Flash Decoding이 해결책: KV cache를 chunk로 분할하고 각 chunk를 별도 SM에서 병렬 처리 → decode에 parallelism 도입. Online softmax로 partial 결과 합산. 장문 컨텍스트 decode에서 8배 속도. 관찰: "LLM 서빙 최적화는 prefill과 decode를 따로 생각해야 함". 이 둘이 근본적으로 다른 병목.


이 글이 도움이 됐다면 다음 포스트도 확인해 보세요:

  • "Transformer Architecture Deep Dive" — Attention의 기본.
  • "CUDA GPU Programming Deep Dive" — 커널 최적화의 기반.
  • "Diffusion Models Deep Dive" — 또 다른 attention-heavy 워크로드.
  • "RDMA & NCCL" — 멀티 GPU 훈련의 통신 기반.