PyTorch Internals & Advanced Optimization: autograd, torch.compile, FSDP, and Triton

Table of Contents

  1. PyTorch Internals: ATen and the Tensor Layer
  2. Autograd Engine and Computational Graph
  3. Custom Operation Implementation
  4. torch.compile() and TorchInductor
  5. Memory Optimization Techniques
  6. Distributed Training: DDP and FSDP
  7. Inference Optimization
  8. Debugging Tools
  9. Quiz

PyTorch Internals

The ATen Library

At the heart of PyTorch is ATen (A Tensor Library) — a C++ tensor operations library that underlies every PyTorch operation.

Python API (torch.*) → TorchDispatch / Dispatcher → ATen (C++ tensor ops) → CUDA / CPU / MPS backends

Key ATen components:

  • Tensor: Multi-dimensional array holding storage, dtype, device, and stride
  • Storage: The raw memory block (shareable between tensors)
  • Dispatcher: Routes operations to the appropriate backend

import torch

x = torch.randn(3, 4)
print(x.untyped_storage())   # raw memory block (x.storage() is deprecated in PyTorch 2.x)
print(x.stride())            # (4, 1) - row-major layout
print(x.storage_offset())    # 0

# Views share the same storage
y = x.view(2, 6)
print(x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr())  # True
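Strides also explain when a view is possible: transposing just swaps the strides, but the result is no longer contiguous, and `.contiguous()` materializes a fresh row-major copy. A quick check:

```python
import torch

x = torch.randn(3, 4)      # contiguous, stride (4, 1)
t = x.t()                  # transpose is a view: shape (4, 3), stride (1, 4)
print(t.stride())          # (1, 4)
print(t.is_contiguous())   # False
c = t.contiguous()         # copies into new row-major storage
print(c.stride())          # (3, 1)
```

This is why some ops (e.g. `.view()` on a transposed tensor) raise errors until you call `.contiguous()` first.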

TorchDispatch

TorchDispatch is a Python-level mechanism to intercept PyTorch operations. It is used to implement custom tensor types.

import torch

class LoggingTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_subclass(cls, elem)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print(f"Calling: {func.__name__}")
        kwargs = kwargs or {}
        return func(*args, **kwargs)

x = LoggingTensor(torch.randn(3, 3))
y = x + x  # prints: Calling: add.Tensor

Autograd Engine

The Computational Graph (DAG)

PyTorch autograd uses a Dynamic Computational Graph. A DAG (Directed Acyclic Graph) is constructed on the fly as operations execute.

import torch

x = torch.tensor(2.0, requires_grad=True)  # leaf tensor
y = x ** 2          # non-leaf, grad_fn=PowBackward0
z = y * 3           # non-leaf, grad_fn=MulBackward0

print(x.is_leaf)    # True
print(y.is_leaf)    # False
print(z.grad_fn)    # MulBackward0

z.backward()
print(x.grad)       # dz/dx = 3 * 2x = 12.0

Leaf tensor vs Non-leaf tensor:

  • Leaf tensor: Created directly by the user with requires_grad=True. Gradients accumulate in .grad
  • Non-leaf tensor: Result of an operation. .grad is None by default (call .retain_grad() to keep it)
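The difference is easy to see directly; a minimal sketch using .retain_grad():

```python
import torch

x = torch.randn(3, requires_grad=True)  # leaf
y = x * 2                               # non-leaf
y.retain_grad()                         # opt in to storing its gradient
y.sum().backward()
print(x.grad)  # populated by default (leaf)
print(y.grad)  # only available because of retain_grad()
```

Without the `retain_grad()` call, `y.grad` would be `None` and PyTorch would emit a warning on access.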

Gradient Accumulation

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Gradient accumulation to simulate larger batch sizes
ACCUMULATION_STEPS = 4
for i, (x, y) in enumerate(dataloader):
    output = model(x)
    loss = criterion(output, y) / ACCUMULATION_STEPS
    loss.backward()  # gradients accumulate in .grad

    if (i + 1) % ACCUMULATION_STEPS == 0:
        optimizer.step()
        optimizer.zero_grad()  # reset accumulated gradients

retain_graph and create_graph

x = torch.tensor(3.0, requires_grad=True)
y = x ** 3

# Higher-order derivatives: create_graph=True keeps grad graph
grad_1 = torch.autograd.grad(y, x, create_graph=True)[0]
grad_2 = torch.autograd.grad(grad_1, x)[0]

print(grad_1)  # 3x^2 = 27.0
print(grad_2)  # 6x  = 18.0
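The heading's other flag, retain_graph, keeps the graph buffers alive so backward() can be called more than once; a sketch:

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 2

# Without retain_graph=True, the first backward frees the graph buffers
# and the second call would raise a RuntimeError
y.backward(retain_graph=True)
y.backward()
print(x.grad)  # gradients accumulate: 6.0 + 6.0 = 12.0
```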

Custom Operations

torch.autograd.Function

Use this to define custom forward and backward passes.

import torch

class SigmoidFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # save values needed for backward
        output = 1 / (1 + torch.exp(-x))
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (output,) = ctx.saved_tensors
        # derivative of sigmoid: sigma(x) * (1 - sigma(x))
        grad_input = grad_output * output * (1 - output)
        return grad_input

# Usage
x = torch.randn(4, requires_grad=True)
y = SigmoidFunction.apply(x)
y.sum().backward()
print(x.grad)
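Hand-written backward passes are easy to get subtly wrong. torch.autograd.gradcheck verifies the analytic gradient against finite differences (double precision recommended); a sketch with a hypothetical Square function:

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x  # d/dx x^2 = 2x

# Compares the analytic backward against numerical derivatives
x = torch.randn(4, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(Square.apply, (x,)))  # True
```

If the backward formula were wrong (say, returning `grad_output * x`), gradcheck would raise with the mismatching Jacobians.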

torch.library API (Registering Custom Ops)

import torch
from torch.library import Library, impl

my_lib = Library("my_ops", "DEF")
my_lib.define("relu_squared(Tensor x) -> Tensor")

@impl(my_lib, "relu_squared", "CPU")
def relu_squared_cpu(x):
    return torch.relu(x) ** 2

@impl(my_lib, "relu_squared", "CUDA")
def relu_squared_cuda(x):
    return torch.relu(x) ** 2

# Use the custom op
x = torch.randn(5)
result = torch.ops.my_ops.relu_squared(x)

Custom CUDA Kernel with Triton

import triton
import triton.language as tl
import torch

@triton.jit
def relu_squared_kernel(
    x_ptr, out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    relu_x = tl.where(x > 0, x, 0.0)
    out = relu_x * relu_x

    tl.store(out_ptr + offsets, out, mask=mask)

def relu_squared_triton(x: torch.Tensor):
    out = torch.empty_like(x)
    n_elements = x.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    relu_squared_kernel[grid](x, out, n_elements, BLOCK_SIZE)
    return out

torch.compile()

Dynamo and Graph Capture

torch.compile() analyzes Python bytecode to extract a computation graph for optimization.

import torch

def model_forward(x, weight):
    x = torch.nn.functional.relu(x @ weight)
    return x.sum()

# fullgraph=True disallows graph breaks
compiled_fn = torch.compile(model_forward, fullgraph=True, backend="inductor")

x = torch.randn(128, 256, device="cuda")
w = torch.randn(256, 512, device="cuda")
out = compiled_fn(x, w)

Graph break conditions:

  • Control flow depending on tensor values (e.g., if tensor.sum() > 0)
  • Calls to external libraries (numpy, scipy, etc.)
  • Unsupported Python patterns

import torch._dynamo
torch._dynamo.config.verbose = True  # debug graph breaks

# fullgraph=False (default) allows graph breaks
compiled = torch.compile(model, backend="inductor")

AOTAutograd and TorchInductor

torch.compile() pipeline:

Python code → Dynamo (graph extraction) → AOTAutograd (joint forward + backward graph) → TorchInductor (kernel generation) → Triton / C++ code

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(512, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = SimpleModel().cuda()

# mode options: "default", "reduce-overhead", "max-autotune"
compiled_model = torch.compile(model, mode="max-autotune")

x = torch.randn(32, 512, device="cuda")
out = compiled_model(x)

Memory Optimization

Gradient Checkpointing

Avoids storing intermediate activations during the forward pass; recomputes them during backward to save memory.

import torch
import torch.utils.checkpoint as checkpoint
import torch.nn as nn

class CheckpointedBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, x):
        # recomputes layers during backward to save memory
        return checkpoint.checkpoint(self.layers, x, use_reentrant=False)

model = nn.Sequential(*[CheckpointedBlock(512) for _ in range(24)]).cuda()
x = torch.randn(32, 512, device="cuda", requires_grad=True)
out = model(x)
out.sum().backward()

AMP (Automatic Mixed Precision)

import torch
from torch.amp import autocast, GradScaler  # torch.cuda.amp imports are deprecated in recent releases

model = SimpleModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler("cuda")  # prevents FP16 underflow

for x, y in dataloader:
    x, y = x.cuda(), y.cuda()
    optimizer.zero_grad()

    # autocast applies FP16/BF16 where appropriate
    with autocast("cuda", dtype=torch.float16):
        output = model(x)
        loss = criterion(output, y)

    # scaler multiplies loss to keep gradients in FP16 range
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
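BF16 keeps FP32's exponent range, so BF16 training typically skips GradScaler entirely; autocast works the same way (CPU autocast shown here so the sketch runs anywhere):

```python
import torch

# Under BF16 autocast, matmuls run in bfloat16; no GradScaler required
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    a = torch.randn(8, 8)
    b = torch.randn(8, 8)
    c = a @ b

print(c.dtype)  # torch.bfloat16
```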

Memory Profiler

import torch
from torch.profiler import profile, ProfilerActivity, record_function

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
    record_shapes=True,
) as prof:
    with record_function("model_inference"):
        output = model(x)

print(prof.key_averages().table(
    sort_by="cuda_memory_usage", row_limit=10
))
prof.export_chrome_trace("trace.json")

Activation Offloading

# Offload saved activations to CPU to free GPU memory
import torch
from torch.autograd.graph import save_on_cpu

def offload_forward(module, x):
    """Saves activations in pinned CPU memory during forward;
    copies them back to GPU on demand during backward."""
    with save_on_cpu(pin_memory=True):
        return module(x)

Distributed Training

DDP (DistributedDataParallel)

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os

def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size):
    setup(rank, world_size)
    torch.cuda.set_device(rank)

    model = SimpleModel().cuda(rank)
    ddp_model = DDP(model, device_ids=[rank])

    optimizer = torch.optim.Adam(ddp_model.parameters())

    for x, y in dataloader:
        x, y = x.cuda(rank), y.cuda(rank)
        output = ddp_model(x)
        loss = criterion(output, y)
        loss.backward()     # gradient all-reduce happens automatically
        optimizer.step()
        optimizer.zero_grad()

    cleanup()
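The loop above assumes dataloader already yields a different shard per rank. In practice DDP is paired with DistributedSampler; num_replicas and rank are normally inferred from the process group, but they are passed explicitly here so the sketch runs in a single process:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8).float())

# Rank 0's shard of an 8-sample dataset split across 2 ranks
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
print(list(sampler))  # [0, 2, 4, 6]
```

With shuffling enabled, call `sampler.set_epoch(epoch)` at the start of each epoch so every rank reshuffles consistently.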

FSDP (Fully Sharded Data Parallel)

FSDP shards parameters, gradients, and optimizer state across all GPUs.

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import torch.nn as nn
import functools

# Wrap at the Transformer block level
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},
)

# Mixed precision policy
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)

model = LargeTransformer().cuda()
fsdp_model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=mp_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # equivalent to ZeRO-3
)

optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-4)

FSDP vs DDP memory comparison:

DDP keeps a full model copy on every GPU. FSDP divides parameters by world_size, reducing per-GPU memory to roughly 1/N. This enables training models too large to fit on a single GPU.
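The 1/N claim can be made concrete with ZeRO-style accounting. A back-of-the-envelope sketch, assuming mixed-precision Adam (FP16 params and grads plus FP32 master weights and two FP32 moment buffers, i.e. 16 bytes per parameter):

```python
def per_gpu_gb(n_params, world_size, sharded):
    # 2 (fp16 params) + 2 (fp16 grads) + 4 (fp32 master) + 4 + 4 (Adam moments)
    bytes_per_param = 2 + 2 + 4 + 4 + 4
    total = n_params * bytes_per_param
    # DDP replicates everything; FULL_SHARD divides it across the world
    return total / (world_size if sharded else 1) / 1e9

n = 10e9  # a 10B-parameter model
print(f"DDP : {per_gpu_gb(n, 8, sharded=False):.0f} GB per GPU")  # 160 GB
print(f"FSDP: {per_gpu_gb(n, 8, sharded=True):.0f} GB per GPU")   # 20 GB
```

The sketch ignores activations and communication buffers, which FSDP's transient all-gathers add back temporarily.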

DeepSpeed ZeRO Integration

# deepspeed_config.json:
# {
#   "zero_optimization": {"stage": 3},
#   "fp16": {"enabled": true},
#   "gradient_accumulation_steps": 4
# }

import deepspeed

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="deepspeed_config.json"
)

for x, y in dataloader:
    output = model_engine(x)
    loss = criterion(output, y)
    model_engine.backward(loss)
    model_engine.step()

Inference Optimization

torch.export() and ONNX

import torch
from torch.export import export

model = SimpleModel().eval()
x = torch.randn(1, 512)

# torch.export: extracts a static computation graph
exported = export(model, (x,))
print(exported.graph)

# ONNX export
torch.onnx.export(
    model, x,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}},
    opset_version=17,
)

Quantization-Aware Training (QAT)

import torch
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat, convert  # torch.quantization is deprecated

model = SimpleModel()
model.qconfig = get_default_qat_qconfig("fbgemm")

# Insert fake quantization ops
model_prepared = prepare_qat(model.train())
optimizer = torch.optim.SGD(model_prepared.parameters(), lr=1e-3)

# Train as usual
for x, y in dataloader:
    optimizer.zero_grad()
    output = model_prepared(x)
    loss = criterion(output, y)
    loss.backward()
    optimizer.step()

# Convert to INT8
model_int8 = convert(model_prepared.eval())
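When retraining is not an option, post-training dynamic quantization is a lighter alternative: weights are stored in INT8 and activations are quantized on the fly. A sketch using torch.ao.quantization.quantize_dynamic:

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

model = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 10)).eval()

# Only the Linear layers are quantized; no calibration data or training needed
qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

x = torch.randn(1, 512)
print(qmodel(x).shape)  # torch.Size([1, 10])
```

Dynamic quantization typically loses less accuracy than static INT8 but gains less speed, since activations stay in floating point between layers.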

Debugging Tools

torch.profiler

import torch
from torch.profiler import profile, ProfilerActivity, schedule

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3, repeat=2),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./log"),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    for step, (x, y) in enumerate(dataloader):
        output = model(x.cuda())
        loss = criterion(output, y.cuda())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        prof.step()

Anomaly Detection

# Detect NaN/Inf gradients with full stack traces
with torch.autograd.detect_anomaly():
    output = model(x)
    loss = output.sum()
    loss.backward()  # prints stack trace if NaN occurs

Tracing grad_fn

import torch

def trace_grad_fn(node, depth=0):
    """Recursively prints the autograd graph rooted at a grad_fn node."""
    if node is None:
        return
    if hasattr(node, "variable"):  # AccumulateGrad nodes hold the leaf tensor
        print("  " * depth + f"Leaf: {tuple(node.variable.shape)}")
        return
    print("  " * depth + node.__class__.__name__)
    for next_node, _ in node.next_functions:
        trace_grad_fn(next_node, depth + 1)

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
z = (x * y).sum()
trace_grad_fn(z.grad_fn)

Quiz

Q1. What is the difference between leaf and non-leaf tensors in PyTorch autograd, and how does gradient accumulation work?

Answer: A leaf tensor is created directly by the user with requires_grad=True. Gradients accumulate in its .grad attribute.

Explanation: Non-leaf tensors are results of operations; their .grad is None by default (memory optimization). Call .retain_grad() to preserve them. Gradient accumulation works because calling backward() multiple times without zero_grad() adds to existing .grad values. This property lets you simulate larger effective batch sizes without increasing GPU memory usage.

Q2. How does Dynamo in torch.compile() trace Python bytecode, and what triggers a graph break?

Answer: Dynamo intercepts Python frame evaluation using PEP 523 APIs to symbolically trace bytecode. Graph breaks occur on unsupported patterns.

Explanation: Dynamo uses CPython's frame evaluation hook to trace operations symbolically. Graph breaks are triggered by: control flow dependent on tensor values (e.g., if tensor.sum() > 0), calls to unsupported external libraries, and C extensions. At each break, Dynamo compiles the graph up to that point and falls back to regular Python execution for the remainder. Setting fullgraph=True turns graph breaks into errors.

Q3. Why is FSDP more memory efficient than DDP from a parameter sharding perspective?

Answer: FSDP shards parameters, gradients, and optimizer state across all GPUs, reducing per-GPU memory to roughly 1/N.

Explanation: DDP replicates the full model on every GPU. A 10B parameter model requires roughly 40 GB per GPU in FP32. FSDP (with FULL_SHARD strategy, equivalent to ZeRO-3) uses all-gather to collect only the parameters needed for the current forward/backward pass, then immediately discards them. With 8 GPUs, memory per GPU drops to about 1/8 of the total, enabling training of models that exceed single-GPU memory.

Q4. What is the compute-memory tradeoff when using gradient checkpointing?

Answer: Memory usage drops to O(sqrt(N)) at the cost of roughly a 33% increase in backward pass time due to recomputation.

Explanation: Standard backpropagation stores all forward activations — O(N) memory. Gradient checkpointing saves only activations at checkpoint boundaries and recomputes the intervening forward pass during backpropagation. For Transformers with per-layer checkpointing, memory scales with the square root of the number of layers rather than linearly. The recomputation overhead adds roughly 30-40% to total training time, but the ability to use larger batch sizes can compensate, often improving overall throughput.

Q5. How does GradScaler in AMP prevent gradient underflow in FP16 training?

Answer: GradScaler multiplies the loss by a large scale factor before backward, keeping gradients in the representable FP16 range, then inverts the scaling before the optimizer update.

Explanation: FP16's smallest positive normal value is approximately 6e-5. Without scaling, small gradients underflow to zero and parameter updates stall. GradScaler multiplies the loss by a scale factor (default 65536), which amplifies all gradients proportionally. Before optimizer.step(), scaler.unscale_() divides gradients back by the scale factor. If any Inf or NaN is detected, the step is skipped and the scale is halved. BF16 shares FP32's exponent range and does not require GradScaler.
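The underflow is easy to demonstrate directly (1e-8 is below FP16's smallest subnormal, roughly 6e-8):

```python
import torch

g = torch.tensor(1e-8)     # a plausibly small gradient value
print(g.half())            # tensor(0., dtype=torch.float16) - underflowed
print((g * 65536).half())  # representable after loss scaling
print(g.bfloat16())        # BF16 keeps FP32's exponent range: no underflow
```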