Split View: PyTorch 고급 기법 완전 가이드: torch.compile, Custom Ops, Memory 최적화
PyTorch 고급 기법 완전 가이드: torch.compile, Custom Ops, Memory 최적화
PyTorch 고급 기법 완전 가이드
PyTorch는 딥러닝 연구와 프로덕션 배포 모두에서 가장 인기 있는 프레임워크 중 하나입니다. 하지만 대부분의 개발자는 기본적인 텐서 연산과 nn.Module 정의에서 멈춥니다. 이 가이드에서는 실제 프로덕션 환경에서 성능을 극대화하고, 커스텀 연산을 구현하며, 메모리를 효율적으로 관리하는 고급 기법들을 다룹니다.
1. torch.compile (PyTorch 2.0+)
PyTorch 2.0에서 도입된 torch.compile은 모델을 컴파일하여 실행 속도를 획기적으로 향상시키는 기능입니다. 기존의 TorchScript나 ONNX export와 달리, torch.compile은 파이썬 코드를 거의 수정하지 않고도 모델에 따라 2배 이상의 속도 향상을 가져올 수 있습니다.
torch.compile 소개와 장점
torch.compile은 세 가지 핵심 컴포넌트로 구성됩니다:
- TorchDynamo: Python 바이트코드를 인터셉트하여 FX 그래프를 생성
- AOTAutograd: 자동 미분 그래프를 사전 컴파일
- Inductor: TorchInductor 백엔드로 최적화된 커널 생성 (Triton GPU 커널 또는 C++ CPU 커널)
import torch
import torch.nn as nn
import time
# 기본 모델 정의
class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder block: self-attention then an MLP, each
    followed by a residual add and LayerNorm."""

    def __init__(self, d_model=512, nhead=8, dim_feedforward=2048):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """x: (batch, seq, d_model) -> same shape."""
        attended, _ = self.attention(x, x, x)
        x = self.norm1(x + attended)
        x = self.norm2(x + self.feed_forward(x))
        return x
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TransformerBlock().to(device)

# Apply torch.compile
compiled_model = torch.compile(model)

# Warm up BOTH paths: the first compiled call triggers compilation, and the
# first eager call pays one-off allocator/kernel-selection costs. The original
# warmed up only the compiled model, biasing the comparison.
x = torch.randn(32, 128, 512, device=device)
for _ in range(3):
    _ = model(x)
    _ = compiled_model(x)


def _timed(fn, n):
    """Time n forward passes. Synchronize BEFORE starting the clock as well as
    after, so pending async CUDA work from warmup is not billed to the timed
    region (perf_counter alone only measures Python-side launch time)."""
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n):
        _ = fn(x)
    if device == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter() - start


N = 100
elapsed_eager = _timed(model, N)
elapsed_compiled = _timed(compiled_model, N)

print(f"Eager mode: {elapsed_eager:.3f}s")
print(f"Compiled mode: {elapsed_compiled:.3f}s")
print(f"Speedup: {elapsed_eager / elapsed_compiled:.2f}x")
컴파일 모드 (Compilation Modes)
torch.compile은 세 가지 모드를 제공합니다:
# Default mode - fast compilation, good performance
model_default = torch.compile(model, mode="default")
# Reduce-overhead mode - favorable for small models
model_reduce = torch.compile(model, mode="reduce-overhead")
# Max autotune - long compilation, best performance
model_autotune = torch.compile(model, mode="max-autotune")
# Full-graph mode - no graph breaks allowed (strict)
model_full = torch.compile(model, fullgraph=True)
# Backend selection
model_eager = torch.compile(model, backend="eager") # no compilation (for debugging)
model_aot = torch.compile(model, backend="aot_eager") # AOT autograd only
model_inductor = torch.compile(model, backend="inductor") # default
동적 형태 지원
import torch._dynamo as dynamo

# Enable dynamic shapes (varying dims become symbolic instead of triggering recompiles)
model_dynamic = torch.compile(model, dynamic=True)
# Runs across different batch sizes without recompilation
for batch_size in [8, 16, 32, 64]:
    x = torch.randn(batch_size, 128, 512, device=device)
    out = model_dynamic(x)
    print(f"Batch {batch_size}: output shape {out.shape}")
# Inspect compilation results (graph breaks, recompile reasons)
print(dynamo.explain(model)(x))
기존 코드 마이그레이션
# Applying torch.compile to a training loop
def train_epoch(model, optimizer, dataloader, criterion):
    """Run one training epoch and return the mean per-batch loss."""
    model.train()
    total_loss = 0
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

# Only the model needs compiling - the loop itself is unchanged
model = MyModel().cuda()  # NOTE(review): MyModel is illustrative, not defined above
compiled_model = torch.compile(model) # just add this one line!
optimizer = torch.optim.Adam(compiled_model.parameters())
criterion = nn.CrossEntropyLoss()
2. 커스텀 자동 미분 (Custom Autograd)
PyTorch의 자동 미분 엔진은 강력하지만, 때로는 수치적으로 안정적이거나 효율적인 커스텀 그래디언트 계산이 필요합니다.
torch.autograd.Function 서브클래싱
import torch
from torch.autograd import Function
class SigmoidFunction(Function):
    """Sigmoid with a hand-written backward pass.

    Forward stashes the activation itself so the gradient can be formed as
    s * (1 - s) without recomputing any exponentials.
    """

    @staticmethod
    def forward(ctx, input):
        out = torch.sigmoid(input)  # 1 / (1 + exp(-x))
        ctx.save_for_backward(out)  # keep for the backward pass
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (out,) = ctx.saved_tensors
        # d(sigmoid)/dx = s * (1 - s), then chain rule with grad_output
        return out * (1 - out) * grad_output
# Usage example
def custom_sigmoid(x):
    """Convenience wrapper: autograd Functions are invoked via .apply()."""
    return SigmoidFunction.apply(x)

x = torch.randn(3, 4, requires_grad=True)
y = custom_sigmoid(x)
loss = y.sum()
loss.backward()  # routes through SigmoidFunction.backward
print(f"Gradient: {x.grad}")
더 복잡한 예시: Leaky ReLU with Custom Backward
class LeakyReLUFunction(Function):
    """LeakyReLU with an explicit backward: slope 1 for x >= 0, `negative_slope` below 0."""

    @staticmethod
    def forward(ctx, input, negative_slope=0.01):
        ctx.save_for_backward(input)
        # plain Python scalar -> attach directly to ctx, not save_for_backward
        ctx.negative_slope = negative_slope
        positive_part = input.clamp(min=0)
        negative_part = input.clamp(max=0)
        return positive_part + negative_slope * negative_part

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[saved_input < 0] *= ctx.negative_slope
        # forward took two inputs; the slope is non-differentiable, so None
        return grad_input, None
class CustomLeakyReLU(nn.Module):
    """nn.Module wrapper so LeakyReLUFunction composes with Sequential etc."""

    def __init__(self, negative_slope=0.01):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, x):
        slope = self.negative_slope
        return LeakyReLUFunction.apply(x, slope)
수치 그래디언트 체크
from torch.autograd import gradcheck

def test_custom_op():
    """Validate SigmoidFunction's analytic gradient against finite differences."""
    # Numerical gradient check (float64 recommended for precision)
    input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
    # gradcheck builds the Jacobian numerically and compares with autograd's
    result = gradcheck(SigmoidFunction.apply, (input,), eps=1e-6, atol=1e-4)
    print(f"Gradient check passed: {result}")
    # NOTE(review): the original comment called this a "double backprop test",
    # but it only re-runs gradcheck with gradient-dtype checking enabled;
    # torch.autograd.gradgradcheck would be needed to exercise double backward.
    input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
    result = gradcheck(
        SigmoidFunction.apply,
        (input,),
        eps=1e-6,
        atol=1e-4,
        check_grad_dtypes=True
    )

test_custom_op()
더블 역전파 (Double Backpropagation)
class SquaredFunction(Function):
    """Custom x**2 whose backward is itself differentiable (double backprop works)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d(x^2)/dx = 2x, expressed with differentiable tensor ops so a second
        # backward (via create_graph=True) can flow through this expression
        grad_input = 2 * x * grad_output
        return grad_input
# Double backprop example (used in meta-learning such as MAML)
x = torch.randn(3, requires_grad=True)
y = SquaredFunction.apply(x)
grad_x = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
# grad_x = 2x; backpropagate through it once more
grad_grad_x = torch.autograd.grad(grad_x.sum(), x)[0]
# grad_grad_x = 2 (a constant)
print(f"Second derivative: {grad_grad_x}")
3. 커스텀 CUDA 연산자
torch.utils.cpp_extension 소개
PyTorch는 C++/CUDA 확장을 작성하고 파이썬에서 사용할 수 있는 도구를 제공합니다.
# JIT-compiled inline extension (suited to development/prototyping)
from torch.utils.cpp_extension import load_inline

# C++ CPU operator
cpp_source = """
#include <torch/extension.h>
torch::Tensor relu_forward(torch::Tensor input) {
return input.clamp_min(0);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("relu_forward", &relu_forward, "ReLU forward");
}
"""
# Inline compilation (builds and imports a native module on the fly;
# requires a C++ toolchain at runtime)
custom_relu_cpp = load_inline(
    name="custom_relu_cpp",
    cpp_sources=cpp_source,
    functions=["relu_forward"],
    verbose=False
)
x = torch.randn(5)
result = custom_relu_cpp.relu_forward(x)
print(result)
CUDA 커널 예시: Fused Softmax
# CUDA kernel source (compiled by nvcc via load_inline).
# NOTE(review): the kernel launches <<<rows, 1>>> - a single thread per row -
# so each row is processed serially; this is an educational example, not a
# fast softmax. It also assumes a contiguous 2-D float32 input.
cuda_source = """
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
__global__ void fused_softmax_kernel(
const float* input,
float* output,
int rows,
int cols
) {
int row = blockIdx.x;
if (row >= rows) return;
const float* row_input = input + row * cols;
float* row_output = output + row * cols;
// Max 찾기 (수치 안정성)
float max_val = row_input[0];
for (int i = 1; i < cols; i++) {
max_val = fmaxf(max_val, row_input[i]);
}
// exp(x - max) 계산 및 합산
float sum = 0.0f;
for (int i = 0; i < cols; i++) {
row_output[i] = expf(row_input[i] - max_val);
sum += row_output[i];
}
// 정규화
for (int i = 0; i < cols; i++) {
row_output[i] /= sum;
}
}
torch::Tensor fused_softmax_cuda(torch::Tensor input) {
auto output = torch::zeros_like(input);
int rows = input.size(0);
int cols = input.size(1);
fused_softmax_kernel<<<rows, 1>>>(
input.data_ptr<float>(),
output.data_ptr<float>(),
rows,
cols
);
return output;
}
"""
cpp_source_cuda = """
#include <torch/extension.h>
torch::Tensor fused_softmax_cuda(torch::Tensor input);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_softmax", &fused_softmax_cuda, "Fused Softmax CUDA");
}
"""
# Compile only when a CUDA device is available
if torch.cuda.is_available():
    from torch.utils.cpp_extension import load_inline
    fused_softmax_ext = load_inline(
        name="fused_softmax",
        cpp_sources=cpp_source_cuda,
        cuda_sources=cuda_source,
        functions=["fused_softmax"],
        verbose=True
    )
    # Smoke-test against the reference torch.softmax
    x = torch.randn(4, 8, device="cuda")
    result = fused_softmax_ext.fused_softmax(x)
    expected = torch.softmax(x, dim=1)
    print(f"Max difference: {(result - expected).abs().max().item():.6f}")
setup.py를 이용한 패키지 빌드
# setup.py - packaged (ahead-of-time) build of the C++/CUDA extension
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="custom_ops",
    ext_modules=[
        CUDAExtension(
            name="custom_ops",
            sources=[
                "custom_ops/ops.cpp",
                "custom_ops/ops_cuda.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": ["-O3", "--use_fast_math"],
            }
        )
    ],
    cmdclass={
        # BuildExtension supplies the Torch include paths and ABI flags
        "build_ext": BuildExtension
    }
)
# Build with: python setup.py install
4. 메모리 최적화 기법
GPU 메모리 프로파일링
import torch
def print_gpu_memory_stats():
    """Print current/peak CUDA allocator statistics (no-op without CUDA)."""
    if torch.cuda.is_available():
        print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
        print(f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
        print(torch.cuda.memory_summary(abbreviated=True))

# Start tracking from a clean peak
torch.cuda.reset_peak_memory_stats()
print_gpu_memory_stats()

# Load the model and observe the delta
model = TransformerBlock().cuda()
print_gpu_memory_stats()
Gradient Checkpointing (Activation Recomputation)
Gradient Checkpointing은 순전파 시 중간 활성화 값을 저장하지 않고, 역전파 시 필요한 경우 다시 계산합니다. 메모리를 크게 절약할 수 있지만, 계산 시간이 약 30% 증가합니다.
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
import torch.nn as nn
class DeepTransformer(nn.Module):
    """Stack of TransformerBlock layers with optional gradient checkpointing.

    Bug fix: the original class defined only the checkpointing variants and no
    ``forward``, so calling the module directly (``model(x)``) raised
    NotImplementedError from nn.Module. A plain forward is now provided.
    """

    def __init__(self, num_layers=12, d_model=512, nhead=8):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, nhead) for _ in range(num_layers)
        ])

    def forward(self, x):
        """Plain forward: run every layer, keeping all activations."""
        for layer in self.layers:
            x = layer(x)
        return x

    def forward_with_checkpointing(self, x):
        """Per-layer gradient checkpointing: activations are recomputed during backward."""
        for layer in self.layers:
            # checkpoint does not store intermediate activations in forward
            x = checkpoint(layer, x, use_reentrant=False)
        return x

    def forward_sequential_checkpointing(self, x):
        """Checkpoint the stack in 3 segments (layers grouped evenly per segment)."""
        x = checkpoint_sequential(self.layers, segments=3, input=x)
        return x
# Memory comparison
model = DeepTransformer(num_layers=24).cuda()
# The input must require grad: without a grad graph, neither variant stores
# activations and both peaks would look nearly identical.
x = torch.randn(16, 512, 512, device="cuda", requires_grad=True)

# Plain forward.
# Bug fix: the original line was
#   out = model(x) if hasattr(model, 'forward') else model.forward_with_checkpointing(x)
# hasattr(model, 'forward') is always True for an nn.Module, so it always
# called model(x) - which crashed because the class defined no forward().
# Running the layers directly works against the original class as written.
torch.cuda.reset_peak_memory_stats()
out = x
for layer in model.layers:
    out = layer(out)
normal_mem = torch.cuda.max_memory_allocated()

# Checkpointed forward
torch.cuda.reset_peak_memory_stats()
out = model.forward_with_checkpointing(x)
checkpoint_mem = torch.cuda.max_memory_allocated()

print(f"Normal memory: {normal_mem / 1024**3:.2f} GB")
print(f"Checkpoint memory: {checkpoint_mem / 1024**3:.2f} GB")
print(f"Memory saved: {(normal_mem - checkpoint_mem) / 1024**3:.2f} GB")
Gradient Accumulation
def train_with_gradient_accumulation(
    model, optimizer, dataloader, criterion,
    accumulation_steps=4
):
    """Emulate a large effective batch (batch_size * accumulation_steps) on a
    small GPU by summing gradients over several backward passes per step.

    Bug fix: when len(dataloader) is not a multiple of accumulation_steps, the
    original version silently dropped the gradients of the final partial group;
    they are now flushed with one last optimizer step.
    """
    model.train()
    optimizer.zero_grad()
    total_loss = 0
    pending = False  # accumulated gradients not yet applied?
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.cuda(), target.cuda()
        # Forward pass
        output = model(data)
        # Divide by accumulation_steps so summed grads match one big batch
        loss = criterion(output, target) / accumulation_steps
        loss.backward()
        pending = True
        total_loss += loss.item() * accumulation_steps
        # Apply an update every accumulation_steps batches
        if (batch_idx + 1) % accumulation_steps == 0:
            # Gradient clipping (optional)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
            pending = False
    if pending:
        # Flush the final partial accumulation group
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
    return total_loss / len(dataloader)
Mixed Precision Training (AMP)
from torch.cuda.amp import autocast, GradScaler
# NOTE(review): torch.cuda.amp is deprecated in recent PyTorch releases in
# favor of torch.amp.autocast("cuda") / torch.amp.GradScaler("cuda").

def train_with_amp(model, optimizer, dataloader, criterion):
    """Automatic Mixed Precision training: saves memory and improves speed."""
    scaler = GradScaler()
    model.train()
    for data, target in dataloader:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        # float16 compute inside the autocast context
        with autocast():
            output = model(data)
            loss = criterion(output, target)
        # backward on the scaled loss (avoids fp16 gradient underflow)
        scaler.scale(loss).backward()
        # unscale before clipping so the threshold applies to the true grads
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # optimizer step (skipped automatically when grads contain NaN/Inf)
        scaler.step(optimizer)
        scaler.update()
8-bit Optimizer
# Requires the bitsandbytes library: pip install bitsandbytes
try:
    import bitsandbytes as bnb
    # 8-bit Adam in place of regular Adam (optimizer state held in int8)
    optimizer_8bit = bnb.optim.Adam8bit(
        model.parameters(),
        lr=1e-4,
        betas=(0.9, 0.999)
    )
    # PagedAdam (optimizer state can be paged out to CPU memory)
    optimizer_paged = bnb.optim.PagedAdam(
        model.parameters(),
        lr=1e-4
    )
    print("8-bit optimizer loaded successfully")
except ImportError:
    print("bitsandbytes not installed, using regular Adam")
    optimizer_8bit = torch.optim.Adam(model.parameters(), lr=1e-4)
5. functorch와 vmap
vmap - 배치화된 연산
vmap은 단일 샘플에 작동하는 함수를 배치에 효율적으로 적용합니다.
import torch
from torch import vmap

# A function written for a single sample
def single_linear(weight, bias, x):
    """Affine map for one sample: weight @ x + bias."""
    return weight @ x + bias

# vmap lifts it to operate over a leading batch dimension
batched_linear = vmap(single_linear)

# Batched data (one weight/bias/x per batch element)
batch_size = 32
weight = torch.randn(batch_size, 10, 5)
bias = torch.randn(batch_size, 10)
x = torch.randn(batch_size, 5)

# Batching happens automatically
result = batched_linear(weight, bias, x)
print(f"Result shape: {result.shape}") # (32, 10)
grad - 함수형 그래디언트
from torch.func import grad, vmap, functional_call

# Functional (stateless) loss over an explicit parameter dict
def scalar_loss(params, x, y):
    pred = functional_call(model, params, (x,))
    return ((pred - y) ** 2).mean()

# Gradient with respect to the parameters
params = dict(model.named_parameters())
grad_fn = grad(scalar_loss)
# NOTE(review): the (1, 10) -> (1, 5) shapes assume a Linear(10, 5)-style
# `model`; the TransformerBlock defined earlier would not accept these inputs.
x = torch.randn(1, 10)
y = torch.randn(1, 5)
grads = grad_fn(params, x, y)
print({k: v.shape for k, v in grads.items()})
앙상블 모델 - vmap 활용
from torch.func import stack_module_state, functional_call, vmap
def create_ensemble(model_class, num_models, *args, **kwargs):
    """Build a vmapped ensemble: one functional_call over stacked parameters.

    Returns (ensemble_forward, params, buffers); ensemble_forward(params,
    buffers, x) evaluates every member on the same input x and returns the
    outputs stacked along a new leading dimension.
    """
    members = [model_class(*args, **kwargs) for _ in range(num_models)]
    # Stack each parameter/buffer across members into a leading model dim
    params, buffers = stack_module_state(members)
    # "Stateless" template instance; its own weights are never used
    template = model_class(*args, **kwargs)

    def call_one(p, b, x):
        return functional_call(template, (p, b), (x,))

    # Map over the stacked params/buffers (dim 0) while broadcasting x
    ensemble_forward = vmap(call_one, in_dims=(0, 0, None))
    return ensemble_forward, params, buffers
# Usage example: an ensemble of 5 independent Linear layers
ensemble_fn, params, buffers = create_ensemble(
    nn.Linear, num_models=5, in_features=10, out_features=5
)
x = torch.randn(32, 10)
ensemble_out = ensemble_fn(params, buffers, x)
print(f"Ensemble output shape: {ensemble_out.shape}") # (5, 32, 5)
메타러닝 (MAML) - grad와 vmap 결합
from torch.func import grad, vmap, functional_call
def inner_loop(params, support_x, support_y, base_model, lr=0.01, steps=5):
    """MAML inner loop: `steps` functional SGD updates on the support set.

    Operates on a dict of parameter tensors; `base_model` only supplies the
    architecture via functional_call. Returns the adapted parameter dict.
    """
    adapted = {name: tensor.clone() for name, tensor in params.items()}

    for _ in range(steps):
        def mse_on_support(candidate):
            prediction = functional_call(base_model, candidate, (support_x,))
            return ((prediction - support_y) ** 2).mean()

        gradients = grad(mse_on_support)(adapted)
        adapted = {
            name: tensor - lr * gradients[name]
            for name, tensor in adapted.items()
        }
    return adapted
6. PyTorch Profiler
기본 프로파일링
import torch
from torch.profiler import profile, record_function, ProfilerActivity

model = TransformerBlock().cuda()
x = torch.randn(32, 128, 512, device="cuda")

# Run the profiler over 10 forward passes
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,    # record per-op input shapes
    profile_memory=True,   # track allocator events
    with_stack=True        # capture Python stack traces
) as prof:
    with record_function("model_inference"):  # named range in the trace
        for _ in range(10):
            output = model(x)

# Print a table of the most expensive ops
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20
))

# Export a Chrome trace
prof.export_chrome_trace("trace.json")
# open it at chrome://tracing
TensorBoard 통합
from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(
        wait=1,     # skip the first step
        warmup=1,   # 1 warmup step (tracing on, results discarded)
        active=3,   # profile 3 steps
        repeat=2    # repeat the cycle twice
    ),
    on_trace_ready=tensorboard_trace_handler("./log/profiler"),
    record_shapes=True,
    profile_memory=True
) as prof:
    # NOTE(review): dataloader, train_step, model, optimizer are assumed to
    # exist in the surrounding training script.
    for step, (data, target) in enumerate(dataloader):
        train_step(model, optimizer, data, target)
        prof.step()  # advance the profiler schedule
# view with: tensorboard --logdir=./log/profiler
세부 메모리 분석
# Memory snapshotting (PyTorch 2.1+; underscore-prefixed private API, may change)
torch.cuda.memory._record_memory_history(max_entries=100000)

# Run the code under observation
x = torch.randn(100, 100, device="cuda")
y = x @ x.T
z = y.sum()

# Save the snapshot
snapshot = torch.cuda.memory._snapshot()
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# Analyze (the pickle can also be opened with the pytorch.org memory_viz tool)
print(f"Active allocations: {len(snapshot['segments'])}")
7. TorchScript
torch.jit.script vs torch.jit.trace
import torch
import torch.nn as nn
class ConditionalModel(nn.Module):
    """Linear layer whose activation is picked by a runtime flag.

    The data-independent `if` on `flag` makes this a good scripting example:
    jit.trace bakes in a single branch, while jit.script keeps both.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x, flag: bool = True):
        h = self.linear(x)
        if flag:
            return torch.relu(h)
        return torch.sigmoid(h)
model = ConditionalModel()

# script: preserves control flow (recommended for branching models)
scripted_model = torch.jit.script(model)
print(scripted_model.code)

# trace: captures only the single path taken by the example input
x = torch.randn(1, 10)
traced_model = torch.jit.trace(model, x)
# caution: the flag=False branch is NOT captured in the trace

# Save and load
scripted_model.save("model_scripted.pt")
loaded = torch.jit.load("model_scripted.pt")
TorchScript 최적화
# Apply inference optimization passes
scripted = torch.jit.script(model)
optimized = torch.jit.optimize_for_inference(scripted)
# Export for a C++ runtime
# In C++: torch::jit::script::Module m = torch::jit::load("model.pt");
8. Dynamic Shapes와 torch.export
torch.export 사용
import torch
from torch.export import export
class MyModel(nn.Module):
    """Minimal single-Linear model used as the torch.export example."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        out = self.linear(x)
        return out
model = MyModel()
example_inputs = (torch.randn(2, 10),)

# Declare which dimensions may vary at runtime
from torch.export import Dim
batch = Dim("batch", min=1, max=100)

# Export the model with a symbolic batch dimension
exported = export(
    model,
    example_inputs,
    dynamic_shapes={"x": {0: batch}}
)
print(exported)
print(exported.graph_module.code)

# Run the ExportedProgram with a different batch size
result = exported.module()(torch.randn(5, 10))
print(f"Result shape: {result.shape}")
9. Custom Dataset과 Sampler
IterableDataset
from torch.utils.data import IterableDataset, DataLoader
import torch
class StreamingDataset(IterableDataset):
    """Stream samples from a list of files without loading everything at once.

    Under a multi-worker DataLoader, the file list is partitioned across
    workers so each file is read by exactly one worker.
    """

    def __init__(self, data_paths, transform=None):
        self.data_paths = data_paths
        self.transform = transform

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # single-process loading: this iterator yields every file
            paths = self.data_paths
        else:
            # Bug fix: the original contiguous split gave each worker
            # len(paths) // num_workers files, silently dropping the remainder
            # when the count was not divisible. A strided split assigns every
            # file to exactly one worker regardless of divisibility.
            paths = self.data_paths[worker_info.id::worker_info.num_workers]
        for path in paths:
            # stream samples file by file
            data = self._load_file(path)
            for sample in data:
                if self.transform:
                    sample = self.transform(sample)
                yield sample

    def _load_file(self, path):
        # placeholder: a real implementation would deserialize `path`
        return [torch.randn(10) for _ in range(100)]
# DataLoader with persistent workers
dataset = StreamingDataset(data_paths=["file1.pt", "file2.pt"])
dataloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,
    persistent_workers=True,  # reuse worker processes across epochs
    pin_memory=True,          # page-locked host memory for faster GPU transfer
    prefetch_factor=2         # batches prefetched per worker
)
커스텀 Sampler
from torch.utils.data import Sampler
import numpy as np
class BalancedClassSampler(Sampler):
    """Oversampling sampler for class-imbalanced datasets.

    Each epoch draws the same number of indices per class (with replacement),
    so minority classes are repeated instead of underrepresented.
    """

    def __init__(self, dataset, num_samples_per_class=None):
        self.dataset = dataset
        # NOTE(review): assumes dataset[i] is an (input, label) pair
        self.labels = torch.tensor([dataset[i][1] for i in range(len(dataset))])
        # map each class id to the dataset indices carrying that label
        self.class_indices = {}
        for cls in torch.unique(self.labels):
            mask = self.labels == cls
            self.class_indices[cls.item()] = mask.nonzero(as_tuple=True)[0].tolist()
        self.num_classes = len(self.class_indices)
        # default: match the largest class size
        self.num_samples_per_class = (
            num_samples_per_class
            or max(len(v) for v in self.class_indices.values())
        )

    def __iter__(self):
        drawn = []
        for class_members in self.class_indices.values():
            # draw an equal count from each class, with replacement
            picks = np.random.choice(
                class_members,
                self.num_samples_per_class,
                replace=True
            ).tolist()
            drawn.extend(picks)
        # shuffle so classes are interleaved
        np.random.shuffle(drawn)
        return iter(drawn)

    def __len__(self):
        return self.num_classes * self.num_samples_per_class
10. PyTorch Lightning
LightningModule 완전 예시
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
class LightningTransformer(pl.LightningModule):
    """Transformer-encoder classifier implementing the Lightning training contract."""

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_layers=6,
        num_classes=10,
        learning_rate=1e-4
    ):
        super().__init__()
        self.save_hyperparameters()  # records init args into self.hparams and checkpoints
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                batch_first=True
            ),
            num_layers=num_layers
        )
        self.classifier = nn.Linear(d_model, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        encoded = self.encoder(x)
        # mean-pool over the sequence dimension
        pooled = encoded.mean(dim=1)
        return self.classifier(pooled)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        # logging (shown on the progress bar)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=0.01
        )
        # cosine-annealing learning-rate schedule
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=100
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }
class LightningDataModule(pl.LightningDataModule):
    """Dummy DataModule: random tensors split 80/20 into train/val subsets."""

    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        # dummy data
        x = torch.randn(1000, 32, 512)
        y = torch.randint(0, 10, (1000,))
        dataset = TensorDataset(x, y)
        n = len(dataset)
        self.train_ds = torch.utils.data.Subset(dataset, range(int(0.8*n)))
        self.val_ds = torch.utils.data.Subset(dataset, range(int(0.8*n), n))

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)
# Run training
model = LightningTransformer()
dm = LightningDataModule()

# Callback setup
callbacks = [
    ModelCheckpoint(
        monitor="val_loss",
        save_top_k=3,
        mode="min",
        filename="transformer-{epoch:02d}-{val_loss:.2f}"
    ),
    EarlyStopping(monitor="val_loss", patience=10, mode="min")
]

# Trainer
trainer = pl.Trainer(
    max_epochs=100,
    accelerator="auto", # auto-detect GPU
    devices="auto",
    callbacks=callbacks,
    logger=TensorBoardLogger("tb_logs", name="transformer"),
    gradient_clip_val=1.0, # gradient clipping
    accumulate_grad_batches=4, # gradient accumulation
    precision="16-mixed", # AMP
)
trainer.fit(model, dm)
11. 모델 양자화 (Quantization)
동적 양자화 (Dynamic Quantization)
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

# Inference-only quantization - the easiest entry point
model = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 128)
)
model.eval()

# Quantize Linear and LSTM layers to int8
# (weights stored as int8; activations quantized on the fly at runtime)
quantized_model = quantize_dynamic(
    model,
    {nn.Linear, nn.LSTM}, # layer types to quantize
    dtype=torch.qint8
)

# Compare on-disk sizes
import os
torch.save(model.state_dict(), "model_fp32.pt")
torch.save(quantized_model.state_dict(), "model_int8.pt")
fp32_size = os.path.getsize("model_fp32.pt")
int8_size = os.path.getsize("model_int8.pt")
print(f"FP32 size: {fp32_size / 1024:.1f} KB")
print(f"INT8 size: {int8_size / 1024:.1f} KB")
print(f"Compression: {fp32_size / int8_size:.1f}x")
정적 양자화 (Static Quantization)
import torch
from torch.ao.quantization import (
get_default_qconfig,
prepare,
convert,
QConfig
)
class QuantizableModel(nn.Module):
    """Two-layer MLP bracketed by Quant/DeQuant stubs for static quantization.

    The stubs act as identity ops in eager FP32 mode; `convert` later replaces
    them with real quantize/dequantize nodes.
    """

    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(512, 256)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(256, 10)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        h = self.quant(x)
        h = self.relu(self.linear1(h))
        h = self.linear2(h)
        return self.dequant(h)
model = QuantizableModel()
model.eval()

# Quantization config ("fbgemm" targets x86 CPUs)
model.qconfig = get_default_qconfig("fbgemm")

# Prepare: insert observers that record activation ranges
prepared_model = prepare(model)

# Calibration: run representative data so the observers collect statistics
with torch.no_grad():
    for _ in range(100):
        calibration_data = torch.randn(32, 512)
        prepared_model(calibration_data)

# Convert to the actual quantized model
quantized_static = convert(prepared_model)

# Inference
x = torch.randn(1, 512)
with torch.no_grad():
    output = quantized_static(x)
print(f"Output shape: {output.shape}")
QAT (Quantization-Aware Training)
from torch.ao.quantization import (
    get_default_qat_qconfig,
    prepare_qat,
    convert
)


def dummy_dataloader():
    """Yield toy (input, label) batches for the QAT training loop."""
    for _ in range(10):
        yield torch.randn(32, 512), torch.randint(0, 10, (32,))


model = QuantizableModel()
model.train()
# QAT configuration
model.qconfig = get_default_qat_qconfig("fbgemm")
# Prepare for QAT (inserts fake-quantize nodes)
prepared_qat = prepare_qat(model)
# Train so the model learns to compensate for quantization error.
# Bug fix: the original script called dummy_dataloader() here but only
# defined it AFTER this loop, raising NameError when run top to bottom;
# the definition has been moved above its first use.
optimizer = torch.optim.SGD(prepared_qat.parameters(), lr=0.0001)
for epoch in range(10):
    for x, y in dummy_dataloader():
        output = prepared_qat(x)
        loss = nn.functional.cross_entropy(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
# Convert after training completes
prepared_qat.eval()
quantized_qat = convert(prepared_qat)
12. Tensor Parallelism과 Pipeline Parallelism
DeviceMesh와 DTensor API
import torch
import torch.distributed as dist
from torch.distributed.tensor.parallel import (
parallelize_module,
ColwiseParallel,
RowwiseParallel
)
from torch.distributed._tensor import DeviceMesh
# Initialize distributed training (NCCL backend; one process per GPU)
def setup_distributed():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())

class TransformerMLP(nn.Module):
    """Feed-forward block used to demonstrate tensor parallelism."""
    def __init__(self, d_model=1024, dim_feedforward=4096):
        super().__init__()
        self.fc1 = nn.Linear(d_model, dim_feedforward)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim_feedforward, d_model)
    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

# Apply Tensor Parallelism
def apply_tensor_parallel(model, mesh):
    """
    Shard fc1 column-wise (splits its output dimension) and
    fc2 row-wise (splits its input dimension).
    """
    parallelize_module(
        model,
        mesh,
        {
            "fc1": ColwiseParallel(),
            "fc2": RowwiseParallel(),
        }
    )
    return model

# Example usage (2-GPU environment)
# device_mesh = DeviceMesh("cuda", [0, 1])
# model = TransformerMLP()
# model = apply_tensor_parallel(model, device_mesh)
FSDP (Fully Sharded Data Parallel)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy
)
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy
)
import functools

def setup_fsdp_model(model):
    """Wrap `model` in FSDP, sharding a large model across GPUs."""
    # Mixed precision: bf16 params/buffers, fp32 gradient reduction
    mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.bfloat16,
    )
    # Wrap at TransformerBlock granularity
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock}
    )
    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=torch.cuda.current_device()
    )
    return model
마치며
이 가이드에서는 PyTorch의 고급 기법들을 살펴보았습니다:
- torch.compile: 코드 수정 없이 2배 이상의 성능 향상
- Custom Autograd: 특수한 그래디언트 계산 구현
- CUDA Extensions: GPU 커널을 PyTorch와 통합
- 메모리 최적화: Gradient Checkpointing, AMP, 8-bit Optimizer
- functorch/vmap: 함수형 API로 배치 처리와 메타러닝
- PyTorch Profiler: 성능 병목 분석
- TorchScript/export: 배포 최적화
- PyTorch Lightning: 코드 구조화와 훈련 자동화
- 양자화: INT8로 모델 크기 축소
- 분산 학습: Tensor Parallel, FSDP
각 기법은 서로 보완적이며, 실제 프로젝트에서는 여러 기법을 조합하여 사용하는 것이 일반적입니다. 특히 대규모 모델 훈련 시에는 AMP + Gradient Checkpointing + FSDP + torch.compile 조합이 강력한 효과를 발휘합니다.
참고 자료
PyTorch Advanced Techniques Complete Guide: torch.compile, Custom Ops, Memory Optimization
PyTorch Advanced Techniques Complete Guide
PyTorch is one of the most popular deep learning frameworks for both research and production. However, most developers stop at basic tensor operations and nn.Module definitions. This guide covers advanced techniques for maximizing performance, implementing custom operations, and managing memory efficiently in real production environments.
1. torch.compile (PyTorch 2.0+)
torch.compile, introduced in PyTorch 2.0, is a feature that compiles your model to dramatically improve execution speed. Unlike TorchScript or ONNX export, torch.compile can bring over 2x speedups with minimal code changes.
Introduction and Benefits
torch.compile is composed of three core components:
- TorchDynamo: Intercepts Python bytecode to generate an FX graph
- AOTAutograd: Pre-compiles the automatic differentiation graph
- Inductor: Generates optimized kernels via TorchInductor backend (Triton GPU kernels or C++ CPU kernels)
import torch
import torch.nn as nn
import time
# Define a base model
class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder block: self-attention then an MLP, each
    followed by a residual add and LayerNorm."""

    def __init__(self, d_model=512, nhead=8, dim_feedforward=2048):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """x: (batch, seq, d_model) -> same shape."""
        attended, _ = self.attention(x, x, x)
        x = self.norm1(x + attended)
        x = self.norm2(x + self.feed_forward(x))
        return x
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TransformerBlock().to(device)

# Apply torch.compile
compiled_model = torch.compile(model)

# Warm up BOTH paths: the first compiled call triggers compilation, and the
# first eager call pays one-off allocator/kernel-selection costs. The original
# warmed up only the compiled model, biasing the comparison.
x = torch.randn(32, 128, 512, device=device)
for _ in range(3):
    _ = model(x)
    _ = compiled_model(x)


def _timed(fn, n):
    """Time n forward passes. Synchronize BEFORE starting the clock as well as
    after, so pending async CUDA work is not billed to the timed region
    (perf_counter alone only measures Python-side launch time)."""
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n):
        _ = fn(x)
    if device == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter() - start


N = 100
elapsed_eager = _timed(model, N)
elapsed_compiled = _timed(compiled_model, N)

print(f"Eager mode: {elapsed_eager:.3f}s")
print(f"Compiled mode: {elapsed_compiled:.3f}s")
print(f"Speedup: {elapsed_eager / elapsed_compiled:.2f}x")
Compilation Modes
torch.compile offers three modes:
# Default mode - fast compile, good performance
model_default = torch.compile(model, mode="default")
# Reduce-overhead mode - better for small models
model_reduce = torch.compile(model, mode="reduce-overhead")
# Max-autotune - longer compile, best performance
model_autotune = torch.compile(model, mode="max-autotune")
# Full graph mode - no graph breaks allowed (errors instead of falling back)
model_full = torch.compile(model, fullgraph=True)
# Backend selection
model_eager = torch.compile(model, backend="eager") # No compilation (debug)
model_aot = torch.compile(model, backend="aot_eager") # AOT autograd only
model_inductor = torch.compile(model, backend="inductor") # Default
Dynamic Shape Support
import torch._dynamo as dynamo

# Enable dynamic shapes (varying dims become symbolic instead of triggering recompiles)
model_dynamic = torch.compile(model, dynamic=True)
# Works without recompilation across different batch sizes
for batch_size in [8, 16, 32, 64]:
    x = torch.randn(batch_size, 128, 512, device=device)
    out = model_dynamic(x)
    print(f"Batch {batch_size}: output shape {out.shape}")
# Check compilation cache / graph-break report
print(dynamo.explain(model)(x))
Migrating Existing Code
# Applying torch.compile to a training loop
def train_epoch(model, optimizer, dataloader, criterion):
    """Run one training epoch and return the mean per-batch loss."""
    model.train()
    total_loss = 0
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

# Just compile the model - the loop itself is unchanged
model = MyModel().cuda()  # NOTE(review): MyModel is illustrative, not defined above
compiled_model = torch.compile(model) # Only this line added!
optimizer = torch.optim.Adam(compiled_model.parameters())
criterion = nn.CrossEntropyLoss()
2. Custom Autograd
PyTorch's automatic differentiation engine is powerful, but sometimes you need custom gradient computation for numerical stability or efficiency.
Subclassing torch.autograd.Function
import torch
from torch.autograd import Function
class SigmoidFunction(Function):
    """Sigmoid with a hand-written backward pass.

    Forward delegates to ``torch.sigmoid`` (already numerically stable);
    backward applies the identity d/dx sigmoid(x) = y * (1 - y) where
    y = sigmoid(x), reusing the saved forward output.
    """

    @staticmethod
    def forward(ctx, input):
        out = torch.sigmoid(input)
        # Stash the activation so backward does not recompute it.
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (out,) = ctx.saved_tensors
        # Chain rule: dL/dx = dL/dy * y * (1 - y).
        return grad_output * out * (1 - out)
# Usage
def custom_sigmoid(x):
    """Apply the custom autograd sigmoid (thin wrapper over ``.apply``)."""
    return SigmoidFunction.apply(x)

x = torch.randn(3, 4, requires_grad=True)
y = custom_sigmoid(x)
loss = y.sum()
loss.backward()  # routes through SigmoidFunction.backward
print(f"Gradient: {x.grad}")
More Complex Example: Leaky ReLU with Custom Backward
class LeakyReLUFunction(Function):
    """Leaky ReLU with an explicit backward implementation."""

    @staticmethod
    def forward(ctx, input, negative_slope=0.01):
        # Positive part passes through; negative part is scaled by the slope.
        ctx.save_for_backward(input)
        ctx.negative_slope = negative_slope
        positive = input.clamp(min=0)
        negative = input.clamp(max=0)
        return positive + negative_slope * negative

    @staticmethod
    def backward(ctx, grad_output):
        # Derivative is 1 for x >= 0 and ``negative_slope`` for x < 0.
        (saved_input,) = ctx.saved_tensors
        slope = ctx.negative_slope
        grad_input = grad_output.clone()
        grad_input[saved_input < 0] *= slope
        # ``negative_slope`` is a plain Python float: no gradient for it.
        return grad_input, None
class CustomLeakyReLU(nn.Module):
    """nn.Module wrapper around :class:`LeakyReLUFunction`."""

    def __init__(self, negative_slope=0.01):
        super().__init__()
        # Stored so the same slope is reused on every forward call.
        self.negative_slope = negative_slope

    def forward(self, x):
        # Dispatch through the custom autograd Function.
        return LeakyReLUFunction.apply(x, self.negative_slope)
Numerical Gradient Check
from torch.autograd import gradcheck
def test_custom_op():
    """Validate SigmoidFunction's gradients against numerical estimates."""
    # Numerical gradient check (float64 recommended: eps=1e-6 is below
    # float32 resolution, so float32 inputs would fail spuriously).
    input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
    # gradcheck computes the Jacobian numerically and compares with autograd
    result = gradcheck(SigmoidFunction.apply, (input,), eps=1e-6, atol=1e-4)
    print(f"Gradient check passed: {result}")
    # Double backprop test.
    # Fix: the original called first-order gradcheck again, which never
    # exercises second-order gradients; gradgradcheck checks grad-of-grad.
    input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
    result = torch.autograd.gradgradcheck(
        SigmoidFunction.apply,
        (input,),
        eps=1e-6,
        atol=1e-4,
    )
    print(f"Second-order gradient check passed: {result}")

test_custom_op()
Double Backpropagation
class SquaredFunction(Function):
    """Custom x**2 with a backward that supports double backprop.

    The backward pass is built from ordinary differentiable tensor ops,
    so autograd can differentiate through it again when the caller asks
    for ``create_graph=True``.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d(x^2)/dx = 2x; chain rule multiplies by the incoming gradient.
        return grad_output * (2 * x)
# Double backprop example (useful in meta-learning like MAML)
x = torch.randn(3, requires_grad=True)
y = SquaredFunction.apply(x)
# First derivative: create_graph=True keeps the graph alive so we can
# differentiate through the gradient itself.
grad_x = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
# grad_x = 2x, backprop through this again
grad_grad_x = torch.autograd.grad(grad_x.sum(), x)[0]
# grad_grad_x = 2 (constant)
print(f"Second derivative: {grad_grad_x}")
3. Custom CUDA Operators
torch.utils.cpp_extension Overview
PyTorch provides tools for writing C++/CUDA extensions and using them from Python.
# JIT compilation (suitable for development/prototyping)
from torch.utils.cpp_extension import load_inline

# C++ CPU operator
cpp_source = """
#include <torch/extension.h>
torch::Tensor relu_forward(torch::Tensor input) {
return input.clamp_min(0);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("relu_forward", &relu_forward, "ReLU forward");
}
"""

# Inline compilation
# NOTE(review): ``functions=`` asks load_inline to auto-generate pybind11
# bindings, while ``cpp_source`` already defines PYBIND11_MODULE itself.
# The documented pattern is to supply one or the other, not both — confirm
# this combination compiles on the targeted PyTorch version.
custom_relu_cpp = load_inline(
    name="custom_relu_cpp",
    cpp_sources=cpp_source,
    functions=["relu_forward"],
    verbose=False
)
x = torch.randn(5)
result = custom_relu_cpp.relu_forward(x)
print(result)
CUDA Kernel Example: Fused Softmax
# CUDA source
cuda_source = """
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
__global__ void fused_softmax_kernel(
const float* input,
float* output,
int rows,
int cols
) {
int row = blockIdx.x;
if (row >= rows) return;
const float* row_input = input + row * cols;
float* row_output = output + row * cols;
// Find max (for numerical stability)
float max_val = row_input[0];
for (int i = 1; i < cols; i++) {
max_val = fmaxf(max_val, row_input[i]);
}
// Compute exp(x - max) and sum
float sum = 0.0f;
for (int i = 0; i < cols; i++) {
row_output[i] = expf(row_input[i] - max_val);
sum += row_output[i];
}
// Normalize
for (int i = 0; i < cols; i++) {
row_output[i] /= sum;
}
}
torch::Tensor fused_softmax_cuda(torch::Tensor input) {
auto output = torch::zeros_like(input);
int rows = input.size(0);
int cols = input.size(1);
fused_softmax_kernel<<<rows, 1>>>(
input.data_ptr<float>(),
output.data_ptr<float>(),
rows,
cols
);
return output;
}
"""

# C++-side declaration plus the pybind11 binding for the CUDA kernel above.
cpp_source_cuda = """
#include <torch/extension.h>
torch::Tensor fused_softmax_cuda(torch::Tensor input);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_softmax", &fused_softmax_cuda, "Fused Softmax CUDA");
}
"""

# Compile only if CUDA is available
# NOTE(review): one thread per block (<<<rows, 1>>>) is illustrative only —
# a production kernel would parallelize the per-row loops across a block.
# NOTE(review): as in the CPU example, passing ``functions=`` alongside a
# hand-written PYBIND11_MODULE may conflict; the kernel also assumes a
# contiguous float32 2-D input — verify or add TORCH_CHECKs.
if torch.cuda.is_available():
    from torch.utils.cpp_extension import load_inline
    fused_softmax_ext = load_inline(
        name="fused_softmax",
        cpp_sources=cpp_source_cuda,
        cuda_sources=cuda_source,
        functions=["fused_softmax"],
        verbose=True
    )
    # Test
    x = torch.randn(4, 8, device="cuda")
    result = fused_softmax_ext.fused_softmax(x)
    expected = torch.softmax(x, dim=1)
    print(f"Max difference: {(result - expected).abs().max().item():.6f}")
Building a Package with setup.py
# setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="custom_ops",
    ext_modules=[
        CUDAExtension(
            name="custom_ops",
            # Mixed C++/CUDA sources; BuildExtension routes .cu files to nvcc.
            sources=[
                "custom_ops/ops.cpp",
                "custom_ops/ops_cuda.cu",
            ],
            # Separate flag sets per compiler toolchain.
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": ["-O3", "--use_fast_math"],
            }
        )
    ],
    cmdclass={
        "build_ext": BuildExtension
    }
)
# Build: python setup.py install
4. Memory Optimization Techniques
GPU Memory Profiling
import torch
def print_gpu_memory_stats():
    """Print allocated, reserved, and peak GPU memory statistics.

    Does nothing on machines without a CUDA device.
    """
    if not torch.cuda.is_available():
        return
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    print(f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
    print(torch.cuda.memory_summary(abbreviated=True))
# Start memory tracking
# Reset the high-water mark so later readings reflect only this section.
torch.cuda.reset_peak_memory_stats()
print_gpu_memory_stats()
# Load model
model = TransformerBlock().cuda()
print_gpu_memory_stats()
Gradient Checkpointing (Activation Recomputation)
Gradient Checkpointing avoids storing intermediate activation values during the forward pass, recomputing them as needed during the backward pass. This saves significant memory at the cost of roughly 30% more computation.
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
import torch.nn as nn
class DeepTransformer(nn.Module):
    """Stack of TransformerBlock layers used to demo activation checkpointing."""

    def __init__(self, num_layers=12, d_model=512, nhead=8):
        super().__init__()
        blocks = [TransformerBlock(d_model, nhead) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward_with_checkpointing(self, x):
        """Run every layer under torch.utils.checkpoint (per-layer granularity)."""
        for block in self.layers:
            # The block's intermediate activations are recomputed during
            # backward instead of being stored; use_reentrant=False is the
            # recommended non-reentrant implementation.
            x = checkpoint(block, x, use_reentrant=False)
        return x

    def forward_sequential_checkpointing(self, x):
        """Checkpoint the stack in 3 segments (with 12 layers: 4 per segment)."""
        return checkpoint_sequential(self.layers, segments=3, input=x)
# Memory comparison
model = DeepTransformer(num_layers=24).cuda()
x = torch.randn(16, 512, 512, device="cuda")
# Checkpointed forward — measure the peak allocation with activation
# recomputation enabled (despite the original label, this is NOT the
# plain forward path).
torch.cuda.reset_peak_memory_stats()
out = model.forward_with_checkpointing(x)
checkpoint_mem = torch.cuda.max_memory_allocated()
print(f"Checkpoint memory: {checkpoint_mem / 1024**3:.2f} GB")
Gradient Accumulation
def train_with_gradient_accumulation(
    model, optimizer, dataloader, criterion,
    accumulation_steps=4
):
    """Train one epoch with gradient accumulation.

    Emulates an effective batch size ``accumulation_steps`` times larger
    than what fits in memory by summing gradients over several
    micro-batches before each optimizer step.

    Args:
        model: network to train; batches are moved to its parameters'
            device (generalized from the original hard-coded ``.cuda()``).
        optimizer: optimizer over ``model``'s parameters.
        dataloader: iterable of ``(data, target)`` micro-batches.
        criterion: loss function.
        accumulation_steps: micro-batches accumulated per optimizer step.

    Returns:
        float: mean (unscaled) loss per micro-batch.
    """
    model.train()
    device = next(model.parameters()).device
    optimizer.zero_grad()
    total_loss = 0.0
    num_batches = 0
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        # Forward pass
        output = model(data)
        # Scale so the accumulated gradient equals the large-batch average.
        loss = criterion(output, target) / accumulation_steps
        loss.backward()
        total_loss += loss.item() * accumulation_steps
        num_batches += 1
        # Update every accumulation_steps
        if (batch_idx + 1) % accumulation_steps == 0:
            # Gradient clipping (optional)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
    # Bug fix: flush the leftover partial accumulation when the number of
    # batches is not a multiple of accumulation_steps; previously the last
    # micro-batches contributed loss but never produced an update.
    if num_batches % accumulation_steps != 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
    return total_loss / len(dataloader)
Mixed Precision Training (AMP)
from torch.cuda.amp import autocast, GradScaler

def train_with_amp(model, optimizer, dataloader, criterion):
    """Automatic Mixed Precision training loop (memory savings + speed).

    Generalized so AMP is active only when the model lives on a CUDA
    device; with ``enabled=False`` both the scaler and autocast become
    no-ops, so the same function runs unchanged (full precision) on CPU.
    """
    device = next(model.parameters()).device
    use_amp = device.type == "cuda"
    scaler = GradScaler(enabled=use_amp)
    model.train()
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # float16 operations inside autocast context
        with autocast(enabled=use_amp):
            output = model(data)
            loss = criterion(output, target)
        # Backward with scaled gradients
        scaler.scale(loss).backward()
        # Unscale then clip gradients (clip must see true-magnitude grads)
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # Optimizer step (skipped by the scaler if grads contain NaN/Inf)
        scaler.step(optimizer)
        scaler.update()
8-bit Optimizer
# Requires bitsandbytes: pip install bitsandbytes
try:
    import bitsandbytes as bnb
    # Use 8-bit Adam instead of regular Adam
    # (keeps optimizer state in int8, cutting state memory vs fp32 Adam)
    optimizer_8bit = bnb.optim.Adam8bit(
        model.parameters(),
        lr=1e-4,
        betas=(0.9, 0.999)
    )
    # PagedAdam (can offload to CPU)
    optimizer_paged = bnb.optim.PagedAdam(
        model.parameters(),
        lr=1e-4
    )
    print("8-bit optimizer loaded successfully")
except ImportError:
    # Graceful fallback keeps the rest of the script runnable.
    print("bitsandbytes not installed, using regular Adam")
    optimizer_8bit = torch.optim.Adam(model.parameters(), lr=1e-4)
5. functorch and vmap
vmap - Batched Operations
vmap efficiently applies a function operating on single samples to a batch.
import torch
from torch import vmap

# Function operating on a single (unbatched) sample.
def single_linear(weight, bias, x):
    """Affine map for one sample: ``weight @ x + bias``."""
    return weight @ x + bias

# vmap lifts the per-sample function over a leading batch dimension
# (in_dims defaults to 0 for every argument).
batched_linear = vmap(single_linear)

batch_size = 32
weight = torch.randn(batch_size, 10, 5)
bias = torch.randn(batch_size, 10)
x = torch.randn(batch_size, 5)

# Sample i computes weight[i] @ x[i] + bias[i], vectorized under the hood.
result = batched_linear(weight, bias, x)
print(f"Result shape: {result.shape}")  # (32, 10)
grad - Functional Gradient
from torch.func import grad, vmap, functional_call

# Functional gradient
# NOTE(review): relies on a ``model`` defined earlier in the article; the
# (1, 10) -> (1, 5) shapes here match the nn.Linear(10, 5) examples, not
# the transformer blocks — confirm which model is intended.
def scalar_loss(params, x, y):
    """Mean squared error of ``model`` evaluated with external ``params``."""
    pred = functional_call(model, params, (x,))
    return ((pred - y) ** 2).mean()

# Gradient with respect to parameters (the first argument of scalar_loss)
params = dict(model.named_parameters())
grad_fn = grad(scalar_loss)
x = torch.randn(1, 10)
y = torch.randn(1, 5)
grads = grad_fn(params, x, y)
print({k: v.shape for k, v in grads.items()})
Ensemble Models with vmap
from torch.func import stack_module_state, functional_call, vmap

def create_ensemble(model_class, num_models, *args, **kwargs):
    """Build a vmap-vectorized ensemble of ``num_models`` identical models.

    Returns ``(forward_fn, params, buffers)`` where
    ``forward_fn(params, buffers, x)`` evaluates every ensemble member on
    the same input ``x`` in a single vectorized call.
    """
    members = [model_class(*args, **kwargs) for _ in range(num_models)]
    # Stack each parameter/buffer across members along a new leading dim.
    params, buffers = stack_module_state(members)
    # A stateless "skeleton" module that all forward passes route through.
    skeleton = model_class(*args, **kwargs)

    def run_one(p, b, x):
        return functional_call(skeleton, (p, b), (x,))

    # Map over the member dimension of params/buffers; broadcast x (None).
    ensemble_forward = vmap(run_one, in_dims=(0, 0, None))
    return ensemble_forward, params, buffers

# Usage example
ensemble_fn, params, buffers = create_ensemble(
    nn.Linear, num_models=5, in_features=10, out_features=5
)
x = torch.randn(32, 10)
ensemble_out = ensemble_fn(params, buffers, x)
print(f"Ensemble output shape: {ensemble_out.shape}")  # (5, 32, 5)
Meta-Learning (MAML) with grad and vmap
from torch.func import grad, vmap, functional_call

def inner_loop(params, support_x, support_y, base_model, lr=0.01, steps=5):
    """MAML inner loop: adapt ``params`` on the support set via SGD.

    Returns a fresh parameter dict after ``steps`` gradient-descent
    updates; the input ``params`` dict is never mutated.
    """
    adapted = {name: tensor.clone() for name, tensor in params.items()}

    def support_loss(p):
        # MSE of the base model evaluated with the candidate parameters.
        out = functional_call(base_model, p, (support_x,))
        return ((out - support_y) ** 2).mean()

    for _ in range(steps):
        step_grads = grad(support_loss)(adapted)
        # Functional SGD update (no in-place ops) so outer-loop gradients
        # can flow through the adaptation when needed.
        adapted = {
            name: p - lr * step_grads[name]
            for name, p in adapted.items()
        }
    return adapted
6. PyTorch Profiler
Basic Profiling
import torch
from torch.profiler import profile, record_function, ProfilerActivity

model = TransformerBlock().cuda()
x = torch.randn(32, 128, 512, device="cuda")

# Run profiler
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,  # record per-op input shapes
    profile_memory=True,  # track allocator events
    with_stack=True  # capture Python stacks (adds overhead)
) as prof:
    # record_function labels this region in the trace output.
    with record_function("model_inference"):
        for _ in range(10):
            output = model(x)

# Print results
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20
))
# Export Chrome Trace
prof.export_chrome_trace("trace.json")
# Open in chrome://tracing
TensorBoard Integration
from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler

# NOTE(review): assumes ``dataloader`` and a ``train_step`` helper exist in
# the surrounding training script.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(
        wait=1,  # Wait 1 step
        warmup=1,  # 1 warmup step
        active=3,  # Profile 3 steps
        repeat=2  # Repeat 2 times
    ),
    on_trace_ready=tensorboard_trace_handler("./log/profiler"),
    record_shapes=True,
    profile_memory=True
) as prof:
    for step, (data, target) in enumerate(dataloader):
        train_step(model, optimizer, data, target)
        prof.step()  # Profile according to schedule
# tensorboard --logdir=./log/profiler
Detailed Memory Analysis
# Memory snapshot (PyTorch 2.1+)
# Underscore-prefixed private API: useful for debugging, but its surface
# may change between releases.
torch.cuda.memory._record_memory_history(max_entries=100000)

# Run code
x = torch.randn(100, 100, device="cuda")
y = x @ x.T
z = y.sum()

# Save snapshot (viewable at pytorch.org/memory_viz)
snapshot = torch.cuda.memory._snapshot()
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
print(f"Active allocations: {len(snapshot['segments'])}")
7. TorchScript
torch.jit.script vs torch.jit.trace
import torch
import torch.nn as nn
class ConditionalModel(nn.Module):
    """Linear layer whose activation is chosen by a boolean flag.

    The data-independent Python branch is what makes torch.jit.trace
    unsuitable here (trace records only one path); torch.jit.script
    compiles both branches.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x, flag: bool = True):
        h = self.linear(x)
        if flag:  # Control flow makes trace impossible
            return torch.relu(h)
        return torch.sigmoid(h)
model = ConditionalModel()

# script: includes control flow (recommended)
scripted_model = torch.jit.script(model)
print(scripted_model.code)

# trace: captures only a single path
x = torch.randn(1, 10)
traced_model = torch.jit.trace(model, x)
# Note: flag=False path is not captured — trace records only the ops
# actually executed with the example input.

# Save and load
scripted_model.save("model_scripted.pt")
loaded = torch.jit.load("model_scripted.pt")
TorchScript Optimization
# Apply optimization passes
scripted = torch.jit.script(model)
# Freezes weights, folds constants, and fuses ops for inference-only use.
optimized = torch.jit.optimize_for_inference(scripted)
# Export for C++ environment
# In C++: torch::jit::script::Module m = torch::jit::load("model.pt");
8. Dynamic Shapes and torch.export
Using torch.export
import torch
from torch.export import export
class MyModel(nn.Module):
    """Minimal single-layer model used to demonstrate torch.export."""

    def __init__(self):
        super().__init__()
        # 10 -> 5 affine projection.
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)
model = MyModel()
example_inputs = (torch.randn(2, 10),)

# Specify dynamic shapes
from torch.export import Dim

# Symbolic batch dimension, constrained to the range [1, 100].
batch = Dim("batch", min=1, max=100)

# Export model: captures one whole-program graph ahead of time.
exported = export(
    model,
    example_inputs,
    dynamic_shapes={"x": {0: batch}}
)
print(exported)
print(exported.graph_module.code)

# Run ExportedProgram (any batch size within the Dim's range works)
result = exported.module()(torch.randn(5, 10))
print(f"Result shape: {result.shape}")
9. Custom Dataset and Sampler
IterableDataset
from torch.utils.data import IterableDataset, DataLoader
import torch
class StreamingDataset(IterableDataset):
    """Stream samples from a list of files without loading everything at once.

    Args:
        data_paths: file paths to stream from.
        transform: optional callable applied to every yielded sample.
    """

    def __init__(self, data_paths, transform=None):
        self.data_paths = data_paths
        self.transform = transform

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: this iterator sees every file.
            paths = self.data_paths
        else:
            # Multiprocessing: shard files across workers by striding.
            # Bug fix: the previous ``len // num_workers`` slicing dropped
            # the remainder files whenever the path count was not an exact
            # multiple of num_workers; striding covers each file exactly once.
            paths = self.data_paths[worker_info.id::worker_info.num_workers]
        for path in paths:
            # Stream data from file
            data = self._load_file(path)
            for sample in data:
                if self.transform:
                    sample = self.transform(sample)
                yield sample

    def _load_file(self, path):
        # In a real implementation, load and parse the file at ``path``.
        return [torch.randn(10) for _ in range(100)]
# Use persistent_workers in DataLoader
dataset = StreamingDataset(data_paths=["file1.pt", "file2.pt"])
dataloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,
    persistent_workers=True,  # Reuse worker processes across epochs
    pin_memory=True,  # Page-locked host memory speeds host-to-GPU copies
    prefetch_factor=2  # Batches each worker prepares ahead of time
)
Custom Sampler
from torch.utils.data import Sampler
import numpy as np
class BalancedClassSampler(Sampler):
    """Oversampling sampler that yields an equal count of indices per class.

    Useful for class-imbalance problems: each epoch draws the same number
    of indices from every class (with replacement) and shuffles them.
    """

    def __init__(self, dataset, num_samples_per_class=None):
        self.dataset = dataset
        # Collect labels by indexing the dataset once up front.
        self.labels = torch.tensor(
            [dataset[i][1] for i in range(len(dataset))]
        )
        # Map each class id to the list of its sample indices.
        self.class_indices = {}
        for cls in torch.unique(self.labels):
            mask = self.labels == cls
            self.class_indices[cls.item()] = (
                mask.nonzero(as_tuple=True)[0].tolist()
            )
        self.num_classes = len(self.class_indices)
        # Default: oversample every class up to the majority-class size.
        self.num_samples_per_class = (
            num_samples_per_class
            or max(len(idx) for idx in self.class_indices.values())
        )

    def __iter__(self):
        drawn = []
        for indices in self.class_indices.values():
            # Equal-size draw per class, with replacement.
            picks = np.random.choice(
                indices,
                self.num_samples_per_class,
                replace=True
            )
            drawn.extend(picks.tolist())
        # Shuffle so classes are interleaved within the epoch.
        np.random.shuffle(drawn)
        return iter(drawn)

    def __len__(self):
        return self.num_classes * self.num_samples_per_class
10. PyTorch Lightning
Complete LightningModule Example
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
class LightningTransformer(pl.LightningModule):
    """Transformer-encoder classifier packaged as a LightningModule.

    Mean-pools the encoder output over the sequence dimension and applies
    a linear classification head.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_layers=6,
        num_classes=10,
        learning_rate=1e-4
    ):
        super().__init__()
        self.save_hyperparameters()  # exposes args via self.hparams and stores them in checkpoints
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                batch_first=True  # inputs are (batch, seq, feature)
            ),
            num_layers=num_layers
        )
        self.classifier = nn.Linear(d_model, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Map (batch, seq, d_model) inputs to (batch, num_classes) logits."""
        encoded = self.encoder(x)
        # Sequence mean pooling
        pooled = encoded.mean(dim=1)
        return self.classifier(pooled)

    def training_step(self, batch, batch_idx):
        """One optimization step; the returned loss is what Lightning backprops."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        # Logging
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation metrics; no return value is needed."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        """AdamW with a cosine-annealing learning-rate schedule."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=0.01
        )
        # Cosine annealing scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=100
        )
        # NOTE(review): "monitor" only matters for metric-driven schedulers
        # such as ReduceLROnPlateau; it is inert for CosineAnnealingLR.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }
class LightningDataModule(pl.LightningDataModule):
    """Synthetic-data module providing an 80/20 train/val split.

    NOTE(review): the class name shadows ``pl.LightningDataModule`` itself;
    a distinct name (e.g. DummyDataModule) would be less confusing.
    """

    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Dummy data: 1000 sequences of shape (32, 512), 10 classes.
        x = torch.randn(1000, 32, 512)
        y = torch.randint(0, 10, (1000,))
        dataset = TensorDataset(x, y)
        n = len(dataset)
        # First 80% for training, remaining 20% for validation.
        self.train_ds = torch.utils.data.Subset(dataset, range(int(0.8*n)))
        self.val_ds = torch.utils.data.Subset(dataset, range(int(0.8*n), n))

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)
# Run training
model = LightningTransformer()
dm = LightningDataModule()

# Setup callbacks
callbacks = [
    # Keep the 3 checkpoints with the lowest validation loss.
    ModelCheckpoint(
        monitor="val_loss",
        save_top_k=3,
        mode="min",
        filename="transformer-{epoch:02d}-{val_loss:.2f}"
    ),
    # Stop when val_loss has not improved for 10 validation epochs.
    EarlyStopping(monitor="val_loss", patience=10, mode="min")
]

# Trainer
trainer = pl.Trainer(
    max_epochs=100,
    accelerator="auto",  # Auto-detect GPU
    devices="auto",
    callbacks=callbacks,
    logger=TensorBoardLogger("tb_logs", name="transformer"),
    gradient_clip_val=1.0,  # Gradient clipping
    accumulate_grad_batches=4,  # Gradient accumulation
    precision="16-mixed",  # AMP
)
trainer.fit(model, dm)
11. Model Quantization
Dynamic Quantization
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

# Inference-only quantization - easiest approach: weights are stored in
# int8 while activations are quantized on the fly at runtime.
model = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 128)
)
model.eval()

# Quantize Linear and LSTM layers to int8
quantized_model = quantize_dynamic(
    model,
    {nn.Linear, nn.LSTM},  # Layer types to quantize
    dtype=torch.qint8
)

# Size comparison
import os
torch.save(model.state_dict(), "model_fp32.pt")
torch.save(quantized_model.state_dict(), "model_int8.pt")
fp32_size = os.path.getsize("model_fp32.pt")
int8_size = os.path.getsize("model_int8.pt")
print(f"FP32 size: {fp32_size / 1024:.1f} KB")
print(f"INT8 size: {int8_size / 1024:.1f} KB")
print(f"Compression: {fp32_size / int8_size:.1f}x")
Static Quantization
import torch
from torch.ao.quantization import (
get_default_qconfig,
prepare,
convert,
)
class QuantizableModel(nn.Module):
    """Two-layer MLP with explicit quant/dequant boundaries.

    QuantStub/DeQuantStub mark where tensors enter and leave the
    quantized region; they act as identity ops in eager mode and are
    replaced with real (de)quantize ops by ``convert``.
    """

    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(512, 256)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(256, 10)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        h = self.quant(x)
        h = self.relu(self.linear1(h))
        h = self.linear2(h)
        return self.dequant(h)
model = QuantizableModel()
model.eval()

# Quantization config
model.qconfig = get_default_qconfig("fbgemm")  # For x86 CPU

# Prepare (insert observers that record activation ranges)
prepared_model = prepare(model)

# Collect statistics with calibration data (should be representative of
# real inference inputs; random data here is for illustration only)
with torch.no_grad():
    for _ in range(100):
        calibration_data = torch.randn(32, 512)
        prepared_model(calibration_data)

# Convert to quantized (observers become fixed scale/zero-point, int8 kernels)
quantized_static = convert(prepared_model)

# Inference
x = torch.randn(1, 512)
with torch.no_grad():
    output = quantized_static(x)
print(f"Output shape: {output.shape}")
QAT (Quantization-Aware Training)
from torch.ao.quantization import (
    get_default_qat_qconfig,
    prepare_qat,
    convert
)

def dummy_dataloader():
    """Yield random (input, label) batches for the QAT demo."""
    for _ in range(10):
        yield torch.randn(32, 512), torch.randint(0, 10, (32,))

model = QuantizableModel()
model.train()

# QAT config
model.qconfig = get_default_qat_qconfig("fbgemm")

# Prepare QAT (insert fake-quantization nodes that simulate int8 rounding)
prepared_qat = prepare_qat(model)

# Train (quantization error is part of the loss, so weights adapt to it).
# Bug fix: dummy_dataloader is now defined BEFORE this loop; the original
# snippet defined it afterwards, so the loop raised NameError at runtime.
optimizer = torch.optim.SGD(prepared_qat.parameters(), lr=0.0001)
for epoch in range(10):
    for x, y in dummy_dataloader():
        output = prepared_qat(x)
        loss = nn.functional.cross_entropy(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Convert after training
prepared_qat.eval()
quantized_qat = convert(prepared_qat)
12. Tensor Parallelism and Pipeline Parallelism
DeviceMesh and DTensor API
import torch
import torch.distributed as dist
from torch.distributed.tensor.parallel import (
parallelize_module,
ColwiseParallel,
RowwiseParallel
)
from torch.distributed._tensor import DeviceMesh
# Initialize distributed training
def setup_distributed():
    """Join the NCCL process group and pin this rank to its own GPU."""
    dist.init_process_group(backend="nccl")
    # One GPU per rank; using the global rank as the device index assumes
    # a single-node launch — multi-node setups need LOCAL_RANK (verify).
    torch.cuda.set_device(dist.get_rank())
class TransformerMLP(nn.Module):
    """Position-wise feed-forward block (Linear -> GELU -> Linear)."""

    def __init__(self, d_model=1024, dim_feedforward=4096):
        super().__init__()
        # Expand to the hidden width, then project back to d_model.
        self.fc1 = nn.Linear(d_model, dim_feedforward)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim_feedforward, d_model)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
# Apply Tensor Parallelism
def apply_tensor_parallel(model, mesh):
    """
    Shard the MLP across the device mesh (modifies in place; also returned).

    fc1: Column-wise sharding (split output dimension) — each rank holds a
    slice of the hidden activations.
    fc2: Row-wise sharding (split input dimension) — per-rank partial
    results are combined back into the full d_model output.
    """
    parallelize_module(
        model,
        mesh,
        {
            "fc1": ColwiseParallel(),
            "fc2": RowwiseParallel(),
        }
    )
    return model

# Example (requires 2-GPU setup)
# device_mesh = DeviceMesh("cuda", [0, 1])
# model = TransformerMLP()
# model = apply_tensor_parallel(model, device_mesh)
FSDP (Fully Sharded Data Parallel)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy
)
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy
)
import functools
def setup_fsdp_model(model):
    """FSDP setup - distribute large models across multiple GPUs.

    Wraps ``model`` so parameters are sharded across all ranks
    (FULL_SHARD strategy) and returns the wrapped module.
    """
    # Mixed precision config: bf16 parameters/buffers, fp32 gradient
    # reduction for numerical stability.
    mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.bfloat16,
    )
    # Set TransformerBlock as wrapping unit: each block becomes its own
    # FSDP unit, limiting how many full parameters are materialized at once.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock}
    )
    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        # Prefetch the next unit's all-gather during backward to overlap
        # communication with computation.
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=torch.cuda.current_device()
    )
    return model
Summary
This guide covered PyTorch's advanced techniques:
- torch.compile: Over 2x performance improvement without code changes
- Custom Autograd: Implementing specialized gradient computation
- CUDA Extensions: Integrating GPU kernels with PyTorch
- Memory Optimization: Gradient Checkpointing, AMP, 8-bit Optimizer
- functorch/vmap: Batch processing and meta-learning via functional API
- PyTorch Profiler: Analyzing performance bottlenecks
- TorchScript/export: Deployment optimization
- PyTorch Lightning: Structuring code and automating training
- Quantization: Reducing model size with INT8
- Distributed Training: Tensor Parallel, FSDP
These techniques are complementary, and in real projects it is common to combine several of them. For large-scale model training, the combination of AMP + Gradient Checkpointing + FSDP + torch.compile is particularly powerful.