はじめに
Transformerアーキテクチャは、NLP、コンピュータビジョン、音声処理など、ほぼすべてのディープラーニング分野の基盤モデルとして定着しました。しかし、Self-Attentionメカニズムの$O(N^2)$メモリ複雑度と繰り返しのGPUメモリアクセスは、学習と推論の双方において深刻なボトルネックを形成しています。特にシーケンス長が数千から数万トークンに拡大する現代のLLM時代において、このボトルネックはもはや無視できないレベルに達しています。
2022年、StanfordのTri Daoを筆頭とする研究チームは、この問題を**アルゴリズムの観点ではなくハードウェアIOの観点**から攻略するFlashAttentionを発表しました。FlashAttentionは、既存のアテンションの精度を一切損なうことなく(exact attention)、GPUメモリ階層構造を明示的に考慮したIO-awareアルゴリズム設計により、学習時に2〜4倍のウォールクロック速度向上と5〜20倍のメモリ削減を達成しました。
その後、FlashAttention-2(2023年)ではGPU内部のワープ(warp)レベルの作業分配を最適化し、A100で理論的最大FLOPSの50〜73%を達成しました。FlashAttention-3(2024年)ではHopperアーキテクチャ(H100)の非同期実行とFP8テンソルコアを活用し、H100で740 TFLOPS/s(FP16)と1.2 PFLOPS/s(FP8)という驚異的な性能を記録しました。
本記事では、FlashAttentionシリーズの発展過程全体を論文レベルで分析します。Standard Attentionの根本的な問題をGPUメモリ階層の観点から診断し、v1のタイリング(Tiling)とrecomputation戦略、v2の並列化改善、v3の非同期パイプラインと低精度対応まで段階的に扱います。さらに、PyTorchおよびTritonベースの実践コード、性能ベンチマーク、そしてプロダクション適用時に直面する失敗事例と復旧方法まで包括的に整理します。
Standard Attentionの問題点: O(N^2)メモリとIOボトルネック
数学的定義
Self-Attentionの数学的定義は以下の通りです。
$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V
$$
ここで$Q, K, V \in \mathbb{R}^{N \times d}$であり、$N$はシーケンス長、$d$はヘッド次元です。問題は、中間行列$S = QK^T \in \mathbb{R}^{N \times N}$を**全体的にmaterialization(物理的に格納)** しなければならない点です。
メモリ分析
シーケンス長別のアテンションスコア行列のメモリ要件は以下の通りです。
| シーケンス長 (N) | アテンション行列サイズ | FP16メモリ | FP32メモリ |
| :--------------: | :--------------------: | :--------: | :--------: |
| 1,024 | 1M | 2 MB | 4 MB |
| 4,096 | 16.7M | 33 MB | 67 MB |
| 16,384 | 268M | 536 MB | 1.07 GB |
| 65,536 | 4.29B | 8.59 GB | 17.18 GB |
| 131,072 | 17.18B | 34.36 GB | 68.72 GB |
バッチサイズとヘッド数を掛けると、実際のメモリ使用量はこれよりはるかに大きくなります。例えば、batch=4、heads=32、N=16384の場合、アテンションスコアだけで約68GBが必要となり、A100 80GBのほぼ全容量を消費します。
IOボトルネックの実態
しかし、純粋な演算の観点から見ると、状況は異なります。Self-AttentionのFLOPSは$O(N^2 d)$であり、メモリアクセス量は$O(N^2 + Nd)$です。ここで算術強度(arithmetic intensity)を計算すると、おおよそ$O(d)$程度ですが、これは現代GPUのFLOPS対メモリ帯域幅比(ops:byte ratio)に比べてかなり低い値です。
A100 GPUの場合、テンソルコアFP16性能は312 TFLOPS/sで、HBM帯域幅は2TB/sです。これはops:byte ratioが約156であることを意味しますが、ヘッド次元$d$が一般的に64〜128であるself-attentionは、この比率に対して演算密度が非常に低くなります。つまり、**GPUは演算を待っているのではなく、データの読み書きを待っている**のです。これこそが、Standard AttentionがGPU活用率30%以下にとどまる根本的な原因です。
def standard_attention(Q, K, V, mask=None):
"""Standard Self-Attention: N x Nスコア行列全体をmaterializationする。"""
d_k = Q.size(-1)
S = Q @ K^T -> (batch, heads, N, N) 全体をメモリに格納
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
softmaxもN x N行列全体に対して実行
attn_weights = F.softmax(scores, dim=-1) # (batch, heads, N, N) を格納
最終出力の計算
output = torch.matmul(attn_weights, V)
return output
ベンチマーク
batch, heads, seq_len, d_head = 4, 32, 4096, 64
Q = torch.randn(batch, heads, seq_len, d_head, device='cuda', dtype=torch.float16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)
ウォームアップ
for _ in range(3):
_ = standard_attention(Q, K, V)
torch.cuda.synchronize()
計測
start = time.time()
for _ in range(100):
_ = standard_attention(Q, K, V)
torch.cuda.synchronize()
elapsed = (time.time() - start) / 100
print(f"Standard Attention: {elapsed*1000:.2f} ms/iter")
print(f"Peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
このコードでは、`scores`テンソルが$4 \times 32 \times 4096 \times 4096 \times 2 = 4.29\text{GB}$のHBMを消費し、この値をHBMから読み書きする過程で大部分の時間が消費されます。
GPUメモリ階層構造: SRAM vs HBM
FlashAttentionの核心的な洞察は、GPUのメモリ階層構造を明示的に活用することです。GPUには大きく2つのメモリレベルが存在します。
HBM (High Bandwidth Memory)
- **容量**: 40〜80GB(A100)、80〜141GB(H100/H200)
- **帯域幅**: 1.5〜3.35 TB/s
- **役割**: モデル重み、活性化値、オプティマイザ状態など大容量データの格納
- **特性**: 容量は大きいが、アクセスレイテンシが相対的に長い
SRAM (On-chip Static RAM)
- **容量**: SM当たり192KB(A100)、全体で約20〜40MBレベル
- **帯域幅**: ~19 TB/s(A100基準)
- **役割**: 各Streaming Multiprocessor(SM)の共有メモリとして、カーネル実行中の一時データを格納
- **特性**: HBM比で約10倍高速なアクセス速度だが、容量は約1000分の1
| メモリ階層 | 容量(A100) | 帯域幅 | アクセスレイテンシ | 例え |
| :--------: | :-----------: | :------: | :----------------: | :--------: |
| L1/SRAM | ~20MB(合計) | ~19 TB/s | ~28サイクル | 机上のメモ |
| L2 Cache | 40MB | ~5 TB/s | ~200サイクル | 引き出し |
| HBM | 80GB | ~2 TB/s | ~400サイクル | 書庫 |
| CPU RAM | ~1TB | ~50 GB/s | ~数千サイクル | 図書館 |
Standard Attentionの問題は、中間結果($S$、$P = \text{softmax}(S)$)をすべてHBMに書き込んでから再度読み出す点です。アテンション演算全体における**HBMアクセス回数は$\Omega(Nd + N^2)$** です。FlashAttentionの目標は、SRAMの活用を通じてこのアクセス回数を$O(N^2 d^2 M^{-1})$に削減することです。ここで$M$はSRAMサイズであり、一般的な$d$(64〜128)と$M$(約100KB〜192KB)において、この値はstandardアプローチよりも数倍から数十倍小さくなります。
FlashAttention v1: Tiling + Recomputation
核心アイデア
FlashAttention v1のアルゴリズムは、2つの核心的な手法で構成されています。
1. **タイリング(Tiling)**: Q、K、V行列をSRAMに収まるサイズのブロックに分割し、ブロック単位でアテンションを計算します。この過程で、N x Nサイズのアテンションスコア行列をHBMに一度も全体的に格納しません。
2. **再計算(Recomputation)**: 逆伝播(backward pass)でアテンションスコアを保存せず、順伝播で保存したソフトマックス正規化統計量(最大値$m$と合計$\ell$)のみを使用して、必要なときに再計算します。
Online Softmaxとブロック単位の累積
行全体に対するsoftmaxをブロック単位で正確に計算するために、**Online Softmax**手法を使用します。各ブロック$j$を処理した後の累積出力は以下のように更新されます。
$$
m_i^{(j)} = \max(m_i^{(j-1)}, \tilde{m}_{ij})
$$
$$
\ell_i^{(j)} = e^{m_i^{(j-1)} - m_i^{(j)}} \ell_i^{(j-1)} + e^{\tilde{m}_{ij} - m_i^{(j)}} \tilde{\ell}_{ij}
$$
$$
O_i^{(j)} = \frac{e^{m_i^{(j-1)} - m_i^{(j)}} \ell_i^{(j-1)} O_i^{(j-1)} + e^{\tilde{m}_{ij} - m_i^{(j)}} \tilde{P}_{ij} V_j}{\ell_i^{(j)}}
$$
ここで$\tilde{m}_{ij}$と$\tilde{\ell}_{ij}$は現在のブロックのローカルsoftmax統計量です。この数式により、全体のsoftmaxを一度に計算せずとも、ブロックごとに段階的に**正確な**結果を累積できます。
アルゴリズム擬似コード
FlashAttention v1の順伝播アルゴリズムをPython擬似コードで表現すると以下の通りです。
def flash_attention_forward(Q, K, V, block_size_q, block_size_kv):
"""FlashAttention v1 Forward Pass 擬似コード。
実際のCUDAカーネルでは、すべての演算が単一のGPUカーネル内で実行され、
中間結果はSRAMにのみ保持されます。
"""
batch, heads, N, d = Q.shape
O = torch.zeros_like(Q) # 出力累積
m = torch.full((batch, heads, N, 1), float('-inf'), device=Q.device) # 行ごとの最大値
l = torch.zeros((batch, heads, N, 1), device=Q.device) # 行ごとの合計
外部ループ: Qをブロック単位で巡回
for i in range(0, N, block_size_q):
qi = Q[:, :, i:i+block_size_q, :] # SRAMへロード
内部ループ: K, Vをブロック単位で巡回
for j in range(0, N, block_size_kv):
kj = K[:, :, j:j+block_size_kv, :] # SRAMへロード
vj = V[:, :, j:j+block_size_kv, :] # SRAMへロード
1. ローカルアテンションスコアの計算(SRAM内で)
sij = torch.matmul(qi, kj.transpose(-2, -1)) / math.sqrt(d)
2. ローカルsoftmax統計量
mij_local = sij.max(dim=-1, keepdim=True).values
pij_local = torch.exp(sij - mij_local)
lij_local = pij_local.sum(dim=-1, keepdim=True)
3. Online Softmaxによるグローバル統計量の更新
mi_old = m[:, :, i:i+block_size_q, :]
li_old = l[:, :, i:i+block_size_q, :]
oi_old = O[:, :, i:i+block_size_q, :]
mi_new = torch.maximum(mi_old, mij_local)
alpha = torch.exp(mi_old - mi_new)
beta = torch.exp(mij_local - mi_new)
li_new = alpha * li_old + beta * lij_local
4. 出力累積の更新
O[:, :, i:i+block_size_q, :] = (
alpha * li_old * oi_old + beta * torch.matmul(pij_local, vj)
) / li_new
m[:, :, i:i+block_size_q, :] = mi_new
l[:, :, i:i+block_size_q, :] = li_new
return O, m, l # m, lはbackwardでのrecomputationに使用
Recomputation戦略
逆伝播において従来の方式では、順伝播で保存した$P = \text{softmax}(S)$行列($N \times N$)を使用して勾配を計算します。FlashAttentionはこの$N \times N$行列を保存せず、順伝播で保存した$m$と$\ell$の統計量(それぞれ$O(N)$サイズ)のみで、逆伝播時に$S$と$P$を**ブロック単位で再計算**します。
この再計算に要する追加演算量は全体の順伝播FLOPSの約33%ですが、HBMアクセスを劇的に削減することにより、ウォールクロック時間基準ではむしろ2〜4倍高速化する結果が得られます。これは、現代のGPUが**compute-boundではなくmemory-bound**状態にあるという事実を逆説的に証明する結果です。
IO複雑度の証明
論文で証明されたFlashAttentionのHBMアクセス回数は以下の通りです。
$$
\Theta\left(\frac{N^2 d^2}{M}\right) \quad \text{HBMアクセス}
$$
ここで$M$はSRAMサイズです。Standard Attentionの$\Theta(Nd + N^2)$と比較すると、$d = 64$、$M = 100\text{KB}$、$N = 4096$の場合、FlashAttentionは約**9倍少ないHBMアクセス**を実行します。また論文は、いかなるexact attentionアルゴリズムも$\Omega(N^2 d^2 M^{-1})$未満のHBMアクセスを達成できないことを証明し、FlashAttentionがIO複雑度の観点で**最適(optimal)** であることを示しました。
FlashAttention-2: 並列化とWork Partitioningの改善
v1の限界
FlashAttention v1はA100で理論的最大FLOPSの約30〜50%しか活用できていませんでした。その原因は3つに分析されました。
1. **非効率なnon-matmul演算**: softmaxのリスケーリング、マスキング演算など、行列積でない演算がGPUテンソルコアを活用できません。
2. **シーケンス長方向の並列化の欠如**: バッチとヘッド次元でのみ並列化されるため、小バッチ・少ヘッド数でGPU占有率が低下します。
3. **ワープ間の非効率な作業分配**: 1つのスレッドブロック内でワープ同士が共有メモリを介して不要な同期を行っています。
主要な改善点
FlashAttention-2は以下の3つの最適化を導入しました。
**1. Non-matmul FLOPSの削減**
Online softmaxのリスケーリング演算を再構成し、不要なスケーリング回数を削減しました。また、causal maskingの場合、マスクが不要なブロックではマスキング演算自体をスキップするように改善しました。
**2. シーケンス長方向の並列化**
Forward passで外部ループをQブロック(行)ではなくK/Vブロック(列)に変更し、シーケンス次元でも並列化を可能にしました。Backward passではQブロックとK/Vブロックの両方に対して並列化を適用しました。これにより、バッチサイズが1でヘッド数が少ない場合でも、高いGPU占有率を維持できるようになりました。
**3. ワープレベルの作業分配最適化**
v1では4つのワープが$Q$ブロックと$K/V$ブロックを分担して処理した後、結果を共有メモリを介して合算していました。v2では4つのワープがすべて**同じQブロック**を処理しつつ、それぞれ異なるK/Vブロックを担当するように変更しました。こうすることで、各ワープの結果を合算する必要がなくなり、独立してQの出力に累積できるため、共有メモリの同期が大幅に削減されます。
| 項目 | FlashAttention v1 | FlashAttention-2 |
| :-----------------: | :--------------------: | :--------------------: |
| A100 FP16 達成率 | ~30-50% | ~50-73% |
| 最大 TFLOPS(A100) | ~170 | ~230 |
| 外部ループ軸 | Qブロック(行) | K/Vブロック(列) |
| ワープ分割方式 | QとKVを分割 | KVのみ分割、Q共有 |
| Causal最適化 | 全ブロックにマスク適用 | 不要ブロックをスキップ |
| シーケンス並列化 | 未対応 | 対応 |
FlashAttention-3: FP8、非同期パイプライン
Hopperアーキテクチャの活用
FlashAttention-3は、NVIDIA Hopperアーキテクチャ(H100)の3つの核心的なハードウェア機能を活用します。
**1. WGMMA (Warpgroup Matrix-Multiply-Accumulate)**
Hopperの新しい行列積命令で、前世代のWMMAより大きなタイルサイズと高いスループットを提供します。特に非同期実行をサポートし、データ移動と演算を同時に実行できます。
**2. TMA (Tensor Memory Accelerator)**
HBMから共有メモリ(SRAM)へのデータ転送をハードウェアレベルで非同期的に処理する専用ユニットです。CPUがDMAコントローラにデータ転送を委任するのと同様に、GPUの計算ユニットがメモリ転送を待たずに他の作業を実行できます。
**3. FP8テンソルコア**
HopperはFP8(E4M3、E5M2)形式をハードウェアレベルでサポートし、FP16比2倍の演算スループットを提供します。
ワープ特化(Warp Specialization)
FlashAttention-3の核心手法の1つが**ワープ特化(Warp Specialization)** です。1つのCTA(Cooperative Thread Array)内のワープを**プロデューサー(producer)**と**コンシューマー(consumer)** の役割に分けます。
- **プロデューサーワープ**: TMAを使用してHBMからSRAMへK/Vブロックを非同期的にロードします。
- **コンシューマーワープ**: WGMMAを使用して、すでにSRAMにロードされたデータで行列積とsoftmaxを実行します。
Hopperの`setmaxnreg`命令により、コンシューマーワープにより多くのレジスタを動的に割り当てることができます。プロデューサーワープは最小限のリソースでデータ転送のみを担当し、コンシューマーワープは最大限のレジスタを活用して演算を実行します。
GEMM-Softmax非同期パイプライン
FlashAttention-3は2ステージパイプラインを構成し、GEMM演算とsoftmax演算をオーバーラップさせます。ブロック$j$のsoftmaxを計算している間に、ブロック$j+1$の$QK^T$ GEMMが同時に進行します。このパイプラインは、softmaxがnon-matmul演算でありテンソルコアを使用しないという点を活用したもので、テンソルコアのアイドル時間を最小化します。
FP8対応と精度の維持
FP8の限られた精度による精度低下を、2つの手法で緩和します。
**ブロック量子化(Block Quantization)**: 各ブロックごとに独立したスケールファクターを計算し、ダイナミックレンジを最大化します。テンソル全体に単一のスケールを適用するよりも、表現範囲が大幅に向上します。
**インコヒーレント処理(Incoherent Processing)**: $Q$と$K$にランダム直交行列を掛けて値の分布を均一にした後、量子化します。理論的にこの変換は量子化誤差の期待値を最小化し、アテンション計算後に逆変換を適用して精度を復元します。
| 項目 | FlashAttention-2 | FlashAttention-3 (FP16) | FlashAttention-3 (FP8) |
| :----------------: | :--------------: | :-------------------------------: | :-------------------------------: |
| 対象GPU | A100 | H100 | H100 |
| 最大TFLOPS | ~230 | ~740 | ~1,200 |
| 理論的活用率 | ~73% | ~75% | ~76% |
| ワープ戦略 | 均等分配 | プロデューサー/コンシューマー特化 | プロデューサー/コンシューマー特化 |
| 非同期パイプライン | 未対応 | GEMM-Softmaxオーバーラップ | GEMM-Softmaxオーバーラップ |
| FP8精度補正 | - | - | Block Quant + Incoherent |
性能ベンチマーク比較
学習速度比較(GPT-2モデル基準)
| 構成 | Standard Attention | FlashAttention v1 | FlashAttention-2 | FlashAttention-3 |
| :-------------------: | :----------------: | :---------------: | :--------------: | :--------------: |
| GPU | A100 | A100 | A100 | H100 |
| GPT-2 Small (seq=1K) | 1.0x | 1.7x | 2.3x | 3.8x |
| GPT-2 Medium (seq=1K) | 1.0x | 1.8x | 2.5x | 4.1x |
| GPT-2 Large (seq=2K) | 1.0x | 2.1x | 2.8x | 4.6x |
| GPT-2 XL (seq=2K) | OOM | 1.0x (baseline) | 1.5x | 2.8x |
| seq=4K, d=64 | OOM | 1.0x | 1.7x | 3.2x |
| seq=16K, d=128 | OOM | 1.0x | 1.9x | 3.5x |
FlashAttention v1対比でv2は約1.5〜2倍、v3はFP16基準で約1.5〜2倍(H100ハードウェア向上を含めると3〜4倍)高速化しました。特にシーケンス長が長くなるほど改善幅が大きくなります。
推論速度比較(Prefill段階)
推論のprefill段階は学習と同様に、シーケンス全体に対するアテンションを一括計算するため、FlashAttentionの利点が大きく適用されます。
| シーケンス長 | Standard (ms) | FlashAttention-2 (ms) | 速度向上 |
| :----------: | :-----------: | :-------------------: | :------: |
| 512 | 0.8 | 0.5 | 1.6x |
| 2,048 | 5.2 | 2.1 | 2.5x |
| 8,192 | 78.3 | 18.6 | 4.2x |
| 32,768 | OOM | 285.0 | - |
実践適用: PyTorchとTritonコード
PyTorch SDPAによるFlashAttentionの使用
PyTorch 2.0以降、`torch.nn.functional.scaled_dot_product_attention`にFlashAttentionが統合され、別途ライブラリをインストールせずに使用できます。
from torch.nn.attention import sdpa_kernel, SDPBackend
基本使用: PyTorchが自動的に最適なバックエンドを選択
batch, heads, seq_len, d_head = 2, 16, 4096, 64
Q = torch.randn(batch, heads, seq_len, d_head, device='cuda', dtype=torch.float16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)
自動バックエンド選択(FlashAttentionがサポートされていれば自動使用)
output = F.scaled_dot_product_attention(Q, K, V)
FlashAttentionバックエンドのみを強制使用
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
output_flash = F.scaled_dot_product_attention(
Q, K, V,
dropout_p=0.1, # 学習時のdropout
is_causal=True, # causal masking(デコーダ用)
scale=None, # デフォルト 1/sqrt(d_k) を使用
)
どのバックエンドが使用されているか確認
print(f"Flash available: {torch.backends.cuda.flash_sdp_enabled()}")
print(f"Memory efficient available: {torch.backends.cuda.mem_efficient_sdp_enabled()}")
flash-attnライブラリの直接使用
`flash-attn`パッケージを直接使用すると、より細かい制御と追加機能を活用できます。
pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func
方法1: Q, K, V 分離入力
形状: (batch, seqlen, nheads, headdim) - 注意: PyTorchとhead/seqの順序が異なる
batch, seqlen, nheads, headdim = 2, 4096, 16, 64
Q = torch.randn(batch, seqlen, nheads, headdim, device='cuda', dtype=torch.float16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)
output = flash_attn_func(
Q, K, V,
dropout_p=0.0,
softmax_scale=None, # デフォルト 1/sqrt(headdim)
causal=True,
window_size=(-1, -1), # (-1, -1)は全アテンション、(w, 0)はsliding window
return_attn_probs=False,
)
方法2: QKV packed形式
形状: (batch, seqlen, 3, nheads, headdim)
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device='cuda', dtype=torch.float16)
output_packed = flash_attn_qkvpacked_func(
qkv,
dropout_p=0.0,
causal=True,
)
方法3: Sliding Window Attention(FlashAttention-2以降)
output_sliding = flash_attn_func(
Q, K, V,
causal=True,
window_size=(256, 0), # 左256トークンのウィンドウ
)
print(f"Output shape: {output.shape}") # (batch, seqlen, nheads, headdim)
TritonによるFlashAttentionカーネル実装スケッチ
OpenAIのTritonコンパイラを使用すると、CUDA Cを直接記述せずにPythonでGPUカーネルを記述できます。
@triton.jit
def flash_attention_kernel(
Q_ptr, K_ptr, V_ptr, O_ptr,
stride_qb, stride_qh, stride_qm, stride_qk,
stride_kb, stride_kh, stride_kn, stride_kk,
stride_vb, stride_vh, stride_vn, stride_vk,
stride_ob, stride_oh, stride_om, stride_ok,
N_CTX: tl.constexpr,
D_HEAD: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""FlashAttention Forward KernelのTriton実装スケッチ。
実際のプロダクションカーネルにはより多くの最適化が含まれますが、
核心となるタイリングロジックの構造を示します。
"""
現在のプログラムが担当するQブロックインデックス
pid_m = tl.program_id(0)
pid_bh = tl.program_id(1)
オフセット計算
off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
off_k = tl.arange(0, D_HEAD)
Qブロックのロード(HBM -> SRAM)
q = tl.load(Q_ptr + pid_bh * stride_qh + off_m[:, None] * stride_qm + off_k[None, :] * stride_qk,
mask=off_m[:, None] < N_CTX)
累積変数の初期化
m_i = tl.full([BLOCK_M], value=float('-inf'), dtype=tl.float32)
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
o_i = tl.zeros([BLOCK_M, D_HEAD], dtype=tl.float32)
K/Vブロックの巡回(内部ループ)
for start_n in range(0, N_CTX, BLOCK_N):
curr_n = start_n + off_n
K, Vブロックのロード(HBM -> SRAM)
k = tl.load(K_ptr + pid_bh * stride_kh + curr_n[:, None] * stride_kn + off_k[None, :] * stride_kk,
mask=curr_n[:, None] < N_CTX)
v = tl.load(V_ptr + pid_bh * stride_vh + curr_n[:, None] * stride_vn + off_k[None, :] * stride_vk,
mask=curr_n[:, None] < N_CTX)
ローカルアテンションスコアの計算(SRAM内)
s = tl.dot(q, tl.trans(k)) * (D_HEAD ** -0.5)
Online Softmaxの更新
m_ij = tl.max(s, axis=1)
m_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_new)
beta = tl.exp(m_ij - m_new)
l_new = alpha * l_i + beta * tl.sum(tl.exp(s - m_ij[:, None]), axis=1)
出力累積の更新
p = tl.exp(s - m_new[:, None])
o_i = alpha[:, None] * l_i[:, None] * o_i + tl.dot(p.to(tl.float16), v)
o_i = o_i / l_new[:, None]
m_i = m_new
l_i = l_new
結果をHBMに格納
tl.store(O_ptr + pid_bh * stride_oh + off_m[:, None] * stride_om + off_k[None, :] * stride_ok,
o_i.to(tl.float16), mask=off_m[:, None] < N_CTX)
HuggingFace Transformersとの統合
HuggingFaceモデルでFlashAttentionを活用する方法です。
from transformers import AutoModelForCausalLM, AutoTokenizer
FlashAttention-2を使用するようにモデルをロード
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B",
torch_dtype=torch.float16,
attn_implementation="flash_attention_2", # 核心設定
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
ベンチマーク比較
text = "FlashAttention は " * 2000 # 長いシーケンス
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=4096).to("cuda")
FlashAttention-2 推論
torch.cuda.reset_peak_memory_stats()
start = time.time()
with torch.no_grad():
outputs = model(**inputs)
torch.cuda.synchronize()
flash_time = time.time() - start
flash_mem = torch.cuda.max_memory_allocated() / 1e9
print(f"FlashAttention-2: {flash_time:.3f}s, Peak Memory: {flash_mem:.2f} GB")
SDPA eagerモードとの比較(再ロードが必要)
model_eager = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B",
torch_dtype=torch.float16,
attn_implementation="eager",
device_map="auto",
)
torch.cuda.reset_peak_memory_stats()
start = time.time()
with torch.no_grad():
outputs_eager = model_eager(**inputs)
torch.cuda.synchronize()
eager_time = time.time() - start
eager_mem = torch.cuda.max_memory_allocated() / 1e9
print(f"Eager Attention: {eager_time:.3f}s, Peak Memory: {eager_mem:.2f} GB")
print(f"Speedup: {eager_time/flash_time:.2f}x, Memory Savings: {eager_mem/flash_mem:.2f}x")
トラブルシューティング: 失敗事例と復旧
事例1: CUDA互換性エラー
**症状**: `flash-attn`インストール時のビルド失敗、または`RuntimeError: FlashAttention only supports Ampere GPUs or newer`
**原因**: FlashAttentionはSM80(A100)以上のGPUアーキテクチャが必要です。V100(SM70)やT4(SM75)では動作しません。
**解決策**:
- GPUアーキテクチャの確認: `nvidia-smi`または`torch.cuda.get_device_capability()`
- SM75以下の場合、`torch.nn.functional.scaled_dot_product_attention`の`mem_efficient`バックエンドを代替として使用します。このバックエンド(xformersベース)はSM50以上で動作します。
- インストール時に`MAX_JOBS=4 pip install flash-attn --no-build-isolation`で並列ビルド数を制限してOOMを防止
事例2: テンソル形状の不一致
**症状**: `RuntimeError: expected query to have shape (batch, seqlen, nheads, headdim)`
**原因**: `flash_attn`ライブラリは`(batch, seqlen, nheads, headdim)`形状を期待しますが、PyTorchの`MultiheadAttention`は`(batch, nheads, seqlen, headdim)`の順序を使用します。
**解決策**: `q.transpose(1, 2)`または`einops.rearrange(q, 'b h s d -> b s h d')`で変換します。
事例3: シーケンス長のアライメント問題
**症状**: 特定のシーケンス長で結果が不正確になるか、性能が急激に低下する
**原因**: FlashAttentionカーネルは、ブロックサイズ(一般的に64または128)の倍数のときに最適性能を発揮します。倍数でない場合、パディング処理が発生します。
**解決策**: 可能な限りシーケンス長を128の倍数に合わせ、`flash_attn`の`flash_attn_varlen_func`を使用して可変長シーケンスを効率的に処理します。
事例4: GQA/MQAサポート問題
**症状**: Grouped Query AttentionやMulti-Query AttentionモデルでFlashAttention適用時にエラーが発生
**原因**: 初期バージョンでは、Q、K、Vのヘッド数が同一である必要がありました。
**解決策**: FlashAttention-2以降ではGQA/MQAをネイティブサポートしています。`flash_attn_func`にヘッド数の異なるQとK/Vを直接渡すと、内部的にK/Vヘッドが繰り返し(repeat)されて処理されます。
事例5: Backward PassでのNaN発生
**症状**: 学習中にlossがNaNに発散する
**原因**: FP16で非常に長いシーケンス(32K以上)を処理する際、softmaxの指数関数で数値的オーバーフローが発生する可能性があります。
**解決策**: BF16の使用を推奨します。BF16はFP16と同じメモリを使用しながらも、指数範囲がFP32と同一であるため、オーバーフローに強くなります。`torch.autocast(device_type='cuda', dtype=torch.bfloat16)`コンテキスト内で実行してください。
運用時の注意事項
メモリバジェットの計画
FlashAttentionはアテンションスコア行列のメモリを節約しますが、Q/K/Vテンソル自体と出力テンソルのメモリは依然として必要です。おおよそのメモリバジェットは以下のように計算します。
$$
\text{Memory}_{\text{flash}} \approx 2 \times (\text{batch} \times \text{heads} \times N \times d) \times \text{dtype\_size} + O(N)
$$
従来の$O(N^2)$項が$O(N)$に削減されたのであって、全体のメモリがゼロになるわけではありません。
CUDAグラフとの互換性
FlashAttentionは`torch.compile`およびCUDA Graphと互換性がありますが、動的なシーケンス長を使用する場合、CUDA Graphの静的グラフ制約と競合する可能性があります。推論サービングでは、シーケンス長をあらかじめ決められたバケットにパディングしてCUDA Graphを再利用する戦略が効果的です。
モニタリング指標
プロダクション環境でFlashAttention適用後にモニタリングすべき核心指標は以下の通りです。
- **GPU HBM使用率**: `nvidia-smi`のmemory utilizationではなく、実際の割り当て量をモニタリング
- **SM活用率**: `dcgm-exporter`によるStreaming Multiprocessor活用率の追跡
- **カーネル実行時間**: `torch.profiler`または`nsight systems`によるアテンションカーネルレイテンシの追跡
- **数値精度**: FP8使用時に出力の相対誤差を定期的にFP16基準と比較
バージョン選択ガイド
| 状況 | 推奨選択 |
| :----------------------------: | :----------------------------------: |
| A100 + PyTorch 2.0以降 | PyTorch SDPA(自動バックエンド選択) |
| A100 + 最大性能が必要 | flash-attnライブラリの直接使用 |
| H100 + FP16 | FlashAttention-3 |
| H100 + 最大推論スループット | FlashAttention-3 FP8 |
| V100/T4(旧世代GPU) | xformers memory_efficient_attention |
| カスタムアテンション変形が必要 | Tritonで直接実装 |
おわりに
FlashAttentionシリーズは、アルゴリズムの数学的結果を変更することなく、ハードウェアのメモリ階層構造を深く理解し活用するだけで劇的な性能向上を達成した模範的な事例です。「同じ演算を実行しつつ、データをどのように移動させるか」という問いが実際のウォールクロック時間で2〜4倍の差を生み出すという事実は、現代のGPUプログラミングにおいてIO-awarenessがいかに重要であるかを雄弁に物語っています。
v1でタイリング(Tiling)と再計算(Recomputation)という核心アイデアを確立し、v2でGPU内部のミクロレベルの作業分配を最適化し、v3で次世代ハードウェアの新機能を先制的に活用する発展過程は、AIシステム最適化研究が進むべき方向を明確に示しています。FlashAttentionは今やPyTorchに内蔵され、ほとんどの実務者が意識せずとも自動的に使用されるインフラレベルの技術となりましたが、その内部動作原理を理解することは、より優れたモデル設計とシステム最適化の出発点となるでしょう。
参考資料
- [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (arXiv)](https://arxiv.org/abs/2205.14135) - FlashAttention v1 原論文
- [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (arXiv)](https://arxiv.org/abs/2307.08691) - FlashAttention-2 論文
- [FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (arXiv)](https://arxiv.org/abs/2407.08608) - FlashAttention-3 論文
- [Dao-AILab/flash-attention GitHub Repository](https://github.com/Dao-AILab/flash-attention) - 公式実装およびインストールガイド
- [PyTorch Scaled Dot Product Attention 公式ドキュメント](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) - PyTorch SDPA APIドキュメント
- [FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (PyTorch Blog)](https://pytorch.org/blog/flashattention-3/) - PyTorch公式ブログのFlashAttention-3紹介
- [Stanford CRFM FlashAttention-2 解説](https://crfm.stanford.edu/2023/07/17/flash2.html) - StanfordによるFlashAttention-2公式解説
- [Tri Dao FlashAttention-3 ブログ](https://tridao.me/blog/2024/flash3/) - 著者によるFlashAttention-3解説
현재 단락 (1/341)
Transformerアーキテクチャは、NLP、コンピュータビジョン、音声処理など、ほぼすべてのディープラーニング分野の基盤モデルとして定着しました。しかし、Self-Attentionメカニズムの$...