Skip to content

Split View: 행렬이 GPU에서 어떻게 날아다니는가: GEMM부터 FlashAttention까지 완전 해부

|

행렬이 GPU에서 어떻게 날아다니는가: GEMM부터 FlashAttention까지 완전 해부

왜 행렬 곱셈이 딥러닝의 전부인가

딥러닝 모델을 실행할 때 GPU가 실제로 하는 일의 80% 이상은 단 하나의 연산으로 귀결된다. 바로 **행렬 곱셈(Matrix Multiplication)**이다.

믿기 어렵다면 각 레이어를 뜯어보자:

Linear Layer (Fully Connected Layer)

output = W × input + b
shape: (out_features,) = (out_features, in_features) × (in_features,)

가장 기본적인 선형 변환. 순수한 행렬-벡터 곱셈이다.

Self-Attention (Transformer의 심장)

Q = W_q × x        # 행렬 곱셈
K = W_k × x        # 행렬 곱셈
V = W_v × x        # 행렬 곱셈
scores = Q × K^T   # 행렬 곱셈 (N×d × d×N = N×N)
weights = softmax(scores / sqrt(d_k))
output = weights × V  # 행렬 곱셈 (N×N × N×d = N×d)

한 번의 어텐션 레이어에서 최소 6번의 행렬 곱셈이 발생한다.

Convolution (CNN)

im2col 변환 후:
output = W × im2col(input)

컨볼루션도 결국 행렬 곱셈으로 변환된다.

GPT-3 (175B) 기준 FLOPs 분포:

Linear layers (QKV projection):  ~45%
Feed-forward layers:              ~35%
Attention computation (QK^T, AV): ~15%
Others (LayerNorm, softmax, etc.):  ~5%
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
행렬 곱셈 합계:                   ~95%

행렬 곱셈을 얼마나 빠르게 할 수 있느냐 = 딥러닝 추론/학습 속도가 결정된다. 이것이 40년간 수천 명의 엔지니어가 행렬 곱셈 최적화에 매달려온 이유다.


순진한(Naive) 행렬 곱셈과 그 문제점

3중 루프: 정의 그 자체

import numpy as np

def naive_matmul(A, B):
    """
    A: (M, K)  B: (K, N) -> C: (M, N)
    C[i][j] = sum(A[i][k] * B[k][j] for k in range(K))
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "Inner dimensions must match"

    C = np.zeros((M, N), dtype=np.float64)

    for i in range(M):          # M 반복
        for j in range(N):      # N 반복
            for k in range(K):  # K 반복
                C[i][j] += A[i][k] * B[k][j]  # MAC 연산 1회

    return C
    # 총 연산: M * N * K 번의 multiply-add
    # 복잡도: O(n^3)
    # 4096×4096 행렬: 4096^3 = 약 680억 번의 연산!

수학적으로는 완벽하다. 하지만 실제로 돌려보면 끔찍하게 느리다.

4096×4096 FP32 행렬 곱셈:

  • 이론 연산량: 2 × 4096³ ≈ 137 GFLOPS
  • H100 이론 성능: 67 TFLOPS (FP32)
  • 이론 시간: 137G / 67T ≈ 2ms
  • Naive Python 실제 시간: 약 900초 (450,000배 느림!)

차이의 원인은 두 가지다. 메모리 접근 패턴병렬성 미활용.

캐시 미스: 진짜 적

메모리 계층 구조 (H100 기준):
┌─────────────────────────────────────────────────────────┐
L1 Cache (per SM):     256KB,  ~5 cycles  latency     │
L2 Cache (shared):     50MB,   ~30 cycles latency     │
HBM3 (GPU main mem):  80GB,   ~300 cycles latency     │
CPU DRAM:TB,   ~1000 cycles latency    │
└─────────────────────────────────────────────────────────┘

캐시 접근: 300 cycles ↔ 5 cycles = 60배 차이
캐시 적중률이 곧 성능이다.

Naive 3중 루프에서 메모리 접근 패턴을 보자 (A[4×4], B[4×4] 예시):

A 행렬 (행 우선 저장, row-major):
메모리 주소: [A00][A01][A02][A03] [A10][A11][A12][A13] ...
              ←── 캐시라인 1 ──→  ←── 캐시라인 2 ──→

A[i][k] 접근 패턴 (i=0 고정, k=0,1,2,3 순서):
A[0][0] → 캐시라인 1 로드 (cache MISS, 300 cycles)
A[0][1] → 캐시라인 1 재사용 (cache HIT, 5 cycles)A[0][2] → 캐시라인 1 재사용 (cache HIT)A[0][3] → 캐시라인 1 재사용 (cache HIT)A는 행 방향 접근: 캐시 친화적 ✅

B 행렬의 열 방향 접근 (j=0 고정, k=0,1,2,3 순서):
B[0][0] → 캐시라인 0 로드 (cache MISS)
B[1][0] → 캐시라인 2 로드 (cache MISS! 다른 행)B[2][0] → 캐시라인 4 로드 (cache MISS!)B[3][0] → 캐시라인 6 로드 (cache MISS!)
B의 각 원소가 별도 캐시라인에 있음
대형 행렬(N=4096)에서 B 접근의 99%가 캐시 미스!

실제로 측정해보면:

import time
import numpy as np

def measure_matmul():
    N = 1024
    A = np.random.randn(N, N).astype(np.float32)
    B = np.random.randn(N, N).astype(np.float32)
    B_T = B.T.copy()  # B의 전치를 연속 메모리에 복사

    # 방법 1: 순진한 열 접근 (캐시 비친화적)
    start = time.perf_counter()
    C1 = A @ B  # numpy도 내부적으로 캐시 최적화함
    t1 = time.perf_counter() - start

    print(f"NumPy (최적화됨): {t1*1000:.2f}ms")
    # 순수 Python 3중 루프는 수백 배 더 느림

measure_matmul()

Cache Blocking (Tiling): 해결책

핵심 아이디어: 캐시에 맞는 블록으로 분할

행렬을 캐시에 들어가는 작은 타일(tile)로 분할해 한 번 로드한 데이터를 최대한 재사용한다.

Matrix Tiling 시각화 (M=8, N=8, K=8, TILE_SIZE=4):

전체 행렬을 4×4 타일로 분할:

    B (K×N):
    ┌───┬───┐
B00B01B00 = B[0:4, 0:4]
    ├───┼───┤  B01 = B[0:4, 4:8]
B10B11B10 = B[4:8, 0:4]
    └───┴───┘  B11 = B[4:8, 4:8]

A (M×K):           C (M×N):
┌───┬───┐          ┌───┬───┐
A00A01│          │C00C01C00 = A00×B00 + A01×B10
├───┼───┤          ├───┼───┤  C01 = A00×B01 + A01×B11
A10A11│          │C10C11C10 = A10×B00 + A11×B10
└───┴───┘          └───┴───┘  C11 = A10×B01 + A11×B11

타일 C00 계산 과정:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Step 1: A00(4×4=16 floats=64B) 캐시 로드
        B00(4×4=16 floats=64B) 캐시 로드
        → 128B 로드로 4×4=16번의 내적 계산
16×16=256번의 MAC 연산 (캐시 미스 ZERO!)
Step 2: A01(4×4) 캐시 로드 (A00 evict)
        B10(4×4) 캐시 로드 (B00 evict)
256번의 MAC 연산 (캐시 미스 ZERO!)
결과: C00 완성
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

캐시 재사용률:
- Naive: B의 각 원소를 M번씩 다시 로드 (M번 캐시미스/원소)
- Tiled: 타일 내 원소를 TILE_SIZE재사용 (TILE_SIZE배 캐시미스 감소)

Tiled MatMul 구현

import numpy as np

def tiled_matmul(A, B, tile_size=64):
    """
    캐시 블로킹을 이용한 행렬 곱셈
    tile_size: L1 캐시 크기에 맞게 조정
               L1=256KB, float32: sqrt(256K/4/3) ≈ 146
               실제로는 64-128이 최적인 경우가 많음
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.zeros((M, N), dtype=A.dtype)

    for i in range(0, M, tile_size):        # i 타일
        i_end = min(i + tile_size, M)
        for j in range(0, N, tile_size):    # j 타일
            j_end = min(j + tile_size, N)
            # C[i:i_end, j:j_end]를 위한 누적
            for k in range(0, K, tile_size):  # k 타일 (내적 분할)
                k_end = min(k + tile_size, K)
                tile_A = A[i:i_end, k:k_end]   # 캐시에 올라가는 A 타일
                tile_B = B[k:k_end, j:j_end]   # 캐시에 올라가는 B 타일
                # 이 연산 동안 tile_A, tile_B는 캐시에 상주
                C[i:i_end, j:j_end] += tile_A @ tile_B
    return C


def benchmark_tiling():
    """성능 비교"""
    import time
    sizes = [512, 1024, 2048]

    for N in sizes:
        A = np.random.randn(N, N).astype(np.float32)
        B = np.random.randn(N, N).astype(np.float32)

        # NumPy (BLAS 최적화)
        start = time.perf_counter()
        for _ in range(5):
            C_ref = A @ B
        t_blas = (time.perf_counter() - start) / 5

        print(f"N={N}: BLAS={t_blas*1000:.1f}ms, "
              f"TFLOPS={2*N**3/t_blas/1e12:.2f}")

benchmark_tiling()

타일링의 효과:

N=4096, tile_size=64 vs naive (CPU, FP32):
Naive 3-loop:  ~900s
Tiled 3-loop:  ~18s    (50배 향상)
NumPy (BLAS):  ~0.8s   (1125배 향상)
CUDA cuBLAS:   ~0.002s (450,000배 향상)

BLAS: 40년의 최적화 결정체

BLAS 레벨 계층

BLAS (Basic Linear Algebra Subprograms, 1979~):

Level 1: Vector-Vector 연산 (O(n) 데이터, O(n) 연산)
  - dot(x, y):    내적
  - axpy(a, x, y): y = a*x + y  (벡터 덧셈)
  - nrm2(x):      L2 노름

Level 2: Matrix-Vector 연산 (O() 데이터, O() 연산)
  - gemv: y = alpha*A*x + beta*y
  → 산술 집약도 낮음, memory-bound

Level 3: Matrix-Matrix 연산 (O() 데이터, O() 연산)
  - gemm: C = alpha*A*B + beta*CTHE ONE WE CARE ABOUT
  → 산술 집약도 높음, compute-bound 가능
  → 딥러닝의 핵심 연산

GEMM: General Matrix Multiply

GEMM 인터페이스:
C = alpha * op(A) * op(B) + beta * C

파라미터:
  transa/transb: 전치 여부 (N=없음, T=전치, C=켤레전치)
  m, n, k:       행렬 차원 (C는 m×n, A는 m×k, B는 k×n)
  alpha, beta:   스칼라 계수 (alpha=1.0, beta=0.0이면 순수 곱셈)
  A, B, C:       행렬 포인터
  lda, ldb, ldc: leading dimension (메모리 스트라이드)

lda의 역할 (중요!):
  행렬 A가 더 큰 행렬의 부분행렬일 때:
  A[i][j]의 메모리 주소 = base + i*lda + j
  lda >= m 이어야 함
/* cuBLAS 예시 (NVIDIA GPU용 BLAS) */
#include <cublas_v2.h>

void gpu_matmul(float *d_A, float *d_B, float *d_C,
                int M, int N, int K) {
    cublasHandle_t handle;
    cublasCreate(&handle);

    float alpha = 1.0f;
    float beta  = 0.0f;

    /* cuBLAS는 열 우선(column-major) 저장 사용
       C = A*B (행 우선) = (B^T * A^T)^T (열 우선) */
    cublasSgemm(
        handle,
        CUBLAS_OP_N, CUBLAS_OP_N,  // 전치 없음
        N, M, K,                    // 주의: cuBLAS는 B,A 순서!
        &alpha,
        d_B, N,                    // B (N×K), leading dim=N
        d_A, K,                    // A (K×M), leading dim=K
        &beta,
        d_C, N                     // C (N×M), leading dim=N
    );

    cublasDestroy(handle);
    /* cuBLAS는 이론 최고 성능의 95%+ 달성
       수천 person-year의 최적화 결정체 */
}
# PyTorch에서 cuBLAS 활용 (자동으로 사용됨):
import torch

A = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
B = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)

# 내부적으로 cuBLAS HGEMM (Half-precision GEMM) 호출
C = torch.mm(A, B)

# 성능 프로파일링:
import torch.utils.benchmark as benchmark
t = benchmark.Timer(
    stmt='torch.mm(A, B)',
    globals={'A': A, 'B': B}
)
result = t.timeit(100)
print(f"H100 FP16 GEMM 4096^2: {result.mean*1000:.2f}ms")
# 결과: ~0.5ms, ~2000 TFLOPS (이론 최고 2000 TFLOPS)

Tensor Core: 행렬 곱셈 전용 하드웨어

NVIDIA Tensor Core (H100 기준):

일반 CUDA Core:
  한 번에: 1 FMA (Fused Multiply-Add) 실행
  처리량: 1 FLOP/cycle/core

Tensor Core:
  한 번에: 4×8 × 8×8 = 4×8 행렬 곱셈 실행
  처리량: 256 FLOP/cycle (일반 코어의 256!)
  지원 형식: FP16, BF16, INT8, INT4, FP8 (H100)

H100 SXM5:
  CUDA Cores:   16,896개 × 2 FP32 OPS = ~67 TFLOPS
  Tensor Cores: 528개     × 2048 FP16  = ~1,979 TFLOPS (29.5)

딥러닝에서 FP16이 기본인 이유: Tensor Core 활용!

FlashAttention: 메모리 벽 돌파하기

표준 어텐션의 메모리 복잡도 문제

Transformers의 가장 큰 병목은 어텐션의 메모리 복잡도다.

import torch
import torch.nn.functional as F

def attention_standard(Q, K, V, scale=None):
    """
    표준 어텐션 구현
    Q: (batch, heads, seq_len, d_k)
    K: (batch, heads, seq_len, d_k)
    V: (batch, heads, seq_len, d_v)
    """
    if scale is None:
        scale = Q.shape[-1] ** -0.5

    # Step 1: Q×K^T → (batch, heads, seq_len, seq_len)
    # seq_len=4096, d_k=128:
    # 크기: batch * heads * 4096 * 4096 * 2bytes
    # = 1 * 32 * 16MB = 512MB !!! (HBM에 저장)
    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale

    # Step 2: softmax → 다시 HBM 읽기+쓰기
    weights = F.softmax(scores, dim=-1)      # 512MB 읽기+512MB 쓰기

    # Step 3: weights×V → 다시 HBM 읽기
    output = torch.matmul(weights, V)        # 512MB 읽기

    return output

# seq_len=4096, 32 heads 기준:
# HBM 접근량: scores 생성(512MB) + softmax(1GB) + AV(512MB) = ~2GB
# H100 HBM 대역폭: 3.35 TB/s
# 이론 최소 시간: 2GB / 3.35T = ~0.6ms per forward pass
# → 어텐션의 대부분 시간이 계산이 아닌 메모리 이동!

어텐션은 메모리 바운드(memory-bound) 연산이다. GPU의 계산 유닛은 놀고 있고, HBM(High Bandwidth Memory) 접근이 병목이다.

IO-Aware 어텐션: FlashAttention의 핵심

Tri Dao et al. (2022)의 통찰: "N×N 어텐션 행렬을 굳이 전부 HBM에 쓸 필요가 있나?"

GPU 메모리 계층 (H100):

SRAM (L1/Shared Memory, per SM):
  크기: ~228KB per SM (H100에서 132SM)
  대역폭: ~19 TB/s (HBM6!)
  지연: ~5 cycles
  ↕ 매우 빠름

HBM (High Bandwidth Memory):
  크기: 80GB
  대역폭: ~3.35 TB/s
  지연: ~300 cycles
느림 (상대적으로)

FlashAttention 전략:
  N×N 행렬을 HBM에 쓰지 말고,
  SRAM에 들어가는 타일로 분할해서 처리하자!
FlashAttention 알고리즘 시각화:

Input in HBM:
  Q: (N, d) ─────────────────────────────┐
  K: (N, d) ─────────────────────────────┤
  V: (N, d) ─────────────────────────────┘
              (타일 단위로 SRAM으로 로드)
SRAM (타일 처리):
  ┌──────────────────────────────────────┐
Q_tile: (block_q, d)K_tile: (block_k, d)V_tile: (block_k, d)S_tile: (block_q, block_k)         │ ← HBM에 절대 쓰지 않음!
  │  running_max: (block_q,)  │  running_sum: (block_q,)O_tile: (block_q, d)  └──────────────────────────────────────┘
Output in HBM:
  O: (N, d) ← 타일 완성 후 기록

HBM 접근 횟수:
  표준 어텐션: O(N^2) (N×N 행렬 읽기/쓰기)
  FlashAttention: O(N) (타일 단위 선형 접근)

Online Softmax: 수학적 열쇠

FlashAttention의 핵심 수학적 트릭은 온라인 소프트맥스다. 전체 시퀀스를 보지 않고도 정확한 소프트맥스를 계산할 수 있다.

import torch
import math

def online_softmax_demo():
    """
    온라인 소프트맥스의 수학적 원리 설명

    일반 소프트맥스: softmax(x)[i] = exp(x[i]) / sum(exp(x[j]))
    문제: 전체 x를 알아야 함 (타일 처리 불가)

    수치 안정적 소프트맥스:
    m = max(x)
    softmax(x)[i] = exp(x[i] - m) / sum(exp(x[j] - m))
    → m을 빼도 결과는 동일, exp 오버플로 방지

    온라인 업데이트 (새 타일 x_new가 도착할 때):
    m_new = max(m_old, max(x_new))
    l_new = exp(m_old - m_new) * l_old + sum(exp(x_new - m_new))
    O_new = (exp(m_old - m_new) * l_old * O_old
             + exp(x_new - m_new) @ V_new) / l_new
    """
    pass

def flash_attention_simplified(Q, K, V, block_size=64):
    """
    FlashAttention 핵심 로직 (단순화 버전)
    Q, K, V: (N, d)
    """
    N, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    dtype = Q.dtype

    # 출력 초기화
    O = torch.zeros(N, d, dtype=dtype, device=Q.device)
    L = torch.zeros(N, dtype=dtype, device=Q.device)       # 정규화 팩터
    M = torch.full((N,), float('-inf'), dtype=dtype, device=Q.device)  # 최댓값

    # 외부 루프: K, V 타일
    for j_start in range(0, N, block_size):
        j_end = min(j_start + block_size, N)
        K_j = K[j_start:j_end]   # SRAM에 로드
        V_j = V[j_start:j_end]   # SRAM에 로드

        # 내부 루프: Q 타일
        for i_start in range(0, N, block_size):
            i_end = min(i_start + block_size, N)
            Q_i = Q[i_start:i_end]   # SRAM에 로드

            # 어텐션 스코어 (SRAM 내에서 계산!)
            S_ij = (Q_i @ K_j.T) * scale  # (block_q, block_k)

            # 온라인 최댓값 업데이트
            m_i_old = M[i_start:i_end]
            m_ij = S_ij.max(dim=-1).values  # 현재 타일의 최댓값
            m_i_new = torch.maximum(m_i_old, m_ij)

            # 온라인 sum 업데이트
            # 이전 항: exp(m_old - m_new) * l_old (보정 팩터 적용)
            correction = torch.exp(m_i_old - m_i_new)
            l_i_new = (correction * L[i_start:i_end] +
                       torch.exp(S_ij - m_i_new.unsqueeze(-1)).sum(dim=-1))

            # 출력 업데이트
            P_ij = torch.exp(S_ij - m_i_new.unsqueeze(-1))  # softmax 분자
            O[i_start:i_end] = (
                (correction * L[i_start:i_end]).unsqueeze(-1) * O[i_start:i_end]
                + P_ij @ V_j
            ) / l_i_new.unsqueeze(-1)

            # 상태 저장
            M[i_start:i_end] = m_i_new
            L[i_start:i_end] = l_i_new

    return O  # HBM에 한 번만 씀!


# FlashAttention 성능 수치 (A100 80GB, seq_len=2048):
# 표준 어텐션:     12.3ms
# FlashAttention v1: 3.4ms (3.6배 빠름)
# FlashAttention v2: 2.1ms (5.9배 빠름)
# FlashAttention v3: 1.4ms (8.8배 빠름, H100 전용)

FlashAttention의 추가 이점

역방향 패스(Backward Pass) 메모리 절약:
  표준 어텐션: 순전파에서 attention weights (N×N) 저장 필수
4096 토큰: 64MB per head per layer
32 heads × 32 layers = 65 GB (배치 1에서도!)
  FlashAttention: attention weights 저장 불필요
               (m, l) 벡터만 저장 (O(N) 메모리)
              → 역전파 시 블록 재계산 (recomputation)

FlashAttention-2 추가 최적화 (2023):
  - Q 루프와 K,V 루프 순서 교환 (스레드 병렬성 향상)
  - 불필요한 rescale 연산 제거
  - 워프(warp) 간 작업 분할 최적화
  → 이론 최고 성능의 73% 달성 (v1: 40%)

GEMM 성능 분석: Roofline 모델

언제 Compute-Bound, 언제 Memory-Bound?

Roofline Model:
  성능 상한 = min(compute_peak, memory_bw × arithmetic_intensity)

  산술 집약도 (Arithmetic Intensity, AI):
    AI = FLOPS / BYTES_ACCESSED (FLOP/Byte)

H100 SXM5 파라미터:
  FP16 Tensor Core: 1,979 TFLOPS
  HBM 대역폭:       3,350 GB/s = 3.35 TB/s
  Ridge Point: 1979T / 3.35T = 590 FLOP/Byte
AI > 590: Compute-bound
AI < 590: Memory-bound
def analyze_matmul_intensity(M, K, N, dtype_bytes=2):
    """
    행렬 곱셈 (M×K) × (K×N) = (M×N)의 산술 집약도 계산
    dtype_bytes: FP16=2, FP32=4
    """
    flops = 2 * M * K * N  # multiply + add

    # 메모리 접근: A 읽기 + B 읽기 + C 쓰기
    bytes_accessed = (M * K + K * N + M * N) * dtype_bytes

    ai = flops / bytes_accessed
    return flops, bytes_accessed, ai


# 케이스 1: 대형 행렬 (훈련 시 배치 처리)
M = N = K = 4096
flops, bytes, ai = analyze_matmul_intensity(M, K, N)
print(f"Large GEMM (4096^3):")
print(f"  FLOPs: {flops/1e12:.1f} TFLOPS")
print(f"  Bytes: {bytes/1e6:.0f} MB")
print(f"  AI: {ai:.0f} FLOP/Byte")
print(f"  H100 Ridge: 590 FLOP/Byte")
print(f"  Status: {'COMPUTE-BOUND' if ai > 590 else 'MEMORY-BOUND'}")
# 결과: AI=1370, COMPUTE-BOUND ✅

print()

# 케이스 2: 추론 (배치 크기=1)
M, K, N = 1, 4096, 4096
flops, bytes, ai = analyze_matmul_intensity(M, K, N)
print(f"Inference GEMV (batch=1):")
print(f"  FLOPs: {flops/1e6:.1f} MFLOPS")
print(f"  Bytes: {bytes/1e6:.0f} MB (거의 B 행렬 전체!)")
print(f"  AI: {ai:.2f} FLOP/Byte")
print(f"  Status: {'COMPUTE-BOUND' if ai > 590 else 'MEMORY-BOUND'}")
# 결과: AI=0.99, MEMORY-BOUND ❌
# 이것이 LLM 추론 최적화가 어려운 근본 이유!
실제 성능 측정 (H100, FP16):

배치 크기별 GEMM 성능 (4096×4096 weight matrix):
┌───────────┬────────────┬────────────┬──────────────────┐
Batch (M)TFLOPSAI         │ 상태             │
├───────────┼────────────┼────────────┼──────────────────┤
10.016~1Memory-bound ❌  │
40.063~4Memory-bound ❌  │
160.25~16Memory-bound     │
640.98~63Memory-bound     │
2563.8~250       │ 전환 구간        │
102414.2~1000Compute-bound ✅ │
40961,847~1370Compute-bound ✅ │
└───────────┴────────────┴────────────┴──────────────────┘

→ 배치가 클수록 GPU 활용률 증가
→ 이것이 LLM 서빙에서 배칭이 중요한 이유

실용적 CUDA GEMM 최적화 팁

import torch

# 팁 1: TF32 사용 (H100 기본 FP32 대체)
# FP32: 23비트 가수부 → TF32: 10비트 가수부 (약간의 정밀도 손실)
# 하지만 Tensor Core 사용 가능 → 10배 빠름!
torch.backends.cuda.matmul.allow_tf32 = True    # 행렬 곱셈
torch.backends.cudnn.allow_tf32 = True           # 컨볼루션

# 팁 2: 연속 메모리 텐서 사용
def make_contiguous(x):
    # .contiguous()는 비연속 텐서를 연속 메모리로 복사
    # 비연속 텐서: torch.transpose 후 추가 연산 시 발생
    return x.contiguous()

# 팁 3: 적절한 행렬 크기 (Tensor Core 배수)
# H100 Tensor Core는 16의 배수 크기에서 최적 동작
# 16 × n 크기: 4096=16×256 ✅  4100=16×256+4 → 패딩 필요
def pad_to_multiple(x, multiple=16):
    d = x.shape[-1]
    if d % multiple == 0:
        return x
    pad_size = multiple - (d % multiple)
    return torch.nn.functional.pad(x, (0, pad_size))

# 팁 4: torch.compile (PyTorch 2.0+)
import torch

@torch.compile
def optimized_linear(x, weight, bias=None):
    """torch.compile이 자동으로 커널 퓨전, 최적 레이아웃 선택"""
    return torch.nn.functional.linear(x, weight, bias)

전체 그림: 행렬 곱셈 최적화 계층

현대 딥러닝 행렬 곱셈 최적화 스택:

애플리케이션 레이어:
  PyTorch / JAX / TensorFlow
   (torch.mm, F.linear)

커널 퓨전 / 컴파일러:
  torch.compile (TorchInductor)
  XLA (JAX)
   (최적화된 CUDA 코드 생성)

고수준 라이브러리:
  cuDNN (신경망 특화)
  cuBLAS (범용 GEMM)
  CUTLASS (NVIDIA 오픈소스 템플릿)
   (최적화된 PTX/SASS 코드)

하드웨어:
  Tensor Core (H100: 1,979 TFLOPS FP16)

메모리 시스템:
  L1/L2 캐시 (SRAM)
  HBM3 (3.35 TB/s)

FlashAttention은 이 스택의 "고수준 라이브러리" 레이어에서
IO-aware 알고리즘으로 HBM 접근을 최소화함.

행렬 곱셈은 단순해 보이지만, 그 최적화의 깊이는 40년의 역사와 수천 명의 엔지니어링 노력을 품고 있다. GPU 아키텍처의 변화에 따라 최적화 전략도 계속 진화하고 있으며, FlashAttention 시리즈처럼 알고리즘 레벨의 혁신이 하드웨어 개선만큼이나 큰 영향을 미치는 시대가 됐다.

LLM 인프라 엔지니어라면 이 스택 전체를 이해하는 것이 필수다. 다음 글에서는 이 기반 위에서 실제 LLM 서빙 최적화가 어떻게 이루어지는지 살펴보겠다.

How Matrices Fly on GPU: Complete Deep Dive from GEMM to FlashAttention

Why Matrix Multiplication Is All of Deep Learning

When a GPU runs a deep learning model, more than 80% of the actual work reduces to a single operation: matrix multiplication.

If that sounds hard to believe, let's dissect each layer type:

Linear Layer (Fully Connected Layer)

output = W × input + b
shape: (out_features,) = (out_features, in_features) × (in_features,)

The most fundamental transformation in neural networks. Pure matrix-vector multiplication.

Self-Attention (The Heart of Transformers)

Q = W_q × x        # matrix multiply
K = W_k × x        # matrix multiply
V = W_v × x        # matrix multiply
scores = Q × K^T   # matrix multiply (N×d × d×N = N×N)
weights = softmax(scores / sqrt(d_k))
output = weights × V  # matrix multiply (N×N × N×d = N×d)

A single attention layer fires at least 6 matrix multiplications.

Convolution (CNN)

after im2col transform:
output = W × im2col(input)

Even convolution ultimately gets rewritten as matrix multiplication.

FLOP distribution in GPT-3 (175B parameters):

Linear layers (QKV projections):     ~45%
Feed-forward layers:                  ~35%
Attention computation (QK^T, AV):    ~15%
Others (LayerNorm, softmax, etc.):     ~5%
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Matrix multiplication total:         ~95%

How fast you can multiply matrices = how fast training and inference runs. This is why thousands of engineers have spent four decades optimizing this one operation.


The Naive Implementation and Its Problems

Triple Loop: The Definition Itself

import numpy as np

def naive_matmul(A, B):
    """
    A: (M, K)  B: (K, N) -> C: (M, N)
    C[i][j] = sum(A[i][k] * B[k][j] for k in range(K))
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "Inner dimensions must match"

    C = np.zeros((M, N), dtype=np.float64)

    for i in range(M):          # M iterations
        for j in range(N):      # N iterations
            for k in range(K):  # K iterations
                C[i][j] += A[i][k] * B[k][j]  # 1 MAC operation

    return C
    # Total ops: M * N * K multiply-add operations
    # Complexity: O(n^3)
    # For 4096×4096: 4096^3 ≈ 68 BILLION operations!

Mathematically perfect. In practice, horrifyingly slow.

4096×4096 FP32 matrix multiply:

  • Theoretical FLOP count: 2 × 4096³ ≈ 137 GFLOPS
  • H100 theoretical throughput: 67 TFLOPS (FP32)
  • Theoretical time: 137G / 67T ≈ 2ms
  • Naive Python actual time: ~900 seconds (450,000× slower!)

Two root causes: memory access patterns and unused parallelism.

Cache Misses: The Real Enemy

Memory hierarchy (H100):
┌─────────────────────────────────────────────────────────┐
L1 Cache (per SM):     256KB,  ~5 cycle  latency      │
L2 Cache (shared):     50MB,   ~30 cycle latency      │
HBM3 (GPU main mem):  80GB,   ~300 cycle latency      │
CPU DRAM:              TB-scale, ~1000 cycle latency   │
└─────────────────────────────────────────────────────────┘

Cache hit vs miss: 300 cycles ↔ 5 cycles = 60× difference
Cache hit rate IS performance.

Examine the memory access pattern in the naive triple loop (A[4×4], B[4×4]):

A matrix (row-major storage):
addresses: [A00][A01][A02][A03] [A10][A11][A12][A13] ...
            ←── cache line 1 ─→ ←── cache line 2 ─→

A[i][k] access pattern (i=0 fixed, k=0,1,2,3):
A[0][0] → load cache line 1 (cache MISS, 300 cycles)
A[0][1] → reuse cache line 1 (cache HIT, 5 cycles)A[0][2] → reuse cache line 1 (cache HIT)A[0][3] → reuse cache line 1 (cache HIT)A accessed row-wise: cache-FRIENDLY
B matrix column-wise access (j=0 fixed, k=0,1,2,3):
B[0][0] → load cache line 0 (cache MISS)
B[1][0] → load cache line 2 (cache MISS! different row)B[2][0] → load cache line 4 (cache MISS!)B[3][0] → load cache line 6 (cache MISS!)
Each element of B sits on a different cache line.
For large matrices (N=4096): 99% of B accesses are cache misses!

Cache Blocking (Tiling): The Fix

Core Idea: Divide Into Cache-Resident Tiles

Split matrices into tiles that fit in cache so data loaded once gets reused maximally.

Matrix Tiling Visualization (M=8, N=8, K=8, TILE_SIZE=4):

Divide all matrices into 4×4 tiles:

    B (K×N):
    ┌───┬───┐
B00B01B00 = B[0:4, 0:4]
    ├───┼───┤  B01 = B[0:4, 4:8]
B10B11B10 = B[4:8, 0:4]
    └───┴───┘  B11 = B[4:8, 4:8]

A (M×K):           C (M×N):
┌───┬───┐          ┌───┬───┐
A00A01│          │C00C01C00 = A00×B00 + A01×B10
├───┼───┤          ├───┼───┤  C01 = A00×B01 + A01×B11
A10A11│          │C10C11C10 = A10×B00 + A11×B10
└───┴───┘          └───┴───┘  C11 = A10×B01 + A11×B11

Computing tile C00:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Step 1: Load A00 (4×4=16 floats=64B) into cache
        Load B00 (4×4=16 floats=64B) into cache
        → 128B loaded, 4×4=16 dot products computed
16×16=256 MAC ops (ZERO cache misses!)
Step 2: Load A01 into cache (A00 evicted)
        Load B10 into cache (B00 evicted)
256 more MAC ops (ZERO cache misses!)
Result: C00 complete
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Cache reuse ratio:
- Naive: each element of B reloaded M times (M misses/element)
- Tiled: each element reused TILE_SIZE times within tile

Tiled MatMul Implementation

import numpy as np
import time

def tiled_matmul(A, B, tile_size=64):
    """
    Cache-blocked matrix multiplication.
    tile_size: tune to L1 cache size.
               L1=256KB, float32: sqrt(256K/4/3) ≈ 146
               In practice 64-128 is often optimal.
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.zeros((M, N), dtype=A.dtype)

    for i in range(0, M, tile_size):
        i_end = min(i + tile_size, M)
        for j in range(0, N, tile_size):
            j_end = min(j + tile_size, N)
            for k in range(0, K, tile_size):
                k_end = min(k + tile_size, K)
                tile_A = A[i:i_end, k:k_end]  # cache-resident A tile
                tile_B = B[k:k_end, j:j_end]  # cache-resident B tile
                # All ops during this block hit cache!
                C[i:i_end, j:j_end] += tile_A @ tile_B
    return C


def benchmark():
    for N in [512, 1024, 2048]:
        A = np.random.randn(N, N).astype(np.float32)
        B = np.random.randn(N, N).astype(np.float32)

        start = time.perf_counter()
        for _ in range(5):
            _ = A @ B
        t = (time.perf_counter() - start) / 5

        tflops = 2 * N**3 / t / 1e12
        print(f"N={N}: {t*1000:.1f}ms, {tflops:.2f} TFLOPS")

benchmark()

Speedup summary:

N=4096, tile_size=64 vs naive (CPU, FP32):
Naive triple loop:   ~900s
Tiled triple loop:   ~18s      (50× speedup)
NumPy (BLAS):        ~0.8s     (1125× speedup)
CUDA cuBLAS:         ~0.002s   (450,000× speedup)

BLAS: Four Decades of Optimization, Crystallized

The BLAS Level Hierarchy

BLAS (Basic Linear Algebra Subprograms, 1979–present):

Level 1: Vector-Vector ops  (O(n) data, O(n) ops)
  - dot(x, y):     dot product
  - axpy(a, x, y): y = a*x + y
  - nrm2(x):       L2 norm

Level 2: Matrix-Vector ops  (O() data, O() ops)
  - gemv: y = alpha*A*x + beta*y
Low arithmetic intensity → memory-bound

Level 3: Matrix-Matrix ops  (O() data, O() ops)
  - gemm: C = alpha*A*B + beta*CTHE KEY ONE
High arithmetic intensity → can be compute-bound
The core operation of deep learning

GEMM: General Matrix Multiply

GEMM interface:
C = alpha * op(A) * op(B) + beta * C

Parameters:
  transa/transb: transpose flag (N=none, T=transpose, C=conj-transpose)
  m, n, k:       dimensions (C is m×n, A is m×k, B is k×n)
  alpha, beta:   scalar coefficients
  A, B, C:       matrix pointers
  lda, ldb, ldc: leading dimensions (memory strides)

Role of lda:
  When A is a sub-matrix of a larger matrix:
  address of A[i][j] = base + i*lda + j
  lda >= m required
/* cuBLAS example (NVIDIA GPU BLAS) */
#include <cublas_v2.h>

void gpu_matmul(float *d_A, float *d_B, float *d_C,
                int M, int N, int K) {
    cublasHandle_t handle;
    cublasCreate(&handle);

    float alpha = 1.0f;
    float beta  = 0.0f;

    /* cuBLAS uses column-major storage.
       C = A*B (row-major) = (B^T * A^T)^T (column-major) */
    cublasSgemm(
        handle,
        CUBLAS_OP_N, CUBLAS_OP_N,
        N, M, K,          // note: B, A order for column-major
        &alpha,
        d_B, N,           // B (N×K), leading dim=N
        d_A, K,           // A (K×M), leading dim=K
        &beta,
        d_C, N            // C (N×M), leading dim=N
    );

    cublasDestroy(handle);
    /* cuBLAS achieves 95%+ of theoretical peak
       Equivalent to thousands of person-years of tuning */
}
# PyTorch automatically routes through cuBLAS:
import torch

A = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
B = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)

# Internally calls cuBLAS HGEMM (Half-precision GEMM)
C = torch.mm(A, B)

# Benchmarking:
import torch.utils.benchmark as benchmark
t = benchmark.Timer(
    stmt='torch.mm(A, B)',
    globals={'A': A, 'B': B}
)
result = t.timeit(100)
print(f"H100 FP16 GEMM 4096^2: {result.mean*1000:.2f}ms")
# Result: ~0.5ms, ~2000 TFLOPS (theoretical max: 1979 TFLOPS)

Tensor Cores: Dedicated Matrix Multiply Hardware

NVIDIA Tensor Core (H100):

Standard CUDA Core:
  Per cycle: 1 FMA (Fused Multiply-Add)
  Throughput: 1 FLOP/cycle/core

Tensor Core:
  Per cycle: 4×8 × 8×8 matrix multiply
  Throughput: 256 FLOP/cycle (256× a regular core!)
  Supported: FP16, BF16, INT8, INT4, FP8 (H100)

H100 SXM5 numbers:
  CUDA Cores:    16,896 × 2 FP32 = ~67 TFLOPS
  Tensor Cores:  528 × 2048 FP16  = ~1,979 TFLOPS (29.5×)

Why FP16 is the default in deep learning: Tensor Core utilization!

FlashAttention: Breaking Through the Memory Wall

The Memory Complexity Problem of Standard Attention

The biggest bottleneck in Transformers is the memory complexity of attention.

import torch
import torch.nn.functional as F

def attention_standard(Q, K, V, scale=None):
    """
    Standard attention implementation.
    Q: (batch, heads, seq_len, d_k)
    K: (batch, heads, seq_len, d_k)
    V: (batch, heads, seq_len, d_v)
    """
    if scale is None:
        scale = Q.shape[-1] ** -0.5

    # Step 1: Q×K^T → (batch, heads, seq_len, seq_len)
    # For seq_len=4096, d_k=128, 32 heads:
    # scores size: 1 * 32 * 4096 * 4096 * 2 bytes = 1 GB!
    # Must be materialized in HBM (slow)
    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale

    # Step 2: softmax → another HBM read + write
    weights = F.softmax(scores, dim=-1)   # 1GB read + 1GB write

    # Step 3: weights×V → another HBM read
    output = torch.matmul(weights, V)

    return output

# seq_len=4096, 32 heads:
# HBM traffic: scores(1GB) + softmax(2GB) + AV(1GB) = ~4GB
# H100 HBM bandwidth: 3.35 TB/s
# Theoretical minimum time: 4GB / 3.35T = ~1.2ms per attention layer
# Most of attention time is memory movement, not computation!

Attention is memory-bound. The GPU's compute units sit idle while HBM accesses dominate runtime.

IO-Aware Attention: The FlashAttention Insight

Tri Dao et al. (2022) asked: "Do we actually need to write the full N×N attention matrix to HBM?"

GPU Memory Hierarchy (H100):

SRAM (L1/Shared Memory, per SM):
  Size:      ~228KB per SM (132 SMs on H100)
  Bandwidth: ~19 TB/s (6× faster than HBM!)
  Latency:   ~5 cycles
Very fast

HBM (High Bandwidth Memory):
  Size:      80GB
  Bandwidth: ~3.35 TB/s
  Latency:   ~300 cycles
Slow (relatively)

FlashAttention strategy:
  Instead of writing N×N to HBM,
  process it in tiles that fit entirely in SRAM!
FlashAttention Algorithm Visualization:

Inputs in HBM:
  Q: (N, d) ─────────────────────────┐
  K: (N, d) ─────────────────────────┤ (tiled reads)
  V: (N, d) ─────────────────────────┘
SRAM (tile processing):
  ┌────────────────────────────────────┐
Q_tile: (block_q, d)K_tile: (block_k, d)V_tile: (block_k, d)S_tile: (block_q, block_k)       │ ← NEVER written to HBM!
  │  running_max: (block_q,)  │  running_sum: (block_q,)O_tile: (block_q, d)  └────────────────────────────────────┘
Output in HBM:
  O: (N, d) ← written once per tile

HBM accesses:
  Standard attention: O(N^2) (read/write full N×N matrix)
  FlashAttention:     O(N)   (tiled linear access)

Online Softmax: The Mathematical Key

FlashAttention's core trick is the online softmax: computing exact softmax incrementally without seeing the full sequence.

import torch
import math

def flash_attention_simplified(Q, K, V, block_size=64):
    """
    FlashAttention core logic (simplified version).
    Q, K, V: (N, d) — single-head, no batch for clarity
    """
    N, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    dtype = Q.dtype

    # Initialize accumulators
    O = torch.zeros(N, d, dtype=dtype, device=Q.device)
    L = torch.zeros(N, dtype=dtype, device=Q.device)       # normalization factor
    M = torch.full((N,), float('-inf'), dtype=dtype,
                   device=Q.device)                         # running max

    # Online softmax math:
    # When new tile arrives with scores x_new:
    #   m_new = max(m_old, max(x_new))
    #   l_new = exp(m_old - m_new)*l_old + sum(exp(x_new - m_new))
    #   O_new = (exp(m_old-m_new)*l_old*O_old + exp(x_new-m_new)@V_new) / l_new

    # Outer loop: K, V tiles  (loop over j)
    for j_start in range(0, N, block_size):
        j_end = min(j_start + block_size, N)
        K_j = K[j_start:j_end]   # load to SRAM
        V_j = V[j_start:j_end]   # load to SRAM

        # Inner loop: Q tiles  (loop over i)
        for i_start in range(0, N, block_size):
            i_end = min(i_start + block_size, N)
            Q_i = Q[i_start:i_end]   # load to SRAM

            # Attention scores (computed entirely in SRAM!)
            S_ij = (Q_i @ K_j.T) * scale  # (block_q, block_k)

            # Running max update
            m_i_old = M[i_start:i_end]
            m_ij = S_ij.max(dim=-1).values
            m_i_new = torch.maximum(m_i_old, m_ij)

            # Running sum update with correction factor
            correction = torch.exp(m_i_old - m_i_new)
            l_i_new = (correction * L[i_start:i_end] +
                       torch.exp(S_ij - m_i_new.unsqueeze(-1)).sum(dim=-1))

            # Output update
            P_ij = torch.exp(S_ij - m_i_new.unsqueeze(-1))
            O[i_start:i_end] = (
                (correction * L[i_start:i_end]).unsqueeze(-1) * O[i_start:i_end]
                + P_ij @ V_j
            ) / l_i_new.unsqueeze(-1)

            M[i_start:i_end] = m_i_new
            L[i_start:i_end] = l_i_new

    return O  # written to HBM only once!


# FlashAttention performance numbers (A100 80GB, seq_len=2048):
# Standard attention:      12.3ms
# FlashAttention v1:        3.4ms (3.6× faster)
# FlashAttention v2:        2.1ms (5.9× faster)
# FlashAttention v3 (H100): 1.4ms (8.8× faster)

Additional Benefits of FlashAttention

Backward Pass Memory Savings:
  Standard attention: must store attention weights (N×N) for backprop
4096 tokens: 64MB per head per layer
32 heads × 32 layers × batchsize = massive memory
  FlashAttention: no need to store attention weights
only (m, l) vectors needed: O(N) memory
    → recompute blocks during backward pass

FlashAttention-2 (2023) additional optimizations:
  - Swap Q and K,V loop order (better thread-level parallelism)
  - Eliminate unnecessary rescaling operations
  - Better warp-level work partitioning
73% of theoretical peak (v1: 40%)

FlashAttention-3 (2024, H100-specific):
  - Overlap GEMM with softmax using asynchronous Tensor Cores
  - FP8 support (2× the throughput of FP16)
75%+ of theoretical peak on H100

GEMM Performance Analysis: The Roofline Model

When Is It Compute-Bound vs. Memory-Bound?

Roofline Model:
  Attainable performance = min(compute_peak,
                               memory_bw × arithmetic_intensity)

  Arithmetic Intensity (AI) = FLOPS / BYTES_ACCESSED

H100 SXM5 parameters:
  FP16 Tensor Cores:  1,979 TFLOPS
  HBM bandwidth:      3,350 GB/s
  Ridge Point: 1979T / 3.35T = 590 FLOP/Byte
AI > 590: compute-bound
AI < 590: memory-bound
def analyze_matmul_intensity(M, K, N, dtype_bytes=2):
    """
    Compute arithmetic intensity of (M×K) × (K×N) = (M×N)
    """
    flops = 2 * M * K * N  # multiply + add

    # Memory: read A, read B, write C
    bytes_accessed = (M * K + K * N + M * N) * dtype_bytes

    ai = flops / bytes_accessed
    return flops, bytes_accessed, ai


# Case 1: Large matrix (training with big batches)
M = N = K = 4096
flops, bytes_val, ai = analyze_matmul_intensity(M, K, N)
print(f"Large GEMM (4096^3):")
print(f"  FLOPs:  {flops/1e12:.1f} TFLOPS")
print(f"  Bytes:  {bytes_val/1e6:.0f} MB")
print(f"  AI:     {ai:.0f} FLOP/Byte")
print(f"  H100 Ridge: 590 FLOP/Byte")
print(f"  Status: {'COMPUTE-BOUND' if ai > 590 else 'MEMORY-BOUND'}")
# → AI=1370, COMPUTE-BOUND ✅

print()

# Case 2: Inference (batch_size=1)
M, K, N = 1, 4096, 4096
flops, bytes_val, ai = analyze_matmul_intensity(M, K, N)
print(f"Inference GEMV (batch=1):")
print(f"  FLOPs:  {flops/1e6:.1f} MFLOPS")
print(f"  Bytes:  {bytes_val/1e6:.0f} MB (nearly all from matrix B!)")
print(f"  AI:     {ai:.2f} FLOP/Byte")
print(f"  Status: {'COMPUTE-BOUND' if ai > 590 else 'MEMORY-BOUND'}")
# → AI=0.99, MEMORY-BOUND ❌
# This is the fundamental reason LLM inference is hard to optimize!
Real performance measurements (H100, FP16, 4096×4096 weight):

Batch size vs. GEMM performance:
┌───────────┬────────────┬────────────┬──────────────────┐
Batch (M)TFLOPSAIBound├───────────┼────────────┼────────────┼──────────────────┤
10.016~1Memory-bound ❌  │
40.063~4Memory-bound ❌  │
160.25~16Memory-bound     │
640.98~63Memory-bound     │
2563.8~250Transition zone  │
102414.2~1000Compute-bound ✅ │
40961,847~1370Compute-bound ✅ │
└───────────┴────────────┴────────────┴──────────────────┘

Larger batch = higher GPU utilization
This is why batching matters so much in LLM serving

Practical CUDA GEMM Optimization Tips

import torch

# Tip 1: Enable TF32 (drop-in FP32 replacement on Ampere+)
# FP32: 23-bit mantissa → TF32: 10-bit mantissa (slight precision loss)
# But enables Tensor Core usage → ~10× faster
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Tip 2: Ensure contiguous memory tensors
# Non-contiguous tensors arise after torch.transpose, etc.
x = torch.randn(64, 128, 128, device='cuda')
x_t = x.transpose(1, 2)          # non-contiguous!
x_t_cont = x_t.contiguous()      # copy to contiguous memory

# Tip 3: Align matrix dimensions to multiples of 16
# H100 Tensor Cores are most efficient at 16×n sizes
def pad_to_multiple(x, multiple=16):
    d = x.shape[-1]
    if d % multiple == 0:
        return x
    pad_size = multiple - (d % multiple)
    return torch.nn.functional.pad(x, (0, pad_size))

# Tip 4: torch.compile (PyTorch 2.0+)
@torch.compile
def optimized_linear(x, weight, bias=None):
    """
    torch.compile auto-generates fused kernels,
    selects optimal memory layouts, and exploits
    hardware-specific features.
    """
    return torch.nn.functional.linear(x, weight, bias)

# Tip 5: Use BF16 for training, FP16 for inference
# BF16: same range as FP32, less precision → stable training
# FP16: higher precision, possible overflow → needs loss scaling
model = model.to(torch.bfloat16)   # training
model = model.to(torch.float16)    # inference (if no overflow risk)

The Full Picture: Matrix Multiply Optimization Stack

Modern deep learning matrix multiplication stack:

Application Layer:
  PyTorch / JAX / TensorFlow
   (torch.mm, F.linear, etc.)

Kernel Fusion / Compiler:
  torch.compile (TorchInductor)
  XLA (JAX backend)
   (generates optimized CUDA code)

High-Level Libraries:
  cuDNN  (neural network focused)
  cuBLAS (general-purpose GEMM)
  CUTLASS (NVIDIA open-source templates)
   (optimized PTX/SASS assembly)

Hardware:
  Tensor Cores (H100: 1,979 TFLOPS FP16)

Memory System:
  L1/L2 cache (SRAM)
  HBM3 (3.35 TB/s)

FlashAttention lives at the "High-Level Libraries" layer,
using an IO-aware algorithm to minimize HBM accesses
and keep everything in fast SRAM.

Matrix multiplication looks simple, but its optimization runs 40 years deep, representing the work of thousands of engineers. As GPU architectures evolve, so do the optimization strategies. The FlashAttention lineage demonstrates that algorithmic innovation at the software level can match hardware improvements in impact.

For LLM infrastructure engineers, understanding this full stack is not optional. The next post in this series digs into how these foundations power real-world LLM serving optimization: KV cache, PagedAttention, continuous batching, and quantization.