Skip to content
Published on

行列はGPUでどう飛び回るか:GEMMからFlashAttentionまで完全解剖

Authors

なぜ行列乗算がディープラーニングのすべてなのか

GPUがディープラーニングモデルを実行する際、実際の作業の80%以上は単一の演算に帰着する。それが**行列乗算(Matrix Multiplication)**だ。

各レイヤーを分解して確認しよう:

Linear Layer(全結合層)

output = W × input + b
shape: (out_features,) = (out_features, in_features) × (in_features,)

最も基本的な線形変換。純粋な行列-ベクトル乗算だ。

Self-Attention(Transformerの核心)

Q = W_q × x        # 行列乗算
K = W_k × x        # 行列乗算
V = W_v × x        # 行列乗算
scores = Q × K^T   # 行列乗算 (N×d × d×N = N×N)
weights = softmax(scores / sqrt(d_k))
output = weights × V  # 行列乗算 (N×N × N×d = N×d)

1回のAttentionレイヤーで最低6回の行列乗算が発生する。

Convolution(CNN)

im2col変換後:
output = W × im2col(input)

畳み込みも最終的には行列乗算に変換される。

GPT-3(175B)のFLOP分布:

Linearレイヤー(QKVプロジェクション):  ~45%
Feed-forwardレイヤー:                   ~35%
Attention計算(QK^TAV:              ~15%
その他(LayerNorm、softmax等):           ~5%
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
行列乗算合計:                           ~95%

行列乗算をどれだけ高速化できるか = ディープラーニングの推論・学習速度が決まる。これが40年間、数千人のエンジニアが行列乗算の最適化に取り組んできた理由だ。


ナイーブな行列乗算とその問題点

3重ループ:定義そのもの

import numpy as np

def naive_matmul(A, B):
    """
    A: (M, K)  B: (K, N) -> C: (M, N)
    C[i][j] = sum(A[i][k] * B[k][j] for k in range(K))
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "内側の次元が一致している必要があります"

    C = np.zeros((M, N), dtype=np.float64)

    for i in range(M):          # M回反復
        for j in range(N):      # N回反復
            for k in range(K):  # K回反復
                C[i][j] += A[i][k] * B[k][j]  # MAC演算1回

    return C
    # 総演算数: M * N * K 回の multiply-add
    # 計算量: O(n^3)
    # 4096×4096行列: 4096^3 = 約680億回の演算!

数学的には完璧だが、実際に動かすと非常に遅い。

4096×4096 FP32行列乗算の場合:

  • 理論演算量:2 × 4096³ ≈ 137 GFLOPS
  • H100の理論性能:67 TFLOPS(FP32)
  • 理論時間:137G / 67T ≈ 2ms
  • ナイーブPythonの実際の時間:約900秒(450,000倍遅い!)

差異の原因は2つ:メモリアクセスパターン並列性の未活用

キャッシュミス:真の敵

メモリ階層(H100基準):
┌─────────────────────────────────────────────────────────┐
L1 Cache (SM):       256KB,  ~5サイクル 遅延         │
L2 Cache (共有):       50MB,   ~30サイクル遅延         │
HBM3 (GPUメインメモリ): 80GB,   ~300サイクル遅延       │
CPU DRAM:TB,    ~1000サイクル遅延       │
└─────────────────────────────────────────────────────────┘

キャッシュアクセス: 300サイクル ↔ 5サイクル = 60倍の差
キャッシュヒット率 = 性能

ナイーブな3重ループのメモリアクセスパターン(A[4×4], B[4×4]の例):

A行列(行優先格納、row-major):
メモリアドレス: [A00][A01][A02][A03] [A10][A11][A12][A13] ...
                ←─ キャッシュライン1 ─→ ←─ キャッシュライン2 ─→

A[i][k]アクセスパターン(i=0固定、k=0,1,2,3順):
A[0][0] → キャッシュライン1ロード(キャッシュMISS300サイクル)
A[0][1] → キャッシュライン1再利用(キャッシュHIT5サイクル)✅
A[0][2] → キャッシュライン1再利用(キャッシュHIT)✅
A[0][3] → キャッシュライン1再利用(キャッシュHIT)✅
Aは行方向アクセス:キャッシュフレンドリー ✅

B行列の列方向アクセス(j=0固定、k=0,1,2,3順):
B[0][0] → キャッシュライン0ロード(キャッシュMISSB[1][0] → キャッシュライン2ロード(キャッシュMISS! 別の行)❌
B[2][0] → キャッシュライン4ロード(キャッシュMISS!)❌
B[3][0] → キャッシュライン6ロード(キャッシュMISS!)❌

Bの各要素が別々のキャッシュラインに存在
大行列(N=4096)ではBアクセスの99%がキャッシュミス!

キャッシュブロッキング(タイリング):解決策

コアアイデア:キャッシュに収まるブロックに分割

行列をキャッシュに収まる小さなタイルに分割し、一度ロードしたデータを最大限再利用する。

行列タイリングの可視化(M=8, N=8, K=8, TILE_SIZE=4:

すべての行列を4×4タイルに分割:

    B (K×N):
    ┌───┬───┐
B00B01B00 = B[0:4, 0:4]
    ├───┼───┤  B01 = B[0:4, 4:8]
B10B11B10 = B[4:8, 0:4]
    └───┴───┘  B11 = B[4:8, 4:8]

A (M×K):           C (M×N):
┌───┬───┐          ┌───┬───┐
A00A01│          │C00C01C00 = A00×B00 + A01×B10
├───┼───┤          ├───┼───┤  C01 = A00×B01 + A01×B11
A10A11│          │C10C11C10 = A10×B00 + A11×B10
└───┴───┘          └───┴───┘  C11 = A10×B01 + A11×B11

タイルC00の計算過程:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Step 1: A004×4=16 float=64B)をキャッシュにロード
        B004×4=16 float=64B)をキャッシュにロード
        → 128Bロードで4×4=16回の内積計算
16×16=256回のMAC演算(キャッシュミスゼロ!)
Step 2: A01をキャッシュにロード(A00をevict)
        B10をキャッシュにロード(B00をevict)
256回のMAC演算(キャッシュミスゼロ!)
結果: C00完成
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

タイルMatMul実装

import numpy as np
import time

def tiled_matmul(A, B, tile_size=64):
    """
    キャッシュブロッキングを使った行列乗算
    tile_size: L1キャッシュサイズに合わせて調整
               L1=256KB、float32: sqrt(256K/4/3) ≈ 146
               実際には64-128が最適なことが多い
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.zeros((M, N), dtype=A.dtype)

    for i in range(0, M, tile_size):
        i_end = min(i + tile_size, M)
        for j in range(0, N, tile_size):
            j_end = min(j + tile_size, N)
            for k in range(0, K, tile_size):
                k_end = min(k + tile_size, K)
                tile_A = A[i:i_end, k:k_end]  # キャッシュ常駐Aタイル
                tile_B = B[k:k_end, j:j_end]  # キャッシュ常駐Bタイル
                # このブロック内の演算はすべてキャッシュヒット!
                C[i:i_end, j:j_end] += tile_A @ tile_B
    return C


def benchmark():
    for N in [512, 1024, 2048]:
        A = np.random.randn(N, N).astype(np.float32)
        B = np.random.randn(N, N).astype(np.float32)

        start = time.perf_counter()
        for _ in range(5):
            _ = A @ B
        t = (time.perf_counter() - start) / 5

        tflops = 2 * N**3 / t / 1e12
        print(f"N={N}: {t*1000:.1f}ms, {tflops:.2f} TFLOPS")

benchmark()

性能比較:

N=4096、tile_size=64 vs ナイーブ(CPUFP32:
ナイーブ3重ループ:  ~900タイル3重ループ:    ~18     (50倍向上)
NumPy(BLAS:      ~0.8    (1125倍向上)
CUDA cuBLAS:        ~0.002  (450,000倍向上)

BLAS:40年の最適化の結晶

BLASレベル階層

BLAS(Basic Linear Algebra Subprograms、1979年〜):

Level 1: ベクトル-ベクトル演算(O(n)データ、O(n)演算)
  - dot(x, y):     内積
  - axpy(a, x, y): y = a*x + y
  - nrm2(x):       L2ノルム

Level 2: 行列-ベクトル演算(O()データ、O()演算)
  - gemv: y = alpha*A*x + beta*y
  → 演算密度低い → メモリバウンド

Level 3: 行列-行列演算(O()データ、O()演算)
  - gemm: C = alpha*A*B + beta*C  ← 最重要
  → 演算密度高い → コンピュートバウンド可能
  → ディープラーニングのコア演算

GEMM:汎用行列乗算

GEMMインターフェース:
C = alpha * op(A) * op(B) + beta * C

パラメータ:
  transa/transb: 転置フラグ(N=なし、T=転置、C=共役転置)
  m, n, k:       行列次元(Cはm×n、Aはm×k、Bはk×n)
  alpha, beta:   スカラー係数
  A, B, C:       行列ポインタ
  lda, ldb, ldc: 先行次元(メモリストライド)
/* cuBLAS例(NVIDIA GPU用BLAS)*/
#include <cublas_v2.h>

void gpu_matmul(float *d_A, float *d_B, float *d_C,
                int M, int N, int K) {
    cublasHandle_t handle;
    cublasCreate(&handle);

    float alpha = 1.0f;
    float beta  = 0.0f;

    /* cuBLASは列優先(column-major)格納を使用
       C = A*B(行優先)= (B^T * A^T)^T(列優先)*/
    cublasSgemm(
        handle,
        CUBLAS_OP_N, CUBLAS_OP_N,
        N, M, K,          // 注意: cuBLASはB,Aの順序!
        &alpha,
        d_B, N,           // B (N×K)、先行次元=N
        d_A, K,           // A (K×M)、先行次元=K
        &beta,
        d_C, N            // C (N×M)、先行次元=N
    );

    cublasDestroy(handle);
    /* cuBLASは理論最高性能の95%以上を達成
       数千 person-yearの最適化の結晶 */
}
# PyTorchはcuBLASを自動的に使用:
import torch

A = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
B = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)

# 内部でcuBLAS HGEMM(半精度GEMM)を呼び出す
C = torch.mm(A, B)

# ベンチマーク:
import torch.utils.benchmark as benchmark
t = benchmark.Timer(
    stmt='torch.mm(A, B)',
    globals={'A': A, 'B': B}
)
result = t.timeit(100)
print(f"H100 FP16 GEMM 4096^2: {result.mean*1000:.2f}ms")
# 結果: ~0.5ms、~2000 TFLOPS(理論最大: 1979 TFLOPS)

テンソルコア:行列乗算専用ハードウェア

NVIDIA テンソルコア(H100:

通常のCUDAコア:
  1サイクルで: 1 FMA(Fused Multiply-Add)実行
  スループット: 1 FLOP/サイクル/コア

テンソルコア:
  1サイクルで: 4×8 × 8×8 行列乗算実行
  スループット: 256 FLOP/サイクル(通常コアの256倍!)
  対応フォーマット: FP16BF16INT8INT4FP8H100
H100 SXM5:
  CUDAコア:   16,896個 × 2 FP32 = ~67 TFLOPS
  テンソルコア: 528個  × 2048 FP16 = ~1,979 TFLOPS29.5倍)

ディープラーニングでFP16がデフォルトである理由:
テンソルコアの活用!

FlashAttention:メモリの壁を突破する

標準Attentionのメモリ複雑度問題

Transformerの最大のボトルネックはAttentionのメモリ複雑度だ。

import torch
import torch.nn.functional as F

def attention_standard(Q, K, V, scale=None):
    """
    標準的なAttention実装
    Q: (batch, heads, seq_len, d_k)
    K: (batch, heads, seq_len, d_k)
    V: (batch, heads, seq_len, d_v)
    """
    if scale is None:
        scale = Q.shape[-1] ** -0.5

    # Step 1: Q×K^T → (batch, heads, seq_len, seq_len)
    # seq_len=4096、d_k=128、32 headsの場合:
    # サイズ: 1 * 32 * 4096 * 4096 * 2バイト = 1GB!!!
    # HBMに保存する必要がある(遅い)
    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale

    # Step 2: softmax → 再度HBM読み書き
    weights = F.softmax(scores, dim=-1)   # 1GB読み + 1GB書き

    # Step 3: weights×V → 再度HBM読み
    output = torch.matmul(weights, V)

    return output

# seq_len=4096、32 headsの場合:
# HBMアクセス量: scores(1GB) + softmax(2GB) + AV(1GB) = ~4GB
# H100 HBM帯域幅: 3.35 TB/s
# 理論最短時間: 4GB / 3.35T = ~1.2ms/Attentionレイヤー
# Attentionの時間のほとんどが計算ではなくメモリ転送!

Attentionは**メモリバウンド(memory-bound)**演算だ。GPUの計算ユニットはアイドル状態で、HBMアクセスがボトルネックとなる。

IO-AwareなAttention:FlashAttentionの洞察

Tri Daoら(2022年)の問い:「N×NのAttention行列を全てHBMに書く必要は本当にあるのか?」

GPU メモリ階層(H100:

SRAML1/共有メモリ、SM毎):
  サイズ:    ~228KB per SMH100132個のSM  帯域幅:    ~19 TB/s(HBM6倍!)
  遅延:      ~5サイクル
  ↕ 非常に高速

HBM(High Bandwidth Memory):
  サイズ:    80GB
  帯域幅:    ~3.35 TB/s
  遅延:      ~300サイクル
  ↕ 低速(相対的に)

FlashAttentionの戦略:
  N×NHBMに書かず、
  SRAMに収まるタイルで処理しよう!
FlashAttentionアルゴリズムの可視化:

HBMの入力:
  Q: (N, d) ─────────────────────────┐
  K: (N, d) ─────────────────────────┤(タイル単位で読み込み)
  V: (N, d) ─────────────────────────┘
SRAM(タイル処理):
  ┌────────────────────────────────────┐
Q_tile: (block_q, d)K_tile: (block_k, d)V_tile: (block_k, d)S_tile: (block_q, block_k)       │ ← HBMには絶対書かない!
  │  running_max: (block_q,)  │  running_sum: (block_q,)O_tile: (block_q, d)  └────────────────────────────────────┘
HBMの出力:
  O: (N, d) ← タイル完成後に書き込み

HBMアクセス回数:
  標準Attention:   O(N^2)N×N行列の読み書き)
  FlashAttention:  O(N)  (タイル単位の線形アクセス)

オンラインSoftmax:数学的な鍵

FlashAttentionのコアトリックはオンラインSoftmaxだ。シーケンス全体を見ずに正確なSoftmaxを計算できる。

import torch
import math

def flash_attention_simplified(Q, K, V, block_size=64):
    """
    FlashAttentionのコアロジック(簡略版)
    Q, K, V: (N, d) — 単一ヘッド、バッチなし
    """
    N, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    dtype = Q.dtype

    # 累算器の初期化
    O = torch.zeros(N, d, dtype=dtype, device=Q.device)
    L = torch.zeros(N, dtype=dtype, device=Q.device)       # 正規化係数
    M = torch.full((N,), float('-inf'), dtype=dtype,
                   device=Q.device)                         # 実行中の最大値

    # オンラインSoftmaxの数学:
    # 新しいタイルx_newが到着した場合:
    #   m_new = max(m_old, max(x_new))
    #   l_new = exp(m_old-m_new)*l_old + sum(exp(x_new-m_new))
    #   O_new = (exp(m_old-m_new)*l_old*O_old + exp(x_new-m_new)@V_new) / l_new

    # 外側ループ: K, Vタイル
    for j_start in range(0, N, block_size):
        j_end = min(j_start + block_size, N)
        K_j = K[j_start:j_end]   # SRAMにロード
        V_j = V[j_start:j_end]   # SRAMにロード

        # 内側ループ: Qタイル
        for i_start in range(0, N, block_size):
            i_end = min(i_start + block_size, N)
            Q_i = Q[i_start:i_end]   # SRAMにロード

            # Attentionスコア(SRAM内で計算!)
            S_ij = (Q_i @ K_j.T) * scale  # (block_q, block_k)

            # 実行中最大値の更新
            m_i_old = M[i_start:i_end]
            m_ij = S_ij.max(dim=-1).values
            m_i_new = torch.maximum(m_i_old, m_ij)

            # 補正係数を使った実行中合計の更新
            correction = torch.exp(m_i_old - m_i_new)
            l_i_new = (correction * L[i_start:i_end] +
                       torch.exp(S_ij - m_i_new.unsqueeze(-1)).sum(dim=-1))

            # 出力の更新
            P_ij = torch.exp(S_ij - m_i_new.unsqueeze(-1))
            O[i_start:i_end] = (
                (correction * L[i_start:i_end]).unsqueeze(-1) * O[i_start:i_end]
                + P_ij @ V_j
            ) / l_i_new.unsqueeze(-1)

            M[i_start:i_end] = m_i_new
            L[i_start:i_end] = l_i_new

    return O  # HBMへの書き込みは1回だけ!


# FlashAttentionの性能数値(A100 80GB、seq_len=2048):
# 標準Attention:          12.3ms
# FlashAttention v1:       3.4ms(3.6倍高速)
# FlashAttention v2:       2.1ms(5.9倍高速)
# FlashAttention v3(H100): 1.4ms(8.8倍高速)

GEMM性能分析:ルーフラインモデル

コンピュートバウンドかメモリバウンドか?

ルーフラインモデル:
  達成可能性能 = min(コンピート最大性能,
                    メモリ帯域幅 × 演算密度)

  演算密度(Arithmetic Intensity、AI:
    AI = FLOPS / BYTES_ACCESSEDFLOP/Byte)

H100 SXM5パラメータ:
  FP16 テンソルコア: 1,979 TFLOPS
  HBM帯域幅:         3,350 GB/s
  リッジポイント: 1979T / 3.35T = 590 FLOP/Byte
AI > 590: コンピュートバウンド
AI < 590: メモリバウンド
def analyze_matmul_intensity(M, K, N, dtype_bytes=2):
    """
    (M×K) × (K×N) = (M×N) の演算密度を計算
    """
    flops = 2 * M * K * N  # 乗算 + 加算

    # メモリアクセス: A読み + B読み + C書き
    bytes_accessed = (M * K + K * N + M * N) * dtype_bytes

    ai = flops / bytes_accessed
    return flops, bytes_accessed, ai


# ケース1: 大規模行列(学習時のバッチ処理)
M = N = K = 4096
flops, bytes_val, ai = analyze_matmul_intensity(M, K, N)
print(f"大規模GEMM(4096^3):")
print(f"  FLOPs:  {flops/1e12:.1f} TFLOPS")
print(f"  Bytes:  {bytes_val/1e6:.0f} MB")
print(f"  AI:     {ai:.0f} FLOP/Byte")
print(f"  H100リッジ: 590 FLOP/Byte")
print(f"  状態: {'コンピュートバウンド' if ai > 590 else 'メモリバウンド'}")
# → AI=1370、コンピュートバウンド ✅

print()

# ケース2: 推論(バッチサイズ=1)
M, K, N = 1, 4096, 4096
flops, bytes_val, ai = analyze_matmul_intensity(M, K, N)
print(f"推論GEMV(batch=1):")
print(f"  FLOPs:  {flops/1e6:.1f} MFLOPS")
print(f"  Bytes:  {bytes_val/1e6:.0f} MB(ほぼB行列全体!)")
print(f"  AI:     {ai:.2f} FLOP/Byte")
print(f"  状態: {'コンピュートバウンド' if ai > 590 else 'メモリバウンド'}")
# → AI=0.99、メモリバウンド ❌
# これがLLM推論の最適化が難しい根本的な理由!
実際の性能測定(H100FP164096×4096重み行列):

バッチサイズ別のGEMM性能:
┌───────────┬────────────┬────────────┬──────────────────┐
バッチ(M)TFLOPSAI         │ 状態             │
├───────────┼────────────┼────────────┼──────────────────┤
10.016~1         │ メモリバウンド ❌│
40.063~4         │ メモリバウンド ❌│
160.25~16        │ メモリバウンド   │
640.98~63        │ メモリバウンド   │
2563.8~250       │ 移行ゾーン       │
102414.2~1000      │ コンピュートバウンド ✅│
40961,847~1370      │ コンピュートバウンド ✅│
└───────────┴────────────┴────────────┴──────────────────┘

→ バッチが大きいほどGPU使用率が向上
→ これがLLMサービングでバッチングが重要な理由

実用的なCUDA GEMM最適化のヒント

import torch

# ヒント1: TF32の有効化(Ampere以降のFP32代替)
# FP32: 23ビット仮数 → TF32: 10ビット仮数(若干の精度低下)
# しかしテンソルコアが使用可能 → 約10倍高速
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# ヒント2: 連続メモリのテンソルを使用
x = torch.randn(64, 128, 128, device='cuda')
x_t = x.transpose(1, 2)          # 非連続!
x_t_cont = x_t.contiguous()      # 連続メモリにコピー

# ヒント3: 行列次元を16の倍数に揃える
# H100テンソルコアは16×nサイズで最適動作
def pad_to_multiple(x, multiple=16):
    d = x.shape[-1]
    if d % multiple == 0:
        return x
    pad_size = multiple - (d % multiple)
    return torch.nn.functional.pad(x, (0, pad_size))

# ヒント4: torch.compile(PyTorch 2.0以降)
@torch.compile
def optimized_linear(x, weight, bias=None):
    """
    torch.compileがカーネルフュージョン、
    最適なメモリレイアウト選択を自動実行
    """
    return torch.nn.functional.linear(x, weight, bias)

全体像:行列乗算最適化のスタック

現代のディープラーニング行列乗算最適化スタック:

アプリケーションレイヤー:
  PyTorch / JAX / TensorFlow
  ↓(torch.mm、F.linear等)

カーネルフュージョン / コンパイラ:
  torch.compile(TorchInductor)
  XLAJAX  ↓(最適化されたCUDAコードを生成)

高水準ライブラリ:
  cuDNN(ニューラルネット特化)
  cuBLAS(汎用GEMM  CUTLASSNVIDIAオープンソーステンプレート)
  ↓(最適化されたPTX/SASSコード)

ハードウェア:
  テンソルコア(H100: 1,979 TFLOPS FP16
メモリシステム:
  L1/L2キャッシュ(SRAM  HBM33.35 TB/s)

FlashAttentionはこのスタックの「高水準ライブラリ」レイヤーで、
IO-awareアルゴリズムによりHBMアクセスを最小化している。

行列乗算はシンプルに見えるが、その最適化の深さは40年の歴史と数千人のエンジニアリングの努力を内包している。FlashAttentionシリーズのように、アルゴリズムレベルの革新がハードウェア改善と同等以上のインパクトを持つ時代となった。

LLMインフラエンジニアにとって、このスタック全体を理解することは必須だ。次の記事では、この基盤の上で実際のLLMサービング最適化がどのように行われるかを解説する:KVキャッシュ、PagedAttention、連続バッチング、そして量子化について深掘りしていく。