
How Matrices Fly on GPU: Complete Deep Dive from GEMM to FlashAttention

Why Matrix Multiplication Is All of Deep Learning

When a GPU runs a deep learning model, more than 80% of the actual work reduces to a single operation: matrix multiplication.

If that sounds hard to believe, let's dissect each layer type:

Linear Layer (Fully Connected Layer)

output = W × input + b
shape: (out_features,) = (out_features, in_features) × (in_features,)

The most fundamental transformation in neural networks. Pure matrix-vector multiplication.
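As a minimal sketch (NumPy standing in for the GPU kernel, with illustrative shapes):

```python
import numpy as np

# A linear layer is a single matrix-vector multiply plus a bias add.
in_features, out_features = 4, 3
W = np.random.randn(out_features, in_features)   # (3, 4) weight matrix
b = np.random.randn(out_features)                # (3,) bias
x = np.random.randn(in_features)                 # (4,) input activation

y = W @ x + b
assert y.shape == (out_features,)

# A batch of inputs turns the same layer into a matrix-matrix multiply:
X = np.random.randn(32, in_features)             # batch of 32 inputs
Y = X @ W.T + b                                  # (32, 3) — one GEMM call
assert Y.shape == (32, out_features)
```

Batching is what turns many small matrix-vector products into one large GEMM, which, as we will see later, is essential for GPU efficiency.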

Self-Attention (The Heart of Transformers)

Q = W_q × x        # matrix multiply
K = W_k × x        # matrix multiply
V = W_v × x        # matrix multiply
scores = Q × K^T   # matrix multiply (N×d × d×N = N×N)
weights = softmax(scores / sqrt(d_k))
output = weights × V  # matrix multiply (N×N × N×d = N×d)

A single attention layer fires at least five matrix multiplications — six counting the output projection W_o.

Convolution (CNN)

after im2col transform:
output = W × im2col(input)

Even convolution ultimately gets rewritten as matrix multiplication.
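A hypothetical `im2col` helper for a single-channel input (stride 1, no padding) makes the claim concrete — a sketch, not an optimized implementation:

```python
import numpy as np

def im2col(x, kh, kw):
    """Unfold an (H, W) input into a (kh*kw, num_patches) matrix
    so that convolution becomes a single GEMM (stride 1, no padding)."""
    H, W = x.shape
    out_h, out_w = H - kh + 1, W - kw + 1
    cols = np.empty((kh * kw, out_h * out_w), dtype=x.dtype)
    idx = 0
    for i in range(out_h):
        for j in range(out_w):
            cols[:, idx] = x[i:i+kh, j:j+kw].ravel()  # one patch per column
            idx += 1
    return cols

x = np.arange(16, dtype=np.float64).reshape(4, 4)
kernel = np.ones((3, 3))
# Convolution as GEMM: (1, 9) x (9, 4) = (1, 4) -> reshape to (2, 2)
out = (kernel.ravel()[None, :] @ im2col(x, 3, 3)).reshape(2, 2)
# → [[45., 54.], [81., 90.]]
```

With multiple input channels and filters, the GEMM simply grows: the weight matrix becomes (num_filters, channels*kh*kw) and the unfolded input (channels*kh*kw, num_patches).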

FLOP distribution in GPT-3 (175B parameters):

Linear layers (QKV projections):     ~45%
Feed-forward layers:                  ~35%
Attention computation (QK^T, AV):    ~15%
Others (LayerNorm, softmax, etc.):     ~5%
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Matrix multiplication total:         ~95%

How fast you can multiply matrices = how fast training and inference runs. This is why thousands of engineers have spent four decades optimizing this one operation.


The Naive Implementation and Its Problems

Triple Loop: The Definition Itself

import numpy as np

def naive_matmul(A, B):
    """
    A: (M, K)  B: (K, N) -> C: (M, N)
    C[i][j] = sum(A[i][k] * B[k][j] for k in range(K))
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "Inner dimensions must match"

    C = np.zeros((M, N), dtype=np.float64)

    for i in range(M):          # M iterations
        for j in range(N):      # N iterations
            for k in range(K):  # K iterations
                C[i][j] += A[i][k] * B[k][j]  # 1 MAC operation

    return C
    # Total ops: M * N * K multiply-add operations
    # Complexity: O(n^3)
    # For 4096×4096: 4096^3 ≈ 68 BILLION operations!

Mathematically perfect. In practice, horrifyingly slow.

4096×4096 FP32 matrix multiply:

  • Theoretical FLOP count: 2 × 4096³ ≈ 137 GFLOPs
  • H100 theoretical throughput: 67 TFLOPS (FP32)
  • Theoretical time: 137G / 67T ≈ 2ms
  • Naive Python actual time: ~900 seconds (450,000× slower!)

Two root causes: memory access patterns and unused parallelism.

Cache Misses: The Real Enemy

Memory hierarchy (H100):
┌──────────────────────────────────────────────────────┐
│ L1 Cache (per SM):    256KB,    ~5 cycle    latency  │
│ L2 Cache (shared):    50MB,     ~30 cycle   latency  │
│ HBM3 (GPU main mem):  80GB,     ~300 cycle  latency  │
│ CPU DRAM:             TB-scale, ~1000 cycle latency  │
└──────────────────────────────────────────────────────┘

Cache hit vs miss: 300 cycles ↔ 5 cycles = 60× difference
Cache hit rate IS performance.

Examine the memory access pattern in the naive triple loop (A[4×4], B[4×4]):

A matrix (row-major storage):
addresses: [A00][A01][A02][A03] [A10][A11][A12][A13] ...
            ←── cache line 1 ─→ ←── cache line 2 ─→

A[i][k] access pattern (i=0 fixed, k=0,1,2,3):
A[0][0] → load cache line 1  (cache MISS, 300 cycles)
A[0][1] → reuse cache line 1 (cache HIT, 5 cycles)
A[0][2] → reuse cache line 1 (cache HIT)
A[0][3] → reuse cache line 1 (cache HIT)
→ A accessed row-wise: cache-FRIENDLY

B[k][j] access pattern (j=0 fixed, k=0,1,2,3):
B[0][0] → load cache line 1 (cache MISS)
B[1][0] → load cache line 2 (cache MISS! different row)
B[2][0] → load cache line 3 (cache MISS!)
B[3][0] → load cache line 4 (cache MISS!)
→ Each element of B sits on a different cache line.
For large matrices (N=4096): ~99% of B accesses are cache misses!
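The same effect is easy to observe from Python. Summing a large row-major matrix slice-by-slice touches identical data either way, but contiguous slices stream through cache lines while strided slices fetch a new line per element (a sketch — absolute timings vary by machine):

```python
import numpy as np
import time

A = np.random.randn(4096, 4096).astype(np.float32)  # row-major (C order)

def sum_rows(A):
    s = 0.0
    for i in range(A.shape[0]):
        s += A[i, :].sum()      # contiguous slice: streams through cache lines
    return s

def sum_cols(A):
    s = 0.0
    for j in range(A.shape[1]):
        s += A[:, j].sum()      # strided slice: one new cache line per element
    return s

for name, fn in [("row-wise", sum_rows), ("col-wise", sum_cols)]:
    start = time.perf_counter()
    fn(A)
    print(f"{name}: {(time.perf_counter() - start) * 1e3:.1f}ms")
# Both compute the same total; the column-wise version is typically
# several times slower purely because of its memory access pattern.
```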

Cache Blocking (Tiling): The Fix

Core Idea: Divide Into Cache-Resident Tiles

Split matrices into tiles that fit in cache so data loaded once gets reused maximally.

Matrix Tiling Visualization (M=8, N=8, K=8, TILE_SIZE=4):

Divide all matrices into 4×4 tiles:

    B (K×N):
    ┌───┬───┐
    │B00│B01│  B00 = B[0:4, 0:4]
    ├───┼───┤  B01 = B[0:4, 4:8]
    │B10│B11│  B10 = B[4:8, 0:4]
    └───┴───┘  B11 = B[4:8, 4:8]

A (M×K):           C (M×N):
┌───┬───┐          ┌───┬───┐
│A00│A01│          │C00│C01│  C00 = A00×B00 + A01×B10
├───┼───┤          ├───┼───┤  C01 = A00×B01 + A01×B11
│A10│A11│          │C10│C11│  C10 = A10×B00 + A11×B10
└───┴───┘          └───┴───┘  C11 = A10×B01 + A11×B11

Computing tile C00:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Step 1: Load A00 (4×4 = 16 floats = 64B) into cache
        Load B00 (4×4 = 16 floats = 64B) into cache
        → 128B loaded, 4×4 = 16 partial dot products computed
        → 16 × 4 = 64 MAC ops (ZERO cache misses!)
Step 2: Load A01 into cache (A00 evicted)
        Load B10 into cache (B00 evicted)
        → 64 more MAC ops (ZERO cache misses!)
Result: C00 complete
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Cache reuse ratio:
- Naive: each element of B reloaded M times (M misses/element)
- Tiled: each element reused TILE_SIZE times within tile

Tiled MatMul Implementation

import numpy as np
import time

def tiled_matmul(A, B, tile_size=64):
    """
    Cache-blocked matrix multiplication.
    tile_size: tune to L1 cache size.
               L1=256KB, float32: sqrt(256K/4/3) ≈ 146
               In practice 64-128 is often optimal.
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.zeros((M, N), dtype=A.dtype)

    for i in range(0, M, tile_size):
        i_end = min(i + tile_size, M)
        for j in range(0, N, tile_size):
            j_end = min(j + tile_size, N)
            for k in range(0, K, tile_size):
                k_end = min(k + tile_size, K)
                tile_A = A[i:i_end, k:k_end]  # cache-resident A tile
                tile_B = B[k:k_end, j:j_end]  # cache-resident B tile
                # All ops during this block hit cache!
                C[i:i_end, j:j_end] += tile_A @ tile_B
    return C


def benchmark():
    for N in [512, 1024, 2048]:
        A = np.random.randn(N, N).astype(np.float32)
        B = np.random.randn(N, N).astype(np.float32)

        for name, fn in [("tiled", tiled_matmul), ("numpy", np.matmul)]:
            start = time.perf_counter()
            for _ in range(5):
                _ = fn(A, B)
            t = (time.perf_counter() - start) / 5
            tflops = 2 * N**3 / t / 1e12
            print(f"N={N} {name}: {t*1000:.1f}ms, {tflops:.2f} TFLOPS")

benchmark()

Speedup summary:

N=4096, tile_size=64 vs naive (CPU, FP32):
Naive triple loop:   ~900s
Tiled triple loop:   ~18s      (50× speedup)
NumPy (BLAS):        ~0.8s     (1125× speedup)
CUDA cuBLAS:         ~0.002s   (450,000× speedup)

BLAS: Four Decades of Optimization, Crystallized

The BLAS Level Hierarchy

BLAS (Basic Linear Algebra Subprograms, 1979–present):

Level 1: Vector-Vector ops  (O(n) data, O(n) ops)
  - dot(x, y):     dot product
  - axpy(a, x, y): y = a*x + y
  - nrm2(x):       L2 norm

Level 2: Matrix-Vector ops  (O(n²) data, O(n²) ops)
  - gemv: y = alpha*A*x + beta*y
  → Low arithmetic intensity → memory-bound

Level 3: Matrix-Matrix ops  (O(n²) data, O(n³) ops)
  - gemm: C = alpha*A*B + beta*C   ← THE KEY ONE
  → High arithmetic intensity → can be compute-bound
  → The core operation of deep learning

GEMM: General Matrix Multiply

GEMM interface:
C = alpha * op(A) * op(B) + beta * C

Parameters:
  transa/transb: transpose flag (N=none, T=transpose, C=conj-transpose)
  m, n, k:       dimensions (C is m×n, A is m×k, B is k×n)
  alpha, beta:   scalar coefficients
  A, B, C:       matrix pointers
  lda, ldb, ldc: leading dimensions (memory strides)

Role of lda:
  When A is a sub-matrix of a larger stored matrix:
  row-major:    address of A[i][j] = base + i*lda + j  (lda ≥ columns of the parent)
  column-major: address of A[i][j] = base + j*lda + i  (lda ≥ m; Fortran BLAS convention)

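NumPy's strides make this visible for the row-major convention (Fortran-order BLAS uses the column-stride analogue). A sub-matrix view keeps the parent's row stride, which is exactly what lda communicates:

```python
import numpy as np

parent = np.zeros((8, 8), dtype=np.float32)   # row-major 8×8 matrix
sub = parent[2:5, 1:4]                        # 3×3 sub-matrix *view*, no copy

# The view's row stride is still 8 floats (32 bytes) — the parent's width.
assert sub.strides == (32, 4)                 # (row stride, col stride) in bytes

# A row-major GEMM on this sub-matrix would pass lda=8, not lda=3:
lda = sub.strides[0] // sub.itemsize
assert lda == 8
```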
/* cuBLAS example (NVIDIA GPU BLAS) */
#include <cublas_v2.h>

void gpu_matmul(float *d_A, float *d_B, float *d_C,
                int M, int N, int K) {
    cublasHandle_t handle;
    cublasCreate(&handle);

    float alpha = 1.0f;
    float beta  = 0.0f;

    /* cuBLAS uses column-major storage.
       C = A*B (row-major) = (B^T * A^T)^T (column-major) */
    cublasSgemm(
        handle,
        CUBLAS_OP_N, CUBLAS_OP_N,
        N, M, K,          // note: B, A order for column-major
        &alpha,
        d_B, N,           // B (N×K), leading dim=N
        d_A, K,           // A (K×M), leading dim=K
        &beta,
        d_C, N            // C (N×M), leading dim=N
    );

    cublasDestroy(handle);
    /* cuBLAS achieves 95%+ of theoretical peak
       Equivalent to thousands of person-years of tuning */
}

# PyTorch automatically routes through cuBLAS:
import torch

A = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
B = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)

# Internally calls cuBLAS HGEMM (Half-precision GEMM)
C = torch.mm(A, B)

# Benchmarking:
import torch.utils.benchmark as benchmark
t = benchmark.Timer(
    stmt='torch.mm(A, B)',
    globals={'A': A, 'B': B}
)
result = t.timeit(100)
print(f"H100 FP16 GEMM 4096×4096: {result.mean*1000:.2f}ms")
# Result: on the order of 0.1ms, i.e. well over 1,000 TFLOPS
# (theoretical max: 1,979 TFLOPS)

Tensor Cores: Dedicated Matrix Multiply Hardware

NVIDIA Tensor Core (H100):

Standard CUDA Core:
  Per cycle: 1 FMA (Fused Multiply-Add)
  Throughput: 2 FLOPs/cycle/core

Tensor Core:
  Per cycle: a 4×8 × 8×8 matrix multiply-accumulate
  Throughput: 256 MACs = 512 FLOPs/cycle (256× a regular core!)
  Supported: FP16, BF16, INT8, INT4, FP8 (H100)

H100 SXM5 numbers:
  CUDA Cores:    16,896 × 2 FP32 = ~67 TFLOPS
  Tensor Cores:  528 × 2048 FP16  = ~1,979 TFLOPS (29.5×)

Why FP16 is the default in deep learning: Tensor Core utilization!

FlashAttention: Breaking Through the Memory Wall

The Memory Complexity Problem of Standard Attention

The biggest bottleneck in Transformers is the memory complexity of attention.

import torch
import torch.nn.functional as F

def attention_standard(Q, K, V, scale=None):
    """
    Standard attention implementation.
    Q: (batch, heads, seq_len, d_k)
    K: (batch, heads, seq_len, d_k)
    V: (batch, heads, seq_len, d_v)
    """
    if scale is None:
        scale = Q.shape[-1] ** -0.5

    # Step 1: Q×K^T → (batch, heads, seq_len, seq_len)
    # For seq_len=4096, d_k=128, 32 heads:
    # scores size: 1 * 32 * 4096 * 4096 * 2 bytes = 1 GB!
    # Must be materialized in HBM (slow)
    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale

    # Step 2: softmax → another HBM read + write
    weights = F.softmax(scores, dim=-1)   # 1GB read + 1GB write

    # Step 3: weights×V → another HBM read
    output = torch.matmul(weights, V)

    return output

# seq_len=4096, 32 heads:
# HBM traffic: scores(1GB) + softmax(2GB) + AV(1GB) = ~4GB
# H100 HBM bandwidth: 3.35 TB/s
# Theoretical minimum time: 4GB / 3.35T = ~1.2ms per attention layer
# Most of attention time is memory movement, not computation!

Attention is memory-bound: the GPU's compute units sit mostly idle while HBM traffic dominates the runtime.
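The traffic numbers in the code comments above follow from simple arithmetic; a sketch (the factor of 4 counts writing the scores, reading and rewriting them through softmax, and reading the weights for A@V):

```python
def attention_hbm_traffic(seq_len, heads, dtype_bytes=2):
    """Rough HBM traffic (bytes) for one standard attention layer:
    four full passes over the materialized (heads, N, N) score tensor."""
    score_bytes = heads * seq_len * seq_len * dtype_bytes
    return 4 * score_bytes

traffic = attention_hbm_traffic(seq_len=4096, heads=32)
hbm_bandwidth = 3.35e12      # H100 HBM3, bytes/s
print(f"traffic: {traffic/1e9:.1f} GB, "
      f"bandwidth floor: {traffic/hbm_bandwidth*1e3:.2f} ms")
```

For seq_len=4096 and 32 heads this gives roughly 4.3 GB of traffic and a ~1.3 ms bandwidth floor, matching the ~4 GB / ~1.2 ms estimate above up to rounding.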

IO-Aware Attention: The FlashAttention Insight

Tri Dao et al. (2022) asked: "Do we actually need to write the full N×N attention matrix to HBM?"

GPU Memory Hierarchy (H100):

SRAM (L1/Shared Memory, per SM):
  Size:      ~228KB per SM (132 SMs on H100)
  Bandwidth: ~19 TB/s (6× faster than HBM!)
  Latency:   ~5 cycles
Very fast

HBM (High Bandwidth Memory):
  Size:      80GB
  Bandwidth: ~3.35 TB/s
  Latency:   ~300 cycles
Slow (relatively)

FlashAttention strategy:
  Instead of writing N×N to HBM,
  process it in tiles that fit entirely in SRAM!

FlashAttention Algorithm Visualization:

Inputs in HBM:
  Q: (N, d) ─────────────────────────┐
  K: (N, d) ─────────────────────────┤ (tiled reads)
  V: (N, d) ─────────────────────────┘

SRAM (tile processing):
  ┌────────────────────────────────────┐
  │ Q_tile: (block_q, d)               │
  │ K_tile: (block_k, d)               │
  │ V_tile: (block_k, d)               │
  │ S_tile: (block_q, block_k)         │ ← NEVER written to HBM!
  │ running_max: (block_q,)            │
  │ running_sum: (block_q,)            │
  │ O_tile: (block_q, d)               │
  └────────────────────────────────────┘

Output in HBM:
  O: (N, d) ← written once per tile

HBM accesses:
  Standard attention: O(N^2) (read/write full N×N matrix)
  FlashAttention:     O(N)   (tiled linear access)

Online Softmax: The Mathematical Key

FlashAttention's core trick is the online softmax: computing exact softmax incrementally without seeing the full sequence.

import torch
import math

def flash_attention_simplified(Q, K, V, block_size=64):
    """
    FlashAttention core logic (simplified version).
    Q, K, V: (N, d) — single-head, no batch for clarity
    """
    N, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    dtype = Q.dtype

    # Initialize accumulators
    O = torch.zeros(N, d, dtype=dtype, device=Q.device)
    L = torch.zeros(N, dtype=dtype, device=Q.device)       # normalization factor
    M = torch.full((N,), float('-inf'), dtype=dtype,
                   device=Q.device)                         # running max

    # Online softmax math:
    # When new tile arrives with scores x_new:
    #   m_new = max(m_old, max(x_new))
    #   l_new = exp(m_old - m_new)*l_old + sum(exp(x_new - m_new))
    #   O_new = (exp(m_old-m_new)*l_old*O_old + exp(x_new-m_new)@V_new) / l_new

    # Outer loop: K, V tiles  (loop over j)
    for j_start in range(0, N, block_size):
        j_end = min(j_start + block_size, N)
        K_j = K[j_start:j_end]   # load to SRAM
        V_j = V[j_start:j_end]   # load to SRAM

        # Inner loop: Q tiles  (loop over i)
        for i_start in range(0, N, block_size):
            i_end = min(i_start + block_size, N)
            Q_i = Q[i_start:i_end]   # load to SRAM

            # Attention scores (computed entirely in SRAM!)
            S_ij = (Q_i @ K_j.T) * scale  # (block_q, block_k)

            # Running max update
            m_i_old = M[i_start:i_end]
            m_ij = S_ij.max(dim=-1).values
            m_i_new = torch.maximum(m_i_old, m_ij)

            # Unnormalized tile probabilities (shared by the l and O updates)
            P_ij = torch.exp(S_ij - m_i_new.unsqueeze(-1))

            # Running sum update with correction factor
            correction = torch.exp(m_i_old - m_i_new)
            l_i_new = correction * L[i_start:i_end] + P_ij.sum(dim=-1)

            # Output update
            O[i_start:i_end] = (
                (correction * L[i_start:i_end]).unsqueeze(-1) * O[i_start:i_end]
                + P_ij @ V_j
            ) / l_i_new.unsqueeze(-1)

            M[i_start:i_end] = m_i_new
            L[i_start:i_end] = l_i_new

    return O  # written to HBM only once!


# FlashAttention performance numbers (A100 80GB, seq_len=2048):
# Standard attention:      12.3ms
# FlashAttention v1:        3.4ms (3.6× faster)
# FlashAttention v2:        2.1ms (5.9× faster)
# FlashAttention v3 (H100): 1.4ms (8.8× faster)
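The (m, l) recurrence is easy to verify in isolation: processing a plain vector chunk-by-chunk with running statistics reproduces the exact softmax. A stripped-down NumPy sketch:

```python
import numpy as np

def online_softmax(x, block=4):
    """Exact softmax computed one chunk at a time — the same
    running-max / running-sum trick FlashAttention relies on."""
    m, l = -np.inf, 0.0
    for i in range(0, len(x), block):
        xb = x[i:i+block]
        m_new = max(m, xb.max())
        # rescale the old sum to the new max, then fold in the new chunk
        l = l * np.exp(m - m_new) + np.exp(xb - m_new).sum()
        m = m_new
    # with the final (m, l), the exact softmax falls out directly
    return np.exp(x - m) / l

x = np.random.randn(64)
ref = np.exp(x - x.max())
ref /= ref.sum()
assert np.allclose(online_softmax(x), ref)
```

The result is bit-for-bit the same softmax up to floating-point rounding, regardless of the chunk size — which is why FlashAttention is exact, not an approximation.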

Additional Benefits of FlashAttention

Backward Pass Memory Savings:
  Standard attention: must store attention weights (N×N) for backprop
    → 4096 tokens: ~32MB per head per layer (FP16)
    → 32 heads × 32 layers × batch size = massive memory
  FlashAttention: no need to store attention weights
    → only the (m, l) vectors needed: O(N) memory
    → recompute blocks during the backward pass

FlashAttention-2 (2023) additional optimizations:
  - Swap Q and K,V loop order (better thread-level parallelism)
  - Eliminate unnecessary rescaling operations
  - Better warp-level work partitioning
  → ~73% of theoretical peak (v1: ~40%)

FlashAttention-3 (2024, H100-specific):
  - Overlap GEMM with softmax using asynchronous Tensor Cores
  - FP8 support (2× the throughput of FP16)
  → 75%+ of theoretical peak on H100

GEMM Performance Analysis: The Roofline Model

When Is It Compute-Bound vs. Memory-Bound?

Roofline Model:
  Attainable performance = min(compute_peak,
                               memory_bw × arithmetic_intensity)

  Arithmetic Intensity (AI) = FLOPS / BYTES_ACCESSED

H100 SXM5 parameters:
  FP16 Tensor Cores:  1,979 TFLOPS
  HBM bandwidth:      3,350 GB/s
  Ridge Point: 1979T / 3.35T = 590 FLOP/Byte
  → AI > 590: compute-bound
  → AI < 590: memory-bound

def analyze_matmul_intensity(M, K, N, dtype_bytes=2):
    """
    Compute arithmetic intensity of (M×K) × (K×N) = (M×N)
    """
    flops = 2 * M * K * N  # multiply + add

    # Memory: read A, read B, write C
    bytes_accessed = (M * K + K * N + M * N) * dtype_bytes

    ai = flops / bytes_accessed
    return flops, bytes_accessed, ai


# Case 1: Large matrix (training with big batches)
M = N = K = 4096
flops, bytes_val, ai = analyze_matmul_intensity(M, K, N)
print(f"Large GEMM (4096^3):")
print(f"  FLOPs:  {flops/1e12:.1f} TFLOPs")
print(f"  Bytes:  {bytes_val/1e6:.0f} MB")
print(f"  AI:     {ai:.0f} FLOP/Byte")
print(f"  H100 Ridge: 590 FLOP/Byte")
print(f"  Status: {'COMPUTE-BOUND' if ai > 590 else 'MEMORY-BOUND'}")
# → AI=1370, COMPUTE-BOUND ✅

print()

# Case 2: Inference (batch_size=1)
M, K, N = 1, 4096, 4096
flops, bytes_val, ai = analyze_matmul_intensity(M, K, N)
print(f"Inference GEMV (batch=1):")
print(f"  FLOPs:  {flops/1e6:.1f} MFLOPs")
print(f"  Bytes:  {bytes_val/1e6:.0f} MB (nearly all from matrix B!)")
print(f"  AI:     {ai:.2f} FLOP/Byte")
print(f"  Status: {'COMPUTE-BOUND' if ai > 590 else 'MEMORY-BOUND'}")
# → AI=0.99, MEMORY-BOUND ❌
# This is the fundamental reason LLM inference is hard to optimize!

Real performance measurements (H100, FP16, 4096×4096 weight):

Batch size vs. GEMM performance:
┌───────────┬──────────┬────────┬──────────────────┐
│ Batch (M) │  TFLOPS  │   AI   │      Bound       │
├───────────┼──────────┼────────┼──────────────────┤
│      1    │   0.016  │    ~1  │ Memory-bound ❌  │
│      4    │   0.063  │    ~4  │ Memory-bound ❌  │
│     16    │   0.25   │   ~16  │ Memory-bound     │
│     64    │   0.98   │   ~63  │ Memory-bound     │
│    256    │   3.8    │  ~250  │ Transition zone  │
│   1024    │  14.2    │ ~1000  │ Compute-bound ✅ │
│   4096    │ 1,847    │ ~1370  │ Compute-bound ✅ │
└───────────┴──────────┴────────┴──────────────────┘

Larger batch = higher GPU utilization
This is why batching matters so much in LLM serving
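The shape of this curve follows directly from the roofline formula; a sketch (the roofline gives a *ceiling* — measured small-batch numbers sit well below it because the model ignores kernel-launch overhead and occupancy):

```python
def roofline_tflops(batch, d_in=4096, d_out=4096,
                    peak_tflops=1979.0, bw_bytes=3.35e12, dtype_bytes=2):
    """Attainable TFLOPS ceiling for a (batch × d_in) @ (d_in × d_out) GEMM
    under the roofline model: min(compute peak, bandwidth × AI)."""
    flops = 2 * batch * d_in * d_out
    bytes_moved = (batch * d_in + d_in * d_out + batch * d_out) * dtype_bytes
    ai = flops / bytes_moved
    return min(peak_tflops, bw_bytes * ai / 1e12)

for b in [1, 64, 1024, 4096]:
    print(f"batch={b:5d}: ceiling {roofline_tflops(b):7.1f} TFLOPS")
```

At batch 1 the ceiling is only a few TFLOPS no matter how fast the compute units are; by batch 4096 the ceiling hits the 1,979 TFLOPS compute peak.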

Practical CUDA GEMM Optimization Tips

import torch

# Tip 1: Enable TF32 (drop-in FP32 replacement on Ampere+)
# FP32: 23-bit mantissa → TF32: 10-bit mantissa (slight precision loss)
# But enables Tensor Core usage → ~10× faster
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Tip 2: Ensure contiguous memory tensors
# Non-contiguous tensors arise after torch.transpose, etc.
x = torch.randn(64, 128, 128, device='cuda')
x_t = x.transpose(1, 2)          # non-contiguous!
x_t_cont = x_t.contiguous()      # copy to contiguous memory

# Tip 3: Align matrix dimensions to multiples of 16
# H100 Tensor Cores are most efficient at 16×n sizes
def pad_to_multiple(x, multiple=16):
    d = x.shape[-1]
    if d % multiple == 0:
        return x
    pad_size = multiple - (d % multiple)
    return torch.nn.functional.pad(x, (0, pad_size))

# Tip 4: torch.compile (PyTorch 2.0+)
@torch.compile
def optimized_linear(x, weight, bias=None):
    """
    torch.compile auto-generates fused kernels,
    selects optimal memory layouts, and exploits
    hardware-specific features.
    """
    return torch.nn.functional.linear(x, weight, bias)

# Tip 5: Use BF16 for training, FP16 for inference
# BF16: same range as FP32, less precision → stable training
# FP16: higher precision, possible overflow → needs loss scaling
model = model.to(torch.bfloat16)   # training
model = model.to(torch.float16)    # inference (if no overflow risk)

The Full Picture: Matrix Multiply Optimization Stack

Modern deep learning matrix multiplication stack:

Application Layer:
  PyTorch / JAX / TensorFlow
   (torch.mm, F.linear, etc.)

Kernel Fusion / Compiler:
  torch.compile (TorchInductor)
  XLA (JAX backend)
   (generates optimized CUDA code)

High-Level Libraries:
  cuDNN  (neural network focused)
  cuBLAS (general-purpose GEMM)
  CUTLASS (NVIDIA open-source templates)
   (optimized PTX/SASS assembly)

Hardware:
  Tensor Cores (H100: 1,979 TFLOPS FP16)

Memory System:
  L1/L2 cache (SRAM)
  HBM3 (3.35 TB/s)

FlashAttention lives at the "High-Level Libraries" layer,
using an IO-aware algorithm to minimize HBM accesses
and keep everything in fast SRAM.

Matrix multiplication looks simple, but its optimization runs 40 years deep, representing the work of thousands of engineers. As GPU architectures evolve, so do the optimization strategies. The FlashAttention lineage demonstrates that algorithmic innovation at the software level can match hardware improvements in impact.

For LLM infrastructure engineers, understanding this full stack is not optional. The next post in this series digs into how these foundations power real-world LLM serving optimization: KV cache, PagedAttention, continuous batching, and quantization.