Skip to content
Published on

FlashAttention 論文分析: IO-Aware Exact AttentionによるTransformer学習・推論速度の革新

Authors
  • Name
    Twitter
FlashAttention

はじめに

Transformerアーキテクチャは、NLP、コンピュータビジョン、音声処理など、ほぼすべてのディープラーニング分野の基盤モデルとして定着しました。しかし、Self-AttentionメカニズムのO(N2)O(N^2)メモリ複雑度と繰り返しのGPUメモリアクセスは、学習と推論の双方において深刻なボトルネックを形成しています。特にシーケンス長が数千から数万トークンに拡大する現代のLLM時代において、このボトルネックはもはや無視できないレベルに達しています。

2022年、StanfordのTri Daoを筆頭とする研究チームは、この問題をアルゴリズムの観点ではなくハードウェアIOの観点から攻略するFlashAttentionを発表しました。FlashAttentionは、既存のアテンションの精度を一切損なうことなく(exact attention)、GPUメモリ階層構造を明示的に考慮したIO-awareアルゴリズム設計により、学習時に2〜4倍のウォールクロック速度向上と5〜20倍のメモリ削減を達成しました。

その後、FlashAttention-2(2023年)ではGPU内部のワープ(warp)レベルの作業分配を最適化し、A100で理論的最大FLOPSの50〜73%を達成しました。FlashAttention-3(2024年)ではHopperアーキテクチャ(H100)の非同期実行とFP8テンソルコアを活用し、H100で740 TFLOPS/s(FP16)と1.2 PFLOPS/s(FP8)という驚異的な性能を記録しました。

本記事では、FlashAttentionシリーズの発展過程全体を論文レベルで分析します。Standard Attentionの根本的な問題をGPUメモリ階層の観点から診断し、v1のタイリング(Tiling)とrecomputation戦略、v2の並列化改善、v3の非同期パイプラインと低精度対応まで段階的に扱います。さらに、PyTorchおよびTritonベースの実践コード、性能ベンチマーク、そしてプロダクション適用時に直面する失敗事例と復旧方法まで包括的に整理します。

Standard Attentionの問題点: O(N^2)メモリとIOボトルネック

数学的定義

Self-Attentionの数学的定義は以下の通りです。

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V

ここでQ,K,VRN×dQ, K, V \in \mathbb{R}^{N \times d}であり、NNはシーケンス長、ddはヘッド次元です。問題は、中間行列S=QKTRN×NS = QK^T \in \mathbb{R}^{N \times N}全体的にmaterialization(物理的に格納) しなければならない点です。

メモリ分析

シーケンス長別のアテンションスコア行列のメモリ要件は以下の通りです。

シーケンス長 (N)アテンション行列サイズFP16メモリFP32メモリ
1,0241M2 MB4 MB
4,09616.7M33 MB67 MB
16,384268M536 MB1.07 GB
65,5364.29B8.59 GB17.18 GB
131,07217.18B34.36 GB68.72 GB

バッチサイズとヘッド数を掛けると、実際のメモリ使用量はこれよりはるかに大きくなります。例えば、batch=4、heads=32、N=16384の場合、アテンションスコアだけで約68GBが必要となり、A100 80GBのほぼ全容量を消費します。

IOボトルネックの実態

しかし、純粋な演算の観点から見ると、状況は異なります。Self-AttentionのFLOPSはO(N2d)O(N^2 d)であり、メモリアクセス量はO(N2+Nd)O(N^2 + Nd)です。ここで算術強度(arithmetic intensity)を計算すると、おおよそO(d)O(d)程度ですが、これは現代GPUのFLOPS対メモリ帯域幅比(ops:byte ratio)に比べてかなり低い値です。

A100 GPUの場合、テンソルコアFP16性能は312 TFLOPS/sで、HBM帯域幅は2TB/sです。これはops:byte ratioが約156であることを意味しますが、ヘッド次元ddが一般的に64〜128であるself-attentionは、この比率に対して演算密度が非常に低くなります。つまり、GPUは演算を待っているのではなく、データの読み書きを待っているのです。これこそが、Standard AttentionがGPU活用率30%以下にとどまる根本的な原因です。

import torch
import torch.nn.functional as F
import time

def standard_attention(Q, K, V, mask=None):
    """Standard Self-Attention: N x Nスコア行列全体をmaterializationする。"""
    d_k = Q.size(-1)
    # S = Q @ K^T -> (batch, heads, N, N) 全体をメモリに格納
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # softmaxもN x N行列全体に対して実行
    attn_weights = F.softmax(scores, dim=-1)  # (batch, heads, N, N) を格納

    # 最終出力の計算
    output = torch.matmul(attn_weights, V)
    return output

# ベンチマーク
batch, heads, seq_len, d_head = 4, 32, 4096, 64
Q = torch.randn(batch, heads, seq_len, d_head, device='cuda', dtype=torch.float16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)

# ウォームアップ
for _ in range(3):
    _ = standard_attention(Q, K, V)
torch.cuda.synchronize()

# 計測
start = time.time()
for _ in range(100):
    _ = standard_attention(Q, K, V)
torch.cuda.synchronize()
elapsed = (time.time() - start) / 100

print(f"Standard Attention: {elapsed*1000:.2f} ms/iter")
print(f"Peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

このコードでは、scoresテンソルが4×32×4096×4096×2=4.29GB4 \times 32 \times 4096 \times 4096 \times 2 = 4.29\text{GB}のHBMを消費し、この値をHBMから読み書きする過程で大部分の時間が消費されます。

GPUメモリ階層構造: SRAM vs HBM

FlashAttentionの核心的な洞察は、GPUのメモリ階層構造を明示的に活用することです。GPUには大きく2つのメモリレベルが存在します。

HBM (High Bandwidth Memory)

  • 容量: 40〜80GB(A100)、80〜141GB(H100/H200)
  • 帯域幅: 1.5〜3.35 TB/s
  • 役割: モデル重み、活性化値、オプティマイザ状態など大容量データの格納
  • 特性: 容量は大きいが、アクセスレイテンシが相対的に長い

SRAM (On-chip Static RAM)

  • 容量: SM当たり192KB(A100)、全体で約20〜40MBレベル
  • 帯域幅: ~19 TB/s(A100基準)
  • 役割: 各Streaming Multiprocessor(SM)の共有メモリとして、カーネル実行中の一時データを格納
  • 特性: HBM比で約10倍高速なアクセス速度だが、容量は約1000分の1
メモリ階層容量(A100)帯域幅アクセスレイテンシ例え
L1/SRAM~20MB(合計)~19 TB/s~28サイクル机上のメモ
L2 Cache40MB~5 TB/s~200サイクル引き出し
HBM80GB~2 TB/s~400サイクル書庫
CPU RAM~1TB~50 GB/s~数千サイクル図書館

Standard Attentionの問題は、中間結果(SSP=softmax(S)P = \text{softmax}(S))をすべてHBMに書き込んでから再度読み出す点です。アテンション演算全体におけるHBMアクセス回数はΩ(Nd+N2)\Omega(Nd + N^2) です。FlashAttentionの目標は、SRAMの活用を通じてこのアクセス回数をO(N2d2M1)O(N^2 d^2 M^{-1})に削減することです。ここでMMはSRAMサイズであり、一般的なdd(64〜128)とMM(約100KB〜192KB)において、この値はstandardアプローチよりも数倍から数十倍小さくなります。

FlashAttention v1: Tiling + Recomputation

核心アイデア

FlashAttention v1のアルゴリズムは、2つの核心的な手法で構成されています。

  1. タイリング(Tiling): Q、K、V行列をSRAMに収まるサイズのブロックに分割し、ブロック単位でアテンションを計算します。この過程で、N x Nサイズのアテンションスコア行列をHBMに一度も全体的に格納しません。

  2. 再計算(Recomputation): 逆伝播(backward pass)でアテンションスコアを保存せず、順伝播で保存したソフトマックス正規化統計量(最大値mmと合計\ell)のみを使用して、必要なときに再計算します。

Online Softmaxとブロック単位の累積

行全体に対するsoftmaxをブロック単位で正確に計算するために、Online Softmax手法を使用します。各ブロックjjを処理した後の累積出力は以下のように更新されます。

mi(j)=max(mi(j1),m~ij)m_i^{(j)} = \max(m_i^{(j-1)}, \tilde{m}_{ij}) i(j)=emi(j1)mi(j)i(j1)+em~ijmi(j)~ij\ell_i^{(j)} = e^{m_i^{(j-1)} - m_i^{(j)}} \ell_i^{(j-1)} + e^{\tilde{m}_{ij} - m_i^{(j)}} \tilde{\ell}_{ij} Oi(j)=emi(j1)mi(j)i(j1)Oi(j1)+em~ijmi(j)P~ijVji(j)O_i^{(j)} = \frac{e^{m_i^{(j-1)} - m_i^{(j)}} \ell_i^{(j-1)} O_i^{(j-1)} + e^{\tilde{m}_{ij} - m_i^{(j)}} \tilde{P}_{ij} V_j}{\ell_i^{(j)}}

ここでm~ij\tilde{m}_{ij}~ij\tilde{\ell}_{ij}は現在のブロックのローカルsoftmax統計量です。この数式により、全体のsoftmaxを一度に計算せずとも、ブロックごとに段階的に正確な結果を累積できます。

アルゴリズム擬似コード

FlashAttention v1の順伝播アルゴリズムをPython擬似コードで表現すると以下の通りです。

import torch
import math

def flash_attention_forward(Q, K, V, block_size_q, block_size_kv):
    """FlashAttention v1 Forward Pass 擬似コード。

    実際のCUDAカーネルでは、すべての演算が単一のGPUカーネル内で実行され、
    中間結果はSRAMにのみ保持されます。
    """
    batch, heads, N, d = Q.shape
    O = torch.zeros_like(Q)           # 出力累積
    m = torch.full((batch, heads, N, 1), float('-inf'), device=Q.device)  # 行ごとの最大値
    l = torch.zeros((batch, heads, N, 1), device=Q.device)                # 行ごとの合計

    # 外部ループ: Qをブロック単位で巡回
    for i in range(0, N, block_size_q):
        qi = Q[:, :, i:i+block_size_q, :]  # SRAMへロード

        # 内部ループ: K, Vをブロック単位で巡回
        for j in range(0, N, block_size_kv):
            kj = K[:, :, j:j+block_size_kv, :]  # SRAMへロード
            vj = V[:, :, j:j+block_size_kv, :]  # SRAMへロード

            # 1. ローカルアテンションスコアの計算(SRAM内で)
            sij = torch.matmul(qi, kj.transpose(-2, -1)) / math.sqrt(d)

            # 2. ローカルsoftmax統計量
            mij_local = sij.max(dim=-1, keepdim=True).values
            pij_local = torch.exp(sij - mij_local)
            lij_local = pij_local.sum(dim=-1, keepdim=True)

            # 3. Online Softmaxによるグローバル統計量の更新
            mi_old = m[:, :, i:i+block_size_q, :]
            li_old = l[:, :, i:i+block_size_q, :]
            oi_old = O[:, :, i:i+block_size_q, :]

            mi_new = torch.maximum(mi_old, mij_local)
            alpha = torch.exp(mi_old - mi_new)
            beta = torch.exp(mij_local - mi_new)

            li_new = alpha * li_old + beta * lij_local

            # 4. 出力累積の更新
            O[:, :, i:i+block_size_q, :] = (
                alpha * li_old * oi_old + beta * torch.matmul(pij_local, vj)
            ) / li_new

            m[:, :, i:i+block_size_q, :] = mi_new
            l[:, :, i:i+block_size_q, :] = li_new

    return O, m, l  # m, lはbackwardでのrecomputationに使用

Recomputation戦略

逆伝播において従来の方式では、順伝播で保存したP=softmax(S)P = \text{softmax}(S)行列(N×NN \times N)を使用して勾配を計算します。FlashAttentionはこのN×NN \times N行列を保存せず、順伝播で保存したmm\ellの統計量(それぞれO(N)O(N)サイズ)のみで、逆伝播時にSSPPブロック単位で再計算します。

この再計算に要する追加演算量は全体の順伝播FLOPSの約33%ですが、HBMアクセスを劇的に削減することにより、ウォールクロック時間基準ではむしろ2〜4倍高速化する結果が得られます。これは、現代のGPUがcompute-boundではなくmemory-bound状態にあるという事実を逆説的に証明する結果です。

IO複雑度の証明

論文で証明されたFlashAttentionのHBMアクセス回数は以下の通りです。

Θ(N2d2M)HBMアクセス\Theta\left(\frac{N^2 d^2}{M}\right) \quad \text{HBMアクセス}

ここでMMはSRAMサイズです。Standard AttentionのΘ(Nd+N2)\Theta(Nd + N^2)と比較すると、d=64d = 64M=100KBM = 100\text{KB}N=4096N = 4096の場合、FlashAttentionは約9倍少ないHBMアクセスを実行します。また論文は、いかなるexact attentionアルゴリズムもΩ(N2d2M1)\Omega(N^2 d^2 M^{-1})未満のHBMアクセスを達成できないことを証明し、FlashAttentionがIO複雑度の観点で最適(optimal) であることを示しました。

FlashAttention-2: 並列化とWork Partitioningの改善

v1の限界

FlashAttention v1はA100で理論的最大FLOPSの約30〜50%しか活用できていませんでした。その原因は3つに分析されました。

  1. 非効率なnon-matmul演算: softmaxのリスケーリング、マスキング演算など、行列積でない演算がGPUテンソルコアを活用できません。
  2. シーケンス長方向の並列化の欠如: バッチとヘッド次元でのみ並列化されるため、小バッチ・少ヘッド数でGPU占有率が低下します。
  3. ワープ間の非効率な作業分配: 1つのスレッドブロック内でワープ同士が共有メモリを介して不要な同期を行っています。

主要な改善点

FlashAttention-2は以下の3つの最適化を導入しました。

1. Non-matmul FLOPSの削減

Online softmaxのリスケーリング演算を再構成し、不要なスケーリング回数を削減しました。また、causal maskingの場合、マスクが不要なブロックではマスキング演算自体をスキップするように改善しました。

2. シーケンス長方向の並列化

Forward passで外部ループをQブロック(行)ではなくK/Vブロック(列)に変更し、シーケンス次元でも並列化を可能にしました。Backward passではQブロックとK/Vブロックの両方に対して並列化を適用しました。これにより、バッチサイズが1でヘッド数が少ない場合でも、高いGPU占有率を維持できるようになりました。

3. ワープレベルの作業分配最適化

v1では4つのワープがQQブロックとK/VK/Vブロックを分担して処理した後、結果を共有メモリを介して合算していました。v2では4つのワープがすべて同じQブロックを処理しつつ、それぞれ異なるK/Vブロックを担当するように変更しました。こうすることで、各ワープの結果を合算する必要がなくなり、独立してQの出力に累積できるため、共有メモリの同期が大幅に削減されます。

項目FlashAttention v1FlashAttention-2
A100 FP16 達成率~30-50%~50-73%
最大 TFLOPS(A100)~170~230
外部ループ軸Qブロック(行)K/Vブロック(列)
ワープ分割方式QとKVを分割KVのみ分割、Q共有
Causal最適化全ブロックにマスク適用不要ブロックをスキップ
シーケンス並列化未対応対応

FlashAttention-3: FP8、非同期パイプライン

Hopperアーキテクチャの活用

FlashAttention-3は、NVIDIA Hopperアーキテクチャ(H100)の3つの核心的なハードウェア機能を活用します。

1. WGMMA (Warpgroup Matrix-Multiply-Accumulate)

Hopperの新しい行列積命令で、前世代のWMMAより大きなタイルサイズと高いスループットを提供します。特に非同期実行をサポートし、データ移動と演算を同時に実行できます。

2. TMA (Tensor Memory Accelerator)

HBMから共有メモリ(SRAM)へのデータ転送をハードウェアレベルで非同期的に処理する専用ユニットです。CPUがDMAコントローラにデータ転送を委任するのと同様に、GPUの計算ユニットがメモリ転送を待たずに他の作業を実行できます。

3. FP8テンソルコア

HopperはFP8(E4M3、E5M2)形式をハードウェアレベルでサポートし、FP16比2倍の演算スループットを提供します。

ワープ特化(Warp Specialization)

FlashAttention-3の核心手法の1つがワープ特化(Warp Specialization) です。1つのCTA(Cooperative Thread Array)内のワープをプロデューサー(producer)コンシューマー(consumer) の役割に分けます。

  • プロデューサーワープ: TMAを使用してHBMからSRAMへK/Vブロックを非同期的にロードします。
  • コンシューマーワープ: WGMMAを使用して、すでにSRAMにロードされたデータで行列積とsoftmaxを実行します。

Hopperのsetmaxnreg命令により、コンシューマーワープにより多くのレジスタを動的に割り当てることができます。プロデューサーワープは最小限のリソースでデータ転送のみを担当し、コンシューマーワープは最大限のレジスタを活用して演算を実行します。

GEMM-Softmax非同期パイプライン

FlashAttention-3は2ステージパイプラインを構成し、GEMM演算とsoftmax演算をオーバーラップさせます。ブロックjjのsoftmaxを計算している間に、ブロックj+1j+1QKTQK^T GEMMが同時に進行します。このパイプラインは、softmaxがnon-matmul演算でありテンソルコアを使用しないという点を活用したもので、テンソルコアのアイドル時間を最小化します。

FP8対応と精度の維持

FP8の限られた精度による精度低下を、2つの手法で緩和します。

ブロック量子化(Block Quantization): 各ブロックごとに独立したスケールファクターを計算し、ダイナミックレンジを最大化します。テンソル全体に単一のスケールを適用するよりも、表現範囲が大幅に向上します。

インコヒーレント処理(Incoherent Processing): QQKKにランダム直交行列を掛けて値の分布を均一にした後、量子化します。理論的にこの変換は量子化誤差の期待値を最小化し、アテンション計算後に逆変換を適用して精度を復元します。

項目FlashAttention-2FlashAttention-3 (FP16)FlashAttention-3 (FP8)
対象GPUA100H100H100
最大TFLOPS~230~740~1,200
理論的活用率~73%~75%~76%
ワープ戦略均等分配プロデューサー/コンシューマー特化プロデューサー/コンシューマー特化
非同期パイプライン未対応GEMM-SoftmaxオーバーラップGEMM-Softmaxオーバーラップ
FP8精度補正--Block Quant + Incoherent

性能ベンチマーク比較

学習速度比較(GPT-2モデル基準)

構成Standard AttentionFlashAttention v1FlashAttention-2FlashAttention-3
GPUA100A100A100H100
GPT-2 Small (seq=1K)1.0x1.7x2.3x3.8x
GPT-2 Medium (seq=1K)1.0x1.8x2.5x4.1x
GPT-2 Large (seq=2K)1.0x2.1x2.8x4.6x
GPT-2 XL (seq=2K)OOM1.0x (baseline)1.5x2.8x
seq=4K, d=64OOM1.0x1.7x3.2x
seq=16K, d=128OOM1.0x1.9x3.5x

FlashAttention v1対比でv2は約1.5〜2倍、v3はFP16基準で約1.5〜2倍(H100ハードウェア向上を含めると3〜4倍)高速化しました。特にシーケンス長が長くなるほど改善幅が大きくなります。

推論速度比較(Prefill段階)

推論のprefill段階は学習と同様に、シーケンス全体に対するアテンションを一括計算するため、FlashAttentionの利点が大きく適用されます。

シーケンス長Standard (ms)FlashAttention-2 (ms)速度向上
5120.80.51.6x
2,0485.22.12.5x
8,19278.318.64.2x
32,768OOM285.0-

実践適用: PyTorchとTritonコード

PyTorch SDPAによるFlashAttentionの使用

PyTorch 2.0以降、torch.nn.functional.scaled_dot_product_attentionにFlashAttentionが統合され、別途ライブラリをインストールせずに使用できます。

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# 基本使用: PyTorchが自動的に最適なバックエンドを選択
batch, heads, seq_len, d_head = 2, 16, 4096, 64
Q = torch.randn(batch, heads, seq_len, d_head, device='cuda', dtype=torch.float16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)

# 自動バックエンド選択(FlashAttentionがサポートされていれば自動使用)
output = F.scaled_dot_product_attention(Q, K, V)

# FlashAttentionバックエンドのみを強制使用
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    output_flash = F.scaled_dot_product_attention(
        Q, K, V,
        dropout_p=0.1,      # 学習時のdropout
        is_causal=True,      # causal masking(デコーダ用)
        scale=None,          # デフォルト 1/sqrt(d_k) を使用
    )

# どのバックエンドが使用されているか確認
print(f"Flash available: {torch.backends.cuda.flash_sdp_enabled()}")
print(f"Memory efficient available: {torch.backends.cuda.mem_efficient_sdp_enabled()}")

flash-attnライブラリの直接使用

flash-attnパッケージを直接使用すると、より細かい制御と追加機能を活用できます。

# pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func
import torch

# 方法1: Q, K, V 分離入力
# 形状: (batch, seqlen, nheads, headdim) - 注意: PyTorchとhead/seqの順序が異なる
batch, seqlen, nheads, headdim = 2, 4096, 16, 64
Q = torch.randn(batch, seqlen, nheads, headdim, device='cuda', dtype=torch.float16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)

output = flash_attn_func(
    Q, K, V,
    dropout_p=0.0,
    softmax_scale=None,   # デフォルト 1/sqrt(headdim)
    causal=True,
    window_size=(-1, -1), # (-1, -1)は全アテンション、(w, 0)はsliding window
    return_attn_probs=False,
)

# 方法2: QKV packed形式
# 形状: (batch, seqlen, 3, nheads, headdim)
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device='cuda', dtype=torch.float16)
output_packed = flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    causal=True,
)

# 方法3: Sliding Window Attention(FlashAttention-2以降)
output_sliding = flash_attn_func(
    Q, K, V,
    causal=True,
    window_size=(256, 0),  # 左256トークンのウィンドウ
)

print(f"Output shape: {output.shape}")  # (batch, seqlen, nheads, headdim)

TritonによるFlashAttentionカーネル実装スケッチ

OpenAIのTritonコンパイラを使用すると、CUDA Cを直接記述せずにPythonでGPUカーネルを記述できます。

import triton
import triton.language as tl
import torch

@triton.jit
def flash_attention_kernel(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    stride_qb, stride_qh, stride_qm, stride_qk,
    stride_kb, stride_kh, stride_kn, stride_kk,
    stride_vb, stride_vh, stride_vn, stride_vk,
    stride_ob, stride_oh, stride_om, stride_ok,
    N_CTX: tl.constexpr,
    D_HEAD: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """FlashAttention Forward KernelのTriton実装スケッチ。

    実際のプロダクションカーネルにはより多くの最適化が含まれますが、
    核心となるタイリングロジックの構造を示します。
    """
    # 現在のプログラムが担当するQブロックインデックス
    pid_m = tl.program_id(0)
    pid_bh = tl.program_id(1)

    # オフセット計算
    off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, BLOCK_N)
    off_k = tl.arange(0, D_HEAD)

    # Qブロックのロード(HBM -> SRAM)
    q = tl.load(Q_ptr + pid_bh * stride_qh + off_m[:, None] * stride_qm + off_k[None, :] * stride_qk,
                mask=off_m[:, None] < N_CTX)

    # 累積変数の初期化
    m_i = tl.full([BLOCK_M], value=float('-inf'), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    o_i = tl.zeros([BLOCK_M, D_HEAD], dtype=tl.float32)

    # K/Vブロックの巡回(内部ループ)
    for start_n in range(0, N_CTX, BLOCK_N):
        curr_n = start_n + off_n

        # K, Vブロックのロード(HBM -> SRAM)
        k = tl.load(K_ptr + pid_bh * stride_kh + curr_n[:, None] * stride_kn + off_k[None, :] * stride_kk,
                    mask=curr_n[:, None] < N_CTX)
        v = tl.load(V_ptr + pid_bh * stride_vh + curr_n[:, None] * stride_vn + off_k[None, :] * stride_vk,
                    mask=curr_n[:, None] < N_CTX)

        # ローカルアテンションスコアの計算(SRAM内)
        s = tl.dot(q, tl.trans(k)) * (D_HEAD ** -0.5)

        # Online Softmaxの更新
        m_ij = tl.max(s, axis=1)
        m_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_new)
        beta = tl.exp(m_ij - m_new)
        l_new = alpha * l_i + beta * tl.sum(tl.exp(s - m_ij[:, None]), axis=1)

        # 出力累積の更新
        p = tl.exp(s - m_new[:, None])
        o_i = alpha[:, None] * l_i[:, None] * o_i + tl.dot(p.to(tl.float16), v)
        o_i = o_i / l_new[:, None]

        m_i = m_new
        l_i = l_new

    # 結果をHBMに格納
    tl.store(O_ptr + pid_bh * stride_oh + off_m[:, None] * stride_om + off_k[None, :] * stride_ok,
             o_i.to(tl.float16), mask=off_m[:, None] < N_CTX)

HuggingFace Transformersとの統合

HuggingFaceモデルでFlashAttentionを活用する方法です。

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# FlashAttention-2を使用するようにモデルをロード
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # 核心設定
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")

# ベンチマーク比較
import time

text = "FlashAttention は " * 2000  # 長いシーケンス
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=4096).to("cuda")

# FlashAttention-2 推論
torch.cuda.reset_peak_memory_stats()
start = time.time()
with torch.no_grad():
    outputs = model(**inputs)
torch.cuda.synchronize()
flash_time = time.time() - start
flash_mem = torch.cuda.max_memory_allocated() / 1e9

print(f"FlashAttention-2: {flash_time:.3f}s, Peak Memory: {flash_mem:.2f} GB")

# SDPA eagerモードとの比較(再ロードが必要)
model_eager = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.float16,
    attn_implementation="eager",
    device_map="auto",
)

torch.cuda.reset_peak_memory_stats()
start = time.time()
with torch.no_grad():
    outputs_eager = model_eager(**inputs)
torch.cuda.synchronize()
eager_time = time.time() - start
eager_mem = torch.cuda.max_memory_allocated() / 1e9

print(f"Eager Attention: {eager_time:.3f}s, Peak Memory: {eager_mem:.2f} GB")
print(f"Speedup: {eager_time/flash_time:.2f}x, Memory Savings: {eager_mem/flash_mem:.2f}x")

トラブルシューティング: 失敗事例と復旧

事例1: CUDA互換性エラー

症状: flash-attnインストール時のビルド失敗、またはRuntimeError: FlashAttention only supports Ampere GPUs or newer

原因: FlashAttentionはSM80(A100)以上のGPUアーキテクチャが必要です。V100(SM70)やT4(SM75)では動作しません。

解決策:

  • GPUアーキテクチャの確認: nvidia-smiまたはtorch.cuda.get_device_capability()
  • SM75以下の場合、torch.nn.functional.scaled_dot_product_attentionmem_efficientバックエンドを代替として使用します。このバックエンド(xformersベース)はSM50以上で動作します。
  • インストール時にMAX_JOBS=4 pip install flash-attn --no-build-isolationで並列ビルド数を制限してOOMを防止

事例2: テンソル形状の不一致

症状: RuntimeError: expected query to have shape (batch, seqlen, nheads, headdim)

原因: flash_attnライブラリは(batch, seqlen, nheads, headdim)形状を期待しますが、PyTorchのMultiheadAttention(batch, nheads, seqlen, headdim)の順序を使用します。

解決策: q.transpose(1, 2)またはeinops.rearrange(q, 'b h s d -> b s h d')で変換します。

事例3: シーケンス長のアライメント問題

症状: 特定のシーケンス長で結果が不正確になるか、性能が急激に低下する

原因: FlashAttentionカーネルは、ブロックサイズ(一般的に64または128)の倍数のときに最適性能を発揮します。倍数でない場合、パディング処理が発生します。

解決策: 可能な限りシーケンス長を128の倍数に合わせ、flash_attnflash_attn_varlen_funcを使用して可変長シーケンスを効率的に処理します。

事例4: GQA/MQAサポート問題

症状: Grouped Query AttentionやMulti-Query AttentionモデルでFlashAttention適用時にエラーが発生

原因: 初期バージョンでは、Q、K、Vのヘッド数が同一である必要がありました。

解決策: FlashAttention-2以降ではGQA/MQAをネイティブサポートしています。flash_attn_funcにヘッド数の異なるQとK/Vを直接渡すと、内部的にK/Vヘッドが繰り返し(repeat)されて処理されます。

事例5: Backward PassでのNaN発生

症状: 学習中にlossがNaNに発散する

原因: FP16で非常に長いシーケンス(32K以上)を処理する際、softmaxの指数関数で数値的オーバーフローが発生する可能性があります。

解決策: BF16の使用を推奨します。BF16はFP16と同じメモリを使用しながらも、指数範囲がFP32と同一であるため、オーバーフローに強くなります。torch.autocast(device_type='cuda', dtype=torch.bfloat16)コンテキスト内で実行してください。

運用時の注意事項

メモリバジェットの計画

FlashAttentionはアテンションスコア行列のメモリを節約しますが、Q/K/Vテンソル自体と出力テンソルのメモリは依然として必要です。おおよそのメモリバジェットは以下のように計算します。

Memoryflash2×(batch×heads×N×d)×dtype_size+O(N)\text{Memory}_{\text{flash}} \approx 2 \times (\text{batch} \times \text{heads} \times N \times d) \times \text{dtype\_size} + O(N)

従来のO(N2)O(N^2)項がO(N)O(N)に削減されたのであって、全体のメモリがゼロになるわけではありません。

CUDAグラフとの互換性

FlashAttentionはtorch.compileおよびCUDA Graphと互換性がありますが、動的なシーケンス長を使用する場合、CUDA Graphの静的グラフ制約と競合する可能性があります。推論サービングでは、シーケンス長をあらかじめ決められたバケットにパディングしてCUDA Graphを再利用する戦略が効果的です。

モニタリング指標

プロダクション環境でFlashAttention適用後にモニタリングすべき核心指標は以下の通りです。

  • GPU HBM使用率: nvidia-smiのmemory utilizationではなく、実際の割り当て量をモニタリング
  • SM活用率: dcgm-exporterによるStreaming Multiprocessor活用率の追跡
  • カーネル実行時間: torch.profilerまたはnsight systemsによるアテンションカーネルレイテンシの追跡
  • 数値精度: FP8使用時に出力の相対誤差を定期的にFP16基準と比較

バージョン選択ガイド

状況推奨選択
A100 + PyTorch 2.0以降PyTorch SDPA(自動バックエンド選択)
A100 + 最大性能が必要flash-attnライブラリの直接使用
H100 + FP16FlashAttention-3
H100 + 最大推論スループットFlashAttention-3 FP8
V100/T4(旧世代GPU)xformers memory_efficient_attention
カスタムアテンション変形が必要Tritonで直接実装

おわりに

FlashAttentionシリーズは、アルゴリズムの数学的結果を変更することなく、ハードウェアのメモリ階層構造を深く理解し活用するだけで劇的な性能向上を達成した模範的な事例です。「同じ演算を実行しつつ、データをどのように移動させるか」という問いが実際のウォールクロック時間で2〜4倍の差を生み出すという事実は、現代のGPUプログラミングにおいてIO-awarenessがいかに重要であるかを雄弁に物語っています。

v1でタイリング(Tiling)と再計算(Recomputation)という核心アイデアを確立し、v2でGPU内部のミクロレベルの作業分配を最適化し、v3で次世代ハードウェアの新機能を先制的に活用する発展過程は、AIシステム最適化研究が進むべき方向を明確に示しています。FlashAttentionは今やPyTorchに内蔵され、ほとんどの実務者が意識せずとも自動的に使用されるインフラレベルの技術となりましたが、その内部動作原理を理解することは、より優れたモデル設計とシステム最適化の出発点となるでしょう。

参考資料