Skip to content

필사 모드: 행렬이 GPU에서 어떻게 날아다니는가: GEMM부터 FlashAttention까지 완전 해부

한국어
0%
정확도 0%
💡 왼쪽 원문을 읽으면서 오른쪽에 따라 써보세요. Tab 키로 힌트를 받을 수 있습니다.
원문 렌더가 준비되기 전까지 텍스트 가이드로 표시합니다.

왜 행렬 곱셈이 딥러닝의 전부인가

딥러닝 모델을 실행할 때 GPU가 실제로 하는 일의 80% 이상은 단 하나의 연산으로 귀결된다. 바로 **행렬 곱셈(Matrix Multiplication)**이다.

믿기 어렵다면 각 레이어를 뜯어보자:

**Linear Layer (Fully Connected Layer)**

output = W × input + b

shape: (out_features,) = (out_features, in_features) × (in_features,)

가장 기본적인 선형 변환. 순수한 행렬-벡터 곱셈이다.

**Self-Attention (Transformer의 심장)**

Q = W_q × x # 행렬 곱셈

K = W_k × x # 행렬 곱셈

V = W_v × x # 행렬 곱셈

scores = Q × K^T # 행렬 곱셈 (N×d × d×N = N×N)

weights = softmax(scores / sqrt(d_k))

output = weights × V # 행렬 곱셈 (N×N × N×d = N×d)

한 번의 어텐션 레이어에서 **최소 6번**의 행렬 곱셈이 발생한다.

**Convolution (CNN)**

im2col 변환 후:

output = W × im2col(input)

컨볼루션도 결국 행렬 곱셈으로 변환된다.

**GPT-3 (175B) 기준 FLOPs 분포:**

Linear layers (QKV projection): ~45%

Feed-forward layers: ~35%

Attention computation (QK^T, AV): ~15%

Others (LayerNorm, softmax, etc.): ~5%

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

행렬 곱셈 합계: ~95%

행렬 곱셈을 얼마나 빠르게 할 수 있느냐 = 딥러닝 추론/학습 속도가 결정된다. 이것이 40년간 수천 명의 엔지니어가 행렬 곱셈 최적화에 매달려온 이유다.

순진한(Naive) 행렬 곱셈과 그 문제점

3중 루프: 정의 그 자체

def naive_matmul(A, B):

"""

A: (M, K) B: (K, N) -> C: (M, N)

C[i][j] = sum(A[i][k] * B[k][j] for k in range(K))

"""

M, K = A.shape

K2, N = B.shape

assert K == K2, "Inner dimensions must match"

C = np.zeros((M, N), dtype=np.float64)

for i in range(M): # M 반복

for j in range(N): # N 반복

for k in range(K): # K 반복

C[i][j] += A[i][k] * B[k][j] # MAC 연산 1회

return C

총 연산: M * N * K 번의 multiply-add

복잡도: O(n^3)

4096×4096 행렬: 4096^3 = 약 680억 번의 연산!

수학적으로는 완벽하다. 하지만 실제로 돌려보면 끔찍하게 느리다.

4096×4096 FP32 행렬 곱셈:

- 이론 연산량: 2 × 4096³ ≈ 137 GFLOPS

- H100 이론 성능: 67 TFLOPS (FP32)

- 이론 시간: 137G / 67T ≈ **2ms**

- Naive Python 실제 시간: **약 900초 (450,000배 느림!)**

차이의 원인은 두 가지다. **메모리 접근 패턴**과 **병렬성 미활용**.

캐시 미스: 진짜 적

메모리 계층 구조 (H100 기준):

┌─────────────────────────────────────────────────────────┐

│ L1 Cache (per SM): 256KB, ~5 cycles latency │

│ L2 Cache (shared): 50MB, ~30 cycles latency │

│ HBM3 (GPU main mem): 80GB, ~300 cycles latency │

│ CPU DRAM: 수 TB, ~1000 cycles latency │

└─────────────────────────────────────────────────────────┘

캐시 접근: 300 cycles ↔ 5 cycles = 60배 차이

캐시 적중률이 곧 성능이다.

Naive 3중 루프에서 메모리 접근 패턴을 보자 (A[4×4], B[4×4] 예시):

A 행렬 (행 우선 저장, row-major):

메모리 주소: [A00][A01][A02][A03] [A10][A11][A12][A13] ...

←── 캐시라인 1 ──→ ←── 캐시라인 2 ──→

A[i][k] 접근 패턴 (i=0 고정, k=0,1,2,3 순서):

A[0][0] → 캐시라인 1 로드 (cache MISS, 300 cycles)

A[0][1] → 캐시라인 1 재사용 (cache HIT, 5 cycles) ✅

A[0][2] → 캐시라인 1 재사용 (cache HIT) ✅

A[0][3] → 캐시라인 1 재사용 (cache HIT) ✅

→ A는 행 방향 접근: 캐시 친화적 ✅

B 행렬의 열 방향 접근 (j=0 고정, k=0,1,2,3 순서):

B[0][0] → 캐시라인 0 로드 (cache MISS)

B[1][0] → 캐시라인 2 로드 (cache MISS! 다른 행) ❌

B[2][0] → 캐시라인 4 로드 (cache MISS!) ❌

B[3][0] → 캐시라인 6 로드 (cache MISS!) ❌

B의 각 원소가 별도 캐시라인에 있음

대형 행렬(N=4096)에서 B 접근의 99%가 캐시 미스!

실제로 측정해보면:

def measure_matmul():

N = 1024

A = np.random.randn(N, N).astype(np.float32)

B = np.random.randn(N, N).astype(np.float32)

B_T = B.T.copy() # B의 전치를 연속 메모리에 복사

방법 1: 순진한 열 접근 (캐시 비친화적)

start = time.perf_counter()

C1 = A @ B # numpy도 내부적으로 캐시 최적화함

t1 = time.perf_counter() - start

print(f"NumPy (최적화됨): {t1*1000:.2f}ms")

순수 Python 3중 루프는 수백 배 더 느림

measure_matmul()

Cache Blocking (Tiling): 해결책

핵심 아이디어: 캐시에 맞는 블록으로 분할

행렬을 캐시에 들어가는 작은 타일(tile)로 분할해 한 번 로드한 데이터를 최대한 재사용한다.

Matrix Tiling 시각화 (M=8, N=8, K=8, TILE_SIZE=4):

전체 행렬을 4×4 타일로 분할:

B (K×N):

┌───┬───┐

│B00│B01│ B00 = B[0:4, 0:4]

├───┼───┤ B01 = B[0:4, 4:8]

│B10│B11│ B10 = B[4:8, 0:4]

└───┴───┘ B11 = B[4:8, 4:8]

A (M×K): C (M×N):

┌───┬───┐ ┌───┬───┐

│A00│A01│ │C00│C01│ C00 = A00×B00 + A01×B10

├───┼───┤ ├───┼───┤ C01 = A00×B01 + A01×B11

│A10│A11│ │C10│C11│ C10 = A10×B00 + A11×B10

└───┴───┘ └───┴───┘ C11 = A10×B01 + A11×B11

타일 C00 계산 과정:

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Step 1: A00(4×4=16 floats=64B) 캐시 로드

B00(4×4=16 floats=64B) 캐시 로드

→ 128B 로드로 4×4=16번의 내적 계산

→ 16×16=256번의 MAC 연산 (캐시 미스 ZERO!)

Step 2: A01(4×4) 캐시 로드 (A00 evict)

B10(4×4) 캐시 로드 (B00 evict)

→ 256번의 MAC 연산 (캐시 미스 ZERO!)

결과: C00 완성

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

캐시 재사용률:

- Naive: B의 각 원소를 M번씩 다시 로드 (총 M번 캐시미스/원소)

- Tiled: 타일 내 원소를 TILE_SIZE번 재사용 (TILE_SIZE배 캐시미스 감소)

Tiled MatMul 구현

def tiled_matmul(A, B, tile_size=64):

"""

캐시 블로킹을 이용한 행렬 곱셈

tile_size: L1 캐시 크기에 맞게 조정

L1=256KB, float32: sqrt(256K/4/3) ≈ 146

실제로는 64-128이 최적인 경우가 많음

"""

M, K = A.shape

K2, N = B.shape

assert K == K2

C = np.zeros((M, N), dtype=A.dtype)

for i in range(0, M, tile_size): # i 타일

i_end = min(i + tile_size, M)

for j in range(0, N, tile_size): # j 타일

j_end = min(j + tile_size, N)

C[i:i_end, j:j_end]를 위한 누적

for k in range(0, K, tile_size): # k 타일 (내적 분할)

k_end = min(k + tile_size, K)

tile_A = A[i:i_end, k:k_end] # 캐시에 올라가는 A 타일

tile_B = B[k:k_end, j:j_end] # 캐시에 올라가는 B 타일

이 연산 동안 tile_A, tile_B는 캐시에 상주

C[i:i_end, j:j_end] += tile_A @ tile_B

return C

def benchmark_tiling():

"""성능 비교"""

sizes = [512, 1024, 2048]

for N in sizes:

A = np.random.randn(N, N).astype(np.float32)

B = np.random.randn(N, N).astype(np.float32)

NumPy (BLAS 최적화)

start = time.perf_counter()

for _ in range(5):

C_ref = A @ B

t_blas = (time.perf_counter() - start) / 5

print(f"N={N}: BLAS={t_blas*1000:.1f}ms, "

f"TFLOPS={2*N**3/t_blas/1e12:.2f}")

benchmark_tiling()

타일링의 효과:

N=4096, tile_size=64 vs naive (CPU, FP32):

Naive 3-loop: ~900s

Tiled 3-loop: ~18s (50배 향상)

NumPy (BLAS): ~0.8s (1125배 향상)

CUDA cuBLAS: ~0.002s (450,000배 향상)

BLAS: 40년의 최적화 결정체

BLAS 레벨 계층

BLAS (Basic Linear Algebra Subprograms, 1979~):

Level 1: Vector-Vector 연산 (O(n) 데이터, O(n) 연산)

- dot(x, y): 내적

- axpy(a, x, y): y = a*x + y (벡터 덧셈)

- nrm2(x): L2 노름

Level 2: Matrix-Vector 연산 (O(n²) 데이터, O(n²) 연산)

- gemv: y = alpha*A*x + beta*y

→ 산술 집약도 낮음, memory-bound

Level 3: Matrix-Matrix 연산 (O(n²) 데이터, O(n³) 연산)

- gemm: C = alpha*A*B + beta*C ← THE ONE WE CARE ABOUT

→ 산술 집약도 높음, compute-bound 가능

→ 딥러닝의 핵심 연산

GEMM: General Matrix Multiply

GEMM 인터페이스:

C = alpha * op(A) * op(B) + beta * C

파라미터:

transa/transb: 전치 여부 (N=없음, T=전치, C=켤레전치)

m, n, k: 행렬 차원 (C는 m×n, A는 m×k, B는 k×n)

alpha, beta: 스칼라 계수 (alpha=1.0, beta=0.0이면 순수 곱셈)

A, B, C: 행렬 포인터

lda, ldb, ldc: leading dimension (메모리 스트라이드)

lda의 역할 (중요!):

행렬 A가 더 큰 행렬의 부분행렬일 때:

A[i][j]의 메모리 주소 = base + i*lda + j

lda >= m 이어야 함

/* cuBLAS 예시 (NVIDIA GPU용 BLAS) */

#include <cublas_v2.h>

void gpu_matmul(float *d_A, float *d_B, float *d_C,

int M, int N, int K) {

cublasHandle_t handle;

cublasCreate(&handle);

float alpha = 1.0f;

float beta = 0.0f;

/* cuBLAS는 열 우선(column-major) 저장 사용

C = A*B (행 우선) = (B^T * A^T)^T (열 우선) */

cublasSgemm(

handle,

CUBLAS_OP_N, CUBLAS_OP_N, // 전치 없음

N, M, K, // 주의: cuBLAS는 B,A 순서!

&alpha,

d_B, N, // B (N×K), leading dim=N

d_A, K, // A (K×M), leading dim=K

&beta,

d_C, N // C (N×M), leading dim=N

);

cublasDestroy(handle);

/* cuBLAS는 이론 최고 성능의 95%+ 달성

수천 person-year의 최적화 결정체 */

}

PyTorch에서 cuBLAS 활용 (자동으로 사용됨):

A = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)

B = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)

내부적으로 cuBLAS HGEMM (Half-precision GEMM) 호출

C = torch.mm(A, B)

성능 프로파일링:

t = benchmark.Timer(

stmt='torch.mm(A, B)',

globals={'A': A, 'B': B}

)

result = t.timeit(100)

print(f"H100 FP16 GEMM 4096^2: {result.mean*1000:.2f}ms")

결과: ~0.5ms, ~2000 TFLOPS (이론 최고 2000 TFLOPS)

Tensor Core: 행렬 곱셈 전용 하드웨어

NVIDIA Tensor Core (H100 기준):

일반 CUDA Core:

한 번에: 1 FMA (Fused Multiply-Add) 실행

처리량: 1 FLOP/cycle/core

Tensor Core:

한 번에: 4×8 × 8×8 = 4×8 행렬 곱셈 실행

처리량: 256 FLOP/cycle (일반 코어의 256배!)

지원 형식: FP16, BF16, INT8, INT4, FP8 (H100)

H100 SXM5:

CUDA Cores: 16,896개 × 2 FP32 OPS = ~67 TFLOPS

Tensor Cores: 528개 × 2048 FP16 = ~1,979 TFLOPS (29.5배)

딥러닝에서 FP16이 기본인 이유: Tensor Core 활용!

FlashAttention: 메모리 벽 돌파하기

표준 어텐션의 메모리 복잡도 문제

Transformers의 가장 큰 병목은 어텐션의 **메모리 복잡도**다.

def attention_standard(Q, K, V, scale=None):

"""

표준 어텐션 구현

Q: (batch, heads, seq_len, d_k)

K: (batch, heads, seq_len, d_k)

V: (batch, heads, seq_len, d_v)

"""

if scale is None:

scale = Q.shape[-1] ** -0.5

Step 1: Q×K^T → (batch, heads, seq_len, seq_len)

seq_len=4096, d_k=128:

크기: batch * heads * 4096 * 4096 * 2bytes

= 1 * 32 * 16MB = 512MB !!! (HBM에 저장)

scores = torch.matmul(Q, K.transpose(-2, -1)) * scale

Step 2: softmax → 다시 HBM 읽기+쓰기

weights = F.softmax(scores, dim=-1) # 512MB 읽기+512MB 쓰기

Step 3: weights×V → 다시 HBM 읽기

output = torch.matmul(weights, V) # 512MB 읽기

return output

seq_len=4096, 32 heads 기준:

HBM 접근량: scores 생성(512MB) + softmax(1GB) + AV(512MB) = ~2GB

H100 HBM 대역폭: 3.35 TB/s

이론 최소 시간: 2GB / 3.35T = ~0.6ms per forward pass

→ 어텐션의 대부분 시간이 계산이 아닌 메모리 이동!

어텐션은 **메모리 바운드(memory-bound)** 연산이다. GPU의 계산 유닛은 놀고 있고, HBM(High Bandwidth Memory) 접근이 병목이다.

IO-Aware 어텐션: FlashAttention의 핵심

Tri Dao et al. (2022)의 통찰: "N×N 어텐션 행렬을 굳이 전부 HBM에 쓸 필요가 있나?"

GPU 메모리 계층 (H100):

SRAM (L1/Shared Memory, per SM):

크기: ~228KB per SM (H100에서 132개 SM)

대역폭: ~19 TB/s (HBM의 6배!)

지연: ~5 cycles

↕ 매우 빠름

HBM (High Bandwidth Memory):

크기: 80GB

대역폭: ~3.35 TB/s

지연: ~300 cycles

↕ 느림 (상대적으로)

FlashAttention 전략:

N×N 행렬을 HBM에 쓰지 말고,

SRAM에 들어가는 타일로 분할해서 처리하자!

FlashAttention 알고리즘 시각화:

Input in HBM:

Q: (N, d) ─────────────────────────────┐

K: (N, d) ─────────────────────────────┤

V: (N, d) ─────────────────────────────┘

(타일 단위로 SRAM으로 로드)

SRAM (타일 처리):

┌──────────────────────────────────────┐

│ Q_tile: (block_q, d) │

│ K_tile: (block_k, d) │

│ V_tile: (block_k, d) │

│ S_tile: (block_q, block_k) │ ← HBM에 절대 쓰지 않음!

│ running_max: (block_q,) │

│ running_sum: (block_q,) │

│ O_tile: (block_q, d) │

└──────────────────────────────────────┘

Output in HBM:

O: (N, d) ← 타일 완성 후 기록

HBM 접근 횟수:

표준 어텐션: O(N^2) (N×N 행렬 읽기/쓰기)

FlashAttention: O(N) (타일 단위 선형 접근)

Online Softmax: 수학적 열쇠

FlashAttention의 핵심 수학적 트릭은 **온라인 소프트맥스**다. 전체 시퀀스를 보지 않고도 정확한 소프트맥스를 계산할 수 있다.

def online_softmax_demo():

"""

온라인 소프트맥스의 수학적 원리 설명

일반 소프트맥스: softmax(x)[i] = exp(x[i]) / sum(exp(x[j]))

문제: 전체 x를 알아야 함 (타일 처리 불가)

수치 안정적 소프트맥스:

m = max(x)

softmax(x)[i] = exp(x[i] - m) / sum(exp(x[j] - m))

→ m을 빼도 결과는 동일, exp 오버플로 방지

온라인 업데이트 (새 타일 x_new가 도착할 때):

m_new = max(m_old, max(x_new))

l_new = exp(m_old - m_new) * l_old + sum(exp(x_new - m_new))

O_new = (exp(m_old - m_new) * l_old * O_old

+ exp(x_new - m_new) @ V_new) / l_new

"""

pass

def flash_attention_simplified(Q, K, V, block_size=64):

"""

FlashAttention 핵심 로직 (단순화 버전)

Q, K, V: (N, d)

"""

N, d = Q.shape

scale = 1.0 / math.sqrt(d)

dtype = Q.dtype

출력 초기화

O = torch.zeros(N, d, dtype=dtype, device=Q.device)

L = torch.zeros(N, dtype=dtype, device=Q.device) # 정규화 팩터

M = torch.full((N,), float('-inf'), dtype=dtype, device=Q.device) # 최댓값

외부 루프: K, V 타일

for j_start in range(0, N, block_size):

j_end = min(j_start + block_size, N)

K_j = K[j_start:j_end] # SRAM에 로드

V_j = V[j_start:j_end] # SRAM에 로드

내부 루프: Q 타일

for i_start in range(0, N, block_size):

i_end = min(i_start + block_size, N)

Q_i = Q[i_start:i_end] # SRAM에 로드

어텐션 스코어 (SRAM 내에서 계산!)

S_ij = (Q_i @ K_j.T) * scale # (block_q, block_k)

온라인 최댓값 업데이트

m_i_old = M[i_start:i_end]

m_ij = S_ij.max(dim=-1).values # 현재 타일의 최댓값

m_i_new = torch.maximum(m_i_old, m_ij)

온라인 sum 업데이트

이전 항: exp(m_old - m_new) * l_old (보정 팩터 적용)

correction = torch.exp(m_i_old - m_i_new)

l_i_new = (correction * L[i_start:i_end] +

torch.exp(S_ij - m_i_new.unsqueeze(-1)).sum(dim=-1))

출력 업데이트

P_ij = torch.exp(S_ij - m_i_new.unsqueeze(-1)) # softmax 분자

O[i_start:i_end] = (

(correction * L[i_start:i_end]).unsqueeze(-1) * O[i_start:i_end]

+ P_ij @ V_j

) / l_i_new.unsqueeze(-1)

상태 저장

M[i_start:i_end] = m_i_new

L[i_start:i_end] = l_i_new

return O # HBM에 한 번만 씀!

FlashAttention 성능 수치 (A100 80GB, seq_len=2048):

표준 어텐션: 12.3ms

FlashAttention v1: 3.4ms (3.6배 빠름)

FlashAttention v2: 2.1ms (5.9배 빠름)

FlashAttention v3: 1.4ms (8.8배 빠름, H100 전용)

FlashAttention의 추가 이점

역방향 패스(Backward Pass) 메모리 절약:

표준 어텐션: 순전파에서 attention weights (N×N) 저장 필수

→ 4096 토큰: 64MB per head per layer

→ 32 heads × 32 layers = 65 GB (배치 1에서도!)

FlashAttention: attention weights 저장 불필요

→ (m, l) 벡터만 저장 (O(N) 메모리)

→ 역전파 시 블록 재계산 (recomputation)

FlashAttention-2 추가 최적화 (2023):

- Q 루프와 K,V 루프 순서 교환 (스레드 병렬성 향상)

- 불필요한 rescale 연산 제거

- 워프(warp) 간 작업 분할 최적화

→ 이론 최고 성능의 73% 달성 (v1: 40%)

GEMM 성능 분석: Roofline 모델

언제 Compute-Bound, 언제 Memory-Bound?

Roofline Model:

성능 상한 = min(compute_peak, memory_bw × arithmetic_intensity)

산술 집약도 (Arithmetic Intensity, AI):

AI = FLOPS / BYTES_ACCESSED (FLOP/Byte)

H100 SXM5 파라미터:

FP16 Tensor Core: 1,979 TFLOPS

HBM 대역폭: 3,350 GB/s = 3.35 TB/s

Ridge Point: 1979T / 3.35T = 590 FLOP/Byte

→ AI > 590: Compute-bound

→ AI < 590: Memory-bound

def analyze_matmul_intensity(M, K, N, dtype_bytes=2):

"""

행렬 곱셈 (M×K) × (K×N) = (M×N)의 산술 집약도 계산

dtype_bytes: FP16=2, FP32=4

"""

flops = 2 * M * K * N # multiply + add

메모리 접근: A 읽기 + B 읽기 + C 쓰기

bytes_accessed = (M * K + K * N + M * N) * dtype_bytes

ai = flops / bytes_accessed

return flops, bytes_accessed, ai

케이스 1: 대형 행렬 (훈련 시 배치 처리)

M = N = K = 4096

flops, bytes, ai = analyze_matmul_intensity(M, K, N)

print(f"Large GEMM (4096^3):")

print(f" FLOPs: {flops/1e12:.1f} TFLOPS")

print(f" Bytes: {bytes/1e6:.0f} MB")

print(f" AI: {ai:.0f} FLOP/Byte")

print(f" H100 Ridge: 590 FLOP/Byte")

print(f" Status: {'COMPUTE-BOUND' if ai > 590 else 'MEMORY-BOUND'}")

결과: AI=1370, COMPUTE-BOUND ✅

print()

케이스 2: 추론 (배치 크기=1)

M, K, N = 1, 4096, 4096

flops, bytes, ai = analyze_matmul_intensity(M, K, N)

print(f"Inference GEMV (batch=1):")

print(f" FLOPs: {flops/1e6:.1f} MFLOPS")

print(f" Bytes: {bytes/1e6:.0f} MB (거의 B 행렬 전체!)")

print(f" AI: {ai:.2f} FLOP/Byte")

print(f" Status: {'COMPUTE-BOUND' if ai > 590 else 'MEMORY-BOUND'}")

결과: AI=0.99, MEMORY-BOUND ❌

이것이 LLM 추론 최적화가 어려운 근본 이유!

실제 성능 측정 (H100, FP16):

배치 크기별 GEMM 성능 (4096×4096 weight matrix):

┌───────────┬────────────┬────────────┬──────────────────┐

│ Batch (M) │ TFLOPS │ AI │ 상태 │

├───────────┼────────────┼────────────┼──────────────────┤

│ 1 │ 0.016 │ ~1 │ Memory-bound ❌ │

│ 4 │ 0.063 │ ~4 │ Memory-bound ❌ │

│ 16 │ 0.25 │ ~16 │ Memory-bound │

│ 64 │ 0.98 │ ~63 │ Memory-bound │

│ 256 │ 3.8 │ ~250 │ 전환 구간 │

│ 1024 │ 14.2 │ ~1000 │ Compute-bound ✅ │

│ 4096 │ 1,847 │ ~1370 │ Compute-bound ✅ │

└───────────┴────────────┴────────────┴──────────────────┘

→ 배치가 클수록 GPU 활용률 증가

→ 이것이 LLM 서빙에서 배칭이 중요한 이유

실용적 CUDA GEMM 최적화 팁

팁 1: TF32 사용 (H100 기본 FP32 대체)

FP32: 23비트 가수부 → TF32: 10비트 가수부 (약간의 정밀도 손실)

하지만 Tensor Core 사용 가능 → 10배 빠름!

torch.backends.cuda.matmul.allow_tf32 = True # 행렬 곱셈

torch.backends.cudnn.allow_tf32 = True # 컨볼루션

팁 2: 연속 메모리 텐서 사용

def make_contiguous(x):

.contiguous()는 비연속 텐서를 연속 메모리로 복사

비연속 텐서: torch.transpose 후 추가 연산 시 발생

return x.contiguous()

팁 3: 적절한 행렬 크기 (Tensor Core 배수)

H100 Tensor Core는 16의 배수 크기에서 최적 동작

16 × n 크기: 4096=16×256 ✅ 4100=16×256+4 → 패딩 필요

def pad_to_multiple(x, multiple=16):

d = x.shape[-1]

if d % multiple == 0:

return x

pad_size = multiple - (d % multiple)

return torch.nn.functional.pad(x, (0, pad_size))

팁 4: torch.compile (PyTorch 2.0+)

@torch.compile

def optimized_linear(x, weight, bias=None):

"""torch.compile이 자동으로 커널 퓨전, 최적 레이아웃 선택"""

return torch.nn.functional.linear(x, weight, bias)

전체 그림: 행렬 곱셈 최적화 계층

현대 딥러닝 행렬 곱셈 최적화 스택:

애플리케이션 레이어:

PyTorch / JAX / TensorFlow

↓ (torch.mm, F.linear 등)

커널 퓨전 / 컴파일러:

torch.compile (TorchInductor)

XLA (JAX)

↓ (최적화된 CUDA 코드 생성)

고수준 라이브러리:

cuDNN (신경망 특화)

cuBLAS (범용 GEMM)

CUTLASS (NVIDIA 오픈소스 템플릿)

↓ (최적화된 PTX/SASS 코드)

하드웨어:

Tensor Core (H100: 1,979 TFLOPS FP16)

메모리 시스템:

L1/L2 캐시 (SRAM)

HBM3 (3.35 TB/s)

FlashAttention은 이 스택의 "고수준 라이브러리" 레이어에서

IO-aware 알고리즘으로 HBM 접근을 최소화함.

행렬 곱셈은 단순해 보이지만, 그 최적화의 깊이는 40년의 역사와 수천 명의 엔지니어링 노력을 품고 있다. GPU 아키텍처의 변화에 따라 최적화 전략도 계속 진화하고 있으며, FlashAttention 시리즈처럼 알고리즘 레벨의 혁신이 하드웨어 개선만큼이나 큰 영향을 미치는 시대가 됐다.

LLM 인프라 엔지니어라면 이 스택 전체를 이해하는 것이 필수다. 다음 글에서는 이 기반 위에서 실제 LLM 서빙 최적화가 어떻게 이루어지는지 살펴보겠다.

현재 단락 (1/423)

딥러닝 모델을 실행할 때 GPU가 실제로 하는 일의 80% 이상은 단 하나의 연산으로 귀결된다. 바로 **행렬 곱셈(Matrix Multiplication)**이다.

작성 글자: 0원문 글자: 12,945작성 단락: 0/423