Skip to content
Published on

PyTorch 내부 구조 & 고급 최적화: autograd, torch.compile, FSDP, Triton까지

Authors

목차

  1. PyTorch 내부 구조: ATen과 Tensor 계층
  2. Autograd 엔진과 계산 그래프
  3. 커스텀 연산 구현
  4. torch.compile()과 TorchInductor
  5. 메모리 최적화 기법
  6. 분산 학습: DDP와 FSDP
  7. 추론 최적화
  8. 디버깅 도구
  9. 퀴즈

PyTorch 내부 구조

ATen 라이브러리

PyTorch의 핵심은 **ATen(A Tensor library)**입니다. C++ 기반의 텐서 연산 라이브러리로, 모든 PyTorch 연산의 하부 구현입니다.

Python API (torch.*)
TorchDispatch / Dispatcher
ATen (C++ tensor ops)
CUDA / CPU / MPS backends

ATen의 주요 구성 요소:

  • Tensor: 다차원 배열로, storage, dtype, device, stride 정보를 보관
  • Storage: 실제 메모리 블록 (공유 가능)
  • Dispatcher: 연산자를 적절한 백엔드로 라우팅
import torch

x = torch.randn(3, 4)
print(x.storage())          # 실제 메모리 블록
print(x.stride())           # (4, 1) - row-major
print(x.storage_offset())   # 0

# View는 storage를 공유
y = x.view(2, 6)
print(x.storage().data_ptr() == y.storage().data_ptr())  # True

TorchDispatch

TorchDispatch는 Python에서 PyTorch 연산을 가로채는 메커니즘입니다. 커스텀 텐서 타입 구현에 활용됩니다.

import torch
from torch.utils._pytree import tree_map

class LoggingTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_subclass(cls, elem)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print(f"Calling: {func.__name__}")
        kwargs = kwargs or {}
        return func(*args, **kwargs)

x = LoggingTensor(torch.randn(3, 3))
y = x + x  # 출력: Calling: add.Tensor

Autograd 엔진

계산 그래프 (DAG)

PyTorch autograd는 동적 계산 그래프(Dynamic Computational Graph)를 사용합니다. 연산이 실행될 때마다 DAG(Directed Acyclic Graph)가 구성됩니다.

import torch

x = torch.tensor(2.0, requires_grad=True)  # leaf tensor
y = x ** 2          # non-leaf, grad_fn=PowBackward0
z = y * 3           # non-leaf, grad_fn=MulBackward0

print(x.is_leaf)    # True
print(y.is_leaf)    # False
print(z.grad_fn)    # MulBackward0

z.backward()
print(x.grad)       # dz/dx = 3 * 2x = 12.0

Leaf tensor vs Non-leaf tensor:

  • Leaf tensor: requires_grad=True이고 사용자가 직접 생성한 텐서. .grad에 gradient가 누적됨
  • Non-leaf tensor: 연산의 결과로 생성된 텐서. 기본적으로 .grad가 None (.retain_grad() 호출 필요)

Gradient 누적 메커니즘

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Gradient 누적 (accumulation)
ACCUMULATION_STEPS = 4
for i, (x, y) in enumerate(dataloader):
    output = model(x)
    loss = criterion(output, y) / ACCUMULATION_STEPS
    loss.backward()  # gradient 누적

    if (i + 1) % ACCUMULATION_STEPS == 0:
        optimizer.step()
        optimizer.zero_grad()  # gradient 초기화

retain_graph와 create_graph

x = torch.tensor(3.0, requires_grad=True)
y = x ** 3

# 고차 미분: create_graph=True
grad_1 = torch.autograd.grad(y, x, create_graph=True)[0]
grad_2 = torch.autograd.grad(grad_1, x)[0]

print(grad_1)  # 3x^2 = 27.0
print(grad_2)  # 6x = 18.0

커스텀 연산 구현

torch.autograd.Function

커스텀 forward/backward를 정의할 때 사용합니다.

import torch

class SigmoidFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # ctx에 backward에 필요한 값 저장
        output = 1 / (1 + torch.exp(-x))
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (output,) = ctx.saved_tensors
        # sigmoid 미분: sigma(x) * (1 - sigma(x))
        grad_input = grad_output * output * (1 - output)
        return grad_input

# 사용
x = torch.randn(4, requires_grad=True)
y = SigmoidFunction.apply(x)
y.sum().backward()
print(x.grad)

torch.library API (커스텀 연산 등록)

import torch
from torch.library import Library, impl

my_lib = Library("my_ops", "DEF")
my_lib.define("relu_squared(Tensor x) -> Tensor")

@impl(my_lib, "relu_squared", "CPU")
def relu_squared_cpu(x):
    return torch.relu(x) ** 2

@impl(my_lib, "relu_squared", "CUDA")
def relu_squared_cuda(x):
    return torch.relu(x) ** 2

# 커스텀 연산 사용
x = torch.randn(5)
result = torch.ops.my_ops.relu_squared(x)

커스텀 CUDA 커널 (Triton)

import triton
import triton.language as tl
import torch

@triton.jit
def relu_squared_kernel(
    x_ptr, out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    relu_x = tl.where(x > 0, x, 0.0)
    out = relu_x * relu_x

    tl.store(out_ptr + offsets, out, mask=mask)

def relu_squared_triton(x: torch.Tensor):
    out = torch.empty_like(x)
    n_elements = x.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    relu_squared_kernel[grid](x, out, n_elements, BLOCK_SIZE)
    return out

torch.compile()

Dynamo와 graph capture

torch.compile()은 Python 바이트코드를 분석해 계산 그래프를 추출합니다.

import torch

def model_forward(x, weight):
    x = torch.nn.functional.relu(x @ weight)
    return x.sum()

# 컴파일: fullgraph=True면 graph break 불허
compiled_fn = torch.compile(model_forward, fullgraph=True, backend="inductor")

x = torch.randn(128, 256, device="cuda")
w = torch.randn(256, 512, device="cuda")
out = compiled_fn(x, w)

Graph break 발생 조건:

  • Python 제어 흐름 (if/for에서 텐서 값 사용)
  • 외부 라이브러리 호출 (numpy 등)
  • 지원되지 않는 Python 패턴
import torch._dynamo
torch._dynamo.config.verbose = True  # graph break 디버깅

# graph break를 허용하려면 fullgraph=False (기본값)
compiled = torch.compile(model, backend="inductor")

AOTAutograd와 TorchInductor

torch.compile() 파이프라인:
Python 코드 → Dynamo (그래프 추출)
AOTAutograd (forward + backward 결합)
TorchInductor (커널 생성)
Triton / C++ 코드
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(512, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = SimpleModel().cuda()

# mode 옵션: "default", "reduce-overhead", "max-autotune"
compiled_model = torch.compile(model, mode="max-autotune")

x = torch.randn(32, 512, device="cuda")
out = compiled_model(x)

메모리 최적화

Gradient Checkpointing

순전파 중 중간 활성화를 저장하지 않고, 역전파 시 재계산하여 메모리를 절약합니다.

import torch
import torch.utils.checkpoint as checkpoint
import torch.nn as nn

class CheckpointedBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, x):
        # checkpoint: forward 재실행으로 메모리 절약
        return checkpoint.checkpoint(self.layers, x, use_reentrant=False)

model = nn.Sequential(*[CheckpointedBlock(512) for _ in range(24)]).cuda()
x = torch.randn(32, 512, device="cuda", requires_grad=True)
out = model(x)
out.sum().backward()

AMP (Automatic Mixed Precision)

import torch
from torch.cuda.amp import autocast, GradScaler

model = SimpleModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()  # FP16 언더플로우 방지

for x, y in dataloader:
    x, y = x.cuda(), y.cuda()
    optimizer.zero_grad()

    # autocast: 연산에 따라 FP16/BF16 자동 적용
    with autocast(dtype=torch.float16):
        output = model(x)
        loss = criterion(output, y)

    # scaler: gradient를 스케일링하여 언더플로우 방지
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()

메모리 프로파일러

import torch
from torch.profiler import profile, ProfilerActivity, record_function

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
    record_shapes=True,
) as prof:
    with record_function("model_inference"):
        output = model(x)

print(prof.key_averages().table(
    sort_by="cuda_memory_usage", row_limit=10
))
prof.export_chrome_trace("trace.json")

Activation Offloading

# CPU 오프로딩으로 GPU 메모리 절약
def offload_checkpoint(module, x):
    """활성화를 CPU로 오프로드 후 역전파 시 GPU로 복귀"""
    def forward_and_save(*inputs):
        output = module(*inputs)
        return output

    return checkpoint.checkpoint(forward_and_save, x, use_reentrant=False)

분산 학습

DDP (DistributedDataParallel)

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os

def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size):
    setup(rank, world_size)
    torch.cuda.set_device(rank)

    model = SimpleModel().cuda(rank)
    ddp_model = DDP(model, device_ids=[rank])

    optimizer = torch.optim.Adam(ddp_model.parameters())

    for x, y in dataloader:
        x, y = x.cuda(rank), y.cuda(rank)
        output = ddp_model(x)
        loss = criterion(output, y)
        loss.backward()     # gradient all-reduce 자동 수행
        optimizer.step()
        optimizer.zero_grad()

    cleanup()

FSDP (Fully Sharded Data Parallel)

FSDP는 파라미터, gradient, optimizer state를 모든 GPU에 샤딩합니다.

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import functools

# Transformer 블록 단위로 FSDP 적용
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},
)

# 혼합 정밀도 설정
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)

model = LargeTransformer().cuda()
fsdp_model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=mp_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3 동등
)

optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-4)

FSDP vs DDP 메모리 비교:

DDP는 각 GPU가 전체 모델 복사본을 보관합니다. FSDP는 파라미터를 world_size로 나누어 각 GPU의 메모리 사용량을 1/N으로 줄입니다.

DeepSpeed ZeRO 통합

# deepspeed_config.json
# {
#   "zero_optimization": {"stage": 3},
#   "fp16": {"enabled": true},
#   "gradient_accumulation_steps": 4
# }

import deepspeed

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="deepspeed_config.json"
)

for x, y in dataloader:
    output = model_engine(x)
    loss = criterion(output, y)
    model_engine.backward(loss)
    model_engine.step()

추론 최적화

torch.export()와 ONNX

import torch
from torch.export import export

model = SimpleModel().eval()
x = torch.randn(1, 512)

# torch.export: 정적 계산 그래프 추출
exported = export(model, (x,))
print(exported.graph)

# ONNX 내보내기
torch.onnx.export(
    model, x,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}},
    opset_version=17,
)

Quantization-Aware Training (QAT)

import torch
from torch.quantization import get_default_qat_qconfig, prepare_qat, convert

model = SimpleModel()
model.qconfig = get_default_qat_qconfig("fbgemm")

# QAT 준비: fake quantization 삽입
model_prepared = prepare_qat(model.train())

# 일반 학습과 동일하게 학습
for x, y in dataloader:
    output = model_prepared(x)
    loss = criterion(output, y)
    loss.backward()
    optimizer.step()

# INT8 모델로 변환
model_int8 = convert(model_prepared.eval())

디버깅 도구

torch.profiler

import torch
from torch.profiler import profile, ProfilerActivity, schedule

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3, repeat=2),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./log"),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    for step, (x, y) in enumerate(dataloader):
        output = model(x.cuda())
        loss = criterion(output, y.cuda())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        prof.step()

Anomaly Detection

# autograd 이상 감지 (NaN/Inf gradient 추적)
with torch.autograd.detect_anomaly():
    output = model(x)
    loss = output.sum()
    loss.backward()  # NaN 발생 시 스택 트레이스 출력

grad_fn 추적

def trace_grad_fn(tensor, depth=0):
    if tensor.grad_fn is None:
        print("  " * depth + f"Leaf: {tensor.shape}")
        return
    print("  " * depth + f"{tensor.grad_fn.__class__.__name__}: {tensor.shape}")
    for inp, _ in tensor.grad_fn.next_functions:
        if inp is not None:
            trace_grad_fn(inp.variable if hasattr(inp, 'variable') else inp, depth + 1)

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
z = (x * y).sum()
trace_grad_fn(z)

퀴즈

Q1. PyTorch autograd에서 leaf tensor와 non-leaf tensor의 차이 및 gradient 누적 방식은?

정답: leaf tensor는 사용자가 직접 생성하고 requires_grad=True인 텐서입니다. .grad 속성에 gradient가 누적됩니다.

설명: non-leaf tensor는 연산의 결과로 생성된 텐서로, 기본적으로 gradient가 저장되지 않습니다 (메모리 절약). retain_grad()를 호출하면 non-leaf tensor의 gradient도 보관할 수 있습니다. Gradient 누적은 optimizer.zero_grad() 없이 여러 번 backward()를 호출하면 .grad에 더해지는 방식으로 작동합니다. 이를 활용해 gradient accumulation으로 배치 크기를 가상으로 늘릴 수 있습니다.

Q2. torch.compile()의 Dynamo가 Python 바이트코드를 추적하는 방식과 graph break 발생 조건은?

정답: Dynamo는 Python 프레임 평가를 가로채 바이트코드를 분석하며, 지원되지 않는 패턴에서 graph break가 발생합니다.

설명: Dynamo는 CPython의 PEP 523 프레임 평가 API를 사용해 Python 바이트코드를 symbolic하게 추적합니다. 텐서 값에 의존하는 제어 흐름 (예: if x.sum() > 0:), 외부 라이브러리 호출, C 확장 등에서 graph break가 발생합니다. Graph break 시 Dynamo는 해당 지점까지의 그래프를 컴파일하고 나머지는 일반 Python으로 실행합니다. fullgraph=True 설정 시 graph break를 에러로 처리합니다.

Q3. FSDP가 DDP보다 메모리 효율적인 이유 (파라미터 샤딩 관점)

정답: FSDP는 파라미터, gradient, optimizer state 모두를 world_size개의 GPU에 분산하여 각 GPU의 메모리를 약 1/N으로 줄입니다.

설명: DDP는 각 GPU가 전체 모델 파라미터를 복사하여 보관합니다. 10B 파라미터 모델은 GPU당 약 40GB (FP32 기준)가 필요합니다. FSDP (ZeRO-3 전략)는 forward/backward 시에만 필요한 파라미터를 all-gather로 수집하고, 사용 후 즉시 해제합니다. 8개 GPU 환경에서 메모리 사용량은 약 1/8로 줄어들어 단일 GPU에 올릴 수 없는 거대 모델을 학습할 수 있습니다.

Q4. Gradient checkpointing에서 forward pass를 재실행하는 트레이드오프

정답: 메모리 사용량을 O(sqrt(N))으로 줄이는 대신, 역전파 시간이 약 33% 증가하는 연산-메모리 트레이드오프입니다.

설명: 일반적인 역전파는 모든 forward 활성화를 저장하므로 O(N) 메모리가 필요합니다. Gradient checkpointing은 체크포인트 경계에서의 활성화만 저장하고, 역전파 시 해당 구간의 forward를 재실행합니다. Transformer에서 레이어별로 체크포인트를 설정하면 레이어 수가 아닌 sqrt(레이어 수)에 비례한 메모리만 필요합니다. 재연산 비용은 전체 학습 시간을 약 30-40% 증가시키지만, 배치 크기를 늘릴 수 있어 실제 처리량은 개선될 수 있습니다.

Q5. AMP에서 GradScaler가 언더플로우를 방지하는 방법

정답: GradScaler는 loss에 큰 스케일 값을 곱해 gradient를 FP16 표현 범위 내로 유지하고, optimizer 업데이트 전에 스케일을 역적용합니다.

설명: FP16의 최솟값은 약 6e-5로, 작은 gradient는 0으로 언더플로우됩니다. GradScaler는 loss에 scale factor (초기값 65536 등)를 곱하면 gradient가 그 배수로 커져 FP16에서도 표현 가능해집니다. scaler.unscale_(optimizer) 단계에서 gradient를 원래 크기로 복원합니다. Inf/NaN 발생 시 scale을 자동으로 줄이고 해당 스텝을 건너뜁니다. BF16은 FP32와 같은 지수 범위를 가지므로 GradScaler가 필요하지 않습니다.