- Published on
PyTorch Advanced Techniques Complete Guide: torch.compile, Custom Ops, Memory Optimization
- Authors

- Name
- Youngju Kim
- @fjvbn20031
PyTorch Advanced Techniques Complete Guide
PyTorch is one of the most popular deep learning frameworks for both research and production. However, most developers stop at basic tensor operations and nn.Module definitions. This guide covers advanced techniques for maximizing performance, implementing custom operations, and managing memory efficiently in real production environments.
1. torch.compile (PyTorch 2.0+)
torch.compile, introduced in PyTorch 2.0, is a feature that compiles your model to dramatically improve execution speed. Unlike TorchScript or ONNX export, torch.compile can bring over 2x speedups with minimal code changes.
Introduction and Benefits
torch.compile is composed of three core components:
- TorchDynamo: Intercepts Python bytecode to generate an FX graph
- AOTAutograd: Pre-compiles the automatic differentiation graph
- Inductor: Generates optimized kernels via TorchInductor backend (Triton GPU kernels or C++ CPU kernels)
import torch
import torch.nn as nn
import time
# Define a base model
class TransformerBlock(nn.Module):
    """A single post-norm Transformer encoder block: self-attention + feed-forward.

    Input/output shape is (batch, seq_len, d_model) because the attention
    module is created with batch_first=True.
    """

    def __init__(self, d_model=512, nhead=8, dim_feedforward=2048):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Self-attention with residual connection, then LayerNorm (post-norm).
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        # Position-wise feed-forward with residual connection.
        ff_out = self.feed_forward(x)
        x = self.norm2(x + ff_out)
        return x
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TransformerBlock().to(device)

# Apply torch.compile (compilation happens lazily on the first calls).
compiled_model = torch.compile(model)

# Warmup: the first few calls pay the one-time compilation cost.
x = torch.randn(32, 128, 512, device=device)
for _ in range(3):
    _ = compiled_model(x)

# Performance comparison
N = 100

# Eager mode
if device == "cuda":
    torch.cuda.synchronize()  # start timing from an idle GPU for fairness
start = time.perf_counter()
for _ in range(N):
    _ = model(x)
if device == "cuda":
    torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
elapsed_eager = time.perf_counter() - start

# Compiled mode
if device == "cuda":
    torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(N):
    _ = compiled_model(x)
if device == "cuda":
    torch.cuda.synchronize()
elapsed_compiled = time.perf_counter() - start

print(f"Eager mode: {elapsed_eager:.3f}s")
print(f"Compiled mode: {elapsed_compiled:.3f}s")
print(f"Speedup: {elapsed_eager / elapsed_compiled:.2f}x")
Compilation Modes
torch.compile offers three modes:
# Default mode - fast compile, good performance
model_default = torch.compile(model, mode="default")
# Reduce-overhead mode - better for small models
# (uses CUDA graphs to cut per-call kernel launch overhead)
model_reduce = torch.compile(model, mode="reduce-overhead")
# Max-autotune - longer compile, best performance
model_autotune = torch.compile(model, mode="max-autotune")
# Full graph mode - no dynamic graphs (strict)
# (errors out instead of silently falling back to eager on a graph break)
model_full = torch.compile(model, fullgraph=True)
# Backend selection
model_eager = torch.compile(model, backend="eager") # No compilation (debug)
model_aot = torch.compile(model, backend="aot_eager") # AOT only
model_inductor = torch.compile(model, backend="inductor") # Default
Dynamic Shape Support
import torch._dynamo as dynamo

# Enable dynamic shapes so one compiled artifact covers many input sizes.
model_dynamic = torch.compile(model, dynamic=True)

# Works without recompilation across different batch sizes
for batch_size in [8, 16, 32, 64]:
    x = torch.randn(batch_size, 128, 512, device=device)
    out = model_dynamic(x)
    print(f"Batch {batch_size}: output shape {out.shape}")

# explain() reports graph breaks and recompilation reasons for this input.
print(dynamo.explain(model)(x))
Migrating Existing Code
# Applying torch.compile to a training loop
def train_epoch(model, optimizer, dataloader, criterion):
    """Run one training epoch and return the mean batch loss.

    Works unchanged for both an eager model and a torch.compile-wrapped one,
    since the compiled module exposes the same call/parameters interface.
    """
    model.train()
    total_loss = 0
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

# Just compile the model
model = MyModel().cuda()
compiled_model = torch.compile(model)  # Only this line added!
optimizer = torch.optim.Adam(compiled_model.parameters())
criterion = nn.CrossEntropyLoss()
2. Custom Autograd
PyTorch's automatic differentiation engine is powerful, but sometimes you need custom gradient computation for numerical stability or efficiency.
Subclassing torch.autograd.Function
import torch
from torch.autograd import Function

class SigmoidFunction(Function):
    """Custom autograd sigmoid.

    The forward pass saves the activation so the backward pass can reuse it
    instead of recomputing the exponential.
    """

    @staticmethod
    def forward(ctx, input):
        # sigmoid(x) = 1 / (1 + exp(-x)); torch.sigmoid is numerically stable.
        sigmoid = torch.sigmoid(input)
        # Save tensors for use in backward
        ctx.save_for_backward(sigmoid)
        return sigmoid

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve the saved forward activation.
        sigmoid, = ctx.saved_tensors
        # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
        grad_input = sigmoid * (1 - sigmoid) * grad_output
        return grad_input

# Usage
def custom_sigmoid(x):
    return SigmoidFunction.apply(x)

x = torch.randn(3, 4, requires_grad=True)
y = custom_sigmoid(x)
loss = y.sum()
loss.backward()
print(f"Gradient: {x.grad}")
More Complex Example: Leaky ReLU with Custom Backward
class LeakyReLUFunction(Function):
    """Leaky ReLU with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, input, negative_slope=0.01):
        ctx.save_for_backward(input)
        # Non-tensor state goes on ctx directly, not through save_for_backward.
        ctx.negative_slope = negative_slope
        # clamp(min=0) keeps the positive part, clamp(max=0) the negative part.
        return input.clamp(min=0) + negative_slope * input.clamp(max=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        negative_slope = ctx.negative_slope
        grad_input = grad_output.clone()
        # Gradient is 1 for x >= 0 and negative_slope for x < 0.
        grad_input[input < 0] *= negative_slope
        # backward returns one value per forward argument; the non-tensor
        # negative_slope argument gets None.
        return grad_input, None

class CustomLeakyReLU(nn.Module):
    """nn.Module wrapper around LeakyReLUFunction."""

    def __init__(self, negative_slope=0.01):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, x):
        return LeakyReLUFunction.apply(x, self.negative_slope)
Numerical Gradient Check
from torch.autograd import gradcheck

def test_custom_op():
    """Validate SigmoidFunction's analytic gradient against finite differences."""
    # Numerical gradient check (float64 recommended for tight tolerances)
    input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
    # gradcheck computes the Jacobian numerically and compares with autograd
    result = gradcheck(SigmoidFunction.apply, (input,), eps=1e-6, atol=1e-4)
    print(f"Gradient check passed: {result}")
    # Re-run with dtype checking of the returned gradients as well.
    # (The original comment called this a "double backprop test";
    # check_grad_dtypes only verifies gradient dtypes, not second derivatives.)
    input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
    result = gradcheck(
        SigmoidFunction.apply,
        (input,),
        eps=1e-6,
        atol=1e-4,
        check_grad_dtypes=True
    )

test_custom_op()
Double Backpropagation
class SquaredFunction(Function):
    """Custom x**2 whose backward pass is itself differentiable.

    Because backward is built from differentiable torch ops, autograd can
    differentiate through it again (double backprop) when the caller passes
    create_graph=True.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # d/dx x^2 = 2x; expressed with differentiable ops for double backprop.
        return 2 * x * grad_output

# Double backprop example (useful in meta-learning like MAML)
x = torch.randn(3, requires_grad=True)
y = SquaredFunction.apply(x)
grad_x = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
# grad_x = 2x; backprop through it again for the second derivative.
grad_grad_x = torch.autograd.grad(grad_x.sum(), x)[0]
# grad_grad_x = 2 (constant)
print(f"Second derivative: {grad_grad_x}")
3. Custom CUDA Operators
torch.utils.cpp_extension Overview
PyTorch provides tools for writing C++/CUDA extensions and using them from Python.
# JIT compilation (suitable for development/prototyping)
# NOTE(review): load_inline invokes a host C++ toolchain at runtime; builds
# are cached (under TORCH_EXTENSIONS_DIR) after the first call.
from torch.utils.cpp_extension import load_inline
# C++ CPU operator
cpp_source = """
#include <torch/extension.h>
torch::Tensor relu_forward(torch::Tensor input) {
return input.clamp_min(0);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("relu_forward", &relu_forward, "ReLU forward");
}
"""
# Inline compilation: builds cpp_source and exposes each listed function as
# an attribute of the returned Python module.
custom_relu_cpp = load_inline(
name="custom_relu_cpp",
cpp_sources=cpp_source,
functions=["relu_forward"],
verbose=False
)
x = torch.randn(5)
result = custom_relu_cpp.relu_forward(x)
print(result)
CUDA Kernel Example: Fused Softmax
# CUDA source: one thread block per row, single thread per block — simple and
# readable, not performance-optimal (a real kernel would parallelize the row).
cuda_source = """
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
__global__ void fused_softmax_kernel(
const float* input,
float* output,
int rows,
int cols
) {
int row = blockIdx.x;
if (row >= rows) return;
const float* row_input = input + row * cols;
float* row_output = output + row * cols;
// Find max (for numerical stability)
float max_val = row_input[0];
for (int i = 1; i < cols; i++) {
max_val = fmaxf(max_val, row_input[i]);
}
// Compute exp(x - max) and sum
float sum = 0.0f;
for (int i = 0; i < cols; i++) {
row_output[i] = expf(row_input[i] - max_val);
sum += row_output[i];
}
// Normalize
for (int i = 0; i < cols; i++) {
row_output[i] /= sum;
}
}
torch::Tensor fused_softmax_cuda(torch::Tensor input) {
auto output = torch::zeros_like(input);
int rows = input.size(0);
int cols = input.size(1);
fused_softmax_kernel<<<rows, 1>>>(
input.data_ptr<float>(),
output.data_ptr<float>(),
rows,
cols
);
return output;
}
"""
# C++ binding side: declares the CUDA entry point and exposes it to Python.
cpp_source_cuda = """
#include <torch/extension.h>
torch::Tensor fused_softmax_cuda(torch::Tensor input);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_softmax", &fused_softmax_cuda, "Fused Softmax CUDA");
}
"""
# Compile only if CUDA is available
if torch.cuda.is_available():
    from torch.utils.cpp_extension import load_inline

    fused_softmax_ext = load_inline(
        name="fused_softmax",
        cpp_sources=cpp_source_cuda,
        cuda_sources=cuda_source,
        functions=["fused_softmax"],
        verbose=True
    )

    # Sanity-check against PyTorch's reference softmax.
    x = torch.randn(4, 8, device="cuda")
    result = fused_softmax_ext.fused_softmax(x)
    expected = torch.softmax(x, dim=1)
    print(f"Max difference: {(result - expected).abs().max().item():.6f}")
Building a Package with setup.py
# setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# BuildExtension supplies the compiler/ABI settings PyTorch extensions need;
# CUDAExtension adds CUDA include/library paths and routes .cu files to nvcc.
setup(
name="custom_ops",
ext_modules=[
CUDAExtension(
name="custom_ops",
sources=[
"custom_ops/ops.cpp",
"custom_ops/ops_cuda.cu",
],
extra_compile_args={
"cxx": ["-O3"],
"nvcc": ["-O3", "--use_fast_math"],
}
)
],
cmdclass={
"build_ext": BuildExtension
}
)
# Build: python setup.py install
4. Memory Optimization Techniques
GPU Memory Profiling
import torch

def print_gpu_memory_stats():
    """Print CUDA allocator statistics for the current device (no-op on CPU).

    memory_allocated = memory held by live tensors; memory_reserved = memory
    cached by the allocator (includes freed-but-cached blocks).
    """
    if torch.cuda.is_available():
        print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
        print(f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
        print(torch.cuda.memory_summary(abbreviated=True))

# Start memory tracking from a clean peak counter.
torch.cuda.reset_peak_memory_stats()
print_gpu_memory_stats()

# Load model
model = TransformerBlock().cuda()
print_gpu_memory_stats()
Gradient Checkpointing (Activation Recomputation)
Gradient Checkpointing avoids storing intermediate activation values during the forward pass, recomputing them as needed during the backward pass. This saves significant memory at the cost of roughly 30% more computation.
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
import torch.nn as nn

class DeepTransformer(nn.Module):
    """Stack of TransformerBlocks with optional activation checkpointing."""

    def __init__(self, num_layers=12, d_model=512, nhead=8):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, nhead) for _ in range(num_layers)
        ])

    def forward_with_checkpointing(self, x):
        """Apply gradient checkpointing to each layer individually."""
        for layer in self.layers:
            # checkpoint does not store intermediate activations; they are
            # recomputed during backward. use_reentrant=False selects the
            # recommended non-reentrant implementation.
            x = checkpoint(layer, x, use_reentrant=False)
        return x

    def forward_sequential_checkpointing(self, x):
        """Checkpoint the stack as 3 contiguous segments (one recompute unit each)."""
        x = checkpoint_sequential(self.layers, segments=3, input=x)
        return x

# Memory comparison
model = DeepTransformer(num_layers=24).cuda()
x = torch.randn(16, 512, 512, device="cuda")

# Checkpointed forward: peak memory reflects the recomputation savings.
torch.cuda.reset_peak_memory_stats()
out = model.forward_with_checkpointing(x)
checkpoint_mem = torch.cuda.max_memory_allocated()
print(f"Checkpoint memory: {checkpoint_mem / 1024**3:.2f} GB")
Gradient Accumulation
def train_with_gradient_accumulation(
    model, optimizer, dataloader, criterion,
    accumulation_steps=4
):
    """Simulate an effective batch `accumulation_steps`x larger than GPU memory allows.

    Gradients are accumulated over accumulation_steps micro-batches before
    each optimizer step. Returns the mean (unscaled) batch loss.
    """
    model.train()
    optimizer.zero_grad()
    total_loss = 0
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.cuda(), target.cuda()
        # Forward pass
        output = model(data)
        # Scale the loss so the accumulated gradient averages over the steps
        # rather than summing them.
        loss = criterion(output, target) / accumulation_steps
        loss.backward()
        # Undo the scaling for reporting.
        total_loss += loss.item() * accumulation_steps
        # Step the optimizer once every accumulation_steps micro-batches.
        if (batch_idx + 1) % accumulation_steps == 0:
            # Gradient clipping (optional)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
    return total_loss / len(dataloader)
Mixed Precision Training (AMP)
from torch.cuda.amp import autocast, GradScaler

def train_with_amp(model, optimizer, dataloader, criterion):
    """Automatic Mixed Precision training loop: memory savings + speed boost."""
    # GradScaler scales the loss so float16 gradients do not underflow.
    scaler = GradScaler()
    model.train()
    for data, target in dataloader:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        # Ops inside autocast run in float16 where safe, float32 otherwise.
        with autocast():
            output = model(data)
            loss = criterion(output, target)
        # Backward on the scaled loss.
        scaler.scale(loss).backward()
        # Unscale first so clipping sees the true gradient magnitudes.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # step() skips the update if the gradients contain NaN/Inf.
        scaler.step(optimizer)
        scaler.update()
8-bit Optimizer
# Requires bitsandbytes: pip install bitsandbytes
try:
    import bitsandbytes as bnb

    # 8-bit Adam stores optimizer state in int8, cutting its memory footprint.
    optimizer_8bit = bnb.optim.Adam8bit(
        model.parameters(),
        lr=1e-4,
        betas=(0.9, 0.999)
    )
    # Paged variant can spill optimizer state to CPU under memory pressure.
    optimizer_paged = bnb.optim.PagedAdam(
        model.parameters(),
        lr=1e-4
    )
    print("8-bit optimizer loaded successfully")
except ImportError:
    # Graceful fallback so the script still runs without bitsandbytes.
    print("bitsandbytes not installed, using regular Adam")
    optimizer_8bit = torch.optim.Adam(model.parameters(), lr=1e-4)
5. functorch and vmap
vmap - Batched Operations
vmap efficiently applies a function operating on single samples to a batch.
import torch
from torch import vmap

# Function operating on a single sample
def single_linear(weight, bias, x):
    return weight @ x + bias

# vmap maps single_linear over the leading (batch) dim of every argument.
batched_linear = vmap(single_linear)

# Batch data: one weight/bias/input per batch element.
batch_size = 32
weight = torch.randn(batch_size, 10, 5)
bias = torch.randn(batch_size, 10)
x = torch.randn(batch_size, 5)

# Automatically handled in batch
result = batched_linear(weight, bias, x)
print(f"Result shape: {result.shape}")  # (32, 10)
grad - Functional Gradient
from torch.func import grad, vmap, functional_call

# Functional (stateless) loss: parameters are passed explicitly rather than
# read from module state, which is what torch.func transforms require.
def scalar_loss(params, x, y):
    pred = functional_call(model, params, (x,))
    return ((pred - y) ** 2).mean()

# grad differentiates w.r.t. the first argument (the parameter dict).
params = dict(model.named_parameters())
grad_fn = grad(scalar_loss)

x = torch.randn(1, 10)
y = torch.randn(1, 5)
grads = grad_fn(params, x, y)
print({k: v.shape for k, v in grads.items()})
Ensemble Models with vmap
from torch.func import stack_module_state, functional_call, vmap

def create_ensemble(model_class, num_models, *args, **kwargs):
    """Build a vmap-vectorized ensemble of `num_models` identical architectures.

    Returns (ensemble_forward, params, buffers) where
    ensemble_forward(params, buffers, x) evaluates every member on the same
    input in a single vectorized call.
    """
    models = [model_class(*args, **kwargs) for _ in range(num_models)]
    # Stack each parameter/buffer across models along a new leading dim.
    params, buffers = stack_module_state(models)

    # A separate instance that supplies the architecture for functional_call.
    base_model = model_class(*args, **kwargs)

    def single_forward(params, buffers, x):
        return functional_call(base_model, (params, buffers), (x,))

    # Map over the model dim of params/buffers; broadcast the input (None).
    ensemble_forward = vmap(single_forward, in_dims=(0, 0, None))
    return ensemble_forward, params, buffers

# Usage example
ensemble_fn, params, buffers = create_ensemble(
    nn.Linear, num_models=5, in_features=10, out_features=5
)
x = torch.randn(32, 10)
ensemble_out = ensemble_fn(params, buffers, x)
print(f"Ensemble output shape: {ensemble_out.shape}")  # (5, 32, 5)
Meta-Learning (MAML) with grad and vmap
from torch.func import grad, vmap, functional_call

def inner_loop(params, support_x, support_y, base_model, lr=0.01, steps=5):
    """MAML inner loop: take `steps` SGD steps on the support set.

    Returns a new parameter dict; `params` is not mutated, and the update is
    built from differentiable ops so the outer loop can backprop through it.
    """
    adapted_params = {k: v.clone() for k, v in params.items()}
    for _ in range(steps):
        def loss_fn(params):
            pred = functional_call(base_model, params, (support_x,))
            return ((pred - support_y) ** 2).mean()
        # Functional gradient w.r.t. the current adapted parameters.
        grads = grad(loss_fn)(adapted_params)
        # One explicit SGD step.
        adapted_params = {
            k: p - lr * grads[k]
            for k, p in adapted_params.items()
        }
    return adapted_params
6. PyTorch Profiler
Basic Profiling
import torch
from torch.profiler import profile, record_function, ProfilerActivity

model = TransformerBlock().cuda()
x = torch.randn(32, 128, 512, device="cuda")

# Run profiler: capture CPU + CUDA activity, tensor shapes, memory events,
# and Python stack traces (with_stack has overhead; disable for production).
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True
) as prof:
    # record_function labels this region in the resulting trace.
    with record_function("model_inference"):
        for _ in range(10):
            output = model(x)

# Aggregated results, sorted by total CUDA time.
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20
))

# Export Chrome Trace
prof.export_chrome_trace("trace.json")
# Open in chrome://tracing
TensorBoard Integration
from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    # Schedule: skip 1 step, warm up for 1, record 3, and repeat twice.
    schedule=torch.profiler.schedule(
        wait=1,      # Wait 1 step
        warmup=1,    # 1 warmup step
        active=3,    # Profile 3 steps
        repeat=2     # Repeat 2 times
    ),
    on_trace_ready=tensorboard_trace_handler("./log/profiler"),
    record_shapes=True,
    profile_memory=True
) as prof:
    for step, (data, target) in enumerate(dataloader):
        train_step(model, optimizer, data, target)
        prof.step()  # advance the profiling schedule once per training step

# tensorboard --logdir=./log/profiler
Detailed Memory Analysis
# Memory snapshot (PyTorch 2.1+)
# NOTE(review): _record_memory_history/_snapshot/_dump_snapshot are private
# (underscore) APIs and may change between releases.
torch.cuda.memory._record_memory_history(max_entries=100000)
# Run code
x = torch.randn(100, 100, device="cuda")
y = x @ x.T
z = y.sum()
# Save snapshot (the pickle can be visualized at pytorch.org/memory_viz)
snapshot = torch.cuda.memory._snapshot()
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
print(f"Active allocations: {len(snapshot['segments'])}")
7. TorchScript
torch.jit.script vs torch.jit.trace
import torch
import torch.nn as nn

class ConditionalModel(nn.Module):
    """Model with data-dependent control flow in forward (trace-unfriendly)."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x, flag: bool = True):
        # Branching on `flag` means tracing records only one path;
        # scripting compiles both.
        if flag:
            return torch.relu(self.linear(x))
        else:
            return torch.sigmoid(self.linear(x))

model = ConditionalModel()

# script: compiles the code itself, including control flow (recommended here)
scripted_model = torch.jit.script(model)
print(scripted_model.code)

# trace: records the ops executed for one concrete input — only the
# flag=True path is captured, so flag=False would run the wrong branch.
x = torch.randn(1, 10)
traced_model = torch.jit.trace(model, x)

# Save and load
scripted_model.save("model_scripted.pt")
loaded = torch.jit.load("model_scripted.pt")
TorchScript Optimization
# Apply optimization passes
scripted = torch.jit.script(model)
# optimize_for_inference freezes the module and applies inference-oriented
# fusions; the result is inference-only (no further training/backward).
optimized = torch.jit.optimize_for_inference(scripted)
# Export for C++ environment
# In C++: torch::jit::script::Module m = torch::jit::load("model.pt");
8. Dynamic Shapes and torch.export
Using torch.export
import torch
from torch.export import export
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 5)
def forward(self, x):
return self.linear(x)
model = MyModel()
example_inputs = (torch.randn(2, 10),)
# Specify dynamic shapes
from torch.export import Dim
batch = Dim("batch", min=1, max=100)
# Export model
exported = export(
model,
example_inputs,
dynamic_shapes={"x": {0: batch}}
)
print(exported)
print(exported.graph_module.code)
# Run ExportedProgram
result = exported.module()(torch.randn(5, 10))
print(f"Result shape: {result.shape}")
9. Custom Dataset and Sampler
IterableDataset
from torch.utils.data import IterableDataset, DataLoader
import torch

class StreamingDataset(IterableDataset):
    """Stream samples from a list of files without loading everything in RAM."""

    def __init__(self, data_paths, transform=None):
        self.data_paths = data_paths
        self.transform = transform

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: this iterator handles every file.
            paths = self.data_paths
        else:
            # Multiprocessing: shard files across workers so each sample is
            # produced exactly once. The last worker takes the remainder —
            # a plain floor-division split would silently drop trailing
            # files whenever len(data_paths) % num_workers != 0.
            per_worker = len(self.data_paths) // worker_info.num_workers
            start = worker_info.id * per_worker
            if worker_info.id == worker_info.num_workers - 1:
                end = len(self.data_paths)
            else:
                end = start + per_worker
            paths = self.data_paths[start:end]
        for path in paths:
            # Stream data from file
            data = self._load_file(path)
            for sample in data:
                if self.transform:
                    sample = self.transform(sample)
                yield sample

    def _load_file(self, path):
        # In a real implementation, load and parse the file at `path`.
        return [torch.randn(10) for _ in range(100)]

# Use persistent_workers in DataLoader
dataset = StreamingDataset(data_paths=["file1.pt", "file2.pt"])
dataloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,
    persistent_workers=True,  # Reuse worker processes between epochs
    pin_memory=True,          # Page-locked host memory speeds GPU transfer
    prefetch_factor=2         # Batches prefetched per worker
)
Custom Sampler
from torch.utils.data import Sampler
import numpy as np

class BalancedClassSampler(Sampler):
    """Oversample minority classes so every class appears equally often.

    Each epoch yields num_classes * num_samples_per_class indices: sampled
    with replacement within each class, then shuffled.
    """

    def __init__(self, dataset, num_samples_per_class=None):
        self.dataset = dataset
        # Assumes dataset[i] returns (sample, label) — TODO confirm for
        # datasets with a different item layout.
        labels = [dataset[i][1] for i in range(len(dataset))]
        self.labels = torch.tensor(labels)
        # Indices per class
        self.class_indices = {}
        for cls in torch.unique(self.labels):
            self.class_indices[cls.item()] = (
                self.labels == cls
            ).nonzero(as_tuple=True)[0].tolist()
        self.num_classes = len(self.class_indices)
        # Default: upsample every class to the size of the largest class.
        self.num_samples_per_class = (
            num_samples_per_class or
            max(len(v) for v in self.class_indices.values())
        )

    def __iter__(self):
        indices = []
        for cls_idx in self.class_indices.values():
            # Sample an equal count from each class (with replacement).
            sampled = np.random.choice(
                cls_idx,
                self.num_samples_per_class,
                replace=True
            ).tolist()
            indices.extend(sampled)
        # Shuffle so consecutive batches mix classes.
        np.random.shuffle(indices)
        return iter(indices)

    def __len__(self):
        return self.num_classes * self.num_samples_per_class
10. PyTorch Lightning
Complete LightningModule Example
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

class LightningTransformer(pl.LightningModule):
    """Transformer-encoder sequence classifier as a LightningModule."""

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_layers=6,
        num_classes=10,
        learning_rate=1e-4
    ):
        super().__init__()
        # Records init args on self.hparams and in checkpoints.
        self.save_hyperparameters()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                batch_first=True
            ),
            num_layers=num_layers
        )
        self.classifier = nn.Linear(d_model, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        encoded = self.encoder(x)
        # Mean-pool over the sequence dimension before classification.
        pooled = encoded.mean(dim=1)
        return self.classifier(pooled)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        # Logging
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=0.01
        )
        # Cosine annealing scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=100
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }

class LightningDataModule(pl.LightningDataModule):
    """Generates dummy sequence data and provides the train/val loaders."""

    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Dummy data: 1000 sequences of length 32 with d_model=512 features.
        x = torch.randn(1000, 32, 512)
        y = torch.randint(0, 10, (1000,))
        dataset = TensorDataset(x, y)
        n = len(dataset)
        # 80/20 train/validation split.
        self.train_ds = torch.utils.data.Subset(dataset, range(int(0.8*n)))
        self.val_ds = torch.utils.data.Subset(dataset, range(int(0.8*n), n))

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)

# Run training
model = LightningTransformer()
dm = LightningDataModule()

# Setup callbacks
callbacks = [
    ModelCheckpoint(
        monitor="val_loss",
        save_top_k=3,
        mode="min",
        filename="transformer-{epoch:02d}-{val_loss:.2f}"
    ),
    EarlyStopping(monitor="val_loss", patience=10, mode="min")
]

# Trainer
trainer = pl.Trainer(
    max_epochs=100,
    accelerator="auto",          # Auto-detect GPU
    devices="auto",
    callbacks=callbacks,
    logger=TensorBoardLogger("tb_logs", name="transformer"),
    gradient_clip_val=1.0,       # Gradient clipping
    accumulate_grad_batches=4,   # Gradient accumulation
    precision="16-mixed",        # AMP
)
trainer.fit(model, dm)
11. Model Quantization
Dynamic Quantization
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic
# Inference-only quantization - easiest approach
# (weights are quantized ahead of time; activations on the fly per call)
model = nn.Sequential(
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 128)
)
model.eval()
# Quantize Linear and LSTM layers to int8
# (this model contains only Linear layers; nn.LSTM is listed to show that
# multiple layer types can be targeted at once)
quantized_model = quantize_dynamic(
model,
{nn.Linear, nn.LSTM}, # Layer types to quantize
dtype=torch.qint8
)
# Size comparison: serialize both state dicts and compare on-disk sizes
import os
torch.save(model.state_dict(), "model_fp32.pt")
torch.save(quantized_model.state_dict(), "model_int8.pt")
fp32_size = os.path.getsize("model_fp32.pt")
int8_size = os.path.getsize("model_int8.pt")
print(f"FP32 size: {fp32_size / 1024:.1f} KB")
print(f"INT8 size: {int8_size / 1024:.1f} KB")
print(f"Compression: {fp32_size / int8_size:.1f}x")
Static Quantization
import torch
from torch.ao.quantization import (
    get_default_qconfig,
    prepare,
    convert,
)

class QuantizableModel(nn.Module):
    """MLP with explicit quant/dequant boundaries for static quantization."""

    def __init__(self):
        super().__init__()
        # QuantStub/DeQuantStub mark where tensors enter/leave int8.
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(512, 256)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(256, 10)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.dequant(x)
        return x

model = QuantizableModel()
model.eval()

# Quantization config for the x86 (fbgemm) backend.
model.qconfig = get_default_qconfig("fbgemm")

# Prepare: insert observers that record activation ranges.
prepared_model = prepare(model)

# Calibration: run representative data so the observers collect statistics.
with torch.no_grad():
    for _ in range(100):
        calibration_data = torch.randn(32, 512)
        prepared_model(calibration_data)

# Convert observed float modules into their quantized equivalents.
quantized_static = convert(prepared_model)

# Inference
x = torch.randn(1, 512)
with torch.no_grad():
    output = quantized_static(x)
print(f"Output shape: {output.shape}")
QAT (Quantization-Aware Training)
from torch.ao.quantization import (
    get_default_qat_qconfig,
    prepare_qat,
    convert
)

def dummy_dataloader():
    """Yield random (input, label) batches for the demo training loop.

    Defined before the training loop: the original listing defined it after
    its first use, which raised NameError when the script ran top to bottom.
    """
    for _ in range(10):
        yield torch.randn(32, 512), torch.randint(0, 10, (32,))

model = QuantizableModel()
model.train()

# QAT config for the x86 (fbgemm) backend.
model.qconfig = get_default_qat_qconfig("fbgemm")

# Prepare QAT: insert fake-quantization nodes so training sees rounding error.
prepared_qat = prepare_qat(model)

# Train with the quantization error included in the loss signal.
optimizer = torch.optim.SGD(prepared_qat.parameters(), lr=0.0001)
for epoch in range(10):
    for x, y in dummy_dataloader():
        output = prepared_qat(x)
        loss = nn.functional.cross_entropy(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Convert after training
prepared_qat.eval()
quantized_qat = convert(prepared_qat)
12. Tensor Parallelism and Pipeline Parallelism
DeviceMesh and DTensor API
import torch
import torch.distributed as dist
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel
)
from torch.distributed._tensor import DeviceMesh

# Initialize distributed training
def setup_distributed():
    """Join the NCCL process group and pin this rank to its own GPU."""
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())

class TransformerMLP(nn.Module):
    """Standard transformer MLP: Linear -> GELU -> Linear."""

    def __init__(self, d_model=1024, dim_feedforward=4096):
        super().__init__()
        self.fc1 = nn.Linear(d_model, dim_feedforward)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim_feedforward, d_model)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

# Apply Tensor Parallelism
def apply_tensor_parallel(model, mesh):
    """Shard the MLP across the mesh.

    fc1: column-wise sharding (splits the output dimension) so each rank
    holds a slice of the hidden activation; fc2: row-wise sharding (splits
    the input dimension) so partial results are reduced back together.
    """
    parallelize_module(
        model,
        mesh,
        {
            "fc1": ColwiseParallel(),
            "fc2": RowwiseParallel(),
        }
    )
    return model

# Example (requires 2-GPU setup)
# device_mesh = DeviceMesh("cuda", [0, 1])
# model = TransformerMLP()
# model = apply_tensor_parallel(model, device_mesh)
FSDP (Fully Sharded Data Parallel)
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy
)
import functools

def setup_fsdp_model(model):
    """Wrap `model` in FSDP so params/grads/optimizer state are sharded.

    Assumes the distributed process group is already initialized and a CUDA
    device is set for this rank.
    """
    # Mixed precision: bf16 parameters/buffers, fp32 gradient reduction
    # (fp32 reduce keeps the all-reduce numerically stable).
    mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.bfloat16,
    )
    # Shard at TransformerBlock granularity.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock}
    )
    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        # FULL_SHARD: shard parameters, gradients AND optimizer state.
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        # Prefetch the next shard's all-gather during backward for overlap.
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=torch.cuda.current_device()
    )
    return model
Summary
This guide covered PyTorch's advanced techniques:
- torch.compile: Over 2x performance improvement without code changes
- Custom Autograd: Implementing specialized gradient computation
- CUDA Extensions: Integrating GPU kernels with PyTorch
- Memory Optimization: Gradient Checkpointing, AMP, 8-bit Optimizer
- functorch/vmap: Batch processing and meta-learning via functional API
- PyTorch Profiler: Analyzing performance bottlenecks
- TorchScript/export: Deployment optimization
- PyTorch Lightning: Structuring code and automating training
- Quantization: Reducing model size with INT8
- Distributed Training: Tensor Parallel, FSDP
These techniques are complementary, and in real projects it is common to combine several of them. For large-scale model training, the combination of AMP + Gradient Checkpointing + FSDP + torch.compile is particularly powerful.