Skip to content

Split View: PyTorch 고급 기법 완전 가이드: torch.compile, Custom Ops, Memory 최적화

|

PyTorch 고급 기법 완전 가이드: torch.compile, Custom Ops, Memory 최적화

PyTorch 고급 기법 완전 가이드

PyTorch는 딥러닝 연구와 프로덕션 배포 모두에서 가장 인기 있는 프레임워크 중 하나입니다. 하지만 대부분의 개발자는 기본적인 텐서 연산과 nn.Module 정의에서 멈춥니다. 이 가이드에서는 실제 프로덕션 환경에서 성능을 극대화하고, 커스텀 연산을 구현하며, 메모리를 효율적으로 관리하는 고급 기법들을 다룹니다.

1. torch.compile (PyTorch 2.0+)

PyTorch 2.0에서 도입된 torch.compile은 모델을 컴파일하여 실행 속도를 획기적으로 향상시키는 기능입니다. 기존의 TorchScript나 ONNX export와 달리, torch.compile은 파이썬 코드를 거의 수정하지 않고도 모델에 따라 2배 이상의 속도 향상을 가져올 수 있습니다.

torch.compile 소개와 장점

torch.compile은 세 가지 핵심 컴포넌트로 구성됩니다:

  1. TorchDynamo: Python 바이트코드를 인터셉트하여 FX 그래프를 생성
  2. AOTAutograd: 자동 미분 그래프를 사전 컴파일
  3. Inductor: TorchInductor 백엔드로 최적화된 커널 생성 (Triton GPU 커널 또는 C++ CPU 커널)
import torch
import torch.nn as nn
import time

# 기본 모델 정의
class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder block: self-attention + position-wise FFN.

    Residual connections wrap both sub-layers, and LayerNorm is applied
    after each residual sum (post-LN ordering). Input/output shape is
    (batch, seq, d_model) because the attention uses batch_first=True.
    """

    def __init__(self, d_model=512, nhead=8, dim_feedforward=2048):
        super().__init__()
        # batch_first=True -> tensors are (batch, seq, d_model)
        self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply both sub-layers; the input shape is preserved."""
        mha_out, _ = self.attention(x, x, x)   # self-attention: Q = K = V = x
        x = self.norm1(x + mha_out)            # residual + post-norm
        x = self.norm2(x + self.feed_forward(x))
        return x

# Pick the GPU when available; the speedup numbers are only meaningful on CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TransformerBlock().to(device)

# Apply torch.compile (compilation actually happens on the first call)
compiled_model = torch.compile(model)

# Warm up BOTH paths: the first compiled calls pay the compilation cost, and
# warming the eager path too keeps allocator/caching state comparable.
x = torch.randn(32, 128, 512, device=device)
for _ in range(3):
    _ = compiled_model(x)
    _ = model(x)

# Performance comparison
N = 100

# Eager model. CUDA kernel launches are asynchronous, so we must synchronize
# BEFORE starting the clock (drain any queued warm-up work) as well as before
# reading it (wait for the timed work to actually finish).
if device == "cuda":
    torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(N):
    _ = model(x)
if device == "cuda":
    torch.cuda.synchronize()
elapsed_eager = time.perf_counter() - start

# Compiled model, timed the same way.
if device == "cuda":
    torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(N):
    _ = compiled_model(x)
if device == "cuda":
    torch.cuda.synchronize()
elapsed_compiled = time.perf_counter() - start

print(f"Eager mode: {elapsed_eager:.3f}s")
print(f"Compiled mode: {elapsed_compiled:.3f}s")
print(f"Speedup: {elapsed_eager / elapsed_compiled:.2f}x")

컴파일 모드 (Compilation Modes)

torch.compile은 세 가지 컴파일 모드와 함께 fullgraph, backend 등의 추가 옵션을 제공합니다:

# 기본 모드 - 빠른 컴파일, 좋은 성능
model_default = torch.compile(model, mode="default")

# 오버헤드 감소 모드 - 작은 모델에 유리
model_reduce = torch.compile(model, mode="reduce-overhead")

# 최대 자동 튜닝 - 긴 컴파일, 최고 성능
model_autotune = torch.compile(model, mode="max-autotune")

# 풀 그래프 모드 - 동적 그래프 없음 (strict)
model_full = torch.compile(model, fullgraph=True)

# 백엔드 선택
model_eager = torch.compile(model, backend="eager")    # 컴파일 없음 (디버그용)
model_aot = torch.compile(model, backend="aot_eager")  # AOT만 적용
model_inductor = torch.compile(model, backend="inductor")  # 기본값

동적 형태 지원

import torch._dynamo as dynamo

# Enable dynamic shapes so varying input sizes do not force a recompile
model_dynamic = torch.compile(model, dynamic=True)

# Works across batch sizes without recompilation
for batch_size in [8, 16, 32, 64]:
    x = torch.randn(batch_size, 128, 512, device=device)
    out = model_dynamic(x)
    print(f"Batch {batch_size}: output shape {out.shape}")

# Inspect graph breaks / guards for this model+input.
# NOTE(review): explain() analyzes the eager `model` here, not `model_dynamic`.
print(dynamo.explain(model)(x))

기존 코드 마이그레이션

# 훈련 루프에 torch.compile 적용
# Applying torch.compile to a training loop
def train_epoch(model, optimizer, dataloader, criterion):
    """Run one training epoch and return the mean per-batch loss.

    Assumes a CUDA device: every batch is moved to the GPU before use.
    """
    model.train()
    running = 0.0

    for data, target in dataloader:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        running += loss.item()

    return running / len(dataloader)

# Only the model needs compiling — the training loop stays unchanged.
# NOTE(review): `MyModel` is an illustrative placeholder; it is not defined above.
model = MyModel().cuda()
compiled_model = torch.compile(model)  # just add this one line!

optimizer = torch.optim.Adam(compiled_model.parameters())
criterion = nn.CrossEntropyLoss()

2. 커스텀 자동 미분 (Custom Autograd)

PyTorch의 자동 미분 엔진은 강력하지만, 때로는 수치적으로 안정적이거나 효율적인 커스텀 그래디언트 계산이 필요합니다.

torch.autograd.Function 서브클래싱

import torch
from torch.autograd import Function

class SigmoidFunction(Function):
    """Custom autograd sigmoid with a hand-written backward pass.

    The forward OUTPUT (not the input) is saved for backward so the
    gradient can reuse the identity d(sigmoid)/dx = s * (1 - s).
    """

    @staticmethod
    def forward(ctx, input):
        # sigmoid(x) = 1 / (1 + exp(-x)); torch.sigmoid is numerically stable
        out = torch.sigmoid(input)
        # Stash the output for use in backward
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve the saved forward output
        (out,) = ctx.saved_tensors
        # Chain rule: dL/dx = s * (1 - s) * dL/ds
        return out * (1 - out) * grad_output

# Usage: Function subclasses are invoked through .apply
def custom_sigmoid(x):
    """Convenience wrapper around SigmoidFunction.apply."""
    return SigmoidFunction.apply(x)

x = torch.randn(3, 4, requires_grad=True)
y = custom_sigmoid(x)
loss = y.sum()
loss.backward()
print(f"Gradient: {x.grad}")

더 복잡한 예시: Leaky ReLU with Custom Backward

class LeakyReLUFunction(Function):
    """LeakyReLU with an explicit backward pass.

    forward(x) = x           for x >= 0
               = slope * x   for x <  0
    """

    @staticmethod
    def forward(ctx, input, negative_slope=0.01):
        ctx.save_for_backward(input)
        # Non-tensor extras are stored directly on ctx
        ctx.negative_slope = negative_slope
        # clamp(min=0) keeps the positive part, clamp(max=0) the negative part
        return input.clamp(min=0) + negative_slope * input.clamp(max=0)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        slope = ctx.negative_slope
        grad = grad_output.clone()
        # Scale the gradient only where the input was negative
        grad[saved_input < 0] *= slope
        # negative_slope is not a tensor input, so its gradient slot is None
        return grad, None

class CustomLeakyReLU(nn.Module):
    """nn.Module wrapper so LeakyReLUFunction composes with nn.Sequential etc."""

    def __init__(self, negative_slope=0.01):
        super().__init__()
        # Slope applied to negative inputs inside the custom autograd function
        self.negative_slope = negative_slope

    def forward(self, x):
        return LeakyReLUFunction.apply(x, self.negative_slope)

수치 그래디언트 체크

from torch.autograd import gradcheck, gradgradcheck

def test_custom_op():
    """Validate SigmoidFunction's analytic gradients against numerical ones."""
    # Numerical gradient check (float64 strongly recommended for precision)
    input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)

    # gradcheck computes the Jacobian numerically and compares it to autograd
    result = gradcheck(SigmoidFunction.apply, (input,), eps=1e-6, atol=1e-4)
    print(f"Gradient check passed: {result}")

    # Double-backward test. Fix: the original called gradcheck again, which
    # only exercises FIRST-order gradients — gradgradcheck is the API that
    # actually verifies second-order (grad-of-grad) correctness.
    input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
    result = gradgradcheck(
        SigmoidFunction.apply,
        (input,),
        eps=1e-6,
        atol=1e-4,
    )
    print(f"Grad-grad check passed: {result}")

test_custom_op()

더블 역전파 (Double Backpropagation)

class SquaredFunction(Function):
    """Custom x**2 whose backward is itself differentiable (double backward)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Built from differentiable torch ops, so autograd can differentiate
        # this expression again when the caller uses create_graph=True.
        return grad_output * 2 * x

# Double-backward example (used in meta-learning methods such as MAML)
x = torch.randn(3, requires_grad=True)
y = SquaredFunction.apply(x)
# create_graph=True keeps the graph of this backward so it can be differentiated again
grad_x = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
# grad_x = 2x; backpropagate through it once more
grad_grad_x = torch.autograd.grad(grad_x.sum(), x)[0]
# grad_grad_x = 2 (constant)
print(f"Second derivative: {grad_grad_x}")

3. 커스텀 CUDA 연산자

torch.utils.cpp_extension 소개

PyTorch는 C++/CUDA 확장을 작성하고 파이썬에서 사용할 수 있는 도구를 제공합니다.

# JIT-compiled extension (good for development/prototyping; needs a C++ toolchain)
from torch.utils.cpp_extension import load_inline

# C++ CPU operator source; the PYBIND11_MODULE block exposes it to Python
cpp_source = """
#include <torch/extension.h>

torch::Tensor relu_forward(torch::Tensor input) {
    return input.clamp_min(0);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("relu_forward", &relu_forward, "ReLU forward");
}
"""

# Inline compile (the build result is cached on disk between runs)
custom_relu_cpp = load_inline(
    name="custom_relu_cpp",
    cpp_sources=cpp_source,
    functions=["relu_forward"],
    verbose=False
)

x = torch.randn(5)
result = custom_relu_cpp.relu_forward(x)
print(result)

CUDA 커널 예시: Fused Softmax

# CUDA source: one block per row, a single thread per block. Illustrative only —
# a production kernel would parallelize the per-row reductions across threads.
cuda_source = """
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

__global__ void fused_softmax_kernel(
    const float* input,
    float* output,
    int rows,
    int cols
) {
    int row = blockIdx.x;
    if (row >= rows) return;

    const float* row_input = input + row * cols;
    float* row_output = output + row * cols;

    // Max 찾기 (수치 안정성)
    float max_val = row_input[0];
    for (int i = 1; i < cols; i++) {
        max_val = fmaxf(max_val, row_input[i]);
    }

    // exp(x - max) 계산 및 합산
    float sum = 0.0f;
    for (int i = 0; i < cols; i++) {
        row_output[i] = expf(row_input[i] - max_val);
        sum += row_output[i];
    }

    // 정규화
    for (int i = 0; i < cols; i++) {
        row_output[i] /= sum;
    }
}

torch::Tensor fused_softmax_cuda(torch::Tensor input) {
    auto output = torch::zeros_like(input);
    int rows = input.size(0);
    int cols = input.size(1);

    fused_softmax_kernel<<<rows, 1>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        rows,
        cols
    );

    return output;
}
"""

# C++ glue: forward declaration plus the pybind11 binding for the CUDA function
cpp_source_cuda = """
#include <torch/extension.h>
torch::Tensor fused_softmax_cuda(torch::Tensor input);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_softmax", &fused_softmax_cuda, "Fused Softmax CUDA");
}
"""

# Compile only when CUDA is present (nvcc is required)
if torch.cuda.is_available():
    from torch.utils.cpp_extension import load_inline
    fused_softmax_ext = load_inline(
        name="fused_softmax",
        cpp_sources=cpp_source_cuda,
        cuda_sources=cuda_source,
        functions=["fused_softmax"],
        verbose=True
    )

    # Sanity-check the kernel against the reference implementation
    x = torch.randn(4, 8, device="cuda")
    result = fused_softmax_ext.fused_softmax(x)
    expected = torch.softmax(x, dim=1)
    print(f"Max difference: {(result - expected).abs().max().item():.6f}")

setup.py를 이용한 패키지 빌드

# setup.py — ahead-of-time build of the extension as an installable package
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="custom_ops",
    ext_modules=[
        CUDAExtension(
            name="custom_ops",
            sources=[
                "custom_ops/ops.cpp",
                "custom_ops/ops_cuda.cu",
            ],
            # Separate flag sets for the host compiler (cxx) and nvcc
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": ["-O3", "--use_fast_math"],
            }
        )
    ],
    cmdclass={
        # BuildExtension wires in the PyTorch include/library paths
        "build_ext": BuildExtension
    }
)
# 빌드: python setup.py install

4. 메모리 최적화 기법

GPU 메모리 프로파일링

import torch

def print_gpu_memory_stats():
    """Print CUDA allocator statistics; silently a no-op on CPU-only machines."""
    if not torch.cuda.is_available():
        return
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"Reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    print(f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
    print(torch.cuda.memory_summary(abbreviated=True))

# Start peak-memory tracking from this point.
# NOTE(review): these calls assume a CUDA device; they raise on CPU-only machines.
torch.cuda.reset_peak_memory_stats()
print_gpu_memory_stats()

# Load the model and observe how much memory its parameters take
model = TransformerBlock().cuda()
print_gpu_memory_stats()

Gradient Checkpointing (Activation Recomputation)

Gradient Checkpointing은 순전파 시 중간 활성화 값을 저장하지 않고, 역전파 시 필요한 경우 다시 계산합니다. 메모리를 크게 절약할 수 있지만, 재계산으로 인해 계산 시간이 일반적으로 약 20~30% 증가합니다.

from torch.utils.checkpoint import checkpoint, checkpoint_sequential
import torch.nn as nn

class DeepTransformer(nn.Module):
    """Stack of TransformerBlock layers with optional gradient checkpointing.

    Checkpointing trades compute for memory: activations are discarded during
    the forward pass and recomputed on demand during backward.
    """

    def __init__(self, num_layers=12, d_model=512, nhead=8):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, nhead) for _ in range(num_layers)
        ])

    # Fix: the original class defined no forward(), so calling model(x)
    # raised NotImplementedError in the memory-comparison code below.
    def forward(self, x):
        """Plain forward pass (all intermediate activations are stored)."""
        for layer in self.layers:
            x = layer(x)
        return x

    def forward_with_checkpointing(self, x):
        """Apply gradient checkpointing to each layer individually."""
        for layer in self.layers:
            # checkpoint() drops the layer's intermediate activations and
            # recomputes them in backward; use_reentrant=False is the
            # recommended (non-reentrant) implementation.
            x = checkpoint(layer, x, use_reentrant=False)
        return x

    def forward_sequential_checkpointing(self, x):
        """Checkpoint the stack in 3 contiguous segments (num_layers / 3 layers each)."""
        x = checkpoint_sequential(self.layers, segments=3, input=x, use_reentrant=False)
        return x

# Memory comparison
model = DeepTransformer(num_layers=24).cuda()
x = torch.randn(16, 512, 512, device="cuda")

# Plain forward
torch.cuda.reset_peak_memory_stats()
# NOTE(review): hasattr(model, 'forward') is always True for an nn.Module, so this
# ternary always calls model(x) — which fails unless DeepTransformer defines forward().
out = model(x) if hasattr(model, 'forward') else model.forward_with_checkpointing(x)
normal_mem = torch.cuda.max_memory_allocated()

# Checkpointed forward
torch.cuda.reset_peak_memory_stats()
out = model.forward_with_checkpointing(x)
checkpoint_mem = torch.cuda.max_memory_allocated()

print(f"Normal memory: {normal_mem / 1024**3:.2f} GB")
print(f"Checkpoint memory: {checkpoint_mem / 1024**3:.2f} GB")
print(f"Memory saved: {(normal_mem - checkpoint_mem) / 1024**3:.2f} GB")

Gradient Accumulation

def train_with_gradient_accumulation(
    model, optimizer, dataloader, criterion,
    accumulation_steps=4
):
    """Simulate a large effective batch size on limited GPU memory.

    Gradients accumulate over `accumulation_steps` mini-batches before each
    optimizer step; the loss is scaled by 1/accumulation_steps so the summed
    gradient matches one large-batch step. Returns the mean (unscaled)
    per-batch loss.
    """
    model.train()
    optimizer.zero_grad()
    total_loss = 0

    num_batches = 0
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.cuda(), target.cuda()

        # Forward pass
        output = model(data)
        # Scale so the accumulated gradient averages over the window
        loss = criterion(output, target) / accumulation_steps
        loss.backward()

        # Undo the scaling when reporting the loss
        total_loss += loss.item() * accumulation_steps
        num_batches = batch_idx + 1

        # Step once per full accumulation window
        if (batch_idx + 1) % accumulation_steps == 0:
            # Optional gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

    # Fix: flush the final partial window. The original version silently
    # dropped the leftover gradients whenever the number of batches was not
    # a multiple of accumulation_steps.
    if num_batches % accumulation_steps != 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    return total_loss / len(dataloader)

Mixed Precision Training (AMP)

from torch.cuda.amp import autocast, GradScaler

def train_with_amp(model, optimizer, dataloader, criterion):
    """One AMP training pass: float16 autocast + gradient scaling (CUDA only).

    NOTE(review): torch.cuda.amp is the pre-2.4 API spelling; newer code uses
    torch.amp.autocast("cuda") / torch.amp.GradScaler("cuda").
    """
    scaler = GradScaler()
    model.train()

    for data, target in dataloader:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()

        # Inside autocast, eligible ops run in float16 automatically
        with autocast():
            output = model(data)
            loss = criterion(output, target)

        # Backward on the scaled loss to avoid float16 gradient underflow
        scaler.scale(loss).backward()

        # Unscale first so the clip threshold is in true gradient units
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # step() skips the update if grads contain NaN/Inf; update() adapts the scale
        scaler.step(optimizer)
        scaler.update()

8-bit Optimizer

# Requires the bitsandbytes library: pip install bitsandbytes
try:
    import bitsandbytes as bnb

    # 8-bit Adam keeps optimizer state in int8, cutting state memory roughly 4x
    optimizer_8bit = bnb.optim.Adam8bit(
        model.parameters(),
        lr=1e-4,
        betas=(0.9, 0.999)
    )

    # PagedAdam can page optimizer state out to CPU memory under pressure
    optimizer_paged = bnb.optim.PagedAdam(
        model.parameters(),
        lr=1e-4
    )

    print("8-bit optimizer loaded successfully")
except ImportError:
    # Graceful fallback when bitsandbytes is unavailable
    print("bitsandbytes not installed, using regular Adam")
    optimizer_8bit = torch.optim.Adam(model.parameters(), lr=1e-4)

5. functorch와 vmap

vmap - 배치화된 연산

vmap은 단일 샘플에 작동하는 함수를 배치에 효율적으로 적용합니다.

import torch
from torch import vmap

# A function written for ONE sample...
def single_linear(weight, bias, x):
    """Affine map for a single example: (10, 5) weight @ (5,) input + (10,) bias."""
    return weight @ x + bias

# ...lifted over the leading batch dimension by vmap (default in_dims=0 for all args)
batched_linear = vmap(single_linear)

# Per-sample weights, biases and inputs stacked along dim 0
batch_size = 32
weight = torch.randn(batch_size, 10, 5)
bias = torch.randn(batch_size, 10)
x = torch.randn(batch_size, 5)

# One call handles the whole batch — no Python loop
result = batched_linear(weight, bias, x)
print(f"Result shape: {result.shape}")  # (32, 10)

grad - 함수형 그래디언트

from torch.func import grad, vmap, functional_call

# Functional-style gradient: loss as a pure function of a parameter dict.
# NOTE(review): `model` refers to whatever model was defined last above; the
# (1, 10) -> (1, 5) shapes assume a Linear(10, 5)-like model — verify in context.
def scalar_loss(params, x, y):
    pred = functional_call(model, params, (x,))
    return ((pred - y) ** 2).mean()

# Gradient with respect to the parameters (returned as a dict of tensors)
params = dict(model.named_parameters())
grad_fn = grad(scalar_loss)

x = torch.randn(1, 10)
y = torch.randn(1, 5)
grads = grad_fn(params, x, y)
print({k: v.shape for k, v in grads.items()})

앙상블 모델 - vmap 활용

from torch.func import stack_module_state, functional_call, vmap

def create_ensemble(model_class, num_models, *args, **kwargs):
    """Build `num_models` independent models and a vmapped forward over them.

    Returns (ensemble_forward, params, buffers). Calling
    ensemble_forward(params, buffers, x) evaluates every member on the SAME
    input x and stacks the results along a new leading dimension.
    """
    members = [model_class(*args, **kwargs) for _ in range(num_models)]

    # Stack each parameter/buffer across members along a new dim 0
    params, buffers = stack_module_state(members)

    # A "stateless" template instance whose weights are supplied per call
    template = model_class(*args, **kwargs)

    def single_forward(params, buffers, x):
        return functional_call(template, (params, buffers), (x,))

    # Map over the stacked state (dim 0) while broadcasting the shared input
    ensemble_forward = vmap(single_forward, in_dims=(0, 0, None))

    return ensemble_forward, params, buffers

# Example: an ensemble of five Linear(10 -> 5) models
ensemble_fn, params, buffers = create_ensemble(
    nn.Linear, num_models=5, in_features=10, out_features=5
)

x = torch.randn(32, 10)
ensemble_out = ensemble_fn(params, buffers, x)
print(f"Ensemble output shape: {ensemble_out.shape}")  # (5, 32, 5)

메타러닝 (MAML) - grad와 vmap 결합

from torch.func import grad, vmap, functional_call

def inner_loop(params, support_x, support_y, base_model, lr=0.01, steps=5):
    """MAML inner loop: `steps` SGD updates of `params` on the support set.

    Returns a NEW parameter dict ("fast weights"); the input `params` is
    never mutated. The updates stay differentiable, which is what allows an
    outer meta-gradient through this adaptation.
    """
    fast_weights = {name: tensor.clone() for name, tensor in params.items()}

    for _ in range(steps):
        def loss_fn(params):
            pred = functional_call(base_model, params, (support_x,))
            return ((pred - support_y) ** 2).mean()

        step_grads = grad(loss_fn)(fast_weights)
        # Plain SGD update expressed with torch ops (keeps the graph intact)
        fast_weights = {
            name: tensor - lr * step_grads[name]
            for name, tensor in fast_weights.items()
        }

    return fast_weights

6. PyTorch Profiler

기본 프로파일링

import torch
from torch.profiler import profile, record_function, ProfilerActivity

model = TransformerBlock().cuda()
x = torch.randn(32, 128, 512, device="cuda")

# Profile 10 forward passes
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,   # record per-op input shapes
    profile_memory=True,  # track allocator events
    with_stack=True       # capture Python stacks (adds overhead)
) as prof:
    with record_function("model_inference"):
        for _ in range(10):
            output = model(x)

# Aggregated results, sorted by total CUDA time
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20
))

# Export a Chrome trace...
prof.export_chrome_trace("trace.json")
# ...and open it at chrome://tracing

TensorBoard 통합

from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    # Profile on a schedule instead of on every step
    schedule=torch.profiler.schedule(
        wait=1,    # skip the first step
        warmup=1,  # warm up for 1 step (recorded, then discarded)
        active=3,  # profile 3 steps
        repeat=2   # repeat the cycle twice
    ),
    on_trace_ready=tensorboard_trace_handler("./log/profiler"),
    record_shapes=True,
    profile_memory=True
) as prof:
    # NOTE(review): `dataloader` and `train_step` are assumed to exist in context
    for step, (data, target) in enumerate(dataloader):
        train_step(model, optimizer, data, target)
        prof.step()  # advance the profiling schedule

# View with: tensorboard --logdir=./log/profiler

세부 메모리 분석

# Memory snapshot (PyTorch 2.1+, CUDA only): records every alloc/free with stacks.
# NOTE(review): _record_memory_history/_snapshot are private (underscore) APIs.
torch.cuda.memory._record_memory_history(max_entries=100000)

# Run the code to analyze
x = torch.randn(100, 100, device="cuda")
y = x @ x.T
z = y.sum()

# Save the snapshot (viewable with the pytorch.org/memory_viz tool)
snapshot = torch.cuda.memory._snapshot()
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# Quick look at the raw structure (segments = allocator segments)
print(f"Active allocations: {len(snapshot['segments'])}")

7. TorchScript

torch.jit.script vs torch.jit.trace

import torch
import torch.nn as nn

class ConditionalModel(nn.Module):
    """Linear layer whose activation depends on a runtime flag.

    The data-independent-but-dynamic branch makes this a good demo of
    scripting vs tracing: torch.jit.trace captures only the branch taken
    for the example input, while torch.jit.script keeps both.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x, flag: bool = True):
        projected = self.linear(x)
        if flag:
            return torch.relu(projected)
        return torch.sigmoid(projected)

model = ConditionalModel()

# script: compiles the source itself, so control flow is preserved (recommended here)
scripted_model = torch.jit.script(model)
print(scripted_model.code)

# trace: records only the single execution path taken for the example input
x = torch.randn(1, 10)
traced_model = torch.jit.trace(model, x)
# Caution: the flag=False path is NOT captured in the trace

# Save and load
scripted_model.save("model_scripted.pt")
loaded = torch.jit.load("model_scripted.pt")

TorchScript 최적화

# Apply inference optimization passes (freezing, fusion, etc.).
# NOTE(review): optimize_for_inference expects the module in eval mode — confirm.
scripted = torch.jit.script(model)
optimized = torch.jit.optimize_for_inference(scripted)

# Export for C++ deployment
# C++에서: torch::jit::script::Module m = torch::jit::load("model.pt");

8. Dynamic Shapes와 torch.export

torch.export 사용

import torch
from torch.export import export

class MyModel(nn.Module):
    """Minimal Linear(10 -> 5) module used to demonstrate torch.export."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        out = self.linear(x)
        return out

model = MyModel()
example_inputs = (torch.randn(2, 10),)

# Declare which dimensions may vary at runtime
from torch.export import Dim
batch = Dim("batch", min=1, max=100)

# Export the model into an ExportedProgram with a symbolic batch dimension
exported = export(
    model,
    example_inputs,
    dynamic_shapes={"x": {0: batch}}
)

print(exported)
print(exported.graph_module.code)

# Run the ExportedProgram with a different batch size (5 instead of the example's 2)
result = exported.module()(torch.randn(5, 10))
print(f"Result shape: {result.shape}")

9. Custom Dataset과 Sampler

IterableDataset

from torch.utils.data import IterableDataset, DataLoader
import torch

class StreamingDataset(IterableDataset):
    """Stream samples from a list of files without loading everything into memory.

    With num_workers > 0, each DataLoader worker receives a disjoint shard of
    `data_paths`, so no sample is duplicated across workers.
    """

    def __init__(self, data_paths, transform=None):
        self.data_paths = data_paths
        # Optional per-sample callable applied before yielding
        self.transform = transform

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is None:
            # Single-process loading: iterate everything
            paths = self.data_paths
        else:
            # Multi-worker: round-robin shard by worker id. Fix: the original
            # floor-division split dropped the last len(paths) % num_workers
            # files entirely; strided slicing covers every path exactly once.
            paths = self.data_paths[worker_info.id::worker_info.num_workers]

        for path in paths:
            # Stream samples out of each file
            data = self._load_file(path)
            for sample in data:
                if self.transform:
                    sample = self.transform(sample)
                yield sample

    def _load_file(self, path):
        # Placeholder: a real implementation would load `path` from disk
        return [torch.randn(10) for _ in range(100)]

# DataLoader tuned for throughput: persistent workers + pinned memory
dataset = StreamingDataset(data_paths=["file1.pt", "file2.pt"])
dataloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,
    persistent_workers=True,  # keep worker processes alive between epochs
    pin_memory=True,           # page-locked host memory speeds up H2D copies
    prefetch_factor=2          # batches pre-fetched per worker
)

커스텀 Sampler

from torch.utils.data import Sampler
import numpy as np

class BalancedClassSampler(Sampler):
    """Oversampling sampler that equalizes class frequencies per epoch.

    Every epoch draws the same number of indices from each class (with
    replacement), so minority classes appear as often as the majority one.
    """

    def __init__(self, dataset, num_samples_per_class=None):
        self.dataset = dataset
        # Materialize all labels once (assumes dataset[i] -> (sample, label))
        self.labels = torch.tensor([dataset[i][1] for i in range(len(dataset))])

        # Map each class id to the list of dataset indices belonging to it
        self.class_indices = {}
        for cls in torch.unique(self.labels):
            mask = self.labels == cls
            self.class_indices[cls.item()] = mask.nonzero(as_tuple=True)[0].tolist()

        self.num_classes = len(self.class_indices)
        # Default: every class is sampled up to the size of the largest class
        self.num_samples_per_class = num_samples_per_class or max(
            len(v) for v in self.class_indices.values()
        )

    def __iter__(self):
        drawn = []
        for indices in self.class_indices.values():
            # Same number of draws per class, with replacement
            picks = np.random.choice(indices, self.num_samples_per_class, replace=True)
            drawn.extend(picks.tolist())

        # Shuffle so classes are interleaved within the epoch
        np.random.shuffle(drawn)
        return iter(drawn)

    def __len__(self):
        return self.num_classes * self.num_samples_per_class

10. PyTorch Lightning

LightningModule 완전 예시

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

class LightningTransformer(pl.LightningModule):
    """Transformer encoder classifier packaged as a LightningModule.

    Lightning drives training_step/validation_step/configure_optimizers;
    forward() does mean-pooled sequence classification over (batch, seq, d_model).
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_layers=6,
        num_classes=10,
        learning_rate=1e-4
    ):
        super().__init__()
        self.save_hyperparameters()  # records __init__ args into self.hparams and checkpoints

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                batch_first=True
            ),
            num_layers=num_layers
        )
        self.classifier = nn.Linear(d_model, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        encoded = self.encoder(x)
        # Mean-pool over the sequence dimension
        pooled = encoded.mean(dim=1)
        return self.classifier(pooled)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        # Logging (shows in the progress bar and whatever logger is attached)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=0.01
        )
        # Cosine annealing learning-rate schedule
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=100
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }


class LightningDataModule(pl.LightningDataModule):
    """Synthetic DataModule: 1000 random sequences split 80/20 into train/val."""

    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Dummy data: (N, seq_len=32, d_model=512) inputs, 10 classes
        x = torch.randn(1000, 32, 512)
        y = torch.randint(0, 10, (1000,))
        dataset = TensorDataset(x, y)
        n = len(dataset)
        self.train_ds = torch.utils.data.Subset(dataset, range(int(0.8*n)))
        self.val_ds = torch.utils.data.Subset(dataset, range(int(0.8*n), n))

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)


# Run training
model = LightningTransformer()
dm = LightningDataModule()

# Callbacks: keep the 3 best checkpoints by val_loss; stop after 10 stale epochs
callbacks = [
    ModelCheckpoint(
        monitor="val_loss",
        save_top_k=3,
        mode="min",
        filename="transformer-{epoch:02d}-{val_loss:.2f}"
    ),
    EarlyStopping(monitor="val_loss", patience=10, mode="min")
]

# Trainer
trainer = pl.Trainer(
    max_epochs=100,
    accelerator="auto",    # auto-detect GPU/TPU/CPU
    devices="auto",
    callbacks=callbacks,
    logger=TensorBoardLogger("tb_logs", name="transformer"),
    gradient_clip_val=1.0,  # gradient clipping
    accumulate_grad_batches=4,  # gradient accumulation
    precision="16-mixed",   # AMP
)

trainer.fit(model, dm)

11. 모델 양자화 (Quantization)

동적 양자화 (Dynamic Quantization)

import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

# Inference-only quantization — the easiest entry point: weights stored as
# int8, activations quantized dynamically at runtime (CPU)
model = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 128)
)
model.eval()

# Quantize Linear and LSTM layers to int8
quantized_model = quantize_dynamic(
    model,
    {nn.Linear, nn.LSTM},  # layer types to quantize
    dtype=torch.qint8
)

# Size comparison (serialized state dicts)
import os
torch.save(model.state_dict(), "model_fp32.pt")
torch.save(quantized_model.state_dict(), "model_int8.pt")
fp32_size = os.path.getsize("model_fp32.pt")
int8_size = os.path.getsize("model_int8.pt")
print(f"FP32 size: {fp32_size / 1024:.1f} KB")
print(f"INT8 size: {int8_size / 1024:.1f} KB")
print(f"Compression: {fp32_size / int8_size:.1f}x")

정적 양자화 (Static Quantization)

import torch
from torch.ao.quantization import (
    get_default_qconfig,
    prepare,
    convert,
    QConfig
)

class QuantizableModel(nn.Module):
    """Two-layer MLP bracketed by Quant/DeQuant stubs for static quantization.

    The stubs are identity ops while the model is still in float mode;
    convert() later replaces them with real quantize/dequantize nodes that
    mark the int8 boundary.
    """

    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(512, 256)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(256, 10)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # float -> (quantized domain) -> float
        h = self.quant(x)
        h = self.relu(self.linear1(h))
        h = self.linear2(h)
        return self.dequant(h)

model = QuantizableModel()
model.eval()

# Quantization config: fbgemm kernels target x86 CPUs ("qnnpack" for ARM)
model.qconfig = get_default_qconfig("fbgemm")

# Prepare: inserts observers that record activation value ranges
prepared_model = prepare(model)

# Calibration: run representative data through the model to collect statistics
with torch.no_grad():
    for _ in range(100):
        calibration_data = torch.randn(32, 512)
        prepared_model(calibration_data)

# Convert observers + weights into real int8 quantized modules
quantized_static = convert(prepared_model)

# Inference
x = torch.randn(1, 512)
with torch.no_grad():
    output = quantized_static(x)
print(f"Output shape: {output.shape}")

QAT (Quantization-Aware Training)

from torch.ao.quantization import (
    get_default_qat_qconfig,
    prepare_qat,
    convert
)

def dummy_dataloader():
    """Yield 10 random (input, label) batches for the demo training loop."""
    for _ in range(10):
        yield torch.randn(32, 512), torch.randint(0, 10, (32,))


model = QuantizableModel()
model.train()

# QAT config (fake-quantization modules simulate int8 during training)
model.qconfig = get_default_qat_qconfig("fbgemm")

# Prepare for QAT: inserts fake-quant nodes
prepared_qat = prepare_qat(model)

# Train — the model learns to compensate for quantization error.
# Fix: dummy_dataloader is now defined BEFORE this loop; the original
# defined it afterwards, which raises NameError when run top to bottom.
optimizer = torch.optim.SGD(prepared_qat.parameters(), lr=0.0001)
for epoch in range(10):
    for x, y in dummy_dataloader():
        output = prepared_qat(x)
        loss = nn.functional.cross_entropy(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# After training, convert to a real int8 model
prepared_qat.eval()
quantized_qat = convert(prepared_qat)

12. Tensor Parallelism과 Pipeline Parallelism

DeviceMesh와 DTensor API

import torch
import torch.distributed as dist
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel
)
from torch.distributed._tensor import DeviceMesh

# Initialize distributed training (one process per GPU, e.g. launched via torchrun)
def setup_distributed():
    """Join the NCCL process group and bind this process to a GPU.

    NOTE(review): this uses the GLOBAL rank as the CUDA device index, which
    only works for single-node jobs — multi-node setups should use LOCAL_RANK.
    """
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())

class TransformerMLP(nn.Module):
    """Standard Transformer feed-forward block: Linear -> GELU -> Linear."""

    def __init__(self, d_model=1024, dim_feedforward=4096):
        super().__init__()
        self.fc1 = nn.Linear(d_model, dim_feedforward)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim_feedforward, d_model)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)

# Apply Tensor Parallelism
def apply_tensor_parallel(model, mesh):
    """Shard the MLP's linear layers across the device mesh (in place).

    fc1 is sharded column-wise (splits output features) and fc2 row-wise
    (splits input features), so the pair needs a single collective at the
    end of the block instead of one per layer.
    """
    parallelize_module(
        model,
        mesh,
        {
            "fc1": ColwiseParallel(),
            "fc2": RowwiseParallel(),
        }
    )
    return model

# Example (requires 2 GPUs and an initialized process group):
# device_mesh = DeviceMesh("cuda", [0, 1])
# model = TransformerMLP()
# model = apply_tensor_parallel(model, device_mesh)

FSDP (Fully Sharded Data Parallel)

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy
)
import functools

def setup_fsdp_model(model):
    """Wrap `model` in FSDP so parameters/gradients/optimizer state are sharded.

    Must be called after the distributed process group has been initialized.
    """

    # Mixed precision: bf16 params/buffers, fp32 gradient reduction for stability
    mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.bfloat16,
    )

    # Shard at TransformerBlock granularity (one FSDP unit per block)
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock}
    )

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,    # shard params + grads + optim state
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch next shard during backward
        device_id=torch.cuda.current_device()
    )

    return model

마치며

이 가이드에서는 PyTorch의 고급 기법들을 살펴보았습니다:

  1. torch.compile: 코드 수정 없이 2배 이상의 성능 향상
  2. Custom Autograd: 특수한 그래디언트 계산 구현
  3. CUDA Extensions: GPU 커널을 PyTorch와 통합
  4. 메모리 최적화: Gradient Checkpointing, AMP, 8-bit Optimizer
  5. functorch/vmap: 함수형 API로 배치 처리와 메타러닝
  6. PyTorch Profiler: 성능 병목 분석
  7. TorchScript/export: 배포 최적화
  8. PyTorch Lightning: 코드 구조화와 훈련 자동화
  9. 양자화: INT8로 모델 크기 축소
  10. 분산 학습: Tensor Parallel, FSDP

각 기법은 서로 보완적이며, 실제 프로젝트에서는 여러 기법을 조합하여 사용하는 것이 일반적입니다. 특히 대규모 모델 훈련 시에는 AMP + Gradient Checkpointing + FSDP + torch.compile 조합이 강력한 효과를 발휘합니다.

참고 자료

PyTorch Advanced Techniques Complete Guide: torch.compile, Custom Ops, Memory Optimization

PyTorch Advanced Techniques Complete Guide

PyTorch is one of the most popular deep learning frameworks for both research and production. However, most developers stop at basic tensor operations and nn.Module definitions. This guide covers advanced techniques for maximizing performance, implementing custom operations, and managing memory efficiently in real production environments.

1. torch.compile (PyTorch 2.0+)

torch.compile, introduced in PyTorch 2.0, is a feature that compiles your model to dramatically improve execution speed. Unlike TorchScript or ONNX export, torch.compile can bring over 2x speedups with minimal code changes.

Introduction and Benefits

torch.compile is composed of three core components:

  1. TorchDynamo: Intercepts Python bytecode to generate an FX graph
  2. AOTAutograd: Pre-compiles the automatic differentiation graph
  3. Inductor: Generates optimized kernels via TorchInductor backend (Triton GPU kernels or C++ CPU kernels)
import torch
import torch.nn as nn
import time

# Define a base model
class TransformerBlock(nn.Module):
    """A single post-norm Transformer encoder layer (self-attention + MLP).

    Residual connections wrap both sub-layers and LayerNorm is applied after
    each residual sum. Input/output shape: (batch, seq_len, d_model).
    """

    def __init__(self, d_model=512, nhead=8, dim_feedforward=2048):
        super().__init__()
        # batch_first=True so inputs are (batch, seq, feature).
        self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Self-attention then the feed-forward MLP, each with a residual."""
        attended, _ = self.attention(x, x, x)
        x = self.norm1(x + attended)
        x = self.norm2(x + self.feed_forward(x))
        return x

device = "cuda" if torch.cuda.is_available() else "cpu"
model = TransformerBlock().to(device)

# Apply torch.compile
compiled_model = torch.compile(model)

# Warm up BOTH paths: the first compiled call triggers compilation, and the
# first eager call pays one-off cuDNN/allocator costs that would skew timing.
x = torch.randn(32, 128, 512, device=device)
for _ in range(3):
    _ = model(x)
    _ = compiled_model(x)

# Performance comparison
N = 100

def _sync():
    # CUDA kernels launch asynchronously; synchronize so the timer brackets
    # exactly the measured loop and not work queued before it started.
    if device == "cuda":
        torch.cuda.synchronize()

# Eager mode
_sync()
start = time.perf_counter()
for _ in range(N):
    _ = model(x)
_sync()
elapsed_eager = time.perf_counter() - start

# Compiled mode
_sync()
start = time.perf_counter()
for _ in range(N):
    _ = compiled_model(x)
_sync()
elapsed_compiled = time.perf_counter() - start

print(f"Eager mode: {elapsed_eager:.3f}s")
print(f"Compiled mode: {elapsed_compiled:.3f}s")
print(f"Speedup: {elapsed_eager / elapsed_compiled:.2f}x")

Compilation Modes

torch.compile offers three modes:

# Default mode - balanced compile time and runtime performance
model_default = torch.compile(model, mode="default")

# Reduce-overhead mode - cuts per-call launch overhead; most useful for
# small models where Python/kernel-launch overhead dominates compute
model_reduce = torch.compile(model, mode="reduce-overhead")

# Max-autotune - longest compile time (autotunes kernels), best runtime
model_autotune = torch.compile(model, mode="max-autotune")

# fullgraph=True - raise on any graph break instead of silently falling
# back to eager for the unsupported region (strict single-graph capture)
model_full = torch.compile(model, fullgraph=True)

# Backend selection
model_eager = torch.compile(model, backend="eager")    # No compilation (debug Dynamo capture only)
model_aot = torch.compile(model, backend="aot_eager")  # AOTAutograd only, no codegen
model_inductor = torch.compile(model, backend="inductor")  # Default backend

Dynamic Shape Support

import torch._dynamo as dynamo

# Enable dynamic shapes: dims are treated as symbolic at compile time
model_dynamic = torch.compile(model, dynamic=True)

# Works without recompilation across different batch sizes
for batch_size in [8, 16, 32, 64]:
    x = torch.randn(batch_size, 128, 512, device=device)
    out = model_dynamic(x)
    print(f"Batch {batch_size}: output shape {out.shape}")

# Check compilation cache
# NOTE(review): explain() traces the ORIGINAL (uncompiled) model afresh and
# reports graph count, graph breaks and recompile reasons — confirm against
# the installed torch version's torch._dynamo.explain signature.
print(dynamo.explain(model)(x))

Migrating Existing Code

# Applying torch.compile to a training loop
def train_epoch(model, optimizer, dataloader, criterion):
    """Run one training epoch and return the mean per-batch loss.

    Each batch is moved to the default CUDA device before the forward pass.
    """
    model.train()
    running_loss = 0.0

    for data, target in dataloader:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    return running_loss / len(dataloader)

# Just compile the model
model = MyModel().cuda()
compiled_model = torch.compile(model)  # Only this line added!

# The compiled wrapper shares parameters with `model`, so optimizing
# either object's .parameters() updates the same tensors.
optimizer = torch.optim.Adam(compiled_model.parameters())
criterion = nn.CrossEntropyLoss()

2. Custom Autograd

PyTorch's automatic differentiation engine is powerful, but sometimes you need custom gradient computation for numerical stability or efficiency.

Subclassing torch.autograd.Function

import torch
from torch.autograd import Function

class SigmoidFunction(Function):
    """Sigmoid with a hand-written backward pass.

    Saves the forward OUTPUT (not the input) so the gradient can be formed
    as s * (1 - s) * upstream without recomputing any exponentials.
    """

    @staticmethod
    def forward(ctx, input):
        out = torch.sigmoid(input)
        # Stash the result for the backward pass.
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (out,) = ctx.saved_tensors
        # d(sigmoid)/dx = s * (1 - s), chained with the upstream gradient.
        return grad_output * out * (1 - out)

# Usage
def custom_sigmoid(x):
    """Apply the custom op; autograd Functions are invoked via .apply()."""
    return SigmoidFunction.apply(x)

x = torch.randn(3, 4, requires_grad=True)
y = custom_sigmoid(x)
loss = y.sum()
loss.backward()  # routes through SigmoidFunction.backward
print(f"Gradient: {x.grad}")

More Complex Example: Leaky ReLU with Custom Backward

class LeakyReLUFunction(Function):
    """Leaky ReLU with a custom backward; the slope is a plain Python float."""

    @staticmethod
    def forward(ctx, input, negative_slope=0.01):
        ctx.save_for_backward(input)
        # Non-tensor state lives on ctx as an attribute, not save_for_backward.
        ctx.negative_slope = negative_slope
        # max(x, 0) + slope * min(x, 0), expressed through clamp.
        return input.clamp(min=0) + negative_slope * input.clamp(max=0)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        slope = ctx.negative_slope
        # Pass the gradient through where x >= 0, scale it by slope otherwise.
        grad_input = torch.where(saved_input < 0, grad_output * slope, grad_output)
        # One return per forward argument; negative_slope gets no gradient.
        return grad_input, None

class CustomLeakyReLU(nn.Module):
    """nn.Module wrapper so the custom Function composes with nn.Sequential."""

    def __init__(self, negative_slope=0.01):
        super().__init__()
        # Plain float hyperparameter (not a learnable Parameter).
        self.negative_slope = negative_slope

    def forward(self, x):
        return LeakyReLUFunction.apply(x, self.negative_slope)

Numerical Gradient Check

from torch.autograd import gradcheck

def test_custom_op():
    """Validate SigmoidFunction's backward against numerical differentiation."""
    # gradcheck perturbs each element by eps, so float64 is required for the
    # finite-difference estimates to be accurate enough to compare.
    input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)

    # gradcheck computes the Jacobian numerically and compares it with the
    # one produced by our custom backward.
    result = gradcheck(SigmoidFunction.apply, (input,), eps=1e-6, atol=1e-4)
    print(f"Gradient check passed: {result}")

    # Second run with gradient-dtype checking enabled.
    # (This verifies backward returns grads of the right dtype — it is NOT a
    # double-backprop test; that would be torch.autograd.gradgradcheck.)
    input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
    result = gradcheck(
        SigmoidFunction.apply,
        (input,),
        eps=1e-6,
        atol=1e-4,
        check_grad_dtypes=True
    )
    # Bug fix: this second result was previously computed and silently dropped.
    print(f"Gradient dtype check passed: {result}")

test_custom_op()

Double Backpropagation

class SquaredFunction(Function):
    """x**2 with a manual backward that remains differentiable.

    The backward is built from ordinary differentiable torch ops, so calling
    torch.autograd.grad(..., create_graph=True) on its output permits a
    second backward pass (useful in meta-learning such as MAML).
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d(x^2)/dx = 2x, chained with the upstream gradient.
        return grad_output * 2 * x

# Double backprop example (useful in meta-learning like MAML)
x = torch.randn(3, requires_grad=True)
y = SquaredFunction.apply(x)
# create_graph=True keeps the graph OF the first backward so that the
# resulting gradient can itself be differentiated.
grad_x = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
# grad_x = 2x, backprop through this again
grad_grad_x = torch.autograd.grad(grad_x.sum(), x)[0]
# grad_grad_x = 2 (constant)
print(f"Second derivative: {grad_grad_x}")

3. Custom CUDA Operators

torch.utils.cpp_extension Overview

PyTorch provides tools for writing C++/CUDA extensions and using them from Python.

# JIT compilation (suitable for development/prototyping)
from torch.utils.cpp_extension import load_inline

# C++ CPU operator
cpp_source = """
#include <torch/extension.h>

torch::Tensor relu_forward(torch::Tensor input) {
    return input.clamp_min(0);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("relu_forward", &relu_forward, "ReLU forward");
}
"""

# Inline compilation: load_inline compiles and links at call time (a C++
# toolchain must be installed); the build artifact is cached on disk, so
# subsequent runs skip recompilation.
custom_relu_cpp = load_inline(
    name="custom_relu_cpp",
    cpp_sources=cpp_source,
    functions=["relu_forward"],
    verbose=False
)

x = torch.randn(5)
result = custom_relu_cpp.relu_forward(x)
print(result)

CUDA Kernel Example: Fused Softmax

# CUDA source
# CUDA kernel source.
# NOTE(review): the launch below uses ONE thread per block (<<<rows, 1>>>),
# so each row is reduced serially by a single thread — fine as a teaching
# example, far from an optimized softmax. It also assumes a contiguous
# 2-D float32 input (raw data_ptr<float> row arithmetic) — confirm callers.
cuda_source = """
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

__global__ void fused_softmax_kernel(
    const float* input,
    float* output,
    int rows,
    int cols
) {
    int row = blockIdx.x;
    if (row >= rows) return;

    const float* row_input = input + row * cols;
    float* row_output = output + row * cols;

    // Find max (for numerical stability)
    float max_val = row_input[0];
    for (int i = 1; i < cols; i++) {
        max_val = fmaxf(max_val, row_input[i]);
    }

    // Compute exp(x - max) and sum
    float sum = 0.0f;
    for (int i = 0; i < cols; i++) {
        row_output[i] = expf(row_input[i] - max_val);
        sum += row_output[i];
    }

    // Normalize
    for (int i = 0; i < cols; i++) {
        row_output[i] /= sum;
    }
}

torch::Tensor fused_softmax_cuda(torch::Tensor input) {
    auto output = torch::zeros_like(input);
    int rows = input.size(0);
    int cols = input.size(1);

    fused_softmax_kernel<<<rows, 1>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        rows,
        cols
    );

    return output;
}
"""

# C++ shim: declares the CUDA entry point and exposes it to Python.
cpp_source_cuda = """
#include <torch/extension.h>
torch::Tensor fused_softmax_cuda(torch::Tensor input);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_softmax", &fused_softmax_cuda, "Fused Softmax CUDA");
}
"""

# Compile only if CUDA is available (requires nvcc and a C++ toolchain)
if torch.cuda.is_available():
    from torch.utils.cpp_extension import load_inline
    fused_softmax_ext = load_inline(
        name="fused_softmax",
        cpp_sources=cpp_source_cuda,
        cuda_sources=cuda_source,
        functions=["fused_softmax"],
        verbose=True
    )

    # Test: compare against the reference implementation.
    x = torch.randn(4, 8, device="cuda")
    result = fused_softmax_ext.fused_softmax(x)
    expected = torch.softmax(x, dim=1)
    print(f"Max difference: {(result - expected).abs().max().item():.6f}")

Building a Package with setup.py

# setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="custom_ops",
    ext_modules=[
        CUDAExtension(
            name="custom_ops",
            sources=[
                "custom_ops/ops.cpp",
                "custom_ops/ops_cuda.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"],
                # --use_fast_math trades a little float accuracy for speed
                "nvcc": ["-O3", "--use_fast_math"],
            }
        )
    ],
    cmdclass={
        # BuildExtension wires nvcc into the setuptools build pipeline
        "build_ext": BuildExtension
    }
)
# Build: python setup.py install

4. Memory Optimization Techniques

GPU Memory Profiling

import torch

def print_gpu_memory_stats():
    """Print current/peak CUDA allocator statistics; silently no-op on CPU."""
    if not torch.cuda.is_available():
        return
    gib = 1024**3
    print(f"Allocated: {torch.cuda.memory_allocated() / gib:.2f} GB")
    print(f"Reserved:  {torch.cuda.memory_reserved() / gib:.2f} GB")
    print(f"Max Allocated: {torch.cuda.max_memory_allocated() / gib:.2f} GB")
    print(torch.cuda.memory_summary(abbreviated=True))

# Start memory tracking (NOTE: requires a CUDA device; resets the
# max_memory_allocated high-water mark so later readings are fresh)
torch.cuda.reset_peak_memory_stats()
print_gpu_memory_stats()

# Load model
model = TransformerBlock().cuda()
print_gpu_memory_stats()

Gradient Checkpointing (Activation Recomputation)

Gradient Checkpointing avoids storing intermediate activation values during the forward pass, recomputing them as needed during the backward pass. This saves significant memory at the cost of roughly 30% more computation.

from torch.utils.checkpoint import checkpoint, checkpoint_sequential
import torch.nn as nn

class DeepTransformer(nn.Module):
    """Stack of TransformerBlock layers used to demo activation checkpointing."""

    def __init__(self, num_layers=12, d_model=512, nhead=8):
        super().__init__()
        self.layers = nn.ModuleList(
            [TransformerBlock(d_model, nhead) for _ in range(num_layers)]
        )

    def forward_with_checkpointing(self, x):
        """Checkpoint every layer: activations are recomputed during backward."""
        for block in self.layers:
            # use_reentrant=False is the recommended non-reentrant variant.
            x = checkpoint(block, x, use_reentrant=False)
        return x

    def forward_sequential_checkpointing(self, x):
        """Checkpoint the whole stack in 3 segments instead of per layer."""
        return checkpoint_sequential(self.layers, segments=3, input=x)

# Memory comparison
model = DeepTransformer(num_layers=24).cuda()
x = torch.randn(16, 512, 512, device="cuda")

# Checkpointed forward (despite measuring only this path, the allocator's
# peak watermark shows the memory saved vs. storing all activations)
torch.cuda.reset_peak_memory_stats()
out = model.forward_with_checkpointing(x)
checkpoint_mem = torch.cuda.max_memory_allocated()

print(f"Checkpoint memory: {checkpoint_mem / 1024**3:.2f} GB")

Gradient Accumulation

def train_with_gradient_accumulation(
    model, optimizer, dataloader, criterion,
    accumulation_steps=4
):
    """
    Implement a large effective batch size with limited GPU memory.

    Gradients of ``accumulation_steps`` micro-batches are summed before each
    optimizer step; the loss is scaled down so the accumulated gradient
    matches what one big batch would produce. Returns the mean per-batch
    (unscaled) loss.
    """
    model.train()
    optimizer.zero_grad()
    total_loss = 0
    batch_idx = -1  # stays -1 if the dataloader yields nothing

    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.cuda(), target.cuda()

        # Forward pass
        output = model(data)
        # Scale loss by accumulation_steps
        loss = criterion(output, target) / accumulation_steps
        loss.backward()

        # Undo the scaling for loss reporting
        total_loss += loss.item() * accumulation_steps

        # Update every accumulation_steps
        if (batch_idx + 1) % accumulation_steps == 0:
            # Gradient clipping (optional)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

    # Bug fix: if len(dataloader) is not a multiple of accumulation_steps,
    # the trailing partial accumulation was previously dropped silently.
    if (batch_idx + 1) % accumulation_steps != 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    return total_loss / len(dataloader)

Mixed Precision Training (AMP)

from torch.cuda.amp import autocast, GradScaler
# NOTE(review): torch.cuda.amp is deprecated in newer PyTorch releases in
# favor of torch.amp (autocast("cuda") / GradScaler("cuda")) — confirm the
# minimum supported torch version before migrating.

def train_with_amp(model, optimizer, dataloader, criterion):
    """Automatic Mixed Precision for memory savings + speed boost"""
    scaler = GradScaler()
    model.train()

    for data, target in dataloader:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()

        # float16 operations inside autocast context
        with autocast():
            output = model(data)
            loss = criterion(output, target)

        # Backward with scaled gradients (scaling prevents fp16 underflow)
        scaler.scale(loss).backward()

        # Unscale then clip gradients (clipping must see true magnitudes)
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Optimizer step (step is skipped automatically on NaN/Inf grads)
        scaler.step(optimizer)
        scaler.update()

8-bit Optimizer

# Requires bitsandbytes: pip install bitsandbytes
try:
    import bitsandbytes as bnb

    # Use 8-bit Adam instead of regular Adam
    # (optimizer state is stored in 8 bits, cutting its memory ~4x)
    optimizer_8bit = bnb.optim.Adam8bit(
        model.parameters(),
        lr=1e-4,
        betas=(0.9, 0.999)
    )

    # PagedAdam (can offload to CPU)
    optimizer_paged = bnb.optim.PagedAdam(
        model.parameters(),
        lr=1e-4
    )

    print("8-bit optimizer loaded successfully")
except ImportError:
    # Graceful fallback when the optional dependency is absent.
    print("bitsandbytes not installed, using regular Adam")
    optimizer_8bit = torch.optim.Adam(model.parameters(), lr=1e-4)

5. functorch and vmap

vmap - Batched Operations

vmap efficiently applies a function operating on single samples to a batch.

import torch
from torch import vmap

# Function operating on a single sample
def single_linear(weight, bias, x):
    """Affine map for ONE sample: here (10, 5) @ (5,) + (10,) -> (10,)."""
    return weight @ x + bias

# Batch processing with vmap (maps over dim 0 of each argument by default)
batched_linear = vmap(single_linear)

# Batch data
batch_size = 32
weight = torch.randn(batch_size, 10, 5)
bias = torch.randn(batch_size, 10)
x = torch.randn(batch_size, 5)

# Automatically handled in batch
result = batched_linear(weight, bias, x)
print(f"Result shape: {result.shape}")  # (32, 10)

grad - Functional Gradient

from torch.func import grad, vmap, functional_call

# Functional gradient
def scalar_loss(params, x, y):
    """MSE of `model` evaluated with the given parameter dict (stateless)."""
    pred = functional_call(model, params, (x,))
    return ((pred - y) ** 2).mean()

# Gradient with respect to parameters
# grad() differentiates w.r.t. the FIRST argument (the params dict here).
params = dict(model.named_parameters())
grad_fn = grad(scalar_loss)

x = torch.randn(1, 10)
y = torch.randn(1, 5)
# assumes `model` maps (1, 10) inputs to (1, 5) outputs — TODO confirm
grads = grad_fn(params, x, y)
print({k: v.shape for k, v in grads.items()})

Ensemble Models with vmap

from torch.func import stack_module_state, functional_call, vmap

def create_ensemble(model_class, num_models, *args, **kwargs):
    """Build a vmap-based ensemble of ``num_models`` identical architectures.

    Returns ``(forward_fn, params, buffers)``; ``forward_fn(params, buffers,
    x)`` evaluates every member on the same input and stacks the outputs
    along a new leading dimension of size ``num_models``.
    """
    members = [model_class(*args, **kwargs) for _ in range(num_models)]

    # Stack each parameter/buffer across members -> leading dim num_models.
    params, buffers = stack_module_state(members)

    # A "skeleton" instance supplies the architecture for functional_call.
    skeleton = model_class(*args, **kwargs)

    def run_one(member_params, member_buffers, x):
        return functional_call(skeleton, (member_params, member_buffers), (x,))

    # Map over the stacked state (dim 0) while broadcasting the input.
    return vmap(run_one, in_dims=(0, 0, None)), params, buffers

# Usage example: a 5-member ensemble of identical Linear layers
ensemble_fn, params, buffers = create_ensemble(
    nn.Linear, num_models=5, in_features=10, out_features=5
)

x = torch.randn(32, 10)
# Leading dim of the output indexes the ensemble member.
ensemble_out = ensemble_fn(params, buffers, x)
print(f"Ensemble output shape: {ensemble_out.shape}")  # (5, 32, 5)

Meta-Learning (MAML) with grad and vmap

from torch.func import grad, vmap, functional_call

def inner_loop(params, support_x, support_y, base_model, lr=0.01, steps=5):
    """MAML inner loop: take ``steps`` SGD steps on the support set.

    Returns a fresh parameter dict; ``params`` itself is never mutated, so
    the outer loop can still differentiate through the adaptation.
    """
    def support_loss(theta):
        # Stateless forward: run base_model's architecture with theta.
        pred = functional_call(base_model, theta, (support_x,))
        return ((pred - support_y) ** 2).mean()

    adapted = {name: tensor.clone() for name, tensor in params.items()}
    for _ in range(steps):
        grads = grad(support_loss)(adapted)
        # Plain SGD update, rebuilt as a new dict each step.
        adapted = {
            name: tensor - lr * grads[name]
            for name, tensor in adapted.items()
        }
    return adapted

6. PyTorch Profiler

Basic Profiling

import torch
from torch.profiler import profile, record_function, ProfilerActivity

model = TransformerBlock().cuda()
x = torch.randn(32, 128, 512, device="cuda")

# Run profiler (with_stack records Python stacks — adds overhead)
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True
) as prof:
    # record_function labels a named range in the trace
    with record_function("model_inference"):
        for _ in range(10):
            output = model(x)

# Print results, heaviest CUDA ops first
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20
))

# Export Chrome Trace
prof.export_chrome_trace("trace.json")
# Open in chrome://tracing

TensorBoard Integration

from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler

# assumes `dataloader`, `train_step`, `model`, `optimizer` are defined by the
# surrounding training code — TODO confirm when adapting this snippet
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(
        wait=1,    # Wait 1 step
        warmup=1,  # 1 warmup step
        active=3,  # Profile 3 steps
        repeat=2   # Repeat 2 times
    ),
    on_trace_ready=tensorboard_trace_handler("./log/profiler"),
    record_shapes=True,
    profile_memory=True
) as prof:
    for step, (data, target) in enumerate(dataloader):
        train_step(model, optimizer, data, target)
        prof.step()  # Profile according to schedule

Detailed Memory Analysis

# Memory snapshot (PyTorch 2.1+)
# NOTE(review): the underscore-prefixed memory APIs are private and may
# change between releases — pin the torch version when relying on them.
torch.cuda.memory._record_memory_history(max_entries=100000)

# Run code
x = torch.randn(100, 100, device="cuda")
y = x @ x.T
z = y.sum()

# Save snapshot (the .pickle can be inspected at pytorch.org/memory_viz)
snapshot = torch.cuda.memory._snapshot()
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# 'segments' counts allocator segments, not individual live allocations
print(f"Active allocations: {len(snapshot['segments'])}")

7. TorchScript

torch.jit.script vs torch.jit.trace

import torch
import torch.nn as nn

class ConditionalModel(nn.Module):
    """Tiny model with data-independent control flow in forward.

    The Python ``if`` on ``flag`` is exactly what torch.jit.trace cannot
    capture (trace records only the branch taken), while torch.jit.script
    preserves both branches.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x, flag: bool = True):
        projected = self.linear(x)
        # Same two branches as a plain if/else, written as a conditional expr.
        return torch.relu(projected) if flag else torch.sigmoid(projected)

model = ConditionalModel()

# script: includes control flow (recommended)
scripted_model = torch.jit.script(model)
print(scripted_model.code)

# trace: captures only a single path (here the one taken with flag's
# default value True — the recorded graph has no branch at all)
x = torch.randn(1, 10)
traced_model = torch.jit.trace(model, x)
# Note: flag=False path is not captured

# Save and load
scripted_model.save("model_scripted.pt")
loaded = torch.jit.load("model_scripted.pt")

TorchScript Optimization

# Apply optimization passes
# NOTE(review): optimize_for_inference expects an eval-mode module (it
# freezes it internally) — call model.eval() first; verify against the
# installed torch version.
scripted = torch.jit.script(model)
optimized = torch.jit.optimize_for_inference(scripted)

# Export for C++ environment
# In C++: torch::jit::script::Module m = torch::jit::load("model.pt");

8. Dynamic Shapes and torch.export

Using torch.export

import torch
from torch.export import export

class MyModel(nn.Module):
    """Minimal (batch, 10) -> (batch, 5) linear model for the export demo."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)

model = MyModel()
example_inputs = (torch.randn(2, 10),)

# Specify dynamic shapes: Dim declares a symbolic size with a legal range
from torch.export import Dim
batch = Dim("batch", min=1, max=100)

# Export model; dim 0 of input "x" is marked dynamic
exported = export(
    model,
    example_inputs,
    dynamic_shapes={"x": {0: batch}}
)

print(exported)
print(exported.graph_module.code)

# Run ExportedProgram (any batch size within the declared range works)
result = exported.module()(torch.randn(5, 10))
print(f"Result shape: {result.shape}")

9. Custom Dataset and Sampler

IterableDataset

from torch.utils.data import IterableDataset, DataLoader
import torch

class StreamingDataset(IterableDataset):
    """Stream samples lazily from a list of files.

    With multi-worker loading, each worker iterates a disjoint slice of
    ``data_paths`` so every file is produced exactly once.
    """

    def __init__(self, data_paths, transform=None):
        self.data_paths = data_paths
        self.transform = transform

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is None:
            # Single process: iterate everything
            paths = self.data_paths
        else:
            # Multiprocessing: split data across workers.
            # Bug fix: the previous floor-division split (len // num_workers)
            # silently dropped the remainder files; ceil-sized slices cover
            # the whole list, with the last worker taking a shorter slice.
            num_workers = worker_info.num_workers
            per_worker = -(-len(self.data_paths) // num_workers)  # ceil div
            start = worker_info.id * per_worker
            end = min(start + per_worker, len(self.data_paths))
            paths = self.data_paths[start:end]

        for path in paths:
            # Stream data from file, one sample at a time
            data = self._load_file(path)
            for sample in data:
                if self.transform:
                    sample = self.transform(sample)
                yield sample

    def _load_file(self, path):
        # In a real implementation, load and parse `path` from disk
        return [torch.randn(10) for _ in range(100)]

# Use persistent_workers in DataLoader
dataset = StreamingDataset(data_paths=["file1.pt", "file2.pt"])
dataloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,
    persistent_workers=True,  # Reuse worker processes across epochs
    pin_memory=True,           # Page-locked host memory for faster GPU copies
    prefetch_factor=2          # Batches prefetched PER worker
)

Custom Sampler

from torch.utils.data import Sampler
import numpy as np

class BalancedClassSampler(Sampler):
    """Oversample minority classes so each epoch sees all classes equally.

    Every __iter__ draws ``num_samples_per_class`` indices per class WITH
    replacement, then shuffles the combined list.
    """

    def __init__(self, dataset, num_samples_per_class=None):
        self.dataset = dataset
        # Label is assumed to be the second element of each dataset item.
        self.labels = torch.tensor([dataset[i][1] for i in range(len(dataset))])

        # Map class id -> list of dataset indices belonging to that class.
        self.class_indices = {
            cls.item(): (self.labels == cls).nonzero(as_tuple=True)[0].tolist()
            for cls in torch.unique(self.labels)
        }

        self.num_classes = len(self.class_indices)
        if not num_samples_per_class:
            # Default: oversample every class up to the largest class size.
            num_samples_per_class = max(len(v) for v in self.class_indices.values())
        self.num_samples_per_class = num_samples_per_class

    def __iter__(self):
        drawn = []
        for candidates in self.class_indices.values():
            # Equal-count sampling per class (with replacement).
            picks = np.random.choice(
                candidates,
                self.num_samples_per_class,
                replace=True
            )
            drawn.extend(picks.tolist())

        np.random.shuffle(drawn)
        return iter(drawn)

    def __len__(self):
        return self.num_classes * self.num_samples_per_class

10. PyTorch Lightning

Complete LightningModule Example

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

class LightningTransformer(pl.LightningModule):
    """Transformer encoder + linear head for sequence classification.

    Input: (batch, seq_len, d_model) float tensors; output: class logits.
    Training/validation loops, device placement and logging are driven by
    the Lightning Trainer via the hook methods below.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_layers=6,
        num_classes=10,
        learning_rate=1e-4
    ):
        super().__init__()
        self.save_hyperparameters()  # Auto-save HParams (exposed as self.hparams)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                batch_first=True
            ),
            num_layers=num_layers
        )
        self.classifier = nn.Linear(d_model, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Encode the sequence, mean-pool over time, return class logits."""
        encoded = self.encoder(x)
        # Sequence mean pooling
        pooled = encoded.mean(dim=1)
        return self.classifier(pooled)

    def training_step(self, batch, batch_idx):
        """One optimization step; the returned loss is backpropagated by Lightning."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        # Logging
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation metrics; no gradient step here."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        """Return AdamW plus a cosine-annealing LR schedule."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=0.01
        )
        # Cosine annealing scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=100
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }


class LightningDataModule(pl.LightningDataModule):
    """Synthetic 80/20 train/val split of random sequence data for the demo."""

    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Dummy data: 1000 sequences of shape (32, 512) with 10 classes
        x = torch.randn(1000, 32, 512)
        y = torch.randint(0, 10, (1000,))
        dataset = TensorDataset(x, y)
        n = len(dataset)
        self.train_ds = torch.utils.data.Subset(dataset, range(int(0.8*n)))
        self.val_ds = torch.utils.data.Subset(dataset, range(int(0.8*n), n))

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)


# Run training
model = LightningTransformer()
dm = LightningDataModule()

# Setup callbacks: keep the 3 best checkpoints by val_loss, stop early
# after 10 epochs without improvement
callbacks = [
    ModelCheckpoint(
        monitor="val_loss",
        save_top_k=3,
        mode="min",
        filename="transformer-{epoch:02d}-{val_loss:.2f}"
    ),
    EarlyStopping(monitor="val_loss", patience=10, mode="min")
]

# Trainer
trainer = pl.Trainer(
    max_epochs=100,
    accelerator="auto",    # Auto-detect GPU
    devices="auto",
    callbacks=callbacks,
    logger=TensorBoardLogger("tb_logs", name="transformer"),
    gradient_clip_val=1.0,  # Gradient clipping
    accumulate_grad_batches=4,  # Gradient accumulation
    precision="16-mixed",   # AMP
)

trainer.fit(model, dm)

11. Model Quantization

Dynamic Quantization

import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

# Inference-only quantization - easiest approach: weights are converted to
# int8 ahead of time, activations are quantized on the fly at runtime
model = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 128)
)
model.eval()

# Quantize Linear and LSTM layers to int8
quantized_model = quantize_dynamic(
    model,
    {nn.Linear, nn.LSTM},  # Layer types to quantize
    dtype=torch.qint8
)

# Size comparison via serialized state dicts
import os
torch.save(model.state_dict(), "model_fp32.pt")
torch.save(quantized_model.state_dict(), "model_int8.pt")
fp32_size = os.path.getsize("model_fp32.pt")
int8_size = os.path.getsize("model_int8.pt")
print(f"FP32 size: {fp32_size / 1024:.1f} KB")
print(f"INT8 size: {int8_size / 1024:.1f} KB")
print(f"Compression: {fp32_size / int8_size:.1f}x")

Static Quantization

import torch
from torch.ao.quantization import (
    get_default_qconfig,
    prepare,
    convert,
)

class QuantizableModel(nn.Module):
    """MLP with Quant/DeQuant stubs marking the int8 region for static PTQ.

    Before prepare/convert the stubs pass tensors through unchanged; after
    convert() they become the actual float<->int8 boundary ops.
    """

    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(512, 256)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(256, 10)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # Enter the quantized region, run the MLP, then exit back to float.
        hidden = self.relu(self.linear1(self.quant(x)))
        return self.dequant(self.linear2(hidden))

model = QuantizableModel()
model.eval()

# Quantization config ("fbgemm" targets x86 CPUs; use "qnnpack" for ARM)
model.qconfig = get_default_qconfig("fbgemm")

# Prepare (insert observers that record activation value ranges)
prepared_model = prepare(model)

# Collect statistics with calibration data; the observed ranges determine
# the quantization scale/zero-point chosen at convert time
with torch.no_grad():
    for _ in range(100):
        calibration_data = torch.randn(32, 512)
        prepared_model(calibration_data)

# Convert to quantized (replaces modules with int8 kernels)
quantized_static = convert(prepared_model)

# Inference
x = torch.randn(1, 512)
with torch.no_grad():
    output = quantized_static(x)
print(f"Output shape: {output.shape}")

QAT (Quantization-Aware Training)

from torch.ao.quantization import (
    get_default_qat_qconfig,
    prepare_qat,
    convert
)


def dummy_dataloader():
    """Yield a few random (input, label) batches for the demo training loop."""
    for _ in range(10):
        yield torch.randn(32, 512), torch.randint(0, 10, (32,))

# Bug fix: dummy_dataloader was previously defined AFTER the training loop
# that calls it, which raises NameError when this snippet runs top-to-bottom.

model = QuantizableModel()
model.train()

# QAT config
model.qconfig = get_default_qat_qconfig("fbgemm")

# Prepare QAT (insert fake-quantization nodes that simulate int8 rounding)
prepared_qat = prepare_qat(model)

# Train (the quantization error is part of the loss the model learns under)
optimizer = torch.optim.SGD(prepared_qat.parameters(), lr=0.0001)
for epoch in range(10):
    for x, y in dummy_dataloader():
        output = prepared_qat(x)
        loss = nn.functional.cross_entropy(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Convert after training (fake-quant nodes -> real int8 kernels)
prepared_qat.eval()
quantized_qat = convert(prepared_qat)

12. Tensor Parallelism and Pipeline Parallelism

DeviceMesh and DTensor API

import torch
import torch.distributed as dist
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel
)
from torch.distributed._tensor import DeviceMesh

# Initialize distributed training
def setup_distributed():
    """Join the NCCL process group and bind this rank to a GPU.

    Expects the torchrun/env:// variables (RANK, WORLD_SIZE, MASTER_ADDR...).
    NOTE(review): using the GLOBAL rank as the CUDA device index is only
    correct for single-node jobs — multi-node launches should use LOCAL_RANK.
    """
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())

class TransformerMLP(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> Linear.

    Attribute names fc1/fc2 are load-bearing: the tensor-parallel plan below
    refers to them by name when sharding.
    """

    def __init__(self, d_model=1024, dim_feedforward=4096):
        super().__init__()
        self.fc1 = nn.Linear(d_model, dim_feedforward)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim_feedforward, d_model)

    def forward(self, x):
        """Expand to dim_feedforward, apply GELU, project back to d_model."""
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)

# Apply Tensor Parallelism
def apply_tensor_parallel(model, mesh):
    """
    fc1: Column-wise sharding (split output dimension)
    fc2: Row-wise sharding (split input dimension)

    Pairing colwise-then-rowwise keeps the intermediate activation sharded
    across ranks so only one collective is needed at the end of fc2.
    """
    # parallelize_module rewrites the named submodules in place.
    parallelize_module(
        model,
        mesh,
        {
            "fc1": ColwiseParallel(),
            "fc2": RowwiseParallel(),
        }
    )
    return model

# Example (requires 2-GPU setup)
# device_mesh = DeviceMesh("cuda", [0, 1])
# model = TransformerMLP()
# model = apply_tensor_parallel(model, device_mesh)

FSDP (Fully Sharded Data Parallel)

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy
)
import functools

def setup_fsdp_model(model):
    """FSDP setup - distribute large models across multiple GPUs"""

    # Mixed precision: bf16 params/buffers for compute, but fp32 gradient
    # reduction to keep the cross-rank all-reduce numerically stable.
    mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.bfloat16,
    )

    # Set TransformerBlock as the wrapping unit: each block becomes its own
    # FSDP shard group, bounding peak memory during gather/scatter.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock}
    )

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        # FULL_SHARD: shard params, grads and optimizer state (ZeRO-3 style)
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        # Prefetch the next shard's params during backward to overlap comms
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=torch.cuda.current_device()
    )

    return model

Summary

This guide covered PyTorch's advanced techniques:

  1. torch.compile: Over 2x performance improvement without code changes
  2. Custom Autograd: Implementing specialized gradient computation
  3. CUDA Extensions: Integrating GPU kernels with PyTorch
  4. Memory Optimization: Gradient Checkpointing, AMP, 8-bit Optimizer
  5. functorch/vmap: Batch processing and meta-learning via functional API
  6. PyTorch Profiler: Analyzing performance bottlenecks
  7. TorchScript/export: Deployment optimization
  8. PyTorch Lightning: Structuring code and automating training
  9. Quantization: Reducing model size with INT8
  10. Distributed Training: Tensor Parallel, FSDP

These techniques are complementary, and in real projects it is common to combine several of them. For large-scale model training, the combination of AMP + Gradient Checkpointing + FSDP + torch.compile is particularly powerful.

References