Skip to content
Published on

PyTorch 고급 기법 완전 가이드: torch.compile, Custom Ops, Memory 최적화

Authors

PyTorch 고급 기법 완전 가이드

PyTorch는 딥러닝 연구와 프로덕션 배포 모두에서 가장 인기 있는 프레임워크 중 하나입니다. 하지만 대부분의 개발자는 기본적인 텐서 연산과 nn.Module 정의에서 멈춥니다. 이 가이드에서는 실제 프로덕션 환경에서 성능을 극대화하고, 커스텀 연산을 구현하며, 메모리를 효율적으로 관리하는 고급 기법들을 다룹니다.

1. torch.compile (PyTorch 2.0+)

PyTorch 2.0에서 도입된 torch.compile은 모델을 컴파일하여 실행 속도를 획기적으로 향상시키는 기능입니다. 기존의 TorchScript나 ONNX export와 달리, torch.compile은 파이썬 코드를 거의 수정하지 않고도 최대 2배 이상의 속도 향상을 가져올 수 있습니다.

torch.compile 소개와 장점

torch.compile은 세 가지 핵심 컴포넌트로 구성됩니다:

  1. TorchDynamo: Python 바이트코드를 인터셉트하여 FX 그래프를 생성
  2. AOTAutograd: 자동 미분 그래프를 사전 컴파일
  3. Inductor: TorchInductor 백엔드로 최적화된 커널 생성 (Triton GPU 커널 또는 C++ CPU 커널)
import torch
import torch.nn as nn
import time

# 기본 모델 정의
class TransformerBlock(nn.Module):
    def __init__(self, d_model=512, nhead=8, dim_feedforward=2048):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        ff_out = self.feed_forward(x)
        x = self.norm2(x + ff_out)
        return x

device = "cuda" if torch.cuda.is_available() else "cpu"
model = TransformerBlock().to(device)

# torch.compile 적용
compiled_model = torch.compile(model)

# 워밍업
x = torch.randn(32, 128, 512, device=device)
for _ in range(3):
    _ = compiled_model(x)

# 성능 비교
N = 100

# 일반 모델
start = time.perf_counter()
for _ in range(N):
    _ = model(x)
if device == "cuda":
    torch.cuda.synchronize()
elapsed_eager = time.perf_counter() - start

# 컴파일된 모델
start = time.perf_counter()
for _ in range(N):
    _ = compiled_model(x)
if device == "cuda":
    torch.cuda.synchronize()
elapsed_compiled = time.perf_counter() - start

print(f"Eager mode: {elapsed_eager:.3f}s")
print(f"Compiled mode: {elapsed_compiled:.3f}s")
print(f"Speedup: {elapsed_eager / elapsed_compiled:.2f}x")

컴파일 모드 (Compilation Modes)

torch.compile은 세 가지 모드를 제공합니다:

# 기본 모드 - 빠른 컴파일, 좋은 성능
model_default = torch.compile(model, mode="default")

# 오버헤드 감소 모드 - 작은 모델에 유리
model_reduce = torch.compile(model, mode="reduce-overhead")

# 최대 자동 튜닝 - 긴 컴파일, 최고 성능
model_autotune = torch.compile(model, mode="max-autotune")

# 풀 그래프 모드 - 동적 그래프 없음 (strict)
model_full = torch.compile(model, fullgraph=True)

# 백엔드 선택
model_eager = torch.compile(model, backend="eager")    # 컴파일 없음 (디버그용)
model_aot = torch.compile(model, backend="aot_eager")  # AOT만 적용
model_inductor = torch.compile(model, backend="inductor")  # 기본값

동적 형태 지원

import torch._dynamo as dynamo

# 동적 형태 활성화
model_dynamic = torch.compile(model, dynamic=True)

# 다양한 배치 크기에서도 재컴파일 없이 동작
for batch_size in [8, 16, 32, 64]:
    x = torch.randn(batch_size, 128, 512, device=device)
    out = model_dynamic(x)
    print(f"Batch {batch_size}: output shape {out.shape}")

# 컴파일 캐시 확인
print(dynamo.explain(model)(x))

기존 코드 마이그레이션

# 훈련 루프에 torch.compile 적용
def train_epoch(model, optimizer, dataloader, criterion):
    model.train()
    total_loss = 0

    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(dataloader)

# 모델만 컴파일하면 됩니다
model = MyModel().cuda()
compiled_model = torch.compile(model)  # 이것만 추가!

optimizer = torch.optim.Adam(compiled_model.parameters())
criterion = nn.CrossEntropyLoss()

2. 커스텀 자동 미분 (Custom Autograd)

PyTorch의 자동 미분 엔진은 강력하지만, 때로는 수치적으로 안정적이거나 효율적인 커스텀 그래디언트 계산이 필요합니다.

torch.autograd.Function 서브클래싱

import torch
from torch.autograd import Function

class SigmoidFunction(Function):
    """수치적으로 안정적인 Sigmoid 구현"""

    @staticmethod
    def forward(ctx, input):
        # sigmoid = 1 / (1 + exp(-x))
        sigmoid = torch.sigmoid(input)
        # backward에서 사용할 텐서를 저장
        ctx.save_for_backward(sigmoid)
        return sigmoid

    @staticmethod
    def backward(ctx, grad_output):
        # sigmoid를 불러옴
        sigmoid, = ctx.saved_tensors
        # gradient = sigmoid * (1 - sigmoid) * grad_output
        grad_input = sigmoid * (1 - sigmoid) * grad_output
        return grad_input

# 사용 예시
def custom_sigmoid(x):
    return SigmoidFunction.apply(x)

x = torch.randn(3, 4, requires_grad=True)
y = custom_sigmoid(x)
loss = y.sum()
loss.backward()
print(f"Gradient: {x.grad}")

더 복잡한 예시: Leaky ReLU with Custom Backward

class LeakyReLUFunction(Function):
    @staticmethod
    def forward(ctx, input, negative_slope=0.01):
        ctx.save_for_backward(input)
        ctx.negative_slope = negative_slope
        return input.clamp(min=0) + negative_slope * input.clamp(max=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        negative_slope = ctx.negative_slope
        grad_input = grad_output.clone()
        grad_input[input < 0] *= negative_slope
        # 두 번째 입력(negative_slope)에 대한 그래디언트는 None
        return grad_input, None

class CustomLeakyReLU(nn.Module):
    def __init__(self, negative_slope=0.01):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, x):
        return LeakyReLUFunction.apply(x, self.negative_slope)

수치 그래디언트 체크

from torch.autograd import gradcheck

def test_custom_op():
    # 수치 그래디언트 체크 (float64 권장)
    input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)

    # gradcheck는 Jacobian을 수치적으로 계산하고 자동 미분과 비교
    result = gradcheck(SigmoidFunction.apply, (input,), eps=1e-6, atol=1e-4)
    print(f"Gradient check passed: {result}")

    # 더블 역전파 테스트
    input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
    result = gradcheck(
        SigmoidFunction.apply,
        (input,),
        eps=1e-6,
        atol=1e-4,
        check_grad_dtypes=True
    )

test_custom_op()

더블 역전파 (Double Backpropagation)

class SquaredFunction(Function):
    """x^2의 커스텀 구현 - 더블 역전파 지원"""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # 더블 역전파를 지원하려면 create_graph=True 설정 필요
        return 2 * x * grad_output

# 더블 역전파 예시 (MAML 같은 메타 학습에서 활용)
x = torch.randn(3, requires_grad=True)
y = SquaredFunction.apply(x)
grad_x = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
# grad_x = 2x, 이를 다시 역전파
grad_grad_x = torch.autograd.grad(grad_x.sum(), x)[0]
# grad_grad_x = 2 (상수)
print(f"Second derivative: {grad_grad_x}")

3. 커스텀 CUDA 연산자

torch.utils.cpp_extension 소개

PyTorch는 C++/CUDA 확장을 작성하고 파이썬에서 사용할 수 있는 도구를 제공합니다.

# JIT 컴파일 방식 (개발/프로토타입에 적합)
from torch.utils.cpp_extension import load_inline

# C++ CPU 연산자
cpp_source = """
#include <torch/extension.h>

torch::Tensor relu_forward(torch::Tensor input) {
    return input.clamp_min(0);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("relu_forward", &relu_forward, "ReLU forward");
}
"""

# 인라인 컴파일
custom_relu_cpp = load_inline(
    name="custom_relu_cpp",
    cpp_sources=cpp_source,
    functions=["relu_forward"],
    verbose=False
)

x = torch.randn(5)
result = custom_relu_cpp.relu_forward(x)
print(result)

CUDA 커널 예시: Fused Softmax

# CUDA 소스
cuda_source = """
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

__global__ void fused_softmax_kernel(
    const float* input,
    float* output,
    int rows,
    int cols
) {
    int row = blockIdx.x;
    if (row >= rows) return;

    const float* row_input = input + row * cols;
    float* row_output = output + row * cols;

    // Max 찾기 (수치 안정성)
    float max_val = row_input[0];
    for (int i = 1; i < cols; i++) {
        max_val = fmaxf(max_val, row_input[i]);
    }

    // exp(x - max) 계산 및 합산
    float sum = 0.0f;
    for (int i = 0; i < cols; i++) {
        row_output[i] = expf(row_input[i] - max_val);
        sum += row_output[i];
    }

    // 정규화
    for (int i = 0; i < cols; i++) {
        row_output[i] /= sum;
    }
}

torch::Tensor fused_softmax_cuda(torch::Tensor input) {
    auto output = torch::zeros_like(input);
    int rows = input.size(0);
    int cols = input.size(1);

    fused_softmax_kernel<<<rows, 1>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        rows,
        cols
    );

    return output;
}
"""

cpp_source_cuda = """
#include <torch/extension.h>
torch::Tensor fused_softmax_cuda(torch::Tensor input);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_softmax", &fused_softmax_cuda, "Fused Softmax CUDA");
}
"""

# CUDA가 있는 경우에만 컴파일
if torch.cuda.is_available():
    from torch.utils.cpp_extension import load_inline
    fused_softmax_ext = load_inline(
        name="fused_softmax",
        cpp_sources=cpp_source_cuda,
        cuda_sources=cuda_source,
        functions=["fused_softmax"],
        verbose=True
    )

    # 테스트
    x = torch.randn(4, 8, device="cuda")
    result = fused_softmax_ext.fused_softmax(x)
    expected = torch.softmax(x, dim=1)
    print(f"Max difference: {(result - expected).abs().max().item():.6f}")

setup.py를 이용한 패키지 빌드

# setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="custom_ops",
    ext_modules=[
        CUDAExtension(
            name="custom_ops",
            sources=[
                "custom_ops/ops.cpp",
                "custom_ops/ops_cuda.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": ["-O3", "--use_fast_math"],
            }
        )
    ],
    cmdclass={
        "build_ext": BuildExtension
    }
)
# 빌드: python setup.py install

4. 메모리 최적화 기법

GPU 메모리 프로파일링

import torch

def print_gpu_memory_stats():
    """GPU 메모리 통계를 출력하는 유틸리티"""
    if torch.cuda.is_available():
        print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"Reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
        print(f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
        print(torch.cuda.memory_summary(abbreviated=True))

# 메모리 추적 시작
torch.cuda.reset_peak_memory_stats()
print_gpu_memory_stats()

# 모델 로드
model = TransformerBlock().cuda()
print_gpu_memory_stats()

Gradient Checkpointing (Activation Recomputation)

Gradient Checkpointing은 순전파 시 중간 활성화 값을 저장하지 않고, 역전파 시 필요한 경우 다시 계산합니다. 메모리를 크게 절약할 수 있지만, 계산 시간이 약 30% 증가합니다.

from torch.utils.checkpoint import checkpoint, checkpoint_sequential
import torch.nn as nn

class DeepTransformer(nn.Module):
    def __init__(self, num_layers=12, d_model=512, nhead=8):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, nhead) for _ in range(num_layers)
        ])

    def forward_with_checkpointing(self, x):
        """각 레이어에 gradient checkpointing 적용"""
        for layer in self.layers:
            # checkpoint는 중간 활성화 값을 저장하지 않음
            x = checkpoint(layer, x, use_reentrant=False)
        return x

    def forward_sequential_checkpointing(self, x):
        """sequential 모듈에 checkpointing 적용"""
        # 4개 레이어씩 그룹화
        x = checkpoint_sequential(self.layers, segments=3, input=x)
        return x

# 메모리 비교
model = DeepTransformer(num_layers=24).cuda()
x = torch.randn(16, 512, 512, device="cuda")

# 일반 forward
torch.cuda.reset_peak_memory_stats()
out = model(x) if hasattr(model, 'forward') else model.forward_with_checkpointing(x)
normal_mem = torch.cuda.max_memory_allocated()

# checkpointing forward
torch.cuda.reset_peak_memory_stats()
out = model.forward_with_checkpointing(x)
checkpoint_mem = torch.cuda.max_memory_allocated()

print(f"Normal memory: {normal_mem / 1024**3:.2f} GB")
print(f"Checkpoint memory: {checkpoint_mem / 1024**3:.2f} GB")
print(f"Memory saved: {(normal_mem - checkpoint_mem) / 1024**3:.2f} GB")

Gradient Accumulation

def train_with_gradient_accumulation(
    model, optimizer, dataloader, criterion,
    accumulation_steps=4
):
    """
    큰 effective batch size를 작은 GPU 메모리로 구현
    """
    model.train()
    optimizer.zero_grad()
    total_loss = 0

    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.cuda(), target.cuda()

        # Forward pass
        output = model(data)
        # accumulation_steps로 나누어 스케일 조정
        loss = criterion(output, target) / accumulation_steps
        loss.backward()

        total_loss += loss.item() * accumulation_steps

        # accumulation_steps마다 업데이트
        if (batch_idx + 1) % accumulation_steps == 0:
            # 그래디언트 클리핑 (선택사항)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

    return total_loss / len(dataloader)

Mixed Precision Training (AMP)

from torch.cuda.amp import autocast, GradScaler

def train_with_amp(model, optimizer, dataloader, criterion):
    """Automatic Mixed Precision으로 메모리 절약 + 속도 향상"""
    scaler = GradScaler()
    model.train()

    for data, target in dataloader:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()

        # autocast 컨텍스트에서 float16 연산
        with autocast():
            output = model(data)
            loss = criterion(output, target)

        # 스케일된 그래디언트로 backward
        scaler.scale(loss).backward()

        # 언스케일 후 그래디언트 클리핑
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Optimizer 스텝 (NaN/Inf 체크 포함)
        scaler.step(optimizer)
        scaler.update()

8-bit Optimizer

# bitsandbytes 라이브러리 필요: pip install bitsandbytes
try:
    import bitsandbytes as bnb

    # 일반 Adam 대신 8-bit Adam 사용
    optimizer_8bit = bnb.optim.Adam8bit(
        model.parameters(),
        lr=1e-4,
        betas=(0.9, 0.999)
    )

    # PagedAdam (CPU로 오프로드 가능)
    optimizer_paged = bnb.optim.PagedAdam(
        model.parameters(),
        lr=1e-4
    )

    print("8-bit optimizer loaded successfully")
except ImportError:
    print("bitsandbytes not installed, using regular Adam")
    optimizer_8bit = torch.optim.Adam(model.parameters(), lr=1e-4)

5. functorch와 vmap

vmap - 배치화된 연산

vmap은 단일 샘플에 작동하는 함수를 배치에 효율적으로 적용합니다.

import torch
from torch import vmap

# 단일 샘플에 작동하는 함수
def single_linear(weight, bias, x):
    return weight @ x + bias

# vmap으로 배치 처리
batched_linear = vmap(single_linear)

# 배치 데이터
batch_size = 32
weight = torch.randn(batch_size, 10, 5)
bias = torch.randn(batch_size, 10)
x = torch.randn(batch_size, 5)

# 자동으로 배치 처리됨
result = batched_linear(weight, bias, x)
print(f"Result shape: {result.shape}")  # (32, 10)

grad - 함수형 그래디언트

from torch.func import grad, vmap, functional_call

# 함수형 그래디언트
def scalar_loss(params, x, y):
    pred = functional_call(model, params, (x,))
    return ((pred - y) ** 2).mean()

# 파라미터에 대한 그래디언트
params = dict(model.named_parameters())
grad_fn = grad(scalar_loss)

x = torch.randn(1, 10)
y = torch.randn(1, 5)
grads = grad_fn(params, x, y)
print({k: v.shape for k, v in grads.items()})

앙상블 모델 - vmap 활용

from torch.func import stack_module_state, functional_call, vmap

def create_ensemble(model_class, num_models, *args, **kwargs):
    """vmap을 이용한 효율적인 앙상블"""
    models = [model_class(*args, **kwargs) for _ in range(num_models)]

    # 모든 모델의 파라미터를 스택
    params, buffers = stack_module_state(models)

    # 단일 모델 forward
    base_model = model_class(*args, **kwargs)

    def single_forward(params, buffers, x):
        return functional_call(base_model, (params, buffers), (x,))

    # vmap으로 모든 모델을 병렬 실행
    ensemble_forward = vmap(single_forward, in_dims=(0, 0, None))

    return ensemble_forward, params, buffers

# 사용 예시
ensemble_fn, params, buffers = create_ensemble(
    nn.Linear, num_models=5, in_features=10, out_features=5
)

x = torch.randn(32, 10)
ensemble_out = ensemble_fn(params, buffers, x)
print(f"Ensemble output shape: {ensemble_out.shape}")  # (5, 32, 5)

메타러닝 (MAML) - grad와 vmap 결합

from torch.func import grad, vmap, functional_call

def inner_loop(params, support_x, support_y, base_model, lr=0.01, steps=5):
    """MAML 내부 루프"""
    adapted_params = {k: v.clone() for k, v in params.items()}

    for _ in range(steps):
        def loss_fn(params):
            pred = functional_call(base_model, params, (support_x,))
            return ((pred - support_y) ** 2).mean()

        grads = grad(loss_fn)(adapted_params)
        adapted_params = {
            k: p - lr * grads[k]
            for k, p in adapted_params.items()
        }

    return adapted_params

6. PyTorch Profiler

기본 프로파일링

import torch
from torch.profiler import profile, record_function, ProfilerActivity

model = TransformerBlock().cuda()
x = torch.randn(32, 128, 512, device="cuda")

# 프로파일러 실행
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True
) as prof:
    with record_function("model_inference"):
        for _ in range(10):
            output = model(x)

# 결과 출력
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20
))

# Chrome Trace 내보내기
prof.export_chrome_trace("trace.json")
# chrome://tracing에서 열기

TensorBoard 통합

from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(
        wait=1,    # 첫 1 스텝 대기
        warmup=1,  # 1 스텝 워밍업
        active=3,  # 3 스텝 프로파일
        repeat=2   # 2회 반복
    ),
    on_trace_ready=tensorboard_trace_handler("./log/profiler"),
    record_shapes=True,
    profile_memory=True
) as prof:
    for step, (data, target) in enumerate(dataloader):
        train_step(model, optimizer, data, target)
        prof.step()  # 스케줄에 따라 프로파일링

# tensorboard --logdir=./log/profiler

세부 메모리 분석

# 메모리 스냅샷 (PyTorch 2.1+)
torch.cuda.memory._record_memory_history(max_entries=100000)

# 코드 실행
x = torch.randn(100, 100, device="cuda")
y = x @ x.T
z = y.sum()

# 스냅샷 저장
snapshot = torch.cuda.memory._snapshot()
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# 분석
print(f"Active allocations: {len(snapshot['segments'])}")

7. TorchScript

torch.jit.script vs torch.jit.trace

import torch
import torch.nn as nn

class ConditionalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x, flag: bool = True):
        if flag:  # 조건 분기가 있으면 trace 불가
            return torch.relu(self.linear(x))
        else:
            return torch.sigmoid(self.linear(x))

model = ConditionalModel()

# script: 제어 흐름 포함 (권장)
scripted_model = torch.jit.script(model)
print(scripted_model.code)

# trace: 단일 경로만 캡처
x = torch.randn(1, 10)
traced_model = torch.jit.trace(model, x)
# 주의: flag=False 경로는 캡처되지 않음

# 저장 및 로드
scripted_model.save("model_scripted.pt")
loaded = torch.jit.load("model_scripted.pt")

TorchScript 최적화

# 최적화 패스 적용
scripted = torch.jit.script(model)
optimized = torch.jit.optimize_for_inference(scripted)

# C++ 환경을 위한 내보내기
# C++에서: torch::jit::script::Module m = torch::jit::load("model.pt");

8. Dynamic Shapes와 torch.export

torch.export 사용

import torch
from torch.export import export

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)

model = MyModel()
example_inputs = (torch.randn(2, 10),)

# 동적 형태 지정
from torch.export import Dim
batch = Dim("batch", min=1, max=100)

# 모델 내보내기
exported = export(
    model,
    example_inputs,
    dynamic_shapes={"x": {0: batch}}
)

print(exported)
print(exported.graph_module.code)

# ExportedProgram 실행
result = exported.module()(torch.randn(5, 10))
print(f"Result shape: {result.shape}")

9. Custom Dataset과 Sampler

IterableDataset

from torch.utils.data import IterableDataset, DataLoader
import torch

class StreamingDataset(IterableDataset):
    """대용량 데이터를 스트리밍으로 처리"""

    def __init__(self, data_paths, transform=None):
        self.data_paths = data_paths
        self.transform = transform

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is None:
            # 단일 프로세스
            paths = self.data_paths
        else:
            # 멀티프로세싱: 데이터를 분할
            per_worker = len(self.data_paths) // worker_info.num_workers
            start = worker_info.id * per_worker
            end = start + per_worker
            paths = self.data_paths[start:end]

        for path in paths:
            # 파일에서 데이터 스트리밍
            data = self._load_file(path)
            for sample in data:
                if self.transform:
                    sample = self.transform(sample)
                yield sample

    def _load_file(self, path):
        # 실제 구현에서는 파일 로드
        return [torch.randn(10) for _ in range(100)]

# DataLoader에 persistent_workers 사용
dataset = StreamingDataset(data_paths=["file1.pt", "file2.pt"])
dataloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,
    persistent_workers=True,  # 워커 프로세스 재사용
    pin_memory=True,           # GPU 전송 최적화
    prefetch_factor=2          # 미리 가져올 배치 수
)

커스텀 Sampler

from torch.utils.data import Sampler
import numpy as np

class BalancedClassSampler(Sampler):
    """클래스 불균형 문제를 위한 가중치 샘플링"""

    def __init__(self, dataset, num_samples_per_class=None):
        self.dataset = dataset
        labels = [dataset[i][1] for i in range(len(dataset))]
        self.labels = torch.tensor(labels)

        # 클래스별 인덱스
        self.class_indices = {}
        for cls in torch.unique(self.labels):
            self.class_indices[cls.item()] = (
                self.labels == cls
            ).nonzero(as_tuple=True)[0].tolist()

        self.num_classes = len(self.class_indices)
        self.num_samples_per_class = (
            num_samples_per_class or
            max(len(v) for v in self.class_indices.values())
        )

    def __iter__(self):
        indices = []
        for cls_idx in self.class_indices.values():
            # 각 클래스에서 동일한 수의 샘플 추출 (복원 추출)
            sampled = np.random.choice(
                cls_idx,
                self.num_samples_per_class,
                replace=True
            ).tolist()
            indices.extend(sampled)

        # 셔플
        np.random.shuffle(indices)
        return iter(indices)

    def __len__(self):
        return self.num_classes * self.num_samples_per_class

10. PyTorch Lightning

LightningModule 완전 예시

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

class LightningTransformer(pl.LightningModule):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_layers=6,
        num_classes=10,
        learning_rate=1e-4
    ):
        super().__init__()
        self.save_hyperparameters()  # HParams 자동 저장

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                batch_first=True
            ),
            num_layers=num_layers
        )
        self.classifier = nn.Linear(d_model, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        encoded = self.encoder(x)
        # 시퀀스 평균 풀링
        pooled = encoded.mean(dim=1)
        return self.classifier(pooled)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        # 로깅
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=0.01
        )
        # 코사인 어닐링 스케줄러
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=100
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }


class LightningDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        # 더미 데이터
        x = torch.randn(1000, 32, 512)
        y = torch.randint(0, 10, (1000,))
        dataset = TensorDataset(x, y)
        n = len(dataset)
        self.train_ds = torch.utils.data.Subset(dataset, range(int(0.8*n)))
        self.val_ds = torch.utils.data.Subset(dataset, range(int(0.8*n), n))

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)


# 훈련 실행
model = LightningTransformer()
dm = LightningDataModule()

# 콜백 설정
callbacks = [
    ModelCheckpoint(
        monitor="val_loss",
        save_top_k=3,
        mode="min",
        filename="transformer-{epoch:02d}-{val_loss:.2f}"
    ),
    EarlyStopping(monitor="val_loss", patience=10, mode="min")
]

# Trainer
trainer = pl.Trainer(
    max_epochs=100,
    accelerator="auto",    # GPU 자동 감지
    devices="auto",
    callbacks=callbacks,
    logger=TensorBoardLogger("tb_logs", name="transformer"),
    gradient_clip_val=1.0,  # 그래디언트 클리핑
    accumulate_grad_batches=4,  # Gradient accumulation
    precision="16-mixed",   # AMP
)

trainer.fit(model, dm)

11. 모델 양자화 (Quantization)

동적 양자화 (Dynamic Quantization)

import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

# 추론 전용 양자화 - 가장 쉬운 방법
model = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 128)
)
model.eval()

# Linear와 LSTM 레이어를 int8로 양자화
quantized_model = quantize_dynamic(
    model,
    {nn.Linear, nn.LSTM},  # 양자화할 레이어 타입
    dtype=torch.qint8
)

# 크기 비교
import os
torch.save(model.state_dict(), "model_fp32.pt")
torch.save(quantized_model.state_dict(), "model_int8.pt")
fp32_size = os.path.getsize("model_fp32.pt")
int8_size = os.path.getsize("model_int8.pt")
print(f"FP32 size: {fp32_size / 1024:.1f} KB")
print(f"INT8 size: {int8_size / 1024:.1f} KB")
print(f"Compression: {fp32_size / int8_size:.1f}x")

정적 양자화 (Static Quantization)

import torch
from torch.ao.quantization import (
    get_default_qconfig,
    prepare,
    convert,
    QConfig
)

class QuantizableModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(512, 256)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(256, 10)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.dequant(x)
        return x

model = QuantizableModel()
model.eval()

# 양자화 설정
model.qconfig = get_default_qconfig("fbgemm")  # x86 CPU용

# 준비 (옵저버 삽입)
prepared_model = prepare(model)

# 캘리브레이션 데이터로 통계 수집
with torch.no_grad():
    for _ in range(100):
        calibration_data = torch.randn(32, 512)
        prepared_model(calibration_data)

# 양자화 변환
quantized_static = convert(prepared_model)

# 추론
x = torch.randn(1, 512)
with torch.no_grad():
    output = quantized_static(x)
print(f"Output shape: {output.shape}")

QAT (Quantization-Aware Training)

from torch.ao.quantization import (
    get_default_qat_qconfig,
    prepare_qat,
    convert
)

model = QuantizableModel()
model.train()

# QAT 설정
model.qconfig = get_default_qat_qconfig("fbgemm")

# QAT 준비 (가짜 양자화 노드 삽입)
prepared_qat = prepare_qat(model)

# 훈련 (양자화 오류를 학습에 포함)
optimizer = torch.optim.SGD(prepared_qat.parameters(), lr=0.0001)
for epoch in range(10):
    for x, y in dummy_dataloader():
        output = prepared_qat(x)
        loss = nn.functional.cross_entropy(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# 훈련 완료 후 변환
prepared_qat.eval()
quantized_qat = convert(prepared_qat)


def dummy_dataloader():
    for _ in range(10):
        yield torch.randn(32, 512), torch.randint(0, 10, (32,))

12. Tensor Parallelism과 Pipeline Parallelism

DeviceMesh와 DTensor API

import torch
import torch.distributed as dist
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel
)
from torch.distributed._tensor import DeviceMesh

# 분산 훈련 초기화
def setup_distributed():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())

class TransformerMLP(nn.Module):
    def __init__(self, d_model=1024, dim_feedforward=4096):
        super().__init__()
        self.fc1 = nn.Linear(d_model, dim_feedforward)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim_feedforward, d_model)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

# Tensor Parallelism 적용
def apply_tensor_parallel(model, mesh):
    """
    fc1은 Column-wise로 샤딩 (출력 차원 분할)
    fc2는 Row-wise로 샤딩 (입력 차원 분할)
    """
    parallelize_module(
        model,
        mesh,
        {
            "fc1": ColwiseParallel(),
            "fc2": RowwiseParallel(),
        }
    )
    return model

# 실행 예시 (2 GPU 환경)
# device_mesh = DeviceMesh("cuda", [0, 1])
# model = TransformerMLP()
# model = apply_tensor_parallel(model, device_mesh)

FSDP (Fully Sharded Data Parallel)

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy
)
import functools

def setup_fsdp_model(model):
    """FSDP 설정 - 대규모 모델을 여러 GPU에 분산"""

    # 혼합 정밀도 설정
    mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.bfloat16,
    )

    # TransformerBlock을 래핑 단위로 설정
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock}
    )

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=torch.cuda.current_device()
    )

    return model

마치며

이 가이드에서는 PyTorch의 고급 기법들을 살펴보았습니다:

  1. torch.compile: 코드 수정 없이 2배 이상의 성능 향상
  2. Custom Autograd: 특수한 그래디언트 계산 구현
  3. CUDA Extensions: GPU 커널을 PyTorch와 통합
  4. 메모리 최적화: Gradient Checkpointing, AMP, 8-bit Optimizer
  5. functorch/vmap: 함수형 API로 배치 처리와 메타러닝
  6. PyTorch Profiler: 성능 병목 분석
  7. TorchScript/export: 배포 최적화
  8. PyTorch Lightning: 코드 구조화와 훈련 자동화
  9. 양자화: INT8로 모델 크기 축소
  10. 분산 학습: Tensor Parallel, FSDP

각 기법은 서로 보완적이며, 실제 프로젝트에서는 여러 기법을 조합하여 사용하는 것이 일반적입니다. 특히 대규모 모델 훈련 시에는 AMP + Gradient Checkpointing + FSDP + torch.compile 조합이 강력한 효과를 발휘합니다.

참고 자료