
PyTorch Internals & Advanced Optimization: autograd, torch.compile, FSDP, and Triton

Table of Contents

  1. PyTorch Internals: ATen and the Tensor Layer
  2. Autograd Engine and Computational Graph
  3. Custom Operation Implementation
  4. torch.compile() and TorchInductor
  5. Memory Optimization Techniques
  6. Distributed Training: DDP and FSDP
  7. Inference Optimization
  8. Debugging Tools
  9. Quiz

PyTorch Internals

The ATen Library

At the heart of PyTorch is ATen (A Tensor Library) — a C++ tensor operations library that underlies every PyTorch operation.

Python API (torch.*)
  ↓
Dispatcher (C++; the __torch_dispatch__ Python hook plugs in here)
  ↓
ATen (C++ tensor op kernels)
  ↓
CUDA / CPU / MPS backends

Key ATen components:

  • Tensor: A view onto a Storage, described by its sizes, strides, storage offset, dtype, and device
  • Storage: The raw memory block (shareable between tensors)
  • Dispatcher: Routes operations to the appropriate backend

import torch

x = torch.randn(3, 4)
print(x.untyped_storage())          # raw memory block (Tensor.storage() is deprecated in 2.x)
print(x.stride())                   # (4, 1): row-major layout
print(x.storage_offset())           # 0

# Views share the same storage
y = x.view(2, 6)
print(x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr())  # True
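To see how strides describe a view, here is a small sketch (assuming PyTorch 2.x, for untyped_storage()) that transposes a tensor without copying, checks the stride arithmetic, and then materializes a contiguous copy:

```python
import torch

x = torch.arange(12.0).reshape(3, 4)   # contiguous, strides (4, 1)
t = x.t()                              # transpose: no copy, strides swapped

print(t.shape, t.stride())             # torch.Size([4, 3]) (1, 4)
print(t.is_contiguous())               # False
shares = t.untyped_storage().data_ptr() == x.untyped_storage().data_ptr()
print(shares)                          # True: same memory block

# Element [i, j] lives at storage_offset + i * stride[0] + j * stride[1]
i, j = 2, 1
flat_index = t.storage_offset() + i * t.stride(0) + j * t.stride(1)
print(t[i, j].item() == x.flatten()[flat_index].item())  # True

c = t.contiguous()                     # materializes a new row-major copy
print(c.stride())                      # (3, 1)
copied = c.untyped_storage().data_ptr() != x.untyped_storage().data_ptr()
print(copied)                          # True
```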

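__torch_dispatch__

__torch_dispatch__ is a Python-level hook that intercepts operations at the dispatcher level, below autograd, once they have been lowered to ATen ops. It is the standard way to implement custom tensor subclasses such as the logging tensor below.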

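import torch
from torch.utils._pytree import tree_map

class LoggingTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # Wrapper subclass: metadata mirrors elem, the real data lives in .elem
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=elem.device,
            requires_grad=elem.requires_grad,
        )
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print(f"Calling: {func.__name__}")
        # Unwrap to plain tensors; calling func on LoggingTensors would recurse forever
        def unwrap(t):
            return t.elem if isinstance(t, LoggingTensor) else t
        def wrap(t):
            return LoggingTensor(t) if isinstance(t, torch.Tensor) else t
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(wrap, out)

x = LoggingTensor(torch.randn(3, 3))
y = x + x  # prints: Calling: add.Tensor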
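Besides subclasses, the same hook can be installed as a mode that sees every op inside a `with` block, including factory functions such as torch.randn that take no tensor inputs. The sketch below uses TorchDispatchMode from torch.utils._python_dispatch, a private module whose location may change between releases; the OpLogger class is illustrative.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class OpLogger(TorchDispatchMode):
    """Hypothetical mode that records the name of every ATen op it sees."""

    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))  # e.g. "aten.mm.default"
        # The mode is disabled while this handler runs, so calling func does not recurse
        return func(*args, **(kwargs or {}))

with OpLogger() as logger:
    a = torch.randn(3, 3)
    b = (a @ a).relu()

print(logger.ops)
```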

Autograd Engine

The Computational Graph (DAG)

PyTorch autograd builds a dynamic computational graph: a DAG (directed acyclic graph) that is recorded on the fly as operations execute and rebuilt on every forward pass, so ordinary Python control flow works naturally.

import torch

x = torch.tensor(2.0, requires_grad=True)  # leaf tensor
y = x ** 2          # non-leaf, grad_fn=PowBackward0
z = y * 3           # non-leaf, grad_fn=MulBackward0

print(x.is_leaf)    # True
print(y.is_leaf)    # False
print(z.grad_fn)    # MulBackward0

z.backward()
print(x.grad)       # dz/dx = 3 * 2x = 12.0

Leaf tensor vs Non-leaf tensor:

  • Leaf tensor: Created directly by the user rather than produced by a tracked operation (e.g., model parameters). If it has requires_grad=True, gradients accumulate in .grad
  • Non-leaf tensor: Result of an operation. .grad is None by default (call .retain_grad() to keep it)
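A quick way to see the difference is to call retain_grad() on an intermediate tensor. This self-contained sketch reuses the x, y, z chain from the example above:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)  # leaf
y = x ** 2                                 # non-leaf
y.retain_grad()                            # ask autograd to keep y.grad
z = y * 3
z.backward()

print(x.grad)   # tensor(12.): dz/dx = 3 * 2x
print(y.grad)   # tensor(3.): dz/dy = 3, kept only because of retain_grad()
```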

Gradient Accumulation

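import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()
# Synthetic data for illustration: 16 micro-batches of 8 samples each
dataloader = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(16)]

# Gradient accumulation to simulate larger batch sizes
ACCUMULATION_STEPS = 4
optimizer.zero_grad()
for i, (x, y) in enumerate(dataloader):
    output = model(x)
    # divide so the accumulated gradient is the average over micro-batches
    loss = criterion(output, y) / ACCUMULATION_STEPS
    loss.backward()  # gradients accumulate in .grad

    if (i + 1) % ACCUMULATION_STEPS == 0:
        optimizer.step()
        optimizer.zero_grad()  # reset accumulated gradients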
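Why divide the loss by ACCUMULATION_STEPS? With a mean-reduced loss and equal-sized micro-batches, the accumulated gradient then matches one full-batch gradient. A small CPU sketch that checks this on synthetic data (sizes chosen arbitrarily):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(10, 1)
criterion = torch.nn.MSELoss()          # mean reduction
x, y = torch.randn(32, 10), torch.randn(32, 1)

# One full batch of 32 samples
criterion(model(x), y).backward()
full_batch_grad = model.weight.grad.clone()
model.zero_grad()

# Four micro-batches of 8, each loss scaled by 1/4, gradients accumulated
for xb, yb in zip(x.chunk(4), y.chunk(4)):
    (criterion(model(xb), yb) / 4).backward()

matches = torch.allclose(full_batch_grad, model.weight.grad, atol=1e-6)
print(matches)  # True
```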

retain_graph and create_graph

import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 3

# Higher-order derivatives: create_graph=True keeps grad graph
grad_1 = torch.autograd.grad(y, x, create_graph=True)[0]
grad_2 = torch.autograd.grad(grad_1, x)[0]

print(grad_1)  # 3x^2 = 27.0
print(grad_2)  # 6x  = 18.0
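The heading's other flag, retain_graph, matters when you call backward more than once on the same graph: autograd frees the saved tensors after the first pass unless told otherwise. A minimal sketch:

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 3

y.backward(retain_graph=True)  # keep saved tensors for a second pass
first = x.grad.clone()         # 3x^2 = 27
y.backward()                   # works; the new gradient accumulates into x.grad
print(first, x.grad)           # tensor(27.) tensor(54.)

# Without retain_graph, a second backward through the freed graph fails
w = x ** 3
w.backward()
try:
    w.backward()
    second_pass_failed = False
except RuntimeError:
    second_pass_failed = True
print(second_pass_failed)      # True
```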

Custom Operations

torch.autograd.Function

Use this to define custom forward and backward passes.

import torch

class SigmoidFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # save values needed for backward
        output = 1 / (1 + torch.exp(-x))
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (output,) = ctx.saved_tensors
        # derivative of sigmoid: sigma(x) * (1 - sigma(x))
        grad_input = grad_output * output * (1 - output)
        return grad_input

# Usage
x = torch.randn(4, requires_grad=True)
y = SigmoidFunction.apply(x)
y.sum().backward()
print(x.grad)
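Hand-written backward formulas are easy to get wrong. torch.autograd.gradcheck compares the analytic backward against finite differences and expects double-precision inputs. A sketch that redefines the SigmoidFunction above so the block runs on its own:

```python
import torch
from torch.autograd import gradcheck

class SigmoidFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        output = 1 / (1 + torch.exp(-x))
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (output,) = ctx.saved_tensors
        return grad_output * output * (1 - output)

# float64 inputs keep the finite-difference error small
x = torch.randn(4, dtype=torch.float64, requires_grad=True)
ok = gradcheck(SigmoidFunction.apply, (x,), eps=1e-6, atol=1e-4)
print(ok)  # True (by default gradcheck raises instead of returning False)

# The forward should also agree with the built-in op
print(torch.allclose(SigmoidFunction.apply(x), torch.sigmoid(x)))  # True
```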

torch.library API (Registering Custom Ops)

import torch
from torch.library import Library, impl

my_lib = Library("my_ops", "DEF")
my_lib.define("relu_squared(Tensor x) -> Tensor")

@impl(my_lib, "relu_squared", "CPU")
def relu_squared_cpu(x):
    return torch.relu(x) ** 2

@impl(my_lib, "relu_squared", "CUDA")
def relu_squared_cuda(x):
    return torch.relu(x) ** 2

# Use the custom op
x = torch.randn(5)
result = torch.ops.my_ops.relu_squared(x)

Custom CUDA Kernel with Triton

import triton
import triton.language as tl
import torch

@triton.jit
def relu_squared_kernel(
    x_ptr, out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    relu_x = tl.where(x > 0, x, 0.0)
    out = relu_x * relu_x

    tl.store(out_ptr + offsets, out, mask=mask)

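def relu_squared_triton(x: torch.Tensor):
    assert x.is_cuda and x.is_contiguous()  # kernel assumes contiguous GPU memory
    out = torch.empty_like(x)
    n_elements = x.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)  # one program per block
    relu_squared_kernel[grid](x, out, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return out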

torch.compile()

Dynamo and Graph Capture

torch.compile() uses TorchDynamo, which hooks into CPython's frame evaluation to analyze a function's bytecode and extract an FX graph of its PyTorch operations for optimization.

import torch

def model_forward(x, weight):
    x = torch.nn.functional.relu(x @ weight)
    return x.sum()

# fullgraph=True disallows graph breaks
compiled_fn = torch.compile(model_forward, fullgraph=True, backend="inductor")

x = torch.randn(128, 256, device="cuda")
w = torch.randn(256, 512, device="cuda")
out = compiled_fn(x, w)

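Common causes of graph breaks:

  • Data-dependent control flow (e.g., if tensor.sum() > 0)
  • Converting tensors to Python scalars with .item() (unless torch._dynamo.config.capture_scalar_outputs is enabled)
  • Calls into code Dynamo cannot trace, such as many C extensions and third-party libraries (recent releases can trace a subset of NumPy)
  • Unsupported Python constructs or builtins

import torch
import torch._dynamo

torch._dynamo.config.verbose = True  # more detailed Dynamo error output
# Setting the environment variable TORCH_LOGS="graph_breaks" logs each break and its reason

# fullgraph=False (the default) allows graph breaks
compiled = torch.compile(model_forward, backend="inductor")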

AOTAutograd and TorchInductor

torch.compile() pipeline:

Python code
  → Dynamo (graph capture from bytecode)
  → AOTAutograd (traces the joint forward + backward graph ahead of time)
  → TorchInductor (kernel generation and fusion)
  → Triton (GPU) / C++ (CPU) code

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(512, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = SimpleModel().cuda()

# mode options: "default", "reduce-overhead", "max-autotune"
compiled_model = torch.compile(model, mode="max-autotune")

x = torch.randn(32, 512, device="cuda")
out = compiled_model(x)

Memory Optimization

Gradient Checkpointing

Gradient checkpointing discards intermediate activations during the forward pass and recomputes them during backward, trading extra compute for lower memory.

import torch
import torch.utils.checkpoint as checkpoint
import torch.nn as nn

class CheckpointedBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, x):
        # recomputes layers during backward to save memory
        return checkpoint.checkpoint(self.layers, x, use_reentrant=False)

model = nn.Sequential(*[CheckpointedBlock(512) for _ in range(24)]).cuda()
x = torch.randn(32, 512, device="cuda", requires_grad=True)
out = model(x)
out.sum().backward()
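Checkpointing should change memory use, not results. A CPU sketch (layer sizes arbitrary) comparing parameter gradients with and without torch.utils.checkpoint.checkpoint on a single block:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(8, 64, requires_grad=True)

# Baseline: activations stored normally
block(x).sum().backward()
baseline = [p.grad.clone() for p in block.parameters()]
block.zero_grad()

# Checkpointed: activations inside `block` are recomputed during backward
checkpoint(block, x, use_reentrant=False).sum().backward()
recomputed = [p.grad.clone() for p in block.parameters()]

same = all(torch.allclose(a, b) for a, b in zip(baseline, recomputed))
print(same)  # True: checkpointing changes memory use, not the math
```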

AMP (Automatic Mixed Precision)

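import torch

model = SimpleModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
# torch.amp.GradScaler("cuda") needs a recent PyTorch; older releases use torch.cuda.amp.GradScaler()
scaler = torch.amp.GradScaler("cuda")  # prevents FP16 gradient underflow

for x, y in dataloader:  # dataloader yields (input, label) batches
    x, y = x.cuda(), y.cuda()
    optimizer.zero_grad()

    # autocast runs eligible ops (e.g., matmuls) in FP16 and keeps others in FP32
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        output = model(x)
        loss = criterion(output, y)

    # scale the loss so small gradients stay representable in FP16
    scaler.scale(loss).backward()
    # unscale before clipping so the threshold applies to the true gradients
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)  # skips the step if gradients contain Inf/NaN
    scaler.update()         # adjusts the scale factor for the next iteration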
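autocast changes the dtype per operation rather than converting the model or its inputs. A CPU-only sketch (BF16 on CPU is chosen here only so it runs without a GPU; the CUDA/FP16 behavior is analogous):

```python
import torch

a = torch.randn(8, 8)
b = torch.randn(8, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    c = a @ b            # matmul is on autocast's lower-precision list

print(a.dtype, c.dtype)  # torch.float32 torch.bfloat16: inputs untouched, output cast
```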

Memory Profiler

import torch
from torch.profiler import profile, ProfilerActivity, record_function

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
    record_shapes=True,
) as prof:
    with record_function("model_inference"):
        output = model(x)

print(prof.key_averages().table(
    sort_by="cuda_memory_usage", row_limit=10
))
prof.export_chrome_trace("trace.json")

Activation Offloading

import torch
from torch.autograd.graph import save_on_cpu

# Offload activations to CPU to free GPU memory
def offload_forward(module, x):
    """Runs module so that tensors saved for backward are stored in pinned CPU
    memory; autograd copies them back to the GPU when backward needs them."""
    with save_on_cpu(pin_memory=True):
        return module(x)
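Under the hood, save_on_cpu is built on autograd's saved-tensor hooks: a pack function runs when autograd stores a tensor for backward, and an unpack function runs when backward needs it back. A minimal sketch of a hand-rolled offloader (the pack/unpack names are illustrative; on a CPU-only machine the copies are no-ops):

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

packed_shapes = []

def pack(t):
    # Called when autograd saves t for backward: move it to CPU, remember its device
    packed_shapes.append(tuple(t.shape))
    return t.device, t.to("cpu")

def unpack(packed):
    # Called during backward: copy the tensor back to its original device
    device, cpu_tensor = packed
    return cpu_tensor.to(device)

x = torch.randn(4, 4, requires_grad=True)
with saved_tensors_hooks(pack, unpack):
    y = (x * x).sum()
y.backward()

print(len(packed_shapes) > 0, torch.allclose(x.grad, 2 * x))  # True True
```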

Distributed Training

DDP (DistributedDataParallel)

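import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size):
    setup(rank, world_size)
    torch.cuda.set_device(rank)

    model = SimpleModel().cuda(rank)
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.Adam(ddp_model.parameters())
    criterion = torch.nn.CrossEntropyLoss()

    # Synthetic data; DistributedSampler gives each rank a distinct shard
    dataset = TensorDataset(torch.randn(1024, 512), torch.randint(0, 10, (1024,)))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # reshuffle differently each epoch
        for x, y in dataloader:
            x, y = x.cuda(rank), y.cuda(rank)
            output = ddp_model(x)
            loss = criterion(output, y)
            loss.backward()     # gradients are all-reduced (averaged) across ranks during backward
            optimizer.step()
            optimizer.zero_grad()

    cleanup()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size)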

FSDP (Fully Sharded Data Parallel)

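With the FULL_SHARD strategy, FSDP shards parameters, gradients, and optimizer state across all GPUs, all-gathering each wrapped unit's parameters only while that unit runs.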

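import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Assumes the process group is already initialized (see setup() in the DDP example)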

# Wrap at the Transformer block level
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},
)

# Mixed precision policy
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)

# LargeTransformer is a placeholder for a model built from nn.TransformerEncoderLayer blocks
model = LargeTransformer().cuda()
fsdp_model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=mp_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # equivalent to ZeRO-3
)

# Create the optimizer after wrapping so it sees the sharded parameters
optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-4)

FSDP vs DDP memory comparison:

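DDP keeps a full copy of the parameters, gradients, and optimizer state on every GPU. FSDP with FULL_SHARD splits all three across world_size ranks, so per-GPU model-state memory is roughly 1/N of the total, plus the temporarily all-gathered parameters of the unit currently running. Activation memory is not sharded. This enables training models whose state does not fit on a single GPU.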
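A back-of-the-envelope sketch of that 1/N factor. It assumes FP32 parameters and gradients plus Adam's two FP32 moment buffers (4 + 4 + 8 = 16 bytes per parameter) and ignores activations, communication buffers, and the temporarily gathered unit; the 7B model size is only an example.

```python
def model_state_gb(num_params: float, world_size: int = 1, bytes_per_param: int = 16) -> float:
    """Approximate per-GPU memory (GB) for params + grads + Adam state.

    world_size=1 models DDP (full replica); world_size=N models FSDP FULL_SHARD.
    """
    return num_params * bytes_per_param / world_size / 1e9

ddp = model_state_gb(7e9)                  # every GPU holds everything
fsdp = model_state_gb(7e9, world_size=8)   # state split across 8 GPUs
print(f"DDP: {ddp:.0f} GB/GPU, FSDP x8: {fsdp:.0f} GB/GPU")  # DDP: 112 GB/GPU, FSDP x8: 14 GB/GPU
```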

DeepSpeed ZeRO Integration

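# deepspeed_config.json (DeepSpeed expects a batch size setting):
# {
#   "train_micro_batch_size_per_gpu": 8,
#   "gradient_accumulation_steps": 4,
#   "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
#   "zero_optimization": {"stage": 3},
#   "fp16": {"enabled": true}
# }

import deepspeed

# model, criterion, and dataloader as in earlier examples; launch with the `deepspeed` CLI
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="deepspeed_config.json",
)

for x, y in dataloader:
    x, y = x.to(model_engine.local_rank), y.to(model_engine.local_rank)
    output = model_engine(x.half())  # fp16 engine expects half-precision inputs
    loss = criterion(output, y)
    model_engine.backward(loss)      # handles loss scaling and accumulation
    model_engine.step()              # updates weights only at accumulation boundaries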

Inference Optimization

torch.export() and ONNX

import torch
from torch.export import export

model = SimpleModel().eval()
x = torch.randn(1, 512)

# torch.export: captures a whole-program graph ahead of time (static shapes unless dynamic_shapes is given)
exported = export(model, (x,))
print(exported.graph)

# ONNX export
torch.onnx.export(
    model, x,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=17,
)

Quantization-Aware Training (QAT)

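import torch
import torch.nn as nn
from torch.ao.quantization import (
    DeQuantStub, QuantStub, convert, get_default_qat_qconfig, prepare_qat,
)

class QuantReadyModel(nn.Module):
    """SimpleModel wrapped with stubs marking where tensors enter/leave INT8."""
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.body = SimpleModel()
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.body(self.quant(x)))

model = QuantReadyModel()
model.qconfig = get_default_qat_qconfig("fbgemm")  # x86 server backend

# Insert fake quantization ops (returns a prepared copy of the model)
model_prepared = prepare_qat(model.train())
optimizer = torch.optim.SGD(model_prepared.parameters(), lr=1e-3)

# Train as usual (dataloader and criterion as in earlier examples)
for x, y in dataloader:
    optimizer.zero_grad()
    loss = criterion(model_prepared(x), y)
    loss.backward()
    optimizer.step()

# Convert to INT8
model_int8 = convert(model_prepared.eval())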

Debugging Tools

torch.profiler

import torch
from torch.profiler import profile, ProfilerActivity, schedule

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3, repeat=2),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./log"),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    for step, (x, y) in enumerate(dataloader):
        output = model(x.cuda())
        loss = criterion(output, y.cuda())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        prof.step()

Anomaly Detection

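# Detect NaN gradients and report which operation produced them
with torch.autograd.detect_anomaly():
    output = model(x)
    loss = output.sum()
    loss.backward()  # raises RuntimeError naming the backward op that returned NaN,
                     # and prints the traceback of the forward op that created it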
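A self-contained sketch that deliberately produces a NaN gradient: the derivative of sqrt at 0 is infinite, and multiplying by zero upstream turns the backward computation into 0/0.

```python
import torch

x = torch.tensor([0.0, 1.0], requires_grad=True)
caught = None
with torch.autograd.detect_anomaly():
    y = (torch.sqrt(x) * 0).sum()  # SqrtBackward0 computes 0 / (2 * sqrt(0)) = nan
    try:
        y.backward()
    except RuntimeError as err:
        caught = err

print(caught)  # message names SqrtBackward0 and reports nan values
```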

Tracing grad_fn

import torch

def trace_grad_fn(fn, depth=0):
    """Recursively print the autograd graph starting from a grad_fn node."""
    indent = "  " * depth
    if hasattr(fn, "variable"):
        # AccumulateGrad nodes hold a reference to their leaf tensor
        print(indent + f"Leaf (AccumulateGrad): {tuple(fn.variable.shape)}")
        return
    print(indent + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        if next_fn is not None:
            trace_grad_fn(next_fn, depth + 1)

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
z = (x * y).sum()
trace_grad_fn(z.grad_fn)  # SumBackward0 -> MulBackward0 -> two leaves

Quiz

Q1. What is the difference between leaf and non-leaf tensors in PyTorch autograd, and how does gradient accumulation work?

Answer: A leaf tensor is created directly by the user with requires_grad=True. Gradients accumulate in its .grad attribute.

Explanation: Non-leaf tensors are results of operations; their .grad is None by default (memory optimization). Call .retain_grad() to preserve them. Gradient accumulation works because calling backward() multiple times without zero_grad() adds to existing .grad values. This property lets you simulate larger effective batch sizes without increasing GPU memory usage.

Q2. How does Dynamo in torch.compile() trace Python bytecode, and what triggers a graph break?

Answer: Dynamo intercepts Python frame evaluation using PEP 523 APIs to symbolically trace bytecode. Graph breaks occur on unsupported patterns.

Explanation: Dynamo uses CPython's frame evaluation hook to trace operations symbolically. Graph breaks are triggered by: control flow dependent on tensor values (e.g., if tensor.sum() > 0), calls to unsupported external libraries, and C extensions. At each break, Dynamo compiles the graph captured so far, runs the unsupported code in regular Python, and then resumes tracing into a new graph after it. Setting fullgraph=True turns graph breaks into errors.

Q3. Why is FSDP more memory efficient than DDP from a parameter sharding perspective?

Answer: FSDP shards parameters, gradients, and optimizer state across all GPUs, reducing per-GPU memory to roughly 1/N.

Explanation: DDP replicates the full model on every GPU: a 10B-parameter model needs roughly 40 GB per GPU for the FP32 weights alone, and about 160 GB once FP32 gradients and Adam's two moment buffers are included. FSDP (with the FULL_SHARD strategy, equivalent to ZeRO-3) all-gathers only the parameters of the unit currently running in the forward or backward pass and frees them right after, while gradients are reduce-scattered so each rank keeps only its shard. With 8 GPUs, the sharded model state per GPU drops to about 1/8 of the total, enabling training of models that exceed single-GPU memory.

Q4. What is the compute-memory tradeoff when using gradient checkpointing?

Answer: Checkpointing every sqrt(N) layers cuts activation memory from O(N) to O(sqrt(N)), at the cost of roughly one extra forward pass per training step (about 33% more compute, since backward costs about twice as much as forward).

Explanation: Standard backpropagation stores all forward activations, which takes O(N) memory for N layers. Gradient checkpointing saves only the activations at checkpoint boundaries and recomputes the forward pass inside each segment during backpropagation. With segments of sqrt(N) layers, only about sqrt(N) boundary activations plus one segment's internals are alive at once, giving O(sqrt(N)) memory. Per-layer checkpointing, common in Transformers, still keeps one input per layer, so it remains O(N) but with a much smaller constant. The recomputation adds roughly 30-40% to total training time, but the larger batch sizes it allows can partially offset this and sometimes improve overall throughput.
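The sqrt(N) figure comes from splitting N layers into segments of k layers: memory is roughly N/k stored boundary activations plus k layers' worth of recomputed internals, and N/k + k is minimized at k = sqrt(N). torch.utils.checkpoint.checkpoint_sequential implements this segment-based scheme for nn.Sequential models; a small CPU sketch with an arbitrary 16-layer stack:

```python
import math

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

N = 16
model = nn.Sequential(*[nn.Sequential(nn.Linear(64, 64), nn.GELU()) for _ in range(N)])
segments = int(math.sqrt(N))  # 4 segments of 4 layers: about sqrt(N) checkpoints

x = torch.randn(8, 64, requires_grad=True)
out = checkpoint_sequential(model, segments, x, use_reentrant=False)
out.sum().backward()           # checkpointed segments are re-run during backward
print(x.grad.shape)            # torch.Size([8, 64])
```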

Q5. How does GradScaler in AMP prevent gradient underflow in FP16 training?

Answer: GradScaler multiplies the loss by a large scale factor before backward, keeping gradients in the representable FP16 range, then inverts the scaling before the optimizer update.

Explanation: FP16's smallest positive normal value is approximately 6e-5 (subnormals extend down to about 6e-8). Without scaling, small gradients lose precision or underflow to zero, and parameter updates stall. GradScaler multiplies the loss by a scale factor (default 65536), which amplifies all gradients proportionally. Before the update, the gradients are divided back by the scale factor, either explicitly via scaler.unscale_() or implicitly inside scaler.step(). If any Inf or NaN is detected, the step is skipped and the scale is halved; after a run of successful steps (2000 by default) the scale is doubled again. BF16 shares FP32's exponent range and does not require GradScaler.
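A small CPU sketch of the underflow and the fix, using an arbitrary gradient value of 1e-8 and GradScaler's default initial scale:

```python
import torch

grad = torch.tensor(1e-8)                  # a tiny gradient value in FP32
print(grad.half())                         # tensor(0., dtype=torch.float16): underflows
print(grad.bfloat16() != 0)                # tensor(True): BF16 keeps FP32's exponent range

scale = 65536.0                            # default initial scale (2**16)
scaled = (grad * scale).half()             # about 6.55e-4, comfortably representable
recovered = scaled.float() / scale         # unscale in FP32
print(recovered.item())                    # close to 1e-8 (FP16 rounding error only)
print(torch.finfo(torch.float16).tiny)     # 6.103515625e-05, smallest normal FP16 value
```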