Skip to content
Published on

행렬이 GPU에서 어떻게 날아다니는가: GEMM부터 FlashAttention까지 완전 해부

Authors

왜 행렬 곱셈이 딥러닝의 전부인가

딥러닝 모델을 실행할 때 GPU가 실제로 하는 일의 80% 이상은 단 하나의 연산으로 귀결된다. 바로 **행렬 곱셈(Matrix Multiplication)**이다.

믿기 어렵다면 각 레이어를 뜯어보자:

Linear Layer (Fully Connected Layer)

output = W × input + b
shape: (out_features,) = (out_features, in_features) × (in_features,)

가장 기본적인 선형 변환. 순수한 행렬-벡터 곱셈이다.

Self-Attention (Transformer의 심장)

Q = W_q × x        # 행렬 곱셈
K = W_k × x        # 행렬 곱셈
V = W_v × x        # 행렬 곱셈
scores = Q × K^T   # 행렬 곱셈 (N×d × d×N = N×N)
weights = softmax(scores / sqrt(d_k))
output = weights × V  # 행렬 곱셈 (N×N × N×d = N×d)

한 번의 어텐션 레이어에서 최소 6번의 행렬 곱셈이 발생한다.

Convolution (CNN)

im2col 변환 후:
output = W × im2col(input)

컨볼루션도 결국 행렬 곱셈으로 변환된다.

GPT-3 (175B) 기준 FLOPs 분포:

Linear layers (QKV projection):  ~45%
Feed-forward layers:              ~35%
Attention computation (QK^T, AV): ~15%
Others (LayerNorm, softmax, etc.):  ~5%
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
행렬 곱셈 합계:                   ~95%

행렬 곱셈을 얼마나 빠르게 할 수 있느냐 = 딥러닝 추론/학습 속도가 결정된다. 이것이 40년간 수천 명의 엔지니어가 행렬 곱셈 최적화에 매달려온 이유다.


순진한(Naive) 행렬 곱셈과 그 문제점

3중 루프: 정의 그 자체

import numpy as np

def naive_matmul(A, B):
    """
    A: (M, K)  B: (K, N) -> C: (M, N)
    C[i][j] = sum(A[i][k] * B[k][j] for k in range(K))
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "Inner dimensions must match"

    C = np.zeros((M, N), dtype=np.float64)

    for i in range(M):          # M 반복
        for j in range(N):      # N 반복
            for k in range(K):  # K 반복
                C[i][j] += A[i][k] * B[k][j]  # MAC 연산 1회

    return C
    # 총 연산: M * N * K 번의 multiply-add
    # 복잡도: O(n^3)
    # 4096×4096 행렬: 4096^3 = 약 680억 번의 연산!

수학적으로는 완벽하다. 하지만 실제로 돌려보면 끔찍하게 느리다.

4096×4096 FP32 행렬 곱셈:

  • 이론 연산량: 2 × 4096³ ≈ 137 GFLOPS
  • H100 이론 성능: 67 TFLOPS (FP32)
  • 이론 시간: 137G / 67T ≈ 2ms
  • Naive Python 실제 시간: 약 900초 (450,000배 느림!)

차이의 원인은 두 가지다. 메모리 접근 패턴병렬성 미활용.

캐시 미스: 진짜 적

메모리 계층 구조 (H100 기준):
┌─────────────────────────────────────────────────────────┐
L1 Cache (per SM):     256KB,  ~5 cycles  latency     │
L2 Cache (shared):     50MB,   ~30 cycles latency     │
HBM3 (GPU main mem):  80GB,   ~300 cycles latency     │
CPU DRAM:TB,   ~1000 cycles latency    │
└─────────────────────────────────────────────────────────┘

캐시 접근: 300 cycles ↔ 5 cycles = 60배 차이
캐시 적중률이 곧 성능이다.

Naive 3중 루프에서 메모리 접근 패턴을 보자 (A[4×4], B[4×4] 예시):

A 행렬 (행 우선 저장, row-major):
메모리 주소: [A00][A01][A02][A03] [A10][A11][A12][A13] ...
              ←── 캐시라인 1 ──→  ←── 캐시라인 2 ──→

A[i][k] 접근 패턴 (i=0 고정, k=0,1,2,3 순서):
A[0][0] → 캐시라인 1 로드 (cache MISS, 300 cycles)
A[0][1] → 캐시라인 1 재사용 (cache HIT, 5 cycles)A[0][2] → 캐시라인 1 재사용 (cache HIT)A[0][3] → 캐시라인 1 재사용 (cache HIT)A는 행 방향 접근: 캐시 친화적 ✅

B 행렬의 열 방향 접근 (j=0 고정, k=0,1,2,3 순서):
B[0][0] → 캐시라인 0 로드 (cache MISS)
B[1][0] → 캐시라인 2 로드 (cache MISS! 다른 행)B[2][0] → 캐시라인 4 로드 (cache MISS!)B[3][0] → 캐시라인 6 로드 (cache MISS!)
B의 각 원소가 별도 캐시라인에 있음
대형 행렬(N=4096)에서 B 접근의 99%가 캐시 미스!

실제로 측정해보면:

import time
import numpy as np

def measure_matmul():
    N = 1024
    A = np.random.randn(N, N).astype(np.float32)
    B = np.random.randn(N, N).astype(np.float32)
    B_T = B.T.copy()  # B의 전치를 연속 메모리에 복사

    # 방법 1: 순진한 열 접근 (캐시 비친화적)
    start = time.perf_counter()
    C1 = A @ B  # numpy도 내부적으로 캐시 최적화함
    t1 = time.perf_counter() - start

    print(f"NumPy (최적화됨): {t1*1000:.2f}ms")
    # 순수 Python 3중 루프는 수백 배 더 느림

measure_matmul()

Cache Blocking (Tiling): 해결책

핵심 아이디어: 캐시에 맞는 블록으로 분할

행렬을 캐시에 들어가는 작은 타일(tile)로 분할해 한 번 로드한 데이터를 최대한 재사용한다.

Matrix Tiling 시각화 (M=8, N=8, K=8, TILE_SIZE=4):

전체 행렬을 4×4 타일로 분할:

    B (K×N):
    ┌───┬───┐
B00B01B00 = B[0:4, 0:4]
    ├───┼───┤  B01 = B[0:4, 4:8]
B10B11B10 = B[4:8, 0:4]
    └───┴───┘  B11 = B[4:8, 4:8]

A (M×K):           C (M×N):
┌───┬───┐          ┌───┬───┐
A00A01│          │C00C01C00 = A00×B00 + A01×B10
├───┼───┤          ├───┼───┤  C01 = A00×B01 + A01×B11
A10A11│          │C10C11C10 = A10×B00 + A11×B10
└───┴───┘          └───┴───┘  C11 = A10×B01 + A11×B11

타일 C00 계산 과정:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Step 1: A00(4×4=16 floats=64B) 캐시 로드
        B00(4×4=16 floats=64B) 캐시 로드
        → 128B 로드로 4×4=16번의 내적 계산
16×16=256번의 MAC 연산 (캐시 미스 ZERO!)
Step 2: A01(4×4) 캐시 로드 (A00 evict)
        B10(4×4) 캐시 로드 (B00 evict)
256번의 MAC 연산 (캐시 미스 ZERO!)
결과: C00 완성
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

캐시 재사용률:
- Naive: B의 각 원소를 M번씩 다시 로드 (M번 캐시미스/원소)
- Tiled: 타일 내 원소를 TILE_SIZE재사용 (TILE_SIZE배 캐시미스 감소)

Tiled MatMul 구현

import numpy as np

def tiled_matmul(A, B, tile_size=64):
    """
    캐시 블로킹을 이용한 행렬 곱셈
    tile_size: L1 캐시 크기에 맞게 조정
               L1=256KB, float32: sqrt(256K/4/3) ≈ 146
               실제로는 64-128이 최적인 경우가 많음
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.zeros((M, N), dtype=A.dtype)

    for i in range(0, M, tile_size):        # i 타일
        i_end = min(i + tile_size, M)
        for j in range(0, N, tile_size):    # j 타일
            j_end = min(j + tile_size, N)
            # C[i:i_end, j:j_end]를 위한 누적
            for k in range(0, K, tile_size):  # k 타일 (내적 분할)
                k_end = min(k + tile_size, K)
                tile_A = A[i:i_end, k:k_end]   # 캐시에 올라가는 A 타일
                tile_B = B[k:k_end, j:j_end]   # 캐시에 올라가는 B 타일
                # 이 연산 동안 tile_A, tile_B는 캐시에 상주
                C[i:i_end, j:j_end] += tile_A @ tile_B
    return C


def benchmark_tiling():
    """성능 비교"""
    import time
    sizes = [512, 1024, 2048]

    for N in sizes:
        A = np.random.randn(N, N).astype(np.float32)
        B = np.random.randn(N, N).astype(np.float32)

        # NumPy (BLAS 최적화)
        start = time.perf_counter()
        for _ in range(5):
            C_ref = A @ B
        t_blas = (time.perf_counter() - start) / 5

        print(f"N={N}: BLAS={t_blas*1000:.1f}ms, "
              f"TFLOPS={2*N**3/t_blas/1e12:.2f}")

benchmark_tiling()

타일링의 효과:

N=4096, tile_size=64 vs naive (CPU, FP32):
Naive 3-loop:  ~900s
Tiled 3-loop:  ~18s    (50배 향상)
NumPy (BLAS):  ~0.8s   (1125배 향상)
CUDA cuBLAS:   ~0.002s (450,000배 향상)

BLAS: 40년의 최적화 결정체

BLAS 레벨 계층

BLAS (Basic Linear Algebra Subprograms, 1979~):

Level 1: Vector-Vector 연산 (O(n) 데이터, O(n) 연산)
  - dot(x, y):    내적
  - axpy(a, x, y): y = a*x + y  (벡터 덧셈)
  - nrm2(x):      L2 노름

Level 2: Matrix-Vector 연산 (O() 데이터, O() 연산)
  - gemv: y = alpha*A*x + beta*y
  → 산술 집약도 낮음, memory-bound

Level 3: Matrix-Matrix 연산 (O() 데이터, O() 연산)
  - gemm: C = alpha*A*B + beta*CTHE ONE WE CARE ABOUT
  → 산술 집약도 높음, compute-bound 가능
  → 딥러닝의 핵심 연산

GEMM: General Matrix Multiply

GEMM 인터페이스:
C = alpha * op(A) * op(B) + beta * C

파라미터:
  transa/transb: 전치 여부 (N=없음, T=전치, C=켤레전치)
  m, n, k:       행렬 차원 (C는 m×n, A는 m×k, B는 k×n)
  alpha, beta:   스칼라 계수 (alpha=1.0, beta=0.0이면 순수 곱셈)
  A, B, C:       행렬 포인터
  lda, ldb, ldc: leading dimension (메모리 스트라이드)

lda의 역할 (중요!):
  행렬 A가 더 큰 행렬의 부분행렬일 때:
  A[i][j]의 메모리 주소 = base + i*lda + j
  lda >= m 이어야 함
/* cuBLAS 예시 (NVIDIA GPU용 BLAS) */
#include <cublas_v2.h>

void gpu_matmul(float *d_A, float *d_B, float *d_C,
                int M, int N, int K) {
    cublasHandle_t handle;
    cublasCreate(&handle);

    float alpha = 1.0f;
    float beta  = 0.0f;

    /* cuBLAS는 열 우선(column-major) 저장 사용
       C = A*B (행 우선) = (B^T * A^T)^T (열 우선) */
    cublasSgemm(
        handle,
        CUBLAS_OP_N, CUBLAS_OP_N,  // 전치 없음
        N, M, K,                    // 주의: cuBLAS는 B,A 순서!
        &alpha,
        d_B, N,                    // B (N×K), leading dim=N
        d_A, K,                    // A (K×M), leading dim=K
        &beta,
        d_C, N                     // C (N×M), leading dim=N
    );

    cublasDestroy(handle);
    /* cuBLAS는 이론 최고 성능의 95%+ 달성
       수천 person-year의 최적화 결정체 */
}
# PyTorch에서 cuBLAS 활용 (자동으로 사용됨):
import torch

A = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
B = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)

# 내부적으로 cuBLAS HGEMM (Half-precision GEMM) 호출
C = torch.mm(A, B)

# 성능 프로파일링:
import torch.utils.benchmark as benchmark
t = benchmark.Timer(
    stmt='torch.mm(A, B)',
    globals={'A': A, 'B': B}
)
result = t.timeit(100)
print(f"H100 FP16 GEMM 4096^2: {result.mean*1000:.2f}ms")
# 결과: ~0.5ms, ~2000 TFLOPS (이론 최고 2000 TFLOPS)

Tensor Core: 행렬 곱셈 전용 하드웨어

NVIDIA Tensor Core (H100 기준):

일반 CUDA Core:
  한 번에: 1 FMA (Fused Multiply-Add) 실행
  처리량: 1 FLOP/cycle/core

Tensor Core:
  한 번에: 4×8 × 8×8 = 4×8 행렬 곱셈 실행
  처리량: 256 FLOP/cycle (일반 코어의 256!)
  지원 형식: FP16, BF16, INT8, INT4, FP8 (H100)

H100 SXM5:
  CUDA Cores:   16,896개 × 2 FP32 OPS = ~67 TFLOPS
  Tensor Cores: 528개     × 2048 FP16  = ~1,979 TFLOPS (29.5)

딥러닝에서 FP16이 기본인 이유: Tensor Core 활용!

FlashAttention: 메모리 벽 돌파하기

표준 어텐션의 메모리 복잡도 문제

Transformers의 가장 큰 병목은 어텐션의 메모리 복잡도다.

import torch
import torch.nn.functional as F

def attention_standard(Q, K, V, scale=None):
    """
    표준 어텐션 구현
    Q: (batch, heads, seq_len, d_k)
    K: (batch, heads, seq_len, d_k)
    V: (batch, heads, seq_len, d_v)
    """
    if scale is None:
        scale = Q.shape[-1] ** -0.5

    # Step 1: Q×K^T → (batch, heads, seq_len, seq_len)
    # seq_len=4096, d_k=128:
    # 크기: batch * heads * 4096 * 4096 * 2bytes
    # = 1 * 32 * 16MB = 512MB !!! (HBM에 저장)
    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale

    # Step 2: softmax → 다시 HBM 읽기+쓰기
    weights = F.softmax(scores, dim=-1)      # 512MB 읽기+512MB 쓰기

    # Step 3: weights×V → 다시 HBM 읽기
    output = torch.matmul(weights, V)        # 512MB 읽기

    return output

# seq_len=4096, 32 heads 기준:
# HBM 접근량: scores 생성(512MB) + softmax(1GB) + AV(512MB) = ~2GB
# H100 HBM 대역폭: 3.35 TB/s
# 이론 최소 시간: 2GB / 3.35T = ~0.6ms per forward pass
# → 어텐션의 대부분 시간이 계산이 아닌 메모리 이동!

어텐션은 메모리 바운드(memory-bound) 연산이다. GPU의 계산 유닛은 놀고 있고, HBM(High Bandwidth Memory) 접근이 병목이다.

IO-Aware 어텐션: FlashAttention의 핵심

Tri Dao et al. (2022)의 통찰: "N×N 어텐션 행렬을 굳이 전부 HBM에 쓸 필요가 있나?"

GPU 메모리 계층 (H100):

SRAM (L1/Shared Memory, per SM):
  크기: ~228KB per SM (H100에서 132SM)
  대역폭: ~19 TB/s (HBM6!)
  지연: ~5 cycles
  ↕ 매우 빠름

HBM (High Bandwidth Memory):
  크기: 80GB
  대역폭: ~3.35 TB/s
  지연: ~300 cycles
느림 (상대적으로)

FlashAttention 전략:
  N×N 행렬을 HBM에 쓰지 말고,
  SRAM에 들어가는 타일로 분할해서 처리하자!
FlashAttention 알고리즘 시각화:

Input in HBM:
  Q: (N, d) ─────────────────────────────┐
  K: (N, d) ─────────────────────────────┤
  V: (N, d) ─────────────────────────────┘
              (타일 단위로 SRAM으로 로드)
SRAM (타일 처리):
  ┌──────────────────────────────────────┐
Q_tile: (block_q, d)K_tile: (block_k, d)V_tile: (block_k, d)S_tile: (block_q, block_k)         │ ← HBM에 절대 쓰지 않음!
  │  running_max: (block_q,)  │  running_sum: (block_q,)O_tile: (block_q, d)  └──────────────────────────────────────┘
Output in HBM:
  O: (N, d) ← 타일 완성 후 기록

HBM 접근 횟수:
  표준 어텐션: O(N^2) (N×N 행렬 읽기/쓰기)
  FlashAttention: O(N) (타일 단위 선형 접근)

Online Softmax: 수학적 열쇠

FlashAttention의 핵심 수학적 트릭은 온라인 소프트맥스다. 전체 시퀀스를 보지 않고도 정확한 소프트맥스를 계산할 수 있다.

import torch
import math

def online_softmax_demo():
    """
    온라인 소프트맥스의 수학적 원리 설명

    일반 소프트맥스: softmax(x)[i] = exp(x[i]) / sum(exp(x[j]))
    문제: 전체 x를 알아야 함 (타일 처리 불가)

    수치 안정적 소프트맥스:
    m = max(x)
    softmax(x)[i] = exp(x[i] - m) / sum(exp(x[j] - m))
    → m을 빼도 결과는 동일, exp 오버플로 방지

    온라인 업데이트 (새 타일 x_new가 도착할 때):
    m_new = max(m_old, max(x_new))
    l_new = exp(m_old - m_new) * l_old + sum(exp(x_new - m_new))
    O_new = (exp(m_old - m_new) * l_old * O_old
             + exp(x_new - m_new) @ V_new) / l_new
    """
    pass

def flash_attention_simplified(Q, K, V, block_size=64):
    """
    FlashAttention 핵심 로직 (단순화 버전)
    Q, K, V: (N, d)
    """
    N, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    dtype = Q.dtype

    # 출력 초기화
    O = torch.zeros(N, d, dtype=dtype, device=Q.device)
    L = torch.zeros(N, dtype=dtype, device=Q.device)       # 정규화 팩터
    M = torch.full((N,), float('-inf'), dtype=dtype, device=Q.device)  # 최댓값

    # 외부 루프: K, V 타일
    for j_start in range(0, N, block_size):
        j_end = min(j_start + block_size, N)
        K_j = K[j_start:j_end]   # SRAM에 로드
        V_j = V[j_start:j_end]   # SRAM에 로드

        # 내부 루프: Q 타일
        for i_start in range(0, N, block_size):
            i_end = min(i_start + block_size, N)
            Q_i = Q[i_start:i_end]   # SRAM에 로드

            # 어텐션 스코어 (SRAM 내에서 계산!)
            S_ij = (Q_i @ K_j.T) * scale  # (block_q, block_k)

            # 온라인 최댓값 업데이트
            m_i_old = M[i_start:i_end]
            m_ij = S_ij.max(dim=-1).values  # 현재 타일의 최댓값
            m_i_new = torch.maximum(m_i_old, m_ij)

            # 온라인 sum 업데이트
            # 이전 항: exp(m_old - m_new) * l_old (보정 팩터 적용)
            correction = torch.exp(m_i_old - m_i_new)
            l_i_new = (correction * L[i_start:i_end] +
                       torch.exp(S_ij - m_i_new.unsqueeze(-1)).sum(dim=-1))

            # 출력 업데이트
            P_ij = torch.exp(S_ij - m_i_new.unsqueeze(-1))  # softmax 분자
            O[i_start:i_end] = (
                (correction * L[i_start:i_end]).unsqueeze(-1) * O[i_start:i_end]
                + P_ij @ V_j
            ) / l_i_new.unsqueeze(-1)

            # 상태 저장
            M[i_start:i_end] = m_i_new
            L[i_start:i_end] = l_i_new

    return O  # HBM에 한 번만 씀!


# FlashAttention 성능 수치 (A100 80GB, seq_len=2048):
# 표준 어텐션:     12.3ms
# FlashAttention v1: 3.4ms (3.6배 빠름)
# FlashAttention v2: 2.1ms (5.9배 빠름)
# FlashAttention v3: 1.4ms (8.8배 빠름, H100 전용)

FlashAttention의 추가 이점

역방향 패스(Backward Pass) 메모리 절약:
  표준 어텐션: 순전파에서 attention weights (N×N) 저장 필수
4096 토큰: 64MB per head per layer
32 heads × 32 layers = 65 GB (배치 1에서도!)
  FlashAttention: attention weights 저장 불필요
               (m, l) 벡터만 저장 (O(N) 메모리)
              → 역전파 시 블록 재계산 (recomputation)

FlashAttention-2 추가 최적화 (2023):
  - Q 루프와 K,V 루프 순서 교환 (스레드 병렬성 향상)
  - 불필요한 rescale 연산 제거
  - 워프(warp) 간 작업 분할 최적화
  → 이론 최고 성능의 73% 달성 (v1: 40%)

GEMM 성능 분석: Roofline 모델

언제 Compute-Bound, 언제 Memory-Bound?

Roofline Model:
  성능 상한 = min(compute_peak, memory_bw × arithmetic_intensity)

  산술 집약도 (Arithmetic Intensity, AI):
    AI = FLOPS / BYTES_ACCESSED (FLOP/Byte)

H100 SXM5 파라미터:
  FP16 Tensor Core: 1,979 TFLOPS
  HBM 대역폭:       3,350 GB/s = 3.35 TB/s
  Ridge Point: 1979T / 3.35T = 590 FLOP/Byte
AI > 590: Compute-bound
AI < 590: Memory-bound
def analyze_matmul_intensity(M, K, N, dtype_bytes=2):
    """
    행렬 곱셈 (M×K) × (K×N) = (M×N)의 산술 집약도 계산
    dtype_bytes: FP16=2, FP32=4
    """
    flops = 2 * M * K * N  # multiply + add

    # 메모리 접근: A 읽기 + B 읽기 + C 쓰기
    bytes_accessed = (M * K + K * N + M * N) * dtype_bytes

    ai = flops / bytes_accessed
    return flops, bytes_accessed, ai


# 케이스 1: 대형 행렬 (훈련 시 배치 처리)
M = N = K = 4096
flops, bytes, ai = analyze_matmul_intensity(M, K, N)
print(f"Large GEMM (4096^3):")
print(f"  FLOPs: {flops/1e12:.1f} TFLOPS")
print(f"  Bytes: {bytes/1e6:.0f} MB")
print(f"  AI: {ai:.0f} FLOP/Byte")
print(f"  H100 Ridge: 590 FLOP/Byte")
print(f"  Status: {'COMPUTE-BOUND' if ai > 590 else 'MEMORY-BOUND'}")
# 결과: AI=1370, COMPUTE-BOUND ✅

print()

# 케이스 2: 추론 (배치 크기=1)
M, K, N = 1, 4096, 4096
flops, bytes, ai = analyze_matmul_intensity(M, K, N)
print(f"Inference GEMV (batch=1):")
print(f"  FLOPs: {flops/1e6:.1f} MFLOPS")
print(f"  Bytes: {bytes/1e6:.0f} MB (거의 B 행렬 전체!)")
print(f"  AI: {ai:.2f} FLOP/Byte")
print(f"  Status: {'COMPUTE-BOUND' if ai > 590 else 'MEMORY-BOUND'}")
# 결과: AI=0.99, MEMORY-BOUND ❌
# 이것이 LLM 추론 최적화가 어려운 근본 이유!
실제 성능 측정 (H100, FP16):

배치 크기별 GEMM 성능 (4096×4096 weight matrix):
┌───────────┬────────────┬────────────┬──────────────────┐
Batch (M)TFLOPSAI         │ 상태             │
├───────────┼────────────┼────────────┼──────────────────┤
10.016~1Memory-bound ❌  │
40.063~4Memory-bound ❌  │
160.25~16Memory-bound     │
640.98~63Memory-bound     │
2563.8~250       │ 전환 구간        │
102414.2~1000Compute-bound ✅ │
40961,847~1370Compute-bound ✅ │
└───────────┴────────────┴────────────┴──────────────────┘

→ 배치가 클수록 GPU 활용률 증가
→ 이것이 LLM 서빙에서 배칭이 중요한 이유

실용적 CUDA GEMM 최적화 팁

import torch

# 팁 1: TF32 사용 (H100 기본 FP32 대체)
# FP32: 23비트 가수부 → TF32: 10비트 가수부 (약간의 정밀도 손실)
# 하지만 Tensor Core 사용 가능 → 10배 빠름!
torch.backends.cuda.matmul.allow_tf32 = True    # 행렬 곱셈
torch.backends.cudnn.allow_tf32 = True           # 컨볼루션

# 팁 2: 연속 메모리 텐서 사용
def make_contiguous(x):
    # .contiguous()는 비연속 텐서를 연속 메모리로 복사
    # 비연속 텐서: torch.transpose 후 추가 연산 시 발생
    return x.contiguous()

# 팁 3: 적절한 행렬 크기 (Tensor Core 배수)
# H100 Tensor Core는 16의 배수 크기에서 최적 동작
# 16 × n 크기: 4096=16×256 ✅  4100=16×256+4 → 패딩 필요
def pad_to_multiple(x, multiple=16):
    d = x.shape[-1]
    if d % multiple == 0:
        return x
    pad_size = multiple - (d % multiple)
    return torch.nn.functional.pad(x, (0, pad_size))

# 팁 4: torch.compile (PyTorch 2.0+)
import torch

@torch.compile
def optimized_linear(x, weight, bias=None):
    """torch.compile이 자동으로 커널 퓨전, 최적 레이아웃 선택"""
    return torch.nn.functional.linear(x, weight, bias)

전체 그림: 행렬 곱셈 최적화 계층

현대 딥러닝 행렬 곱셈 최적화 스택:

애플리케이션 레이어:
  PyTorch / JAX / TensorFlow
   (torch.mm, F.linear)

커널 퓨전 / 컴파일러:
  torch.compile (TorchInductor)
  XLA (JAX)
   (최적화된 CUDA 코드 생성)

고수준 라이브러리:
  cuDNN (신경망 특화)
  cuBLAS (범용 GEMM)
  CUTLASS (NVIDIA 오픈소스 템플릿)
   (최적화된 PTX/SASS 코드)

하드웨어:
  Tensor Core (H100: 1,979 TFLOPS FP16)

메모리 시스템:
  L1/L2 캐시 (SRAM)
  HBM3 (3.35 TB/s)

FlashAttention은 이 스택의 "고수준 라이브러리" 레이어에서
IO-aware 알고리즘으로 HBM 접근을 최소화함.

행렬 곱셈은 단순해 보이지만, 그 최적화의 깊이는 40년의 역사와 수천 명의 엔지니어링 노력을 품고 있다. GPU 아키텍처의 변화에 따라 최적화 전략도 계속 진화하고 있으며, FlashAttention 시리즈처럼 알고리즘 레벨의 혁신이 하드웨어 개선만큼이나 큰 영향을 미치는 시대가 됐다.

LLM 인프라 엔지니어라면 이 스택 전체를 이해하는 것이 필수다. 다음 글에서는 이 기반 위에서 실제 LLM 서빙 최적화가 어떻게 이루어지는지 살펴보겠다.