- Authors

- Name
- Youngju Kim
- @fjvbn20031
- なぜ行列乗算がディープラーニングのすべてなのか
- ナイーブな行列乗算とその問題点
- キャッシュブロッキング(タイリング):解決策
- BLAS:40年の最適化の結晶
- FlashAttention:メモリの壁を突破する
- GEMM性能分析:ルーフラインモデル
- 全体像:行列乗算最適化のスタック
なぜ行列乗算がディープラーニングのすべてなのか
GPUがディープラーニングモデルを実行する際、実際の作業の80%以上は単一の演算に帰着する。それが**行列乗算(Matrix Multiplication)**だ。
各レイヤーを分解して確認しよう:
Linear Layer(全結合層)
output = W × input + b
shape: (out_features,) = (out_features, in_features) × (in_features,)
最も基本的な線形変換。純粋な行列-ベクトル乗算だ。
Self-Attention(Transformerの核心)
Q = W_q × x # 行列乗算
K = W_k × x # 行列乗算
V = W_v × x # 行列乗算
scores = Q × K^T # 行列乗算 (N×d × d×N = N×N)
weights = softmax(scores / sqrt(d_k))
output = weights × V # 行列乗算 (N×N × N×d = N×d)
1回のAttentionレイヤーで最低6回の行列乗算が発生する。
Convolution(CNN)
im2col変換後:
output = W × im2col(input)
畳み込みも最終的には行列乗算に変換される。
GPT-3(175B)のFLOP分布:
Linearレイヤー(QKVプロジェクション): ~45%
Feed-forwardレイヤー: ~35%
Attention計算(QK^T、AV): ~15%
その他(LayerNorm、softmax等): ~5%
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
行列乗算合計: ~95%
行列乗算をどれだけ高速化できるか = ディープラーニングの推論・学習速度が決まる。これが40年間、数千人のエンジニアが行列乗算の最適化に取り組んできた理由だ。
ナイーブな行列乗算とその問題点
3重ループ:定義そのもの
import numpy as np
def naive_matmul(A, B):
"""
A: (M, K) B: (K, N) -> C: (M, N)
C[i][j] = sum(A[i][k] * B[k][j] for k in range(K))
"""
M, K = A.shape
K2, N = B.shape
assert K == K2, "内側の次元が一致している必要があります"
C = np.zeros((M, N), dtype=np.float64)
for i in range(M): # M回反復
for j in range(N): # N回反復
for k in range(K): # K回反復
C[i][j] += A[i][k] * B[k][j] # MAC演算1回
return C
# 総演算数: M * N * K 回の multiply-add
# 計算量: O(n^3)
# 4096×4096行列: 4096^3 = 約680億回の演算!
数学的には完璧だが、実際に動かすと非常に遅い。
4096×4096 FP32行列乗算の場合:
- 理論演算量:2 × 4096³ ≈ 137 GFLOPS
- H100の理論性能:67 TFLOPS(FP32)
- 理論時間:137G / 67T ≈ 2ms
- ナイーブPythonの実際の時間:約900秒(450,000倍遅い!)
差異の原因は2つ:メモリアクセスパターンと並列性の未活用。
キャッシュミス:真の敵
メモリ階層(H100基準):
┌─────────────────────────────────────────────────────────┐
│ L1 Cache (SM毎): 256KB, ~5サイクル 遅延 │
│ L2 Cache (共有): 50MB, ~30サイクル遅延 │
│ HBM3 (GPUメインメモリ): 80GB, ~300サイクル遅延 │
│ CPU DRAM: 数TB, ~1000サイクル遅延 │
└─────────────────────────────────────────────────────────┘
キャッシュアクセス: 300サイクル ↔ 5サイクル = 60倍の差
キャッシュヒット率 = 性能
ナイーブな3重ループのメモリアクセスパターン(A[4×4], B[4×4]の例):
A行列(行優先格納、row-major):
メモリアドレス: [A00][A01][A02][A03] [A10][A11][A12][A13] ...
←─ キャッシュライン1 ─→ ←─ キャッシュライン2 ─→
A[i][k]アクセスパターン(i=0固定、k=0,1,2,3順):
A[0][0] → キャッシュライン1ロード(キャッシュMISS、300サイクル)
A[0][1] → キャッシュライン1再利用(キャッシュHIT、5サイクル)✅
A[0][2] → キャッシュライン1再利用(キャッシュHIT)✅
A[0][3] → キャッシュライン1再利用(キャッシュHIT)✅
→ Aは行方向アクセス:キャッシュフレンドリー ✅
B行列の列方向アクセス(j=0固定、k=0,1,2,3順):
B[0][0] → キャッシュライン0ロード(キャッシュMISS)
B[1][0] → キャッシュライン2ロード(キャッシュMISS! 別の行)❌
B[2][0] → キャッシュライン4ロード(キャッシュMISS!)❌
B[3][0] → キャッシュライン6ロード(キャッシュMISS!)❌
Bの各要素が別々のキャッシュラインに存在
大行列(N=4096)ではBアクセスの99%がキャッシュミス!
キャッシュブロッキング(タイリング):解決策
コアアイデア:キャッシュに収まるブロックに分割
行列をキャッシュに収まる小さなタイルに分割し、一度ロードしたデータを最大限再利用する。
行列タイリングの可視化(M=8, N=8, K=8, TILE_SIZE=4):
すべての行列を4×4タイルに分割:
B (K×N):
┌───┬───┐
│B00│B01│ B00 = B[0:4, 0:4]
├───┼───┤ B01 = B[0:4, 4:8]
│B10│B11│ B10 = B[4:8, 0:4]
└───┴───┘ B11 = B[4:8, 4:8]
A (M×K): C (M×N):
┌───┬───┐ ┌───┬───┐
│A00│A01│ │C00│C01│ C00 = A00×B00 + A01×B10
├───┼───┤ ├───┼───┤ C01 = A00×B01 + A01×B11
│A10│A11│ │C10│C11│ C10 = A10×B00 + A11×B10
└───┴───┘ └───┴───┘ C11 = A10×B01 + A11×B11
タイルC00の計算過程:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Step 1: A00(4×4=16 float=64B)をキャッシュにロード
B00(4×4=16 float=64B)をキャッシュにロード
→ 128Bロードで4×4=16回の内積計算
→ 16×16=256回のMAC演算(キャッシュミスゼロ!)
Step 2: A01をキャッシュにロード(A00をevict)
B10をキャッシュにロード(B00をevict)
→ 256回のMAC演算(キャッシュミスゼロ!)
結果: C00完成
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
タイルMatMul実装
import numpy as np
import time
def tiled_matmul(A, B, tile_size=64):
"""
キャッシュブロッキングを使った行列乗算
tile_size: L1キャッシュサイズに合わせて調整
L1=256KB、float32: sqrt(256K/4/3) ≈ 146
実際には64-128が最適なことが多い
"""
M, K = A.shape
K2, N = B.shape
assert K == K2
C = np.zeros((M, N), dtype=A.dtype)
for i in range(0, M, tile_size):
i_end = min(i + tile_size, M)
for j in range(0, N, tile_size):
j_end = min(j + tile_size, N)
for k in range(0, K, tile_size):
k_end = min(k + tile_size, K)
tile_A = A[i:i_end, k:k_end] # キャッシュ常駐Aタイル
tile_B = B[k:k_end, j:j_end] # キャッシュ常駐Bタイル
# このブロック内の演算はすべてキャッシュヒット!
C[i:i_end, j:j_end] += tile_A @ tile_B
return C
def benchmark():
for N in [512, 1024, 2048]:
A = np.random.randn(N, N).astype(np.float32)
B = np.random.randn(N, N).astype(np.float32)
start = time.perf_counter()
for _ in range(5):
_ = A @ B
t = (time.perf_counter() - start) / 5
tflops = 2 * N**3 / t / 1e12
print(f"N={N}: {t*1000:.1f}ms, {tflops:.2f} TFLOPS")
benchmark()
性能比較:
N=4096、tile_size=64 vs ナイーブ(CPU、FP32):
ナイーブ3重ループ: ~900秒
タイル3重ループ: ~18秒 (50倍向上)
NumPy(BLAS): ~0.8秒 (1125倍向上)
CUDA cuBLAS: ~0.002秒 (450,000倍向上)
BLAS:40年の最適化の結晶
BLASレベル階層
BLAS(Basic Linear Algebra Subprograms、1979年〜):
Level 1: ベクトル-ベクトル演算(O(n)データ、O(n)演算)
- dot(x, y): 内積
- axpy(a, x, y): y = a*x + y
- nrm2(x): L2ノルム
Level 2: 行列-ベクトル演算(O(n²)データ、O(n²)演算)
- gemv: y = alpha*A*x + beta*y
→ 演算密度低い → メモリバウンド
Level 3: 行列-行列演算(O(n²)データ、O(n³)演算)
- gemm: C = alpha*A*B + beta*C ← 最重要
→ 演算密度高い → コンピュートバウンド可能
→ ディープラーニングのコア演算
GEMM:汎用行列乗算
GEMMインターフェース:
C = alpha * op(A) * op(B) + beta * C
パラメータ:
transa/transb: 転置フラグ(N=なし、T=転置、C=共役転置)
m, n, k: 行列次元(Cはm×n、Aはm×k、Bはk×n)
alpha, beta: スカラー係数
A, B, C: 行列ポインタ
lda, ldb, ldc: 先行次元(メモリストライド)
/* cuBLAS例(NVIDIA GPU用BLAS)*/
#include <cublas_v2.h>
void gpu_matmul(float *d_A, float *d_B, float *d_C,
int M, int N, int K) {
cublasHandle_t handle;
cublasCreate(&handle);
float alpha = 1.0f;
float beta = 0.0f;
/* cuBLASは列優先(column-major)格納を使用
C = A*B(行優先)= (B^T * A^T)^T(列優先)*/
cublasSgemm(
handle,
CUBLAS_OP_N, CUBLAS_OP_N,
N, M, K, // 注意: cuBLASはB,Aの順序!
&alpha,
d_B, N, // B (N×K)、先行次元=N
d_A, K, // A (K×M)、先行次元=K
&beta,
d_C, N // C (N×M)、先行次元=N
);
cublasDestroy(handle);
/* cuBLASは理論最高性能の95%以上を達成
数千 person-yearの最適化の結晶 */
}
# PyTorchはcuBLASを自動的に使用:
import torch
A = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
B = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
# 内部でcuBLAS HGEMM(半精度GEMM)を呼び出す
C = torch.mm(A, B)
# ベンチマーク:
import torch.utils.benchmark as benchmark
t = benchmark.Timer(
stmt='torch.mm(A, B)',
globals={'A': A, 'B': B}
)
result = t.timeit(100)
print(f"H100 FP16 GEMM 4096^2: {result.mean*1000:.2f}ms")
# 結果: ~0.5ms、~2000 TFLOPS(理論最大: 1979 TFLOPS)
テンソルコア:行列乗算専用ハードウェア
NVIDIA テンソルコア(H100):
通常のCUDAコア:
1サイクルで: 1 FMA(Fused Multiply-Add)実行
スループット: 1 FLOP/サイクル/コア
テンソルコア:
1サイクルで: 4×8 × 8×8 行列乗算実行
スループット: 256 FLOP/サイクル(通常コアの256倍!)
対応フォーマット: FP16、BF16、INT8、INT4、FP8(H100)
H100 SXM5:
CUDAコア: 16,896個 × 2 FP32 = ~67 TFLOPS
テンソルコア: 528個 × 2048 FP16 = ~1,979 TFLOPS(29.5倍)
ディープラーニングでFP16がデフォルトである理由:
テンソルコアの活用!
FlashAttention:メモリの壁を突破する
標準Attentionのメモリ複雑度問題
Transformerの最大のボトルネックはAttentionのメモリ複雑度だ。
import torch
import torch.nn.functional as F
def attention_standard(Q, K, V, scale=None):
"""
標準的なAttention実装
Q: (batch, heads, seq_len, d_k)
K: (batch, heads, seq_len, d_k)
V: (batch, heads, seq_len, d_v)
"""
if scale is None:
scale = Q.shape[-1] ** -0.5
# Step 1: Q×K^T → (batch, heads, seq_len, seq_len)
# seq_len=4096、d_k=128、32 headsの場合:
# サイズ: 1 * 32 * 4096 * 4096 * 2バイト = 1GB!!!
# HBMに保存する必要がある(遅い)
scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
# Step 2: softmax → 再度HBM読み書き
weights = F.softmax(scores, dim=-1) # 1GB読み + 1GB書き
# Step 3: weights×V → 再度HBM読み
output = torch.matmul(weights, V)
return output
# seq_len=4096、32 headsの場合:
# HBMアクセス量: scores(1GB) + softmax(2GB) + AV(1GB) = ~4GB
# H100 HBM帯域幅: 3.35 TB/s
# 理論最短時間: 4GB / 3.35T = ~1.2ms/Attentionレイヤー
# Attentionの時間のほとんどが計算ではなくメモリ転送!
Attentionは**メモリバウンド(memory-bound)**演算だ。GPUの計算ユニットはアイドル状態で、HBMアクセスがボトルネックとなる。
IO-AwareなAttention:FlashAttentionの洞察
Tri Daoら(2022年)の問い:「N×NのAttention行列を全てHBMに書く必要は本当にあるのか?」
GPU メモリ階層(H100):
SRAM(L1/共有メモリ、SM毎):
サイズ: ~228KB per SM(H100は132個のSM)
帯域幅: ~19 TB/s(HBMの6倍!)
遅延: ~5サイクル
↕ 非常に高速
HBM(High Bandwidth Memory):
サイズ: 80GB
帯域幅: ~3.35 TB/s
遅延: ~300サイクル
↕ 低速(相対的に)
FlashAttentionの戦略:
N×NをHBMに書かず、
SRAMに収まるタイルで処理しよう!
FlashAttentionアルゴリズムの可視化:
HBMの入力:
Q: (N, d) ─────────────────────────┐
K: (N, d) ─────────────────────────┤(タイル単位で読み込み)
V: (N, d) ─────────────────────────┘
↓
SRAM(タイル処理):
┌────────────────────────────────────┐
│ Q_tile: (block_q, d) │
│ K_tile: (block_k, d) │
│ V_tile: (block_k, d) │
│ S_tile: (block_q, block_k) │ ← HBMには絶対書かない!
│ running_max: (block_q,) │
│ running_sum: (block_q,) │
│ O_tile: (block_q, d) │
└────────────────────────────────────┘
↓
HBMの出力:
O: (N, d) ← タイル完成後に書き込み
HBMアクセス回数:
標準Attention: O(N^2)(N×N行列の読み書き)
FlashAttention: O(N) (タイル単位の線形アクセス)
オンラインSoftmax:数学的な鍵
FlashAttentionのコアトリックはオンラインSoftmaxだ。シーケンス全体を見ずに正確なSoftmaxを計算できる。
import torch
import math
def flash_attention_simplified(Q, K, V, block_size=64):
"""
FlashAttentionのコアロジック(簡略版)
Q, K, V: (N, d) — 単一ヘッド、バッチなし
"""
N, d = Q.shape
scale = 1.0 / math.sqrt(d)
dtype = Q.dtype
# 累算器の初期化
O = torch.zeros(N, d, dtype=dtype, device=Q.device)
L = torch.zeros(N, dtype=dtype, device=Q.device) # 正規化係数
M = torch.full((N,), float('-inf'), dtype=dtype,
device=Q.device) # 実行中の最大値
# オンラインSoftmaxの数学:
# 新しいタイルx_newが到着した場合:
# m_new = max(m_old, max(x_new))
# l_new = exp(m_old-m_new)*l_old + sum(exp(x_new-m_new))
# O_new = (exp(m_old-m_new)*l_old*O_old + exp(x_new-m_new)@V_new) / l_new
# 外側ループ: K, Vタイル
for j_start in range(0, N, block_size):
j_end = min(j_start + block_size, N)
K_j = K[j_start:j_end] # SRAMにロード
V_j = V[j_start:j_end] # SRAMにロード
# 内側ループ: Qタイル
for i_start in range(0, N, block_size):
i_end = min(i_start + block_size, N)
Q_i = Q[i_start:i_end] # SRAMにロード
# Attentionスコア(SRAM内で計算!)
S_ij = (Q_i @ K_j.T) * scale # (block_q, block_k)
# 実行中最大値の更新
m_i_old = M[i_start:i_end]
m_ij = S_ij.max(dim=-1).values
m_i_new = torch.maximum(m_i_old, m_ij)
# 補正係数を使った実行中合計の更新
correction = torch.exp(m_i_old - m_i_new)
l_i_new = (correction * L[i_start:i_end] +
torch.exp(S_ij - m_i_new.unsqueeze(-1)).sum(dim=-1))
# 出力の更新
P_ij = torch.exp(S_ij - m_i_new.unsqueeze(-1))
O[i_start:i_end] = (
(correction * L[i_start:i_end]).unsqueeze(-1) * O[i_start:i_end]
+ P_ij @ V_j
) / l_i_new.unsqueeze(-1)
M[i_start:i_end] = m_i_new
L[i_start:i_end] = l_i_new
return O # HBMへの書き込みは1回だけ!
# FlashAttentionの性能数値(A100 80GB、seq_len=2048):
# 標準Attention: 12.3ms
# FlashAttention v1: 3.4ms(3.6倍高速)
# FlashAttention v2: 2.1ms(5.9倍高速)
# FlashAttention v3(H100): 1.4ms(8.8倍高速)
GEMM性能分析:ルーフラインモデル
コンピュートバウンドかメモリバウンドか?
ルーフラインモデル:
達成可能性能 = min(コンピート最大性能,
メモリ帯域幅 × 演算密度)
演算密度(Arithmetic Intensity、AI):
AI = FLOPS / BYTES_ACCESSED(FLOP/Byte)
H100 SXM5パラメータ:
FP16 テンソルコア: 1,979 TFLOPS
HBM帯域幅: 3,350 GB/s
リッジポイント: 1979T / 3.35T = 590 FLOP/Byte
→ AI > 590: コンピュートバウンド
→ AI < 590: メモリバウンド
def analyze_matmul_intensity(M, K, N, dtype_bytes=2):
"""
(M×K) × (K×N) = (M×N) の演算密度を計算
"""
flops = 2 * M * K * N # 乗算 + 加算
# メモリアクセス: A読み + B読み + C書き
bytes_accessed = (M * K + K * N + M * N) * dtype_bytes
ai = flops / bytes_accessed
return flops, bytes_accessed, ai
# ケース1: 大規模行列(学習時のバッチ処理)
M = N = K = 4096
flops, bytes_val, ai = analyze_matmul_intensity(M, K, N)
print(f"大規模GEMM(4096^3):")
print(f" FLOPs: {flops/1e12:.1f} TFLOPS")
print(f" Bytes: {bytes_val/1e6:.0f} MB")
print(f" AI: {ai:.0f} FLOP/Byte")
print(f" H100リッジ: 590 FLOP/Byte")
print(f" 状態: {'コンピュートバウンド' if ai > 590 else 'メモリバウンド'}")
# → AI=1370、コンピュートバウンド ✅
print()
# ケース2: 推論(バッチサイズ=1)
M, K, N = 1, 4096, 4096
flops, bytes_val, ai = analyze_matmul_intensity(M, K, N)
print(f"推論GEMV(batch=1):")
print(f" FLOPs: {flops/1e6:.1f} MFLOPS")
print(f" Bytes: {bytes_val/1e6:.0f} MB(ほぼB行列全体!)")
print(f" AI: {ai:.2f} FLOP/Byte")
print(f" 状態: {'コンピュートバウンド' if ai > 590 else 'メモリバウンド'}")
# → AI=0.99、メモリバウンド ❌
# これがLLM推論の最適化が難しい根本的な理由!
実際の性能測定(H100、FP16、4096×4096重み行列):
バッチサイズ別のGEMM性能:
┌───────────┬────────────┬────────────┬──────────────────┐
│ バッチ(M) │ TFLOPS │ AI │ 状態 │
├───────────┼────────────┼────────────┼──────────────────┤
│ 1 │ 0.016 │ ~1 │ メモリバウンド ❌│
│ 4 │ 0.063 │ ~4 │ メモリバウンド ❌│
│ 16 │ 0.25 │ ~16 │ メモリバウンド │
│ 64 │ 0.98 │ ~63 │ メモリバウンド │
│ 256 │ 3.8 │ ~250 │ 移行ゾーン │
│ 1024 │ 14.2 │ ~1000 │ コンピュートバウンド ✅│
│ 4096 │ 1,847 │ ~1370 │ コンピュートバウンド ✅│
└───────────┴────────────┴────────────┴──────────────────┘
→ バッチが大きいほどGPU使用率が向上
→ これがLLMサービングでバッチングが重要な理由
実用的なCUDA GEMM最適化のヒント
import torch
# ヒント1: TF32の有効化(Ampere以降のFP32代替)
# FP32: 23ビット仮数 → TF32: 10ビット仮数(若干の精度低下)
# しかしテンソルコアが使用可能 → 約10倍高速
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# ヒント2: 連続メモリのテンソルを使用
x = torch.randn(64, 128, 128, device='cuda')
x_t = x.transpose(1, 2) # 非連続!
x_t_cont = x_t.contiguous() # 連続メモリにコピー
# ヒント3: 行列次元を16の倍数に揃える
# H100テンソルコアは16×nサイズで最適動作
def pad_to_multiple(x, multiple=16):
d = x.shape[-1]
if d % multiple == 0:
return x
pad_size = multiple - (d % multiple)
return torch.nn.functional.pad(x, (0, pad_size))
# ヒント4: torch.compile(PyTorch 2.0以降)
@torch.compile
def optimized_linear(x, weight, bias=None):
"""
torch.compileがカーネルフュージョン、
最適なメモリレイアウト選択を自動実行
"""
return torch.nn.functional.linear(x, weight, bias)
全体像:行列乗算最適化のスタック
現代のディープラーニング行列乗算最適化スタック:
アプリケーションレイヤー:
PyTorch / JAX / TensorFlow
↓(torch.mm、F.linear等)
カーネルフュージョン / コンパイラ:
torch.compile(TorchInductor)
XLA(JAX)
↓(最適化されたCUDAコードを生成)
高水準ライブラリ:
cuDNN(ニューラルネット特化)
cuBLAS(汎用GEMM)
CUTLASS(NVIDIAオープンソーステンプレート)
↓(最適化されたPTX/SASSコード)
ハードウェア:
テンソルコア(H100: 1,979 TFLOPS FP16)
↓
メモリシステム:
L1/L2キャッシュ(SRAM)
HBM3(3.35 TB/s)
FlashAttentionはこのスタックの「高水準ライブラリ」レイヤーで、
IO-awareアルゴリズムによりHBMアクセスを最小化している。
行列乗算はシンプルに見えるが、その最適化の深さは40年の歴史と数千人のエンジニアリングの努力を内包している。FlashAttentionシリーズのように、アルゴリズムレベルの革新がハードウェア改善と同等以上のインパクトを持つ時代となった。
LLMインフラエンジニアにとって、このスタック全体を理解することは必須だ。次の記事では、この基盤の上で実際のLLMサービング最適化がどのように行われるかを解説する:KVキャッシュ、PagedAttention、連続バッチング、そして量子化について深掘りしていく。