Author: Youngju Kim (@fjvbn20031)
- Why Matrix Multiplication Is All of Deep Learning
- The Naive Implementation and Its Problems
- Cache Blocking (Tiling): The Fix
- BLAS: Four Decades of Optimization, Crystallized
- FlashAttention: Breaking Through the Memory Wall
- GEMM Performance Analysis: The Roofline Model
- The Full Picture: Matrix Multiply Optimization Stack
Why Matrix Multiplication Is All of Deep Learning
When a GPU runs a deep learning model, more than 80% of the actual work reduces to a single operation: matrix multiplication.
If that sounds hard to believe, let's dissect each layer type:
Linear Layer (Fully Connected Layer)
output = W × input + b
shape: (out_features,) = (out_features, in_features) × (in_features,)
The most fundamental transformation in neural networks. Pure matrix-vector multiplication.
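In NumPy the whole layer is a single line (toy sizes, chosen arbitrarily):

```python
import numpy as np

in_features, out_features = 3, 2     # arbitrary toy sizes
W = np.random.randn(out_features, in_features)
b = np.random.randn(out_features)
x = np.random.randn(in_features)

# The entire layer: one matrix-vector product plus a bias
output = W @ x + b
assert output.shape == (out_features,)
```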
Self-Attention (The Heart of Transformers)
Q = W_q × x # matrix multiply
K = W_k × x # matrix multiply
V = W_v × x # matrix multiply
scores = Q × K^T # matrix multiply (N×d × d×N = N×N)
weights = softmax(scores / sqrt(d_k))
output = weights × V # matrix multiply (N×N × N×d = N×d)
Counting the output projection (W_o, applied after the weighted sum), a single attention layer fires at least 6 matrix multiplications.
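The shape bookkeeping above can be verified in a few lines of NumPy (toy sizes; W_o is the output projection that usually follows the weighted sum):

```python
import numpy as np

N, d_model, d_k = 8, 16, 4                      # toy sequence/feature sizes
x = np.random.randn(N, d_model)
W_q, W_k, W_v = (np.random.randn(d_model, d_k) for _ in range(3))
W_o = np.random.randn(d_k, d_model)

Q, K, V = x @ W_q, x @ W_k, x @ W_v             # matmuls 1-3
scores = Q @ K.T / np.sqrt(d_k)                 # matmul 4: (N,d_k)x(d_k,N) = (N,N)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)  # softmax: NOT a matmul
out = (weights @ V) @ W_o                       # matmuls 5-6: (N,N)x(N,d_k), then W_o
assert out.shape == (N, d_model)
```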
Convolution (CNN)
after im2col transform:
output = W × im2col(input)
Even convolution ultimately gets rewritten as matrix multiplication.
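A minimal im2col sketch makes this concrete (stride 1, no padding, single channel — a simplification of what cuDNN actually does):

```python
import numpy as np

def im2col(x, kh, kw):
    """Unfold each kh×kw patch of a 2-D input into one column (stride 1, no pad)."""
    H, W = x.shape
    out_h, out_w = H - kh + 1, W - kw + 1
    cols = np.empty((kh * kw, out_h * out_w))
    idx = 0
    for i in range(out_h):
        for j in range(out_w):
            cols[:, idx] = x[i:i+kh, j:j+kw].ravel()
            idx += 1
    return cols

x = np.arange(16, dtype=float).reshape(4, 4)
w = np.ones((3, 3))   # 3×3 kernel of ones (i.e. a window sum)

# Convolution = one matrix multiply against the unfolded input
out = (w.ravel() @ im2col(x, 3, 3)).reshape(2, 2)

# Direct sliding-window computation for comparison
direct = np.array([[x[i:i+3, j:j+3].sum() for j in range(2)] for i in range(2)])
assert np.allclose(out, direct)
```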
FLOP distribution in GPT-3 (175B parameters):
Linear layers (QKV projections): ~45%
Feed-forward layers: ~35%
Attention computation (QK^T, AV): ~15%
Others (LayerNorm, softmax, etc.): ~5%
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Matrix multiplication total: ~95%
How fast you can multiply matrices = how fast training and inference runs. This is why thousands of engineers have spent four decades optimizing this one operation.
The Naive Implementation and Its Problems
Triple Loop: The Definition Itself
import numpy as np

def naive_matmul(A, B):
    """
    A: (M, K), B: (K, N) -> C: (M, N)
    C[i][j] = sum(A[i][k] * B[k][j] for k in range(K))
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "Inner dimensions must match"
    C = np.zeros((M, N), dtype=np.float64)
    for i in range(M):              # M iterations
        for j in range(N):          # N iterations
            for k in range(K):      # K iterations
                C[i][j] += A[i][k] * B[k][j]  # 1 MAC operation
    return C

# Total ops: M * N * K multiply-add operations
# Complexity: O(n^3)
# For 4096×4096: 4096^3 ≈ 68 BILLION operations!
Mathematically perfect. In practice, horrifyingly slow.
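A quick sanity check first — the triple loop matches NumPy exactly (the loop restated inline so the snippet stands alone):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 7))
B = rng.standard_normal((7, 3))

# Naive triple loop, verbatim
C = np.zeros((5, 3))
for i in range(5):
    for j in range(3):
        for k in range(7):
            C[i, j] += A[i, k] * B[k, j]

assert np.allclose(C, A @ B)   # identical result, vastly different speed
```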
4096×4096 FP32 matrix multiply:
- Theoretical FLOP count: 2 × 4096³ ≈ 137 GFLOP
- H100 theoretical throughput: 67 TFLOPS (FP32)
- Theoretical time: 137G / 67T ≈ 2ms
- Naive Python actual time: ~900 seconds (450,000× slower!)
Two root causes: memory access patterns and unused parallelism.
Cache Misses: The Real Enemy
Memory hierarchy (H100):
┌─────────────────────────────────────────────────────────┐
│ L1 Cache (per SM): 256KB, ~5 cycle latency │
│ L2 Cache (shared): 50MB, ~30 cycle latency │
│ HBM3 (GPU main mem): 80GB, ~300 cycle latency │
│ CPU DRAM: TB-scale, ~1000 cycle latency │
└─────────────────────────────────────────────────────────┘
Cache hit vs miss: 300 cycles ↔ 5 cycles = 60× difference
Cache hit rate IS performance.
Examine the memory access pattern in the naive triple loop (A[4×4], B[4×4]):
A matrix (row-major storage):
addresses: [A00][A01][A02][A03] [A10][A11][A12][A13] ...
←── cache line 1 ─→ ←── cache line 2 ─→
A[i][k] access pattern (i=0 fixed, k=0,1,2,3):
A[0][0] → load cache line 1 (cache MISS, 300 cycles)
A[0][1] → reuse cache line 1 (cache HIT, 5 cycles) ✅
A[0][2] → reuse cache line 1 (cache HIT) ✅
A[0][3] → reuse cache line 1 (cache HIT) ✅
→ A accessed row-wise: cache-FRIENDLY ✅
B matrix column-wise access (j=0 fixed, k=0,1,2,3):
B[0][0] → load cache line 1 of B (cache MISS)
B[1][0] → load cache line 2 (cache MISS! different row) ❌
B[2][0] → load cache line 3 (cache MISS!) ❌
B[3][0] → load cache line 4 (cache MISS!) ❌
Each element of B sits on a different cache line.
For large matrices (N=4096): 99% of B accesses are cache misses!
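The 99% figure can be checked with a crude back-of-envelope model (a hypothetical helper; assumes a 256KB cache with 64-byte lines and ignores A and C competing for the same cache):

```python
def naive_B_miss_rate(N, cache_bytes=256 * 1024, line_bytes=64, dtype_bytes=4):
    """Estimate B's cache-miss rate in the naive loop order.
    Each inner-loop step touches a new row of B, i.e. a new cache line;
    those lines are only reused on later j iterations if one full column's
    worth of lines fits in the cache at once."""
    bytes_needed = N * line_bytes          # one line per row of B
    if bytes_needed < cache_bytes:
        # Lines survive across j iterations: only 1 cold miss per line
        floats_per_line = line_bytes // dtype_bytes
        return 1.0 / floats_per_line
    return 1.0                             # every access lands on an evicted line

assert naive_B_miss_rate(4096) == 1.0      # N=4096: essentially all misses
assert naive_B_miss_rate(512) < 0.1        # small N: column span fits in cache
```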
Cache Blocking (Tiling): The Fix
Core Idea: Divide Into Cache-Resident Tiles
Split matrices into tiles that fit in cache so data loaded once gets reused maximally.
Matrix Tiling Visualization (M=8, N=8, K=8, TILE_SIZE=4):
Divide all matrices into 4×4 tiles:
B (K×N):
┌───┬───┐
│B00│B01│ B00 = B[0:4, 0:4]
├───┼───┤ B01 = B[0:4, 4:8]
│B10│B11│ B10 = B[4:8, 0:4]
└───┴───┘ B11 = B[4:8, 4:8]
A (M×K): C (M×N):
┌───┬───┐ ┌───┬───┐
│A00│A01│ │C00│C01│ C00 = A00×B00 + A01×B10
├───┼───┤ ├───┼───┤ C01 = A00×B01 + A01×B11
│A10│A11│ │C10│C11│ C10 = A10×B00 + A11×B10
└───┴───┘ └───┴───┘ C11 = A10×B01 + A11×B11
Computing tile C00:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Step 1: Load A00 (4×4=16 floats=64B) into cache
Load B00 (4×4=16 floats=64B) into cache
→ 128B loaded, 4×4=16 partial dot products computed
→ 16×4=64 MAC ops (ZERO cache misses!)
Step 2: Load A01 into cache (A00 evicted)
Load B10 into cache (B00 evicted)
→ 64 more MAC ops (ZERO cache misses!)
Result: C00 complete
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Cache reuse ratio:
- Naive: each element of B reloaded M times (M misses/element)
- Tiled: each element reused TILE_SIZE times within tile
Tiled MatMul Implementation
import numpy as np
import time

def tiled_matmul(A, B, tile_size=64):
    """
    Cache-blocked matrix multiplication.
    tile_size: tune to L1 cache size. Three tiles must fit at once:
    L1=256KB, float32: sqrt(256K/4/3) ≈ 148.
    In practice 64-128 is often optimal.
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.zeros((M, N), dtype=A.dtype)
    for i in range(0, M, tile_size):
        i_end = min(i + tile_size, M)
        for j in range(0, N, tile_size):
            j_end = min(j + tile_size, N)
            for k in range(0, K, tile_size):
                k_end = min(k + tile_size, K)
                tile_A = A[i:i_end, k:k_end]  # cache-resident A tile
                tile_B = B[k:k_end, j:j_end]  # cache-resident B tile
                # All ops during this block hit cache!
                C[i:i_end, j:j_end] += tile_A @ tile_B
    return C

def benchmark():
    for N in [512, 1024, 2048]:
        A = np.random.randn(N, N).astype(np.float32)
        B = np.random.randn(N, N).astype(np.float32)
        start = time.perf_counter()
        for _ in range(5):
            _ = A @ B   # NumPy's BLAS path, as a reference point
        t = (time.perf_counter() - start) / 5
        tflops = 2 * N**3 / t / 1e12
        print(f"N={N}: {t*1000:.1f}ms, {tflops:.2f} TFLOPS")

benchmark()
Speedup summary:
N=4096, tile_size=64 vs naive (CPU, FP32):
Naive triple loop: ~900s
Tiled triple loop: ~18s (50× speedup)
NumPy (BLAS): ~0.8s (1125× speedup)
CUDA cuBLAS: ~0.002s (450,000× speedup)
BLAS: Four Decades of Optimization, Crystallized
The BLAS Level Hierarchy
BLAS (Basic Linear Algebra Subprograms, 1979–present):
Level 1: Vector-Vector ops (O(n) data, O(n) ops)
- dot(x, y): dot product
- axpy(a, x, y): y = a*x + y
- nrm2(x): L2 norm
Level 2: Matrix-Vector ops (O(n²) data, O(n²) ops)
- gemv: y = alpha*A*x + beta*y
→ Low arithmetic intensity → memory-bound
Level 3: Matrix-Matrix ops (O(n²) data, O(n³) ops)
- gemm: C = alpha*A*B + beta*C ← THE KEY ONE
→ High arithmetic intensity → can be compute-bound
→ The core operation of deep learning
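The three levels map directly onto familiar NumPy expressions (NumPy dispatches the Level-3 case to whatever optimized BLAS it was built against, e.g. OpenBLAS or MKL):

```python
import numpy as np

n = 4
x, y = np.random.randn(n), np.random.randn(n)
A, B = np.random.randn(n, n), np.random.randn(n, n)
C = np.random.randn(n, n)
alpha, beta = 2.0, 0.5

# Level 1: vector-vector — O(n) data, O(n) work
dot  = x @ y                      # ddot
axpy = alpha * x + y              # daxpy: y := a*x + y
nrm2 = np.linalg.norm(x)          # dnrm2

# Level 2: matrix-vector (gemv) — O(n²) data, O(n²) work
y2 = alpha * (A @ x) + beta * y

# Level 3: matrix-matrix (gemm) — O(n²) data, O(n³) work
C2 = alpha * (A @ B) + beta * C
```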
GEMM: General Matrix Multiply
GEMM interface:
C = alpha * op(A) * op(B) + beta * C
Parameters:
transa/transb: transpose flag (N=none, T=transpose, C=conj-transpose)
m, n, k: dimensions (C is m×n, A is m×k, B is k×n)
alpha, beta: scalar coefficients
A, B, C: matrix pointers
lda, ldb, ldc: leading dimensions (memory strides)
Role of lda:
When A is a sub-matrix of a larger matrix:
address of A[i][j] = base + i*lda + j (row-major view)
→ lda is the parent matrix's row stride, so lda ≥ the parent's column count
(classic Fortran BLAS is column-major: address = base + i + j*lda, lda ≥ m)
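The leading-dimension idea is visible in NumPy, where a sub-matrix view keeps the parent's row stride (row-major shown here, matching the address formula above):

```python
import numpy as np

parent = np.arange(36, dtype=np.float64).reshape(6, 6)  # 6×6, row-major
sub = parent[1:4, 2:5]                                  # 3×3 sub-matrix VIEW

# Row stride of the view equals the PARENT's row length (the "lda" role):
lda = parent.strides[0] // parent.itemsize
assert lda == 6                        # not 3 — strides come from the parent
assert sub.strides == parent.strides   # the view reuses the parent's layout

# Address arithmetic: sub[i][j] lives at parent offset (1+i)*lda + (2+j)
i, j = 2, 1
assert sub[i, j] == parent.ravel()[(1 + i) * lda + (2 + j)]
```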
/* cuBLAS example (NVIDIA GPU BLAS) */
#include <cublas_v2.h>

void gpu_matmul(float *d_A, float *d_B, float *d_C,
                int M, int N, int K) {
    cublasHandle_t handle;
    cublasCreate(&handle);
    float alpha = 1.0f;
    float beta = 0.0f;
    /* cuBLAS uses column-major storage.
       C = A*B (row-major) = (B^T * A^T)^T (column-major) */
    cublasSgemm(
        handle,
        CUBLAS_OP_N, CUBLAS_OP_N,
        N, M, K,        // note: B, A order for column-major
        &alpha,
        d_B, N,         // B (N×K in column-major view), leading dim=N
        d_A, K,         // A (K×M in column-major view), leading dim=K
        &beta,
        d_C, N          // C (N×M in column-major view), leading dim=N
    );
    cublasDestroy(handle);
    /* cuBLAS achieves 95%+ of theoretical peak —
       equivalent to thousands of person-years of tuning */
}
# PyTorch automatically routes through cuBLAS:
import torch

A = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
B = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)

# Internally calls cuBLAS HGEMM (half-precision GEMM)
C = torch.mm(A, B)

# Benchmarking:
import torch.utils.benchmark as benchmark
t = benchmark.Timer(
    stmt='torch.mm(A, B)',
    globals={'A': A, 'B': B}
)
result = t.timeit(100)
print(f"H100 FP16 GEMM 4096^2: {result.mean*1000:.2f}ms")
# Result: ~0.1ms, ~1,400 TFLOPS (~70% of the 1,979 TFLOPS theoretical peak)
Tensor Cores: Dedicated Matrix Multiply Hardware
NVIDIA Tensor Core (H100):
Standard CUDA Core:
Per cycle: 1 FMA (Fused Multiply-Add) = 2 FLOPs
Throughput: 2 FLOPs/cycle/core
Tensor Core:
Per cycle: one small matrix multiply-accumulate (e.g. 4×8 by 8×8)
Throughput: 256 MACs = 512 FLOPs/cycle (256× the MACs of a regular core!)
Supported: FP16, BF16, INT8, INT4, FP8 (H100)
H100 SXM5 numbers:
CUDA Cores: 16,896 cores × 2 FLOPs × ~1.98 GHz ≈ 67 TFLOPS (FP32)
Tensor Cores: 528 TCs × 2048 FLOPs/cycle × ~1.83 GHz ≈ 1,979 TFLOPS FP16 (29.5×)
Why FP16 is the default in deep learning: Tensor Core utilization!
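The peak figures follow directly from units × FLOPs/cycle × clock; a back-of-envelope check (the ~1.98/~1.83 GHz clocks are approximate published boost values):

```python
def peak_tflops(units, flops_per_cycle, clock_ghz):
    """Peak throughput = execution units × FLOPs/cycle/unit × clock."""
    return units * flops_per_cycle * clock_ghz * 1e9 / 1e12

cuda_fp32 = peak_tflops(16896, 2, 1.98)    # FMA = 2 FLOPs/cycle per core
tc_fp16   = peak_tflops(528, 2048, 1.83)   # per-Tensor-Core MMA throughput

assert round(cuda_fp32) == 67     # ≈ 67 TFLOPS FP32
assert round(tc_fp16) == 1979     # ≈ 1,979 TFLOPS FP16
```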
FlashAttention: Breaking Through the Memory Wall
The Memory Complexity Problem of Standard Attention
The biggest bottleneck in Transformers is the memory complexity of attention.
import torch
import torch.nn.functional as F

def attention_standard(Q, K, V, scale=None):
    """
    Standard attention implementation.
    Q: (batch, heads, seq_len, d_k)
    K: (batch, heads, seq_len, d_k)
    V: (batch, heads, seq_len, d_v)
    """
    if scale is None:
        scale = Q.shape[-1] ** -0.5
    # Step 1: Q×K^T → (batch, heads, seq_len, seq_len)
    # For seq_len=4096, d_k=128, 32 heads:
    #   scores size: 1 * 32 * 4096 * 4096 * 2 bytes ≈ 1 GB!
    #   Must be materialized in HBM (slow)
    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
    # Step 2: softmax → another HBM read + write
    weights = F.softmax(scores, dim=-1)  # 1GB read + 1GB write
    # Step 3: weights×V → another HBM read
    output = torch.matmul(weights, V)
    return output

# seq_len=4096, 32 heads:
#   HBM traffic: scores(1GB) + softmax(2GB) + AV(1GB) = ~4GB
#   H100 HBM bandwidth: 3.35 TB/s
#   Theoretical minimum time: 4GB / 3.35T ≈ 1.2ms per attention layer
#   Most of attention time is memory movement, not computation!
Attention is memory-bound. The GPU's compute units sit idle while HBM accesses dominate runtime.
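The traffic estimate generalizes to any configuration; a small helper (a hypothetical function name, counting only the N×N intermediates at FP16 = 2 bytes):

```python
def attention_hbm_traffic_gb(seq_len, heads, batch=1, bytes_per_el=2):
    """HBM bytes moved by the N×N intermediates of standard attention:
    write scores, read + write for softmax, read weights for the AV matmul."""
    nxn = batch * heads * seq_len * seq_len * bytes_per_el
    return 4 * nxn / 1e9   # scores(w) + softmax(r+w) + AV(r)

traffic = attention_hbm_traffic_gb(seq_len=4096, heads=32)
assert 4.0 < traffic < 4.5                       # ≈ 4.3 GB, matching the ~4GB above
# Minimum time at H100's 3.35 TB/s bandwidth:
assert abs(traffic / 3350 * 1e3 - 1.28) < 0.1    # ≈ 1.3 ms per layer
```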
IO-Aware Attention: The FlashAttention Insight
Tri Dao et al. (2022) asked: "Do we actually need to write the full N×N attention matrix to HBM?"
GPU Memory Hierarchy (H100):
SRAM (L1/Shared Memory, per SM):
Size: 256KB L1 per SM, up to 228KB usable as shared memory (132 SMs on H100)
Bandwidth: ~19 TB/s (6× faster than HBM!)
Latency: ~5 cycles
↕ Very fast
HBM (High Bandwidth Memory):
Size: 80GB
Bandwidth: ~3.35 TB/s
Latency: ~300 cycles
↕ Slow (relatively)
FlashAttention strategy:
Instead of writing N×N to HBM,
process it in tiles that fit entirely in SRAM!
FlashAttention Algorithm Visualization:
Inputs in HBM:
Q: (N, d) ─────────────────────────┐
K: (N, d) ─────────────────────────┤ (tiled reads)
V: (N, d) ─────────────────────────┘
↓
SRAM (tile processing):
┌────────────────────────────────────┐
│ Q_tile: (block_q, d) │
│ K_tile: (block_k, d) │
│ V_tile: (block_k, d) │
│ S_tile: (block_q, block_k) │ ← NEVER written to HBM!
│ running_max: (block_q,) │
│ running_sum: (block_q,) │
│ O_tile: (block_q, d) │
└────────────────────────────────────┘
↓
Output in HBM:
O: (N, d) ← written once per tile
HBM accesses:
Standard attention: O(N^2) (read/write full N×N matrix)
FlashAttention: O(N) (tiled linear access)
Online Softmax: The Mathematical Key
FlashAttention's core trick is the online softmax: computing exact softmax incrementally without seeing the full sequence.
import torch
import math

def flash_attention_simplified(Q, K, V, block_size=64):
    """
    FlashAttention core logic (simplified version).
    Q, K, V: (N, d) — single-head, no batch for clarity
    """
    N, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    dtype = Q.dtype
    # Initialize accumulators
    O = torch.zeros(N, d, dtype=dtype, device=Q.device)
    L = torch.zeros(N, dtype=dtype, device=Q.device)  # normalization factor
    M = torch.full((N,), float('-inf'), dtype=dtype,
                   device=Q.device)                    # running max
    # Online softmax math — when a new tile arrives with scores x_new:
    #   m_new = max(m_old, max(x_new))
    #   l_new = exp(m_old - m_new)*l_old + sum(exp(x_new - m_new))
    #   O_new = (exp(m_old - m_new)*l_old*O_old + exp(x_new - m_new) @ V_new) / l_new
    # Outer loop: K, V tiles (loop over j)
    for j_start in range(0, N, block_size):
        j_end = min(j_start + block_size, N)
        K_j = K[j_start:j_end]  # load to SRAM
        V_j = V[j_start:j_end]  # load to SRAM
        # Inner loop: Q tiles (loop over i)
        for i_start in range(0, N, block_size):
            i_end = min(i_start + block_size, N)
            Q_i = Q[i_start:i_end]  # load to SRAM
            # Attention scores (computed entirely in SRAM!)
            S_ij = (Q_i @ K_j.T) * scale  # (block_q, block_k)
            # Running max update
            m_i_old = M[i_start:i_end]
            m_ij = S_ij.max(dim=-1).values
            m_i_new = torch.maximum(m_i_old, m_ij)
            # Running sum update with correction factor
            correction = torch.exp(m_i_old - m_i_new)
            P_ij = torch.exp(S_ij - m_i_new.unsqueeze(-1))
            l_i_new = correction * L[i_start:i_end] + P_ij.sum(dim=-1)
            # Output update (rescale old contribution, add new tile)
            O[i_start:i_end] = (
                (correction * L[i_start:i_end]).unsqueeze(-1) * O[i_start:i_end]
                + P_ij @ V_j
            ) / l_i_new.unsqueeze(-1)
            M[i_start:i_end] = m_i_new
            L[i_start:i_end] = l_i_new
    return O  # only O(N·d) outputs touch HBM — the N×N matrix is never materialized
# FlashAttention performance numbers (A100 80GB, seq_len=2048):
# Standard attention: 12.3ms
# FlashAttention v1: 3.4ms (3.6× faster)
# FlashAttention v2: 2.1ms (5.9× faster)
# FlashAttention v3 (H100): 1.4ms (8.8× faster)
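The online-softmax recurrence can be sanity-checked against a direct softmax in plain NumPy (a standalone re-derivation for verification, not the actual GPU kernel):

```python
import numpy as np

rng = np.random.default_rng(42)
N, d, blk = 16, 8, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
scale = 1.0 / np.sqrt(d)

# Reference: materialize the full N×N score matrix
S = Q @ K.T * scale
P = np.exp(S - S.max(axis=-1, keepdims=True))
ref = (P / P.sum(axis=-1, keepdims=True)) @ V

# Online version: stream over K/V blocks, never holding all of S
O = np.zeros((N, d)); l = np.zeros(N); m = np.full(N, -np.inf)
for j in range(0, N, blk):
    S_j = Q @ K[j:j+blk].T * scale
    m_new = np.maximum(m, S_j.max(axis=-1))
    corr = np.exp(m - m_new)                # rescale factor for old state
    P_j = np.exp(S_j - m_new[:, None])
    O = corr[:, None] * O * l[:, None] + P_j @ V[j:j+blk]  # un-normalize, add
    l = corr * l + P_j.sum(axis=-1)
    O = O / l[:, None]                      # re-normalize with the new sum
    m = m_new

assert np.allclose(O, ref)  # exact softmax, computed incrementally
```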
Additional Benefits of FlashAttention
Backward Pass Memory Savings:
Standard attention: must store attention weights (N×N) for backprop
→ 4096 tokens: 4096² × 2 bytes ≈ 32MB per head per layer (FP16)
→ 32 heads × 32 layers × batch size = massive memory
FlashAttention: no need to store attention weights
→ only (m, l) vectors needed: O(N) memory
→ recompute blocks during backward pass
FlashAttention-2 (2023) additional optimizations:
- Swap Q and K,V loop order (better thread-level parallelism)
- Eliminate unnecessary rescaling operations
- Better warp-level work partitioning
→ 73% of theoretical peak (v1: 40%)
FlashAttention-3 (2024, H100-specific):
- Overlap GEMM with softmax using asynchronous Tensor Cores
- FP8 support (2× the throughput of FP16)
→ 75%+ of theoretical peak on H100
GEMM Performance Analysis: The Roofline Model
When Is It Compute-Bound vs. Memory-Bound?
Roofline Model:
Attainable performance = min(compute_peak,
memory_bw × arithmetic_intensity)
Arithmetic Intensity (AI) = FLOPS / BYTES_ACCESSED
H100 SXM5 parameters:
FP16 Tensor Cores: 1,979 TFLOPS
HBM bandwidth: 3,350 GB/s
Ridge Point: 1979T / 3.35T = 590 FLOP/Byte
→ AI > 590: compute-bound
→ AI < 590: memory-bound
def analyze_matmul_intensity(M, K, N, dtype_bytes=2):
    """
    Compute arithmetic intensity of (M×K) × (K×N) = (M×N)
    """
    flops = 2 * M * K * N  # multiply + add
    # Memory: read A, read B, write C
    bytes_accessed = (M * K + K * N + M * N) * dtype_bytes
    ai = flops / bytes_accessed
    return flops, bytes_accessed, ai

# Case 1: Large matrix (training with big batches)
M = N = K = 4096
flops, bytes_val, ai = analyze_matmul_intensity(M, K, N)
print(f"Large GEMM (4096^3):")
print(f"  FLOPs: {flops/1e9:.0f} GFLOP")
print(f"  Bytes: {bytes_val/1e6:.0f} MB")
print(f"  AI: {ai:.0f} FLOP/Byte")
print(f"  H100 Ridge: 590 FLOP/Byte")
print(f"  Status: {'COMPUTE-BOUND' if ai > 590 else 'MEMORY-BOUND'}")
# → AI=1365, COMPUTE-BOUND ✅
print()

# Case 2: Inference (batch_size=1)
M, K, N = 1, 4096, 4096
flops, bytes_val, ai = analyze_matmul_intensity(M, K, N)
print(f"Inference GEMV (batch=1):")
print(f"  FLOPs: {flops/1e6:.1f} MFLOP")
print(f"  Bytes: {bytes_val/1e6:.0f} MB (nearly all from matrix B!)")
print(f"  AI: {ai:.2f} FLOP/Byte")
print(f"  Status: {'COMPUTE-BOUND' if ai > 590 else 'MEMORY-BOUND'}")
# → AI≈1.0, MEMORY-BOUND ❌
# This is the fundamental reason LLM inference is hard to optimize!
Roofline-predicted bounds (H100, FP16, 4096×4096 weight):
Batch size vs. attainable GEMM performance:
┌───────────┬────────────┬────────────┬──────────────────┐
│ Batch (M) │ Max TFLOPS │ AI         │ Bound            │
├───────────┼────────────┼────────────┼──────────────────┤
│ 1         │ ~3.4       │ ~1         │ Memory-bound ❌  │
│ 4         │ ~13        │ ~4         │ Memory-bound ❌  │
│ 16        │ ~53        │ ~16        │ Memory-bound     │
│ 64        │ ~208       │ ~62        │ Memory-bound     │
│ 256       │ ~763       │ ~228       │ Memory-bound     │
│ 1024      │ 1,979      │ ~683       │ Compute-bound ✅ │
│ 4096      │ 1,979      │ ~1365      │ Compute-bound ✅ │
└───────────┴────────────┴────────────┴──────────────────┘
(Max TFLOPS = min(1,979, 3.35 TB/s × AI); real kernels land somewhat below these bounds)
→ Larger batch = higher GPU utilization
→ This is why batching matters so much in LLM serving
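The table's "Bound" column is just the roofline formula applied per row; as a helper (peak/bandwidth values are the H100 figures used throughout this section):

```python
def roofline_tflops(ai, peak_tflops=1979.0, bw_tbs=3.35):
    """Attainable TFLOPS = min(compute peak, memory bandwidth × arithmetic intensity)."""
    return min(peak_tflops, bw_tbs * ai)

assert roofline_tflops(1.0) < 4           # batch=1 GEMV: hopelessly memory-bound
assert roofline_tflops(1365) == 1979.0    # 4096³ GEMM: compute-bound at peak
# Ridge point: the AI where the two terms meet
ridge = 1979.0 / 3.35
assert 590 < ridge < 591
```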
Practical CUDA GEMM Optimization Tips
import torch

# Tip 1: Enable TF32 (drop-in FP32 replacement on Ampere+)
# FP32: 23-bit mantissa → TF32: 10-bit mantissa (slight precision loss)
# But enables Tensor Core usage → up to ~8× faster
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Tip 2: Ensure contiguous memory tensors
# Non-contiguous tensors arise after torch.transpose, etc.
x = torch.randn(64, 128, 128, device='cuda')
x_t = x.transpose(1, 2)       # non-contiguous!
x_t_cont = x_t.contiguous()   # copy to contiguous memory

# Tip 3: Align matrix dimensions to multiples of 16
# H100 Tensor Cores are most efficient at 16×n sizes
def pad_to_multiple(x, multiple=16):
    d = x.shape[-1]
    if d % multiple == 0:
        return x
    pad_size = multiple - (d % multiple)
    return torch.nn.functional.pad(x, (0, pad_size))

# Tip 4: torch.compile (PyTorch 2.0+)
@torch.compile
def optimized_linear(x, weight, bias=None):
    """
    torch.compile auto-generates fused kernels,
    selects optimal memory layouts, and exploits
    hardware-specific features.
    """
    return torch.nn.functional.linear(x, weight, bias)

# Tip 5: Use BF16 for training, FP16 for inference
# BF16: same exponent range as FP32, fewer mantissa bits → stable training
# FP16: more mantissa bits, narrow range → overflow risk → needs loss scaling
model = model.to(torch.bfloat16)  # training
model = model.to(torch.float16)   # inference (if no overflow risk)
The Full Picture: Matrix Multiply Optimization Stack
Modern deep learning matrix multiplication stack:
Application Layer:
PyTorch / JAX / TensorFlow
↓ (torch.mm, F.linear, etc.)
Kernel Fusion / Compiler:
torch.compile (TorchInductor)
XLA (JAX backend)
↓ (generates optimized CUDA code)
High-Level Libraries:
cuDNN (neural network focused)
cuBLAS (general-purpose GEMM)
CUTLASS (NVIDIA open-source templates)
↓ (optimized PTX/SASS assembly)
Hardware:
Tensor Cores (H100: 1,979 TFLOPS FP16)
↓
Memory System:
L1/L2 cache (SRAM)
HBM3 (3.35 TB/s)
FlashAttention lives at the "High-Level Libraries" layer,
using an IO-aware algorithm to minimize HBM accesses
and keep everything in fast SRAM.
Matrix multiplication looks simple, but its optimization runs 40 years deep, representing the work of thousands of engineers. As GPU architectures evolve, so do the optimization strategies. The FlashAttention lineage demonstrates that algorithmic innovation at the software level can match hardware improvements in impact.
For LLM infrastructure engineers, understanding this full stack is not optional. The next post in this series digs into how these foundations power real-world LLM serving optimization: KV cache, PagedAttention, continuous batching, and quantization.