- Authors

- Name
- Youngju Kim
- @fjvbn20031
- 왜 행렬 곱셈이 딥러닝의 전부인가
- 순진한(Naive) 행렬 곱셈과 그 문제점
- Cache Blocking (Tiling): 해결책
- BLAS: 40년의 최적화 결정체
- FlashAttention: 메모리 벽 돌파하기
- GEMM 성능 분석: Roofline 모델
- 전체 그림: 행렬 곱셈 최적화 계층
왜 행렬 곱셈이 딥러닝의 전부인가
딥러닝 모델을 실행할 때 GPU가 실제로 하는 일의 80% 이상은 단 하나의 연산으로 귀결된다. 바로 **행렬 곱셈(Matrix Multiplication)**이다.
믿기 어렵다면 각 레이어를 뜯어보자:
Linear Layer (Fully Connected Layer)
output = W × input + b
shape: (out_features,) = (out_features, in_features) × (in_features,)
가장 기본적인 선형 변환. 순수한 행렬-벡터 곱셈이다.
Self-Attention (Transformer의 심장)
Q = W_q × x # 행렬 곱셈
K = W_k × x # 행렬 곱셈
V = W_v × x # 행렬 곱셈
scores = Q × K^T # 행렬 곱셈 (N×d × d×N = N×N)
weights = softmax(scores / sqrt(d_k))
output = weights × V # 행렬 곱셈 (N×N × N×d = N×d)
한 번의 어텐션 레이어에서 최소 6번의 행렬 곱셈이 발생한다.
Convolution (CNN)
im2col 변환 후:
output = W × im2col(input)
컨볼루션도 결국 행렬 곱셈으로 변환된다.
GPT-3 (175B) 기준 FLOPs 분포:
Linear layers (QKV projection): ~45%
Feed-forward layers: ~35%
Attention computation (QK^T, AV): ~15%
Others (LayerNorm, softmax, etc.): ~5%
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
행렬 곱셈 합계: ~95%
행렬 곱셈을 얼마나 빠르게 할 수 있느냐 = 딥러닝 추론/학습 속도가 결정된다. 이것이 40년간 수천 명의 엔지니어가 행렬 곱셈 최적화에 매달려온 이유다.
순진한(Naive) 행렬 곱셈과 그 문제점
3중 루프: 정의 그 자체
import numpy as np
def naive_matmul(A, B):
"""
A: (M, K) B: (K, N) -> C: (M, N)
C[i][j] = sum(A[i][k] * B[k][j] for k in range(K))
"""
M, K = A.shape
K2, N = B.shape
assert K == K2, "Inner dimensions must match"
C = np.zeros((M, N), dtype=np.float64)
for i in range(M): # M 반복
for j in range(N): # N 반복
for k in range(K): # K 반복
C[i][j] += A[i][k] * B[k][j] # MAC 연산 1회
return C
# 총 연산: M * N * K 번의 multiply-add
# 복잡도: O(n^3)
# 4096×4096 행렬: 4096^3 = 약 680억 번의 연산!
수학적으로는 완벽하다. 하지만 실제로 돌려보면 끔찍하게 느리다.
4096×4096 FP32 행렬 곱셈:
- 이론 연산량: 2 × 4096³ ≈ 137 GFLOPS
- H100 이론 성능: 67 TFLOPS (FP32)
- 이론 시간: 137G / 67T ≈ 2ms
- Naive Python 실제 시간: 약 900초 (450,000배 느림!)
차이의 원인은 두 가지다. 메모리 접근 패턴과 병렬성 미활용.
캐시 미스: 진짜 적
메모리 계층 구조 (H100 기준):
┌─────────────────────────────────────────────────────────┐
│ L1 Cache (per SM): 256KB, ~5 cycles latency │
│ L2 Cache (shared): 50MB, ~30 cycles latency │
│ HBM3 (GPU main mem): 80GB, ~300 cycles latency │
│ CPU DRAM: 수 TB, ~1000 cycles latency │
└─────────────────────────────────────────────────────────┘
캐시 접근: 300 cycles ↔ 5 cycles = 60배 차이
캐시 적중률이 곧 성능이다.
Naive 3중 루프에서 메모리 접근 패턴을 보자 (A[4×4], B[4×4] 예시):
A 행렬 (행 우선 저장, row-major):
메모리 주소: [A00][A01][A02][A03] [A10][A11][A12][A13] ...
←── 캐시라인 1 ──→ ←── 캐시라인 2 ──→
A[i][k] 접근 패턴 (i=0 고정, k=0,1,2,3 순서):
A[0][0] → 캐시라인 1 로드 (cache MISS, 300 cycles)
A[0][1] → 캐시라인 1 재사용 (cache HIT, 5 cycles) ✅
A[0][2] → 캐시라인 1 재사용 (cache HIT) ✅
A[0][3] → 캐시라인 1 재사용 (cache HIT) ✅
→ A는 행 방향 접근: 캐시 친화적 ✅
B 행렬의 열 방향 접근 (j=0 고정, k=0,1,2,3 순서):
B[0][0] → 캐시라인 0 로드 (cache MISS)
B[1][0] → 캐시라인 2 로드 (cache MISS! 다른 행) ❌
B[2][0] → 캐시라인 4 로드 (cache MISS!) ❌
B[3][0] → 캐시라인 6 로드 (cache MISS!) ❌
B의 각 원소가 별도 캐시라인에 있음
대형 행렬(N=4096)에서 B 접근의 99%가 캐시 미스!
실제로 측정해보면:
import time
import numpy as np
def measure_matmul():
N = 1024
A = np.random.randn(N, N).astype(np.float32)
B = np.random.randn(N, N).astype(np.float32)
B_T = B.T.copy() # B의 전치를 연속 메모리에 복사
# 방법 1: 순진한 열 접근 (캐시 비친화적)
start = time.perf_counter()
C1 = A @ B # numpy도 내부적으로 캐시 최적화함
t1 = time.perf_counter() - start
print(f"NumPy (최적화됨): {t1*1000:.2f}ms")
# 순수 Python 3중 루프는 수백 배 더 느림
measure_matmul()
Cache Blocking (Tiling): 해결책
핵심 아이디어: 캐시에 맞는 블록으로 분할
행렬을 캐시에 들어가는 작은 타일(tile)로 분할해 한 번 로드한 데이터를 최대한 재사용한다.
Matrix Tiling 시각화 (M=8, N=8, K=8, TILE_SIZE=4):
전체 행렬을 4×4 타일로 분할:
B (K×N):
┌───┬───┐
│B00│B01│ B00 = B[0:4, 0:4]
├───┼───┤ B01 = B[0:4, 4:8]
│B10│B11│ B10 = B[4:8, 0:4]
└───┴───┘ B11 = B[4:8, 4:8]
A (M×K): C (M×N):
┌───┬───┐ ┌───┬───┐
│A00│A01│ │C00│C01│ C00 = A00×B00 + A01×B10
├───┼───┤ ├───┼───┤ C01 = A00×B01 + A01×B11
│A10│A11│ │C10│C11│ C10 = A10×B00 + A11×B10
└───┴───┘ └───┴───┘ C11 = A10×B01 + A11×B11
타일 C00 계산 과정:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Step 1: A00(4×4=16 floats=64B) 캐시 로드
B00(4×4=16 floats=64B) 캐시 로드
→ 128B 로드로 4×4=16번의 내적 계산
→ 16×16=256번의 MAC 연산 (캐시 미스 ZERO!)
Step 2: A01(4×4) 캐시 로드 (A00 evict)
B10(4×4) 캐시 로드 (B00 evict)
→ 256번의 MAC 연산 (캐시 미스 ZERO!)
결과: C00 완성
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
캐시 재사용률:
- Naive: B의 각 원소를 M번씩 다시 로드 (총 M번 캐시미스/원소)
- Tiled: 타일 내 원소를 TILE_SIZE번 재사용 (TILE_SIZE배 캐시미스 감소)
Tiled MatMul 구현
import numpy as np
def tiled_matmul(A, B, tile_size=64):
"""
캐시 블로킹을 이용한 행렬 곱셈
tile_size: L1 캐시 크기에 맞게 조정
L1=256KB, float32: sqrt(256K/4/3) ≈ 146
실제로는 64-128이 최적인 경우가 많음
"""
M, K = A.shape
K2, N = B.shape
assert K == K2
C = np.zeros((M, N), dtype=A.dtype)
for i in range(0, M, tile_size): # i 타일
i_end = min(i + tile_size, M)
for j in range(0, N, tile_size): # j 타일
j_end = min(j + tile_size, N)
# C[i:i_end, j:j_end]를 위한 누적
for k in range(0, K, tile_size): # k 타일 (내적 분할)
k_end = min(k + tile_size, K)
tile_A = A[i:i_end, k:k_end] # 캐시에 올라가는 A 타일
tile_B = B[k:k_end, j:j_end] # 캐시에 올라가는 B 타일
# 이 연산 동안 tile_A, tile_B는 캐시에 상주
C[i:i_end, j:j_end] += tile_A @ tile_B
return C
def benchmark_tiling():
"""성능 비교"""
import time
sizes = [512, 1024, 2048]
for N in sizes:
A = np.random.randn(N, N).astype(np.float32)
B = np.random.randn(N, N).astype(np.float32)
# NumPy (BLAS 최적화)
start = time.perf_counter()
for _ in range(5):
C_ref = A @ B
t_blas = (time.perf_counter() - start) / 5
print(f"N={N}: BLAS={t_blas*1000:.1f}ms, "
f"TFLOPS={2*N**3/t_blas/1e12:.2f}")
benchmark_tiling()
타일링의 효과:
N=4096, tile_size=64 vs naive (CPU, FP32):
Naive 3-loop: ~900s
Tiled 3-loop: ~18s (50배 향상)
NumPy (BLAS): ~0.8s (1125배 향상)
CUDA cuBLAS: ~0.002s (450,000배 향상)
BLAS: 40년의 최적화 결정체
BLAS 레벨 계층
BLAS (Basic Linear Algebra Subprograms, 1979~):
Level 1: Vector-Vector 연산 (O(n) 데이터, O(n) 연산)
- dot(x, y): 내적
- axpy(a, x, y): y = a*x + y (벡터 덧셈)
- nrm2(x): L2 노름
Level 2: Matrix-Vector 연산 (O(n²) 데이터, O(n²) 연산)
- gemv: y = alpha*A*x + beta*y
→ 산술 집약도 낮음, memory-bound
Level 3: Matrix-Matrix 연산 (O(n²) 데이터, O(n³) 연산)
- gemm: C = alpha*A*B + beta*C ← THE ONE WE CARE ABOUT
→ 산술 집약도 높음, compute-bound 가능
→ 딥러닝의 핵심 연산
GEMM: General Matrix Multiply
GEMM 인터페이스:
C = alpha * op(A) * op(B) + beta * C
파라미터:
transa/transb: 전치 여부 (N=없음, T=전치, C=켤레전치)
m, n, k: 행렬 차원 (C는 m×n, A는 m×k, B는 k×n)
alpha, beta: 스칼라 계수 (alpha=1.0, beta=0.0이면 순수 곱셈)
A, B, C: 행렬 포인터
lda, ldb, ldc: leading dimension (메모리 스트라이드)
lda의 역할 (중요!):
행렬 A가 더 큰 행렬의 부분행렬일 때:
A[i][j]의 메모리 주소 = base + i*lda + j
lda >= m 이어야 함
/* cuBLAS 예시 (NVIDIA GPU용 BLAS) */
#include <cublas_v2.h>
void gpu_matmul(float *d_A, float *d_B, float *d_C,
int M, int N, int K) {
cublasHandle_t handle;
cublasCreate(&handle);
float alpha = 1.0f;
float beta = 0.0f;
/* cuBLAS는 열 우선(column-major) 저장 사용
C = A*B (행 우선) = (B^T * A^T)^T (열 우선) */
cublasSgemm(
handle,
CUBLAS_OP_N, CUBLAS_OP_N, // 전치 없음
N, M, K, // 주의: cuBLAS는 B,A 순서!
&alpha,
d_B, N, // B (N×K), leading dim=N
d_A, K, // A (K×M), leading dim=K
&beta,
d_C, N // C (N×M), leading dim=N
);
cublasDestroy(handle);
/* cuBLAS는 이론 최고 성능의 95%+ 달성
수천 person-year의 최적화 결정체 */
}
# PyTorch에서 cuBLAS 활용 (자동으로 사용됨):
import torch
A = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
B = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
# 내부적으로 cuBLAS HGEMM (Half-precision GEMM) 호출
C = torch.mm(A, B)
# 성능 프로파일링:
import torch.utils.benchmark as benchmark
t = benchmark.Timer(
stmt='torch.mm(A, B)',
globals={'A': A, 'B': B}
)
result = t.timeit(100)
print(f"H100 FP16 GEMM 4096^2: {result.mean*1000:.2f}ms")
# 결과: ~0.5ms, ~2000 TFLOPS (이론 최고 2000 TFLOPS)
Tensor Core: 행렬 곱셈 전용 하드웨어
NVIDIA Tensor Core (H100 기준):
일반 CUDA Core:
한 번에: 1 FMA (Fused Multiply-Add) 실행
처리량: 1 FLOP/cycle/core
Tensor Core:
한 번에: 4×8 × 8×8 = 4×8 행렬 곱셈 실행
처리량: 256 FLOP/cycle (일반 코어의 256배!)
지원 형식: FP16, BF16, INT8, INT4, FP8 (H100)
H100 SXM5:
CUDA Cores: 16,896개 × 2 FP32 OPS = ~67 TFLOPS
Tensor Cores: 528개 × 2048 FP16 = ~1,979 TFLOPS (29.5배)
딥러닝에서 FP16이 기본인 이유: Tensor Core 활용!
FlashAttention: 메모리 벽 돌파하기
표준 어텐션의 메모리 복잡도 문제
Transformers의 가장 큰 병목은 어텐션의 메모리 복잡도다.
import torch
import torch.nn.functional as F
def attention_standard(Q, K, V, scale=None):
"""
표준 어텐션 구현
Q: (batch, heads, seq_len, d_k)
K: (batch, heads, seq_len, d_k)
V: (batch, heads, seq_len, d_v)
"""
if scale is None:
scale = Q.shape[-1] ** -0.5
# Step 1: Q×K^T → (batch, heads, seq_len, seq_len)
# seq_len=4096, d_k=128:
# 크기: batch * heads * 4096 * 4096 * 2bytes
# = 1 * 32 * 16MB = 512MB !!! (HBM에 저장)
scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
# Step 2: softmax → 다시 HBM 읽기+쓰기
weights = F.softmax(scores, dim=-1) # 512MB 읽기+512MB 쓰기
# Step 3: weights×V → 다시 HBM 읽기
output = torch.matmul(weights, V) # 512MB 읽기
return output
# seq_len=4096, 32 heads 기준:
# HBM 접근량: scores 생성(512MB) + softmax(1GB) + AV(512MB) = ~2GB
# H100 HBM 대역폭: 3.35 TB/s
# 이론 최소 시간: 2GB / 3.35T = ~0.6ms per forward pass
# → 어텐션의 대부분 시간이 계산이 아닌 메모리 이동!
어텐션은 메모리 바운드(memory-bound) 연산이다. GPU의 계산 유닛은 놀고 있고, HBM(High Bandwidth Memory) 접근이 병목이다.
IO-Aware 어텐션: FlashAttention의 핵심
Tri Dao et al. (2022)의 통찰: "N×N 어텐션 행렬을 굳이 전부 HBM에 쓸 필요가 있나?"
GPU 메모리 계층 (H100):
SRAM (L1/Shared Memory, per SM):
크기: ~228KB per SM (H100에서 132개 SM)
대역폭: ~19 TB/s (HBM의 6배!)
지연: ~5 cycles
↕ 매우 빠름
HBM (High Bandwidth Memory):
크기: 80GB
대역폭: ~3.35 TB/s
지연: ~300 cycles
↕ 느림 (상대적으로)
FlashAttention 전략:
N×N 행렬을 HBM에 쓰지 말고,
SRAM에 들어가는 타일로 분할해서 처리하자!
FlashAttention 알고리즘 시각화:
Input in HBM:
Q: (N, d) ─────────────────────────────┐
K: (N, d) ─────────────────────────────┤
V: (N, d) ─────────────────────────────┘
(타일 단위로 SRAM으로 로드)
↓
SRAM (타일 처리):
┌──────────────────────────────────────┐
│ Q_tile: (block_q, d) │
│ K_tile: (block_k, d) │
│ V_tile: (block_k, d) │
│ S_tile: (block_q, block_k) │ ← HBM에 절대 쓰지 않음!
│ running_max: (block_q,) │
│ running_sum: (block_q,) │
│ O_tile: (block_q, d) │
└──────────────────────────────────────┘
↓
Output in HBM:
O: (N, d) ← 타일 완성 후 기록
HBM 접근 횟수:
표준 어텐션: O(N^2) (N×N 행렬 읽기/쓰기)
FlashAttention: O(N) (타일 단위 선형 접근)
Online Softmax: 수학적 열쇠
FlashAttention의 핵심 수학적 트릭은 온라인 소프트맥스다. 전체 시퀀스를 보지 않고도 정확한 소프트맥스를 계산할 수 있다.
import torch
import math
def online_softmax_demo():
"""
온라인 소프트맥스의 수학적 원리 설명
일반 소프트맥스: softmax(x)[i] = exp(x[i]) / sum(exp(x[j]))
문제: 전체 x를 알아야 함 (타일 처리 불가)
수치 안정적 소프트맥스:
m = max(x)
softmax(x)[i] = exp(x[i] - m) / sum(exp(x[j] - m))
→ m을 빼도 결과는 동일, exp 오버플로 방지
온라인 업데이트 (새 타일 x_new가 도착할 때):
m_new = max(m_old, max(x_new))
l_new = exp(m_old - m_new) * l_old + sum(exp(x_new - m_new))
O_new = (exp(m_old - m_new) * l_old * O_old
+ exp(x_new - m_new) @ V_new) / l_new
"""
pass
def flash_attention_simplified(Q, K, V, block_size=64):
"""
FlashAttention 핵심 로직 (단순화 버전)
Q, K, V: (N, d)
"""
N, d = Q.shape
scale = 1.0 / math.sqrt(d)
dtype = Q.dtype
# 출력 초기화
O = torch.zeros(N, d, dtype=dtype, device=Q.device)
L = torch.zeros(N, dtype=dtype, device=Q.device) # 정규화 팩터
M = torch.full((N,), float('-inf'), dtype=dtype, device=Q.device) # 최댓값
# 외부 루프: K, V 타일
for j_start in range(0, N, block_size):
j_end = min(j_start + block_size, N)
K_j = K[j_start:j_end] # SRAM에 로드
V_j = V[j_start:j_end] # SRAM에 로드
# 내부 루프: Q 타일
for i_start in range(0, N, block_size):
i_end = min(i_start + block_size, N)
Q_i = Q[i_start:i_end] # SRAM에 로드
# 어텐션 스코어 (SRAM 내에서 계산!)
S_ij = (Q_i @ K_j.T) * scale # (block_q, block_k)
# 온라인 최댓값 업데이트
m_i_old = M[i_start:i_end]
m_ij = S_ij.max(dim=-1).values # 현재 타일의 최댓값
m_i_new = torch.maximum(m_i_old, m_ij)
# 온라인 sum 업데이트
# 이전 항: exp(m_old - m_new) * l_old (보정 팩터 적용)
correction = torch.exp(m_i_old - m_i_new)
l_i_new = (correction * L[i_start:i_end] +
torch.exp(S_ij - m_i_new.unsqueeze(-1)).sum(dim=-1))
# 출력 업데이트
P_ij = torch.exp(S_ij - m_i_new.unsqueeze(-1)) # softmax 분자
O[i_start:i_end] = (
(correction * L[i_start:i_end]).unsqueeze(-1) * O[i_start:i_end]
+ P_ij @ V_j
) / l_i_new.unsqueeze(-1)
# 상태 저장
M[i_start:i_end] = m_i_new
L[i_start:i_end] = l_i_new
return O # HBM에 한 번만 씀!
# FlashAttention 성능 수치 (A100 80GB, seq_len=2048):
# 표준 어텐션: 12.3ms
# FlashAttention v1: 3.4ms (3.6배 빠름)
# FlashAttention v2: 2.1ms (5.9배 빠름)
# FlashAttention v3: 1.4ms (8.8배 빠름, H100 전용)
FlashAttention의 추가 이점
역방향 패스(Backward Pass) 메모리 절약:
표준 어텐션: 순전파에서 attention weights (N×N) 저장 필수
→ 4096 토큰: 64MB per head per layer
→ 32 heads × 32 layers = 65 GB (배치 1에서도!)
FlashAttention: attention weights 저장 불필요
→ (m, l) 벡터만 저장 (O(N) 메모리)
→ 역전파 시 블록 재계산 (recomputation)
FlashAttention-2 추가 최적화 (2023):
- Q 루프와 K,V 루프 순서 교환 (스레드 병렬성 향상)
- 불필요한 rescale 연산 제거
- 워프(warp) 간 작업 분할 최적화
→ 이론 최고 성능의 73% 달성 (v1: 40%)
GEMM 성능 분석: Roofline 모델
언제 Compute-Bound, 언제 Memory-Bound?
Roofline Model:
성능 상한 = min(compute_peak, memory_bw × arithmetic_intensity)
산술 집약도 (Arithmetic Intensity, AI):
AI = FLOPS / BYTES_ACCESSED (FLOP/Byte)
H100 SXM5 파라미터:
FP16 Tensor Core: 1,979 TFLOPS
HBM 대역폭: 3,350 GB/s = 3.35 TB/s
Ridge Point: 1979T / 3.35T = 590 FLOP/Byte
→ AI > 590: Compute-bound
→ AI < 590: Memory-bound
def analyze_matmul_intensity(M, K, N, dtype_bytes=2):
"""
행렬 곱셈 (M×K) × (K×N) = (M×N)의 산술 집약도 계산
dtype_bytes: FP16=2, FP32=4
"""
flops = 2 * M * K * N # multiply + add
# 메모리 접근: A 읽기 + B 읽기 + C 쓰기
bytes_accessed = (M * K + K * N + M * N) * dtype_bytes
ai = flops / bytes_accessed
return flops, bytes_accessed, ai
# 케이스 1: 대형 행렬 (훈련 시 배치 처리)
M = N = K = 4096
flops, bytes, ai = analyze_matmul_intensity(M, K, N)
print(f"Large GEMM (4096^3):")
print(f" FLOPs: {flops/1e12:.1f} TFLOPS")
print(f" Bytes: {bytes/1e6:.0f} MB")
print(f" AI: {ai:.0f} FLOP/Byte")
print(f" H100 Ridge: 590 FLOP/Byte")
print(f" Status: {'COMPUTE-BOUND' if ai > 590 else 'MEMORY-BOUND'}")
# 결과: AI=1370, COMPUTE-BOUND ✅
print()
# 케이스 2: 추론 (배치 크기=1)
M, K, N = 1, 4096, 4096
flops, bytes, ai = analyze_matmul_intensity(M, K, N)
print(f"Inference GEMV (batch=1):")
print(f" FLOPs: {flops/1e6:.1f} MFLOPS")
print(f" Bytes: {bytes/1e6:.0f} MB (거의 B 행렬 전체!)")
print(f" AI: {ai:.2f} FLOP/Byte")
print(f" Status: {'COMPUTE-BOUND' if ai > 590 else 'MEMORY-BOUND'}")
# 결과: AI=0.99, MEMORY-BOUND ❌
# 이것이 LLM 추론 최적화가 어려운 근본 이유!
실제 성능 측정 (H100, FP16):
배치 크기별 GEMM 성능 (4096×4096 weight matrix):
┌───────────┬────────────┬────────────┬──────────────────┐
│ Batch (M) │ TFLOPS │ AI │ 상태 │
├───────────┼────────────┼────────────┼──────────────────┤
│ 1 │ 0.016 │ ~1 │ Memory-bound ❌ │
│ 4 │ 0.063 │ ~4 │ Memory-bound ❌ │
│ 16 │ 0.25 │ ~16 │ Memory-bound │
│ 64 │ 0.98 │ ~63 │ Memory-bound │
│ 256 │ 3.8 │ ~250 │ 전환 구간 │
│ 1024 │ 14.2 │ ~1000 │ Compute-bound ✅ │
│ 4096 │ 1,847 │ ~1370 │ Compute-bound ✅ │
└───────────┴────────────┴────────────┴──────────────────┘
→ 배치가 클수록 GPU 활용률 증가
→ 이것이 LLM 서빙에서 배칭이 중요한 이유
실용적 CUDA GEMM 최적화 팁
import torch
# 팁 1: TF32 사용 (H100 기본 FP32 대체)
# FP32: 23비트 가수부 → TF32: 10비트 가수부 (약간의 정밀도 손실)
# 하지만 Tensor Core 사용 가능 → 10배 빠름!
torch.backends.cuda.matmul.allow_tf32 = True # 행렬 곱셈
torch.backends.cudnn.allow_tf32 = True # 컨볼루션
# 팁 2: 연속 메모리 텐서 사용
def make_contiguous(x):
# .contiguous()는 비연속 텐서를 연속 메모리로 복사
# 비연속 텐서: torch.transpose 후 추가 연산 시 발생
return x.contiguous()
# 팁 3: 적절한 행렬 크기 (Tensor Core 배수)
# H100 Tensor Core는 16의 배수 크기에서 최적 동작
# 16 × n 크기: 4096=16×256 ✅ 4100=16×256+4 → 패딩 필요
def pad_to_multiple(x, multiple=16):
d = x.shape[-1]
if d % multiple == 0:
return x
pad_size = multiple - (d % multiple)
return torch.nn.functional.pad(x, (0, pad_size))
# 팁 4: torch.compile (PyTorch 2.0+)
import torch
@torch.compile
def optimized_linear(x, weight, bias=None):
"""torch.compile이 자동으로 커널 퓨전, 최적 레이아웃 선택"""
return torch.nn.functional.linear(x, weight, bias)
전체 그림: 행렬 곱셈 최적화 계층
현대 딥러닝 행렬 곱셈 최적화 스택:
애플리케이션 레이어:
PyTorch / JAX / TensorFlow
↓ (torch.mm, F.linear 등)
커널 퓨전 / 컴파일러:
torch.compile (TorchInductor)
XLA (JAX)
↓ (최적화된 CUDA 코드 생성)
고수준 라이브러리:
cuDNN (신경망 특화)
cuBLAS (범용 GEMM)
CUTLASS (NVIDIA 오픈소스 템플릿)
↓ (최적화된 PTX/SASS 코드)
하드웨어:
Tensor Core (H100: 1,979 TFLOPS FP16)
↓
메모리 시스템:
L1/L2 캐시 (SRAM)
HBM3 (3.35 TB/s)
FlashAttention은 이 스택의 "고수준 라이브러리" 레이어에서
IO-aware 알고리즘으로 HBM 접근을 최소화함.
행렬 곱셈은 단순해 보이지만, 그 최적화의 깊이는 40년의 역사와 수천 명의 엔지니어링 노력을 품고 있다. GPU 아키텍처의 변화에 따라 최적화 전략도 계속 진화하고 있으며, FlashAttention 시리즈처럼 알고리즘 레벨의 혁신이 하드웨어 개선만큼이나 큰 영향을 미치는 시대가 됐다.
LLM 인프라 엔지니어라면 이 스택 전체를 이해하는 것이 필수다. 다음 글에서는 이 기반 위에서 실제 LLM 서빙 최적화가 어떻게 이루어지는지 살펴보겠다.