- Authors
- Name
- 1. はじめに
- 2. Standard Attentionのメモリ問題
- 3. GPUメモリ階層構造
- 4. IO Complexity分析
- 5. Tiling技法:SRAMに収まるブロック単位の演算
- 6. Online Softmax(Safe Softmax)アルゴリズム
- 7. Backward PassのRecomputation戦略
- 8. FlashAttention-2の改善点
- 9. FlashAttention-3 最新の進展
- 10. ベンチマーク:速度/メモリ比較
- 11. PyTorch torch.nn.functional.scaled_dot_product_attention との連携
- 12. まとめと核心的教訓
- References
1. はじめに
Transformerアーキテクチャの中核であるSelf-Attentionは、シーケンス内の全トークンペア間の関係を計算する。この演算は強力な表現力を提供するが、シーケンス長 に対して時間およびメモリ計算量が で増加するという根本的な限界を持つ。GPT-4、LLaMA、Geminiなどの最新LLMが128K以上の長いコンテキストを処理するには、この のボトルネックを実質的に解決する必要がある。
FlashAttention(Dao et al., 2022)は、この問題を近似なしで解決する。核心のアイデアはシンプルでありながら深い:attention演算自体の計算量を削減するのではなく、GPUメモリ階層間のデータ移動(IO)を最小化する。 本稿では、FlashAttentionの原理をGPUハードウェアの観点から体系的に分析し、FlashAttention-2からFlashAttention-3までの発展を追う。
2. Standard Attentionのメモリ問題
2.1 Standard Attentionの演算フロー
Standard Self-Attentionは以下のように計算される。入力 に対して:
ここで はシーケンス長、 はhead dimensionである。
2.2 メモリ計算量の分析
問題の核心は中間行列 と にある。これらの行列のサイズは であり、シーケンス長に対して**二次(quadratic)**のメモリを要求する。具体的な数値を計算すると:
| シーケンス長 () | Attention行列サイズ | FP16メモリ |
|---|---|---|
| 1,024 | 1M要素 | 2 MB |
| 4,096 | 16.7M要素 | 33 MB |
| 16,384 | 268M要素 | 536 MB |
| 65,536 | 4.3B要素 | 8.6 GB |
| 131,072 | 17.2B要素 | 34.4 GB |
これらの数値は単一head、単一バッチに対するものである。Multi-head attentionではhead数 を乗じ、バッチサイズ を乗じると、実際のメモリ使用量は遥かに大きくなる。シーケンス長65,536では、単一headだけでもA100 80GB GPUのHBMのかなりの部分を消費することになる。
2.3 HBMボトルネック
Standard attentionの実装では、この 行列をGPU HBM(High Bandwidth Memory)にmaterializationする。すなわち、 を計算してHBMに書き込み、softmaxのために再度読み出し、結果 をHBMに書き込み、 のために再度読み出す。この過程でのHBMに対する読み書き回数は となる。
実際のGPUでこの演算が遅い理由は、計算(compute)ではなくメモリアクセス(memory access)がボトルネックだからである。A100 GPUの計算スループットは312 TFLOPS(FP16)であるのに対し、HBM帯域幅は約2 TB/sに過ぎない。Attention演算はarithmetic intensity(演算量/メモリアクセス量の比率)が低く、典型的なmemory-bound演算である。
3. GPUメモリ階層構造
FlashAttentionを理解するには、GPUのメモリ階層を正確に理解する必要がある。
3.1 HBM(High Bandwidth Memory)
- 容量: A100基準で40GBまたは80GB
- 帯域幅: 約1.5〜2.0 TB/s(A100 80GB SXM: 2,039 GB/s)
- アクセスレイテンシ: 約200〜600サイクル
- 役割: GPUのメインメモリ。モデルパラメータ、入力テンソル、出力テンソルなど全データが格納される
3.2 SRAM(On-chip Shared Memory)
- 容量: A100基準でSMあたり約192KB、合計約20MB(108個のSM)
- 帯域幅: 約19 TB/s
- アクセスレイテンシ: 約20〜30サイクル
- 役割: 各Streaming Multiprocessor(SM)内の高速オンチップメモリ
3.3 核心的な非対称性
SRAMとHBMの間には劇的な非対称性が存在する:
| 特性 | SRAM | HBM |
|---|---|---|
| 帯域幅 | ~19 TB/s | ~2 TB/s |
| 容量 | ~20 MB | 40〜80 GB |
| アクセスレイテンシ | 20〜30サイクル | 200〜600サイクル |
SRAMはHBMより約10倍高速だが、容量は約4000分の1である。 FlashAttentionの核心的な洞察は、この非対称性を積極的に活用することにある: 行列全体をHBMにmaterializationする代わりに、SRAMに収まる小さなブロック単位で演算を実行すれば、HBMアクセスを劇的に削減できる。
4. IO Complexity分析
4.1 Standard AttentionのIO Complexity
Standard attentionは以下のようなHBMアクセスパターンを示す:
- をHBMから読み出し を計算 -> をHBMに書き込み: IO
- をHBMから読み出し を計算 -> をHBMに書き込み: IO
- をHBMから読み出し を計算 -> をHBMに書き込み: IO
合計HBMアクセス量:
シーケンス長 がhead dimension (通常64または128)よりもはるかに大きいため、 の項が支配的となる。
4.2 FlashAttentionのIO Complexity
FlashAttentionはtilingによりHBMアクセス量を以下に削減する:
ここで はSRAMサイズである。直感的には、SRAMが大きいほどより大きなブロックを一度に処理でき、HBMアクセスが減少する。
4.3 最適性の証明(Lower Bound)
論文はさらに以下の下限(lower bound)を証明している:
定理: であるすべてのSRAMサイズ に対して、exact attentionを計算するいかなるアルゴリズムも のHBMアクセスを必要とする。
これは、FlashAttentionが**IO complexityの観点で最適(optimal)**であることを意味する。定数因子や多項対数因子を除けば、より少ないHBMアクセスでexact attentionを計算することは不可能である。
4.4 数値例
A100のSRAMサイズ KB、head dimension 、シーケンス長 の場合:
- Standard attention IO: 要素
- FlashAttention IO: 要素(ブロックサイズにより変動)
実際には サイズの中間行列がHBMに一切書き込まれないため、節約効果はさらに大きい。特にシーケンス長が長くなるほど効果は顕著になる。
5. Tiling技法:SRAMに収まるブロック単位の演算
5.1 アルゴリズム概要
FlashAttentionの核心アルゴリズムは以下の通りである:
- を 個のブロックに分割する:、各ブロックサイズ
- を 個のブロックに分割する: および 、各ブロックサイズ
- ブロックサイズ はSRAMサイズ に合わせて設定:,
5.2 Forward Pass擬似コード
Algorithm: FlashAttention Forward Pass
---------------------------------------
Input: Q, K, V in HBM, SRAM size M
Output: O in HBM
1. ブロックサイズ設定: B_c = ceil(M / 4d), B_r = min(ceil(M / 4d), d)
2. O = zeros(N, d), l = zeros(N), m = -inf * ones(N) をHBMに初期化
3. for j = 1 to T_c: # 外側ループ: K, Vブロック
K_j, V_j をHBMからSRAMにロード
for i = 1 to T_r: # 内側ループ: Qブロック
Q_i, O_i, l_i, m_i をHBMからSRAMにロード
# SRAM内でブロック単位の演算を実行
S_ij = Q_i @ K_j^T # (B_r x B_c)
m_ij = rowmax(S_ij)
P_ij = exp(S_ij - m_ij)
l_ij = rowsum(P_ij)
# 前ブロックの統計量と結合(Online Softmax)
m_new = max(m_i, m_ij)
l_new = exp(m_i - m_new) * l_i + exp(m_ij - m_new) * l_ij
# 出力の更新(rescaling込み)
O_i = diag(exp(m_i - m_new))^(-1) * (diag(l_i) * O_i)
+ diag(exp(m_ij - m_new))^(-1) * P_ij @ V_j
O_i = diag(l_new)^(-1) * O_i
# 統計量の更新
m_i = m_new, l_i = l_new
O_i, l_i, m_i をHBMに書き戻し
end for
end for
4. return O
5.3 なぜこれが機能するのか
核心は、 サイズのattention行列 と がHBMに一切materializationされないという点にある。各 ブロック はSRAM内で計算され、即座にsoftmax統計量の更新と出力の蓄積に使用された後、破棄される。
これを可能にする数学的技法がOnline Softmaxである。
6. Online Softmax(Safe Softmax)アルゴリズム
6.1 Standard Softmaxの問題
Softmaxは**大域的演算(global operation)**である。行ベクトル に対して:
これを計算するには、分母の合計のために行全体を一度に見る必要がある。これがtilingを困難にする根本的な障壁である -- ブロック だけを見てもsoftmaxは完成できない。残りのブロック の値によって分母が変わるからである。
また、数値安定性のために「safe softmax」を使用する:
これも大域的最大値 が必要であり、まず行全体をスキャンする必要がある。
6.2 Online Softmaxトリック
Online Softmax(Milakov & Gimelshein, 2018)の核心アイデアは、running statisticsを維持しながらブロック単位で漸進的(incremental)にsoftmaxを計算することである。
行ごとに2つのスカラーを維持する:
- : これまでに見た要素の最大値(running max)
- : これまでの正規化定数(running sum of exponentials)
新しいブロック が到来すると:
- 新ブロックの行別最大値を計算:
- 大域最大値を更新:
- 前の正規化定数をrescale:
- 前の出力もrescale:
この過程は**数学的に正確(exact)**である。近似ではない。ブロックをどの順序で処理しても、最終結果はstandard attentionとビット単位で同一である(浮動小数点演算の順序による微小な数値差を除く)。
6.3 数学的正当性
証明の核心はsoftmaxのrescaling propertyである:
最大値が から に更新されても、分子と分母に同一のfactor が乗じられるため、比率は変わらない。この性質により、前ブロックの結果を新しい最大値基準で安全にrescaleできる。
7. Backward PassのRecomputation戦略
7.1 Standard Backward Passの問題
Standard attentionのbackward passでは、gradient計算のためにforward passで保存した中間行列 と が必要である。これらのサイズが であるため、forwardで保存しbackwardで再度読み出すことは のメモリを要求する。
7.2 FlashAttentionのRecomputation
FlashAttentionはgradient checkpointingの変形を使用する。Forward passで** と を保存せず**、代わりに以下のみを保存する:
- 最終出力
- Softmax正規化統計量 (行別の最大値と合計)
Backward passでは、これらの統計量と元の を使用して、** と の必要なブロックをSRAM内で再計算(recompute)**する。このrecomputationは追加のFLOPを要求するが、HBMアクセスを大幅に削減する。
7.3 Recomputationの逆説的効果
一般的にgradient checkpointingはメモリを節約する代わりに速度を犠牲にする。しかし、FlashAttentionのrecomputationはむしろ速度まで向上させる。理由は以下の通りである:
- FLOPは増加する:forwardで一度計算したものをbackwardで再計算するため、総FLOPは若干増加する。
- HBM IOは減少する: サイズの 、 をHBMに書き込み読み出すコストがなくなる。
現代のGPUではHBMアクセスが計算よりもはるかに遅いため、FLOP増加量よりもIO削減の利得の方が大きい。実験結果、recomputationによる追加のランタイムオーバーヘッドは5%未満でありながら、メモリ使用量は から に減少する。
7.4 メモリ節約効果
| シーケンス長 | Standard Attentionメモリ | FlashAttentionメモリ | 節約率 |
|---|---|---|---|
| 1K | ~2 MB | ~0.13 MB | ~15x |
| 2K | ~8 MB | ~0.26 MB | ~30x |
| 4K | ~33 MB | ~0.52 MB | ~63x |
| 8K | ~131 MB | ~1.04 MB | ~126x |
この節約効果により、同じGPUメモリでより長いシーケンスを処理したり、より大きなバッチサイズを使用したりすることが可能になる。
8. FlashAttention-2の改善点
Dao(2023)はFlashAttention-2で3つの核心的な改善を導入した。
8.1 Non-matmul FLOPの最小化
A100 GPUのTensor Coreは行列乗算(matmul)に対して312 TFLOPS(FP16)を提供するが、non-matmul演算(softmaxのexp、max、sumなど)は19.5 TFLOPS(FP32)で約16倍遅い。FlashAttention-1ではnon-matmul演算の比率がかなり高かった。
FlashAttention-2はアルゴリズムを再構成してこれらのnon-matmul FLOPを最小化する。具体的には、rescaling演算の回数を削減し、softmax統計量の更新をより効率的に実行する。最終rescalingをループの最後に一度だけ実行するよう変更したことが核心である。
8.2 並列性の向上:シーケンス長次元の並列化
FlashAttention-1はbatch次元とhead次元でのみ並列化していた。バッチサイズが小さい場合やhead数が少ない場合、GPUのSM(Streaming Multiprocessor)を十分に活用できなかった。
FlashAttention-2はシーケンス長次元でも並列化する。外側ループをQブロック基準に変更し( ブロックではなく ブロックを外側ループに)、各Qブロックを独立したthread blockで処理できるようにした。この変更によりforward passでのoccupancyが大幅に向上する。
8.3 Work Partitioningの最適化
Thread block内でのwarp間の作業分配も改善された:
- FlashAttention-1: K, Vを4つのwarpに分割し、各warpが独立して を計算後に結果を同期。この方式はshared memoryを通じた通信と同期オーバーヘッドが発生する。
- FlashAttention-2: Qを4つのwarpに分割し、KとVは全warpが共有。各warpはQの異なる部分に対して独立して出力を計算するため、warp間通信が不要である。
8.4 性能結果
これら3つの改善を組み合わせると:
- FlashAttention-1比約2倍のスピードアップ
- A100でFP16/BF16基準230 TFLOPSを達成(理論最大値の約73%)
- Standard PyTorch attention比最大9倍のスピードアップ
- GEMM(行列乗算)演算の効率に接近
9. FlashAttention-3 最新の進展
FlashAttention-3(Shah et al., 2024)は、NVIDIA Hopperアーキテクチャ(H100)の新しいハードウェア機能を活用してさらに一段階進化した。
9.1 Hopper GPUの新機能
H100 GPUはA100に比べて以下の核心機能を提供する:
- WGMMA(Warpgroup Matrix Multiply-Accumulate): A100の
mma.syncよりもはるかに高いスループットを持つ新しいTensor Core命令 - TMA(Tensor Memory Accelerator): Global memoryとShared memory間のデータ転送を専担するハードウェアユニット。インデックス計算と境界チェックをハードウェアで処理
9.2 3つの核心技法
1. Warp Specializationによる非同期実行
演算(WGMMA)とデータ移動(TMA)を異なるwarp groupに割り当て、**パイプライン方式で重畳(overlap)**実行する。一方のwarp groupが現在のブロックを計算している間に、もう一方のwarp groupが次のブロックのデータをprefetchする。
2. MatmulとSoftmaxのInterleaving
従来はmatmulが終わってからsoftmaxを実行し、再びmatmulを実行する逐次的方式であった。FlashAttention-3はこれをインターリーブして、matmulとsoftmaxが異なるハードウェアユニット上で同時に実行されるようにする。Tensor Coreが次のブロックの を計算している間に、CUDA Coreが現在のブロックのsoftmaxを処理する。
3. FP8低精度サポート
H100のFP8 Tensor Coreを活用してスループットを2倍に引き上げる。単純にFP8で量子化すると精度が低下するが、FlashAttention-3は2つの技法でこれを解決する:
- Block quantization: ブロック単位で個別のスケールファクターを維持してダイナミックレンジを保持
- Incoherent processing: ランダム直交行列を乗じてoutlierを分散させてから量子化。これによりFP8 baseline比2.6倍低い数値誤差を達成
9.3 性能結果
H100でのFlashAttention-3の性能:
| 設定 | TFLOPS | GPU活用率 |
|---|---|---|
| FP16 FlashAttention-2 | ~400 | ~50% |
| FP16 FlashAttention-3 | ~740 | ~75% |
| FP8 FlashAttention-3 | ~1,200 | ~75% |
FP16でFlashAttention-2比1.5〜2.0倍のスピードアップ、FP8では1.2 PFLOPSに迫る性能を達成した。
10. ベンチマーク:速度/メモリ比較
10.1 Attention Forward Pass速度(A100 80GB、FP16)
FlashAttention論文および後続ベンチマークで報告された主要数値は以下の通りである:
| シーケンス長 | Standard Attention | FlashAttention | FlashAttention-2 | Speedup(FA2 vs Std) |
|---|---|---|---|---|
| 512 | 12.2 ms | 3.5 ms | 1.9 ms | 6.4x |
| 1K | 45.8 ms | 7.8 ms | 4.1 ms | 11.2x |
| 2K | 178 ms | 18.9 ms | 9.8 ms | 18.2x |
| 4K | 710 ms | 52.3 ms | 27.1 ms | 26.2x |
| 8K | OOM | 145 ms | 75 ms | - |
| 16K | OOM | 520 ms | 270 ms | - |
シーケンス長が長くなるほどスピードアップはより劇的に増大する。8K以上ではstandard attentionがOOM(Out of Memory)で実行自体が不可能だが、FlashAttentionは問題なく処理する。
10.2 End-to-End学習性能
| モデル | Standard | FlashAttention | Speedup |
|---|---|---|---|
| BERT-large(seq 512) | 100%(MLPerf基準) | 115% | 1.15x |
| GPT-2(seq 1K) | 100% | 300% | 3.0x |
| Long-range Arena(seq 1K-4K) | 100% | 240% | 2.4x |
10.3 メモリ使用量比較
FlashAttentionのattention演算メモリはシーケンス長に対して**線形(linear)であり、standard attentionの二次(quadratic)**と比べて劇的な改善である:
- シーケンス長2K: 約10倍のメモリ節約
- シーケンス長4K: 約20倍のメモリ節約
- シーケンス長64K: standard attentionはA100 80GBでもOOM、FlashAttentionは正常動作
11. PyTorch torch.nn.functional.scaled_dot_product_attention との連携
11.1 ネイティブ統合
PyTorch 2.0以降、FlashAttentionは torch.nn.functional.scaled_dot_product_attention(SDPA)にネイティブに統合されている。PyTorch 2.2以降はFlashAttention-2がデフォルトバックエンドとして使用される。
import torch
import torch.nn.functional as F
# 基本的な使用法 - 自動的にFlashAttentionバックエンドを選択
query = torch.randn(batch_size, num_heads, seq_len, head_dim,
device='cuda', dtype=torch.float16)
key = torch.randn(batch_size, num_heads, seq_len, head_dim,
device='cuda', dtype=torch.float16)
value = torch.randn(batch_size, num_heads, seq_len, head_dim,
device='cuda', dtype=torch.float16)
# PyTorchが自動的に最適なバックエンドを選択する
output = F.scaled_dot_product_attention(query, key, value)
11.2 バックエンドの明示的選択
特定のバックエンドを強制的に使用または除外できる:
from torch.nn.attention import sdpa_kernel, SDPBackend
# FlashAttentionバックエンドのみ使用
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
output = F.scaled_dot_product_attention(query, key, value)
# Memory-efficient attentionバックエンドのみ使用
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
output = F.scaled_dot_product_attention(query, key, value)
# Math(naive)バックエンド使用 - デバッグ用
with sdpa_kernel(SDPBackend.MATH):
output = F.scaled_dot_product_attention(query, key, value)
# CuDNNバックエンド使用(PyTorch 2.2+)
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
output = F.scaled_dot_product_attention(query, key, value)
11.3 Causal Maskとの併用
LLMの自己回帰生成に不可欠なcausal maskもサポートされている:
# is_causal=Trueでcausal maskを適用
# FlashAttentionはcausal maskをfused kernel内で処理し、追加メモリは不要
output = F.scaled_dot_product_attention(
query, key, value,
is_causal=True
)
# カスタムattention maskの使用
attn_mask = torch.tril(torch.ones(seq_len, seq_len, device='cuda', dtype=torch.bool))
output = F.scaled_dot_product_attention(
query, key, value,
attn_mask=attn_mask
)
11.4 バックエンド選択条件
PyTorch SDPAがFlashAttentionバックエンドを選択するための条件:
- dtype:
float16またはbfloat16(float32は不可) - device: CUDA GPU(CPUは非対応)
- head dimension: 最大256(FlashAttention-2基準)
- attention mask: boolean maskまたは
is_causal=Trueをサポート、任意のfloat maskは非対応
これらの条件を満たさない場合、PyTorchは自動的にmemory-efficient attentionまたはmathバックエンドにフォールバックする。
11.5 実務での活用ヒント
# どのバックエンドが使用されているか確認
import torch.backends.cuda
# 各バックエンドの有効状態を確認
print(f"Flash SDP enabled: {torch.backends.cuda.flash_sdp_enabled()}")
print(f"Mem efficient SDP enabled: {torch.backends.cuda.mem_efficient_sdp_enabled()}")
print(f"Math SDP enabled: {torch.backends.cuda.math_sdp_enabled()}")
# グローバルに特定バックエンドを無効化
torch.backends.cuda.enable_flash_sdp(False) # FlashAttentionを無効化
torch.backends.cuda.enable_mem_efficient_sdp(True)
11.6 flash-attnライブラリの直接使用
PyTorchネイティブのSDPA以外にも、Tri Daoの flash-attn パッケージを直接使用できる。このパッケージはPyTorch SDPAよりも多くの機能(例:sliding window attention、ALiBi、cross-attention最適化)を提供する:
# pip install flash-attn
from flash_attn import flash_attn_func
# (batch, seqlen, nheads, headdim) 形状
output = flash_attn_func(q, k, v, causal=True)
12. まとめと核心的教訓
FlashAttentionの核心的教訓は、アルゴリズムのFLOP計算量だけが性能を決定するわけではないということである。現代のGPUではメモリアクセスパターンが実際の実行時間を支配しており、IO-awareなアルゴリズム設計が実用的な性能に決定的である。
主要な貢献を要約すると:
- IO-Aware設計原則: GPUメモリ階層(HBM vs SRAM)の非対称性を活用したアルゴリズム設計
- Tiling + Online Softmax: SRAMに収まるブロック単位の演算で 行列のHBM materialization除去
- Recomputation戦略: Backward passで中間値を再計算して から へメモリ節約、同時に速度も向上
- 最適性の証明: IO complexityの観点から下限を証明してアルゴリズムの最適性を立証
- Exact Computation: すべての最適化にもかかわらず、近似なしのexact attentionを維持
FlashAttentionは、理論的な美しさと実用的な効果を兼ね備えた稀有な研究であり、現代のLLM学習と推論の核心インフラとなった。PyTorchへのネイティブ統合により、追加の実装なしに F.scaled_dot_product_attention の呼び出しだけでその恩恵を享受できる。
References
- Dao, T., Fu, D.Y., Ermon, S., Rudra, A., & Re, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022. https://arxiv.org/abs/2205.14135
- Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. ICLR 2024. https://arxiv.org/abs/2307.08691
- Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., & Dao, T. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision. NeurIPS 2024 Spotlight. https://arxiv.org/abs/2407.08608
- Dao-AILab. flash-attention GitHub Repository. https://github.com/Dao-AILab/flash-attention
- PyTorch Documentation.
torch.nn.functional.scaled_dot_product_attention. https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - PyTorch Documentation.
torch.nn.attention.sdpa_kernel. https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html - PyTorch Blog. FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision. https://pytorch.org/blog/flashattention-3/
- Milakov, M. & Gimelshein, N. (2018). Online Normalizer Calculation for Softmax. arXiv:1805.02867. https://arxiv.org/abs/1805.02867
- NVIDIA. A100 Tensor Core GPU Architecture Whitepaper. https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
- NVIDIA. Hopper Architecture In-Depth. https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/