Skip to content
Published on

PyTorch Advanced Techniques Complete Guide: torch.compile, Custom Ops, Memory Optimization

Authors

PyTorch Advanced Techniques Complete Guide

PyTorch is one of the most popular deep learning frameworks for both research and production. However, most developers stop at basic tensor operations and nn.Module definitions. This guide covers advanced techniques for maximizing performance, implementing custom operations, and managing memory efficiently in real production environments.

1. torch.compile (PyTorch 2.0+)

torch.compile, introduced in PyTorch 2.0, is a feature that compiles your model to dramatically improve execution speed. Unlike TorchScript or ONNX export, torch.compile can bring over 2x speedups with minimal code changes.

Introduction and Benefits

torch.compile is composed of three core components:

  1. TorchDynamo: Intercepts Python bytecode to generate an FX graph
  2. AOTAutograd: Pre-compiles the automatic differentiation graph
  3. Inductor: Generates optimized kernels via TorchInductor backend (Triton GPU kernels or C++ CPU kernels)
import torch
import torch.nn as nn
import time

# Define a base model
class TransformerBlock(nn.Module):
    """Single post-norm Transformer encoder block.

    Self-attention followed by a position-wise feed-forward network,
    each wrapped in a residual connection and a LayerNorm.
    """

    def __init__(self, d_model=512, nhead=8, dim_feedforward=2048):
        super().__init__()
        # batch_first=True -> inputs are (batch, seq, d_model)
        self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Self-attention sub-layer: residual add, then normalize.
        attended, _ = self.attention(x, x, x)
        x = self.norm1(x + attended)
        # Feed-forward sub-layer: residual add, then normalize.
        x = self.norm2(x + self.feed_forward(x))
        return x

device = "cuda" if torch.cuda.is_available() else "cpu"
model = TransformerBlock().to(device)

# Apply torch.compile
compiled_model = torch.compile(model)

# Warm up BOTH paths: the first compiled call pays the compilation cost,
# and a few eager calls let cuDNN heuristics/caches settle so the eager
# timing below is not penalized either.
x = torch.randn(32, 128, 512, device=device)
for _ in range(3):
    _ = compiled_model(x)
    _ = model(x)

def _sync():
    # CUDA kernels launch asynchronously; synchronize so the wall-clock
    # timers measure completed GPU work, not just kernel launches.
    if device == "cuda":
        torch.cuda.synchronize()

# Performance comparison
N = 100

# Eager mode
_sync()  # BUG FIX: drain queued warmup work before starting the timer
start = time.perf_counter()
for _ in range(N):
    _ = model(x)
_sync()
elapsed_eager = time.perf_counter() - start

# Compiled mode
_sync()
start = time.perf_counter()
for _ in range(N):
    _ = compiled_model(x)
_sync()
elapsed_compiled = time.perf_counter() - start

print(f"Eager mode: {elapsed_eager:.3f}s")
print(f"Compiled mode: {elapsed_compiled:.3f}s")
print(f"Speedup: {elapsed_eager / elapsed_compiled:.2f}x")

Compilation Modes

torch.compile offers three performance modes ("default", "reduce-overhead", "max-autotune"), along with fullgraph and backend options:

# Default mode - fast compile, good performance
model_default = torch.compile(model, mode="default")

# Reduce-overhead mode - better for small models
model_reduce = torch.compile(model, mode="reduce-overhead")

# Max-autotune - longer compile, best performance
model_autotune = torch.compile(model, mode="max-autotune")

# Full graph mode - no dynamic graphs (strict)
model_full = torch.compile(model, fullgraph=True)

# Backend selection
model_eager = torch.compile(model, backend="eager")    # No compilation (debug)
model_aot = torch.compile(model, backend="aot_eager")  # AOT only
model_inductor = torch.compile(model, backend="inductor")  # Default

Dynamic Shape Support

import torch._dynamo as dynamo

# Enable dynamic shapes
# (dynamic=True marks sizes as symbolic up front instead of specializing
# and recompiling for each concrete shape)
model_dynamic = torch.compile(model, dynamic=True)

# Works without recompilation across different batch sizes
for batch_size in [8, 16, 32, 64]:
    x = torch.randn(batch_size, 128, 512, device=device)
    out = model_dynamic(x)
    print(f"Batch {batch_size}: output shape {out.shape}")

# Check compilation cache
# dynamo.explain reports captured graphs, graph breaks and their reasons.
print(dynamo.explain(model)(x))

Migrating Existing Code

# Applying torch.compile to a training loop
# Applying torch.compile to a training loop
def train_epoch(model, optimizer, dataloader, criterion, device="cuda"):
    """Run one training epoch and return the mean batch loss.

    Args:
        model: Network to train (compiled or eager -- the loop is identical).
        optimizer: Optimizer stepping ``model``'s parameters.
        dataloader: Iterable of ``(data, target)`` batches.
        criterion: Loss function mapping ``(output, target)`` to a scalar.
        device: Device batches are moved to. Defaults to ``"cuda"``, which
            preserves the original hard-coded ``.cuda()`` behavior; pass
            ``"cpu"`` for CPU-only runs.

    Returns:
        float: Sum of per-batch losses divided by the number of batches.
    """
    model.train()
    total_loss = 0.0

    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # .item() detaches the scalar so the autograd graph is freed each step.
        total_loss += loss.item()

    return total_loss / len(dataloader)

# Just compile the model
# NOTE(review): `MyModel` here is a placeholder for your own nn.Module --
# it is not defined at this point in the article.
model = MyModel().cuda()
compiled_model = torch.compile(model)  # Only this line added!

optimizer = torch.optim.Adam(compiled_model.parameters())
criterion = nn.CrossEntropyLoss()

2. Custom Autograd

PyTorch's automatic differentiation engine is powerful, but sometimes you need custom gradient computation for numerical stability or efficiency.

Subclassing torch.autograd.Function

import torch
from torch.autograd import Function

class SigmoidFunction(Function):
    """Custom autograd sigmoid with an analytic backward pass."""

    @staticmethod
    def forward(ctx, input):
        # s(x) = 1 / (1 + exp(-x)); torch.sigmoid is already numerically
        # stable for large |x|.
        result = torch.sigmoid(input)
        # Stash the OUTPUT: the gradient can be expressed purely in terms
        # of it, so the input itself need not be kept alive.
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # ds/dx = s * (1 - s), chained with the incoming gradient.
        (result,) = ctx.saved_tensors
        return grad_output * result * (1 - result)

# Usage
def custom_sigmoid(x):
    """Convenience wrapper routing ``x`` through :class:`SigmoidFunction`."""
    return SigmoidFunction.apply(x)

x = torch.randn(3, 4, requires_grad=True)
y = custom_sigmoid(x)
loss = y.sum()
loss.backward()
print(f"Gradient: {x.grad}")

More Complex Example: Leaky ReLU with Custom Backward

class LeakyReLUFunction(Function):
    """Leaky ReLU with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, input, negative_slope=0.01):
        # Keep the raw input: the backward mask depends on its sign.
        ctx.save_for_backward(input)
        ctx.negative_slope = negative_slope
        # Positive entries pass through; negative ones are scaled down.
        return torch.where(input > 0, input, negative_slope * input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # d/dx is 1 for x >= 0 and negative_slope for x < 0.
        grad_input = torch.where(
            input < 0, grad_output * ctx.negative_slope, grad_output
        )
        # negative_slope is a plain float, so it receives no gradient.
        return grad_input, None

class CustomLeakyReLU(nn.Module):
    """``nn.Module`` wrapper around :class:`LeakyReLUFunction`."""

    def __init__(self, negative_slope=0.01):
        super().__init__()
        # Stored once so the same slope is reused on every forward call.
        self.negative_slope = negative_slope

    def forward(self, x):
        return LeakyReLUFunction.apply(x, self.negative_slope)

Numerical Gradient Check

from torch.autograd import gradcheck

def test_custom_op():
    """Validate SigmoidFunction's backward against numerical gradients."""
    # Numerical gradient check (float64 recommended)
    input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)

    # gradcheck computes Jacobian numerically and compares with autograd
    result = gradcheck(SigmoidFunction.apply, (input,), eps=1e-6, atol=1e-4)
    print(f"Gradient check passed: {result}")

    # Double backprop test
    # check_grad_dtypes additionally verifies gradient dtypes match the
    # inputs; gradcheck raises on failure, so returning here means it passed.
    input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
    result = gradcheck(
        SigmoidFunction.apply,
        (input,),
        eps=1e-6,
        atol=1e-4,
        check_grad_dtypes=True
    )

test_custom_op()

Double Backpropagation

class SquaredFunction(Function):
    """Custom x**2 whose backward itself supports differentiation.

    Because backward is written with ordinary differentiable tensor ops,
    autograd can build a graph through it when ``create_graph=True`` is
    requested -- enabling second derivatives (e.g. for MAML).
    """

    @staticmethod
    def forward(ctx, x):
        # The input is needed in backward: d(x^2)/dx = 2x.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * (2 * x)

# Double backprop example (useful in meta-learning like MAML)
x = torch.randn(3, requires_grad=True)
y = SquaredFunction.apply(x)
grad_x = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
# grad_x = 2x, backprop through this again
grad_grad_x = torch.autograd.grad(grad_x.sum(), x)[0]
# grad_grad_x = 2 (constant)
print(f"Second derivative: {grad_grad_x}")

3. Custom CUDA Operators

torch.utils.cpp_extension Overview

PyTorch provides tools for writing C++/CUDA extensions and using them from Python.

# JIT compilation (suitable for development/prototyping)
from torch.utils.cpp_extension import load_inline

# C++ CPU operator
cpp_source = """
#include <torch/extension.h>

torch::Tensor relu_forward(torch::Tensor input) {
    return input.clamp_min(0);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("relu_forward", &relu_forward, "ReLU forward");
}
"""

# Inline compilation
custom_relu_cpp = load_inline(
    name="custom_relu_cpp",
    cpp_sources=cpp_source,
    functions=["relu_forward"],
    verbose=False
)

x = torch.randn(5)
result = custom_relu_cpp.relu_forward(x)
print(result)

CUDA Kernel Example: Fused Softmax

# CUDA source
cuda_source = """
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

__global__ void fused_softmax_kernel(
    const float* input,
    float* output,
    int rows,
    int cols
) {
    int row = blockIdx.x;
    if (row >= rows) return;

    const float* row_input = input + row * cols;
    float* row_output = output + row * cols;

    // Find max (for numerical stability)
    float max_val = row_input[0];
    for (int i = 1; i < cols; i++) {
        max_val = fmaxf(max_val, row_input[i]);
    }

    // Compute exp(x - max) and sum
    float sum = 0.0f;
    for (int i = 0; i < cols; i++) {
        row_output[i] = expf(row_input[i] - max_val);
        sum += row_output[i];
    }

    // Normalize
    for (int i = 0; i < cols; i++) {
        row_output[i] /= sum;
    }
}

torch::Tensor fused_softmax_cuda(torch::Tensor input) {
    auto output = torch::zeros_like(input);
    int rows = input.size(0);
    int cols = input.size(1);

    fused_softmax_kernel<<<rows, 1>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        rows,
        cols
    );

    return output;
}
"""

cpp_source_cuda = """
#include <torch/extension.h>
torch::Tensor fused_softmax_cuda(torch::Tensor input);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_softmax", &fused_softmax_cuda, "Fused Softmax CUDA");
}
"""

# Compile only if CUDA is available
if torch.cuda.is_available():
    from torch.utils.cpp_extension import load_inline
    fused_softmax_ext = load_inline(
        name="fused_softmax",
        cpp_sources=cpp_source_cuda,
        cuda_sources=cuda_source,
        functions=["fused_softmax"],
        verbose=True
    )

    # Test
    x = torch.randn(4, 8, device="cuda")
    result = fused_softmax_ext.fused_softmax(x)
    expected = torch.softmax(x, dim=1)
    print(f"Max difference: {(result - expected).abs().max().item():.6f}")

Building a Package with setup.py

# setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="custom_ops",
    ext_modules=[
        CUDAExtension(
            name="custom_ops",
            sources=[
                "custom_ops/ops.cpp",
                "custom_ops/ops_cuda.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": ["-O3", "--use_fast_math"],
            }
        )
    ],
    cmdclass={
        "build_ext": BuildExtension
    }
)
# Build: python setup.py install

4. Memory Optimization Techniques

GPU Memory Profiling

import torch

def print_gpu_memory_stats():
    """Print allocated/reserved/peak GPU memory and the allocator summary.

    Silently does nothing on CPU-only machines.
    """
    if not torch.cuda.is_available():
        return
    gib = 1024 ** 3
    print(f"Allocated: {torch.cuda.memory_allocated() / gib:.2f} GB")
    print(f"Reserved:  {torch.cuda.memory_reserved() / gib:.2f} GB")
    print(f"Max Allocated: {torch.cuda.max_memory_allocated() / gib:.2f} GB")
    print(torch.cuda.memory_summary(abbreviated=True))

# Start memory tracking
torch.cuda.reset_peak_memory_stats()
print_gpu_memory_stats()

# Load model
model = TransformerBlock().cuda()
print_gpu_memory_stats()

Gradient Checkpointing (Activation Recomputation)

Gradient Checkpointing avoids storing intermediate activation values during the forward pass, recomputing them as needed during the backward pass. This saves significant memory at the cost of roughly 30% more computation.

from torch.utils.checkpoint import checkpoint, checkpoint_sequential
import torch.nn as nn

class DeepTransformer(nn.Module):
    """Stack of TransformerBlocks used to demo activation checkpointing."""

    def __init__(self, num_layers=12, d_model=512, nhead=8):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, nhead) for _ in range(num_layers)
        ])

    def forward_with_checkpointing(self, x):
        """Apply gradient checkpointing to each layer"""
        for layer in self.layers:
            # checkpoint does not store intermediate activations; they are
            # recomputed during backward. use_reentrant=False selects the
            # recommended autograd-graph-based implementation.
            x = checkpoint(layer, x, use_reentrant=False)
        return x

    def forward_sequential_checkpointing(self, x):
        """Apply checkpointing to sequential modules"""
        # Split the stack into 3 segments (4 layers each with the default
        # num_layers=12); only segment boundaries keep activations.
        x = checkpoint_sequential(self.layers, segments=3, input=x)
        return x

# Memory comparison
model = DeepTransformer(num_layers=24).cuda()
x = torch.randn(16, 512, 512, device="cuda")

# Checkpointed forward (NOTE: only the checkpointed path is measured here;
# to actually compare, run a plain forward through model.layers as well and
# read max_memory_allocated for both)
torch.cuda.reset_peak_memory_stats()
out = model.forward_with_checkpointing(x)
checkpoint_mem = torch.cuda.max_memory_allocated()

print(f"Checkpoint memory: {checkpoint_mem / 1024**3:.2f} GB")

Gradient Accumulation

def train_with_gradient_accumulation(
    model, optimizer, dataloader, criterion,
    accumulation_steps=4, device="cuda"
):
    """
    Implement large effective batch size with limited GPU memory.

    Gradients are accumulated over ``accumulation_steps`` micro-batches
    before a single optimizer step, so the effective batch size becomes
    ``micro_batch_size * accumulation_steps``.

    Args:
        model: Network to train.
        optimizer: Optimizer for ``model``'s parameters.
        dataloader: Iterable of ``(data, target)`` micro-batches.
        criterion: Loss function.
        accumulation_steps: Micro-batches accumulated per optimizer step.
        device: Device batches are moved to (default ``"cuda"`` preserves
            the original behavior).

    Returns:
        float: Mean (unscaled) loss per micro-batch.
    """
    model.train()
    optimizer.zero_grad()
    total_loss = 0.0
    pending = False  # True while gradients are accumulated but not yet stepped

    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)

        # Forward pass; divide so summed gradients match one big-batch step.
        output = model(data)
        loss = criterion(output, target) / accumulation_steps
        loss.backward()
        pending = True

        total_loss += loss.item() * accumulation_steps

        # Update every accumulation_steps
        if (batch_idx + 1) % accumulation_steps == 0:
            # Gradient clipping (optional)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
            pending = False

    # BUG FIX: without this final flush, a trailing partial accumulation
    # (when the batch count is not a multiple of accumulation_steps) was
    # silently discarded.
    if pending:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    return total_loss / len(dataloader)

Mixed Precision Training (AMP)

from torch.cuda.amp import autocast, GradScaler

def train_with_amp(model, optimizer, dataloader, criterion):
    """Automatic Mixed Precision for memory savings + speed boost"""
    scaler = GradScaler()
    model.train()

    for data, target in dataloader:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()

        # float16 operations inside autocast context
        with autocast():
            output = model(data)
            loss = criterion(output, target)

        # Backward with scaled gradients
        scaler.scale(loss).backward()

        # Unscale then clip gradients
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Optimizer step (with NaN/Inf check)
        scaler.step(optimizer)
        scaler.update()

8-bit Optimizer

# Requires bitsandbytes: pip install bitsandbytes
try:
    import bitsandbytes as bnb

    # Use 8-bit Adam instead of regular Adam
    optimizer_8bit = bnb.optim.Adam8bit(
        model.parameters(),
        lr=1e-4,
        betas=(0.9, 0.999)
    )

    # PagedAdam (can offload to CPU)
    optimizer_paged = bnb.optim.PagedAdam(
        model.parameters(),
        lr=1e-4
    )

    print("8-bit optimizer loaded successfully")
except ImportError:
    print("bitsandbytes not installed, using regular Adam")
    optimizer_8bit = torch.optim.Adam(model.parameters(), lr=1e-4)

5. functorch and vmap

vmap - Batched Operations

vmap efficiently applies a function operating on single samples to a batch.

import torch
from torch import vmap

# Function operating on a single sample
def single_linear(weight, bias, x):
    """Affine transform of ONE un-batched sample: ``weight @ x + bias``."""
    product = weight @ x
    return product + bias

# Batch processing with vmap
batched_linear = vmap(single_linear)

# Batch data
batch_size = 32
weight = torch.randn(batch_size, 10, 5)
bias = torch.randn(batch_size, 10)
x = torch.randn(batch_size, 5)

# Automatically handled in batch
result = batched_linear(weight, bias, x)
print(f"Result shape: {result.shape}")  # (32, 10)

grad - Functional Gradient

from torch.func import grad, vmap, functional_call

# Functional gradient
def scalar_loss(params, x, y):
    pred = functional_call(model, params, (x,))
    return ((pred - y) ** 2).mean()

# Gradient with respect to parameters
params = dict(model.named_parameters())
grad_fn = grad(scalar_loss)

x = torch.randn(1, 10)
y = torch.randn(1, 5)
grads = grad_fn(params, x, y)
print({k: v.shape for k, v in grads.items()})

Ensemble Models with vmap

from torch.func import stack_module_state, functional_call, vmap

def create_ensemble(model_class, num_models, *args, **kwargs):
    """Build a vmap-vectorized ensemble of ``num_models`` identical models.

    Returns ``(forward_fn, params, buffers)`` where
    ``forward_fn(params, buffers, x)`` evaluates every ensemble member on
    the same input ``x`` in one batched call.
    """
    members = [model_class(*args, **kwargs) for _ in range(num_models)]

    # Stack each parameter/buffer across members along a new leading dim.
    params, buffers = stack_module_state(members)

    # A "stateless" template whose weights are supplied per call.
    template = model_class(*args, **kwargs)

    def run_one(member_params, member_buffers, x):
        return functional_call(template, (member_params, member_buffers), (x,))

    # Map over the stacked member dimension; share x across members.
    return vmap(run_one, in_dims=(0, 0, None)), params, buffers

# Usage example
ensemble_fn, params, buffers = create_ensemble(
    nn.Linear, num_models=5, in_features=10, out_features=5
)

x = torch.randn(32, 10)
ensemble_out = ensemble_fn(params, buffers, x)
print(f"Ensemble output shape: {ensemble_out.shape}")  # (5, 32, 5)

Meta-Learning (MAML) with grad and vmap

from torch.func import grad, vmap, functional_call

def inner_loop(params, support_x, support_y, base_model, lr=0.01, steps=5):
    """MAML inner loop: take ``steps`` gradient-descent steps on the support set.

    Returns a NEW parameter dict; the input ``params`` is left untouched.
    """
    def support_loss(candidate):
        pred = functional_call(base_model, candidate, (support_x,))
        return ((pred - support_y) ** 2).mean()

    adapted = {name: tensor.clone() for name, tensor in params.items()}
    for _ in range(steps):
        gradients = grad(support_loss)(adapted)
        # Plain (differentiable) SGD update on the adapted copy.
        adapted = {name: tensor - lr * gradients[name]
                   for name, tensor in adapted.items()}

    return adapted

6. PyTorch Profiler

Basic Profiling

import torch
from torch.profiler import profile, record_function, ProfilerActivity

model = TransformerBlock().cuda()
x = torch.randn(32, 128, 512, device="cuda")

# Run profiler
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True
) as prof:
    with record_function("model_inference"):
        for _ in range(10):
            output = model(x)

# Print results
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20
))

# Export Chrome Trace
prof.export_chrome_trace("trace.json")
# Open in chrome://tracing

TensorBoard Integration

from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(
        wait=1,    # Wait 1 step
        warmup=1,  # 1 warmup step
        active=3,  # Profile 3 steps
        repeat=2   # Repeat 2 times
    ),
    on_trace_ready=tensorboard_trace_handler("./log/profiler"),
    record_shapes=True,
    profile_memory=True
) as prof:
    for step, (data, target) in enumerate(dataloader):
        train_step(model, optimizer, data, target)
        prof.step()  # Profile according to schedule

# tensorboard --logdir=./log/profiler

Detailed Memory Analysis

# Memory snapshot (PyTorch 2.1+)
torch.cuda.memory._record_memory_history(max_entries=100000)

# Run code
x = torch.randn(100, 100, device="cuda")
y = x @ x.T
z = y.sum()

# Save snapshot
snapshot = torch.cuda.memory._snapshot()
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

print(f"Active allocations: {len(snapshot['segments'])}")

7. TorchScript

torch.jit.script vs torch.jit.trace

import torch
import torch.nn as nn

class ConditionalModel(nn.Module):
    """Toy model with data-independent control flow in forward.

    The if/else makes torch.jit.trace unsuitable (only one branch would be
    captured); torch.jit.script preserves both branches.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x, flag: bool = True):
        projected = self.linear(x)
        # Branching on a Python bool -> real control flow, not traceable.
        if flag:
            return torch.relu(projected)
        return torch.sigmoid(projected)

model = ConditionalModel()

# script: includes control flow (recommended)
scripted_model = torch.jit.script(model)
print(scripted_model.code)

# trace: captures only a single path
x = torch.randn(1, 10)
traced_model = torch.jit.trace(model, x)
# Note: flag=False path is not captured

# Save and load
scripted_model.save("model_scripted.pt")
loaded = torch.jit.load("model_scripted.pt")

TorchScript Optimization

# Apply optimization passes
scripted = torch.jit.script(model)
optimized = torch.jit.optimize_for_inference(scripted)

# Export for C++ environment
# In C++: torch::jit::script::Module m = torch::jit::load("model.pt");

8. Dynamic Shapes and torch.export

Using torch.export

import torch
from torch.export import export

class MyModel(nn.Module):
    """Minimal linear model used to demonstrate torch.export."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        # (batch, 10) -> (batch, 5)
        return self.linear(x)

model = MyModel()
example_inputs = (torch.randn(2, 10),)

# Specify dynamic shapes
from torch.export import Dim
batch = Dim("batch", min=1, max=100)

# Export model
exported = export(
    model,
    example_inputs,
    dynamic_shapes={"x": {0: batch}}
)

print(exported)
print(exported.graph_module.code)

# Run ExportedProgram
result = exported.module()(torch.randn(5, 10))
print(f"Result shape: {result.shape}")

9. Custom Dataset and Sampler

IterableDataset

from torch.utils.data import IterableDataset, DataLoader
import torch

class StreamingDataset(IterableDataset):
    """Stream large datasets efficiently.

    Samples are yielded one at a time, so the full dataset never has to be
    held in memory. Under a multi-worker DataLoader each worker handles a
    disjoint subset of the files.

    Args:
        data_paths: Sequence of file paths to stream from.
        transform: Optional callable applied to each sample before yielding.
    """

    def __init__(self, data_paths, transform=None):
        self.data_paths = data_paths
        self.transform = transform

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is None:
            # Single process: this iterator sees every file.
            paths = self.data_paths
        else:
            # BUG FIX: the original contiguous split (len // num_workers per
            # worker) silently dropped the remainder files. Strided
            # assignment covers every path exactly once across workers.
            paths = self.data_paths[worker_info.id::worker_info.num_workers]

        for path in paths:
            # Stream data from file
            data = self._load_file(path)
            for sample in data:
                if self.transform:
                    sample = self.transform(sample)
                yield sample

    def _load_file(self, path):
        # Placeholder: a real implementation would stream from `path`.
        return [torch.randn(10) for _ in range(100)]

# Use persistent_workers in DataLoader
dataset = StreamingDataset(data_paths=["file1.pt", "file2.pt"])
dataloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,
    persistent_workers=True,  # Reuse worker processes
    pin_memory=True,           # Optimize GPU transfer
    prefetch_factor=2          # Number of batches to prefetch
)

Custom Sampler

from torch.utils.data import Sampler
import numpy as np

class BalancedClassSampler(Sampler):
    """Oversampling sampler that equalizes per-class frequencies.

    Each epoch draws the same number of indices from every class (with
    replacement), counteracting class imbalance in the dataset.
    """

    def __init__(self, dataset, num_samples_per_class=None):
        self.dataset = dataset
        # NOTE(review): assumes dataset[i] is (sample, label) with an
        # int-like label -- confirm against the actual dataset.
        self.labels = torch.tensor(
            [dataset[i][1] for i in range(len(dataset))]
        )

        # Map each class id -> list of dataset indices belonging to it.
        self.class_indices = {}
        for cls in torch.unique(self.labels):
            mask = self.labels == cls
            self.class_indices[cls.item()] = (
                mask.nonzero(as_tuple=True)[0].tolist()
            )

        self.num_classes = len(self.class_indices)
        # Default: upsample every class to the size of the largest one.
        self.num_samples_per_class = (
            num_samples_per_class or
            max(len(idxs) for idxs in self.class_indices.values())
        )

    def __iter__(self):
        drawn = []
        for class_members in self.class_indices.values():
            # Same count from every class, sampled with replacement.
            drawn.extend(
                np.random.choice(
                    class_members,
                    self.num_samples_per_class,
                    replace=True
                ).tolist()
            )

        # Shuffle so batches interleave classes.
        np.random.shuffle(drawn)
        return iter(drawn)

    def __len__(self):
        return self.num_classes * self.num_samples_per_class

10. PyTorch Lightning

Complete LightningModule Example

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

class LightningTransformer(pl.LightningModule):
    """Transformer-encoder classifier wrapped as a LightningModule.

    Lightning supplies the training loop; this class only defines the
    model, the per-step losses/metrics, and the optimizer setup.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_layers=6,
        num_classes=10,
        learning_rate=1e-4
    ):
        super().__init__()
        self.save_hyperparameters()  # Auto-save HParams

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                batch_first=True
            ),
            num_layers=num_layers
        )
        self.classifier = nn.Linear(d_model, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        # x assumed (batch, seq, d_model) -- confirm against the datamodule.
        encoded = self.encoder(x)
        # Sequence mean pooling
        pooled = encoded.mean(dim=1)
        return self.classifier(pooled)

    def training_step(self, batch, batch_idx):
        """One training micro-step; Lightning handles backward/step."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        # Logging
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation metrics; logging is enough, no return value needed."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        """AdamW + cosine LR schedule, returned in Lightning's dict form."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=0.01
        )
        # Cosine annealing scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=100
        )
        # NOTE(review): "monitor" only matters for schedulers like
        # ReduceLROnPlateau; for CosineAnnealingLR it is ignored.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }


class LightningDataModule(pl.LightningDataModule):
    """Synthetic datamodule: 1000 random sequences, 80/20 train/val split."""

    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Dummy data
        x = torch.randn(1000, 32, 512)
        y = torch.randint(0, 10, (1000,))
        dataset = TensorDataset(x, y)
        n = len(dataset)
        # Deterministic (non-shuffled) 80/20 split via index ranges.
        self.train_ds = torch.utils.data.Subset(dataset, range(int(0.8*n)))
        self.val_ds = torch.utils.data.Subset(dataset, range(int(0.8*n), n))

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)


# Run training
model = LightningTransformer()
dm = LightningDataModule()

# Setup callbacks
callbacks = [
    ModelCheckpoint(
        monitor="val_loss",
        save_top_k=3,
        mode="min",
        filename="transformer-{epoch:02d}-{val_loss:.2f}"
    ),
    EarlyStopping(monitor="val_loss", patience=10, mode="min")
]

# Trainer
trainer = pl.Trainer(
    max_epochs=100,
    accelerator="auto",    # Auto-detect GPU
    devices="auto",
    callbacks=callbacks,
    logger=TensorBoardLogger("tb_logs", name="transformer"),
    gradient_clip_val=1.0,  # Gradient clipping
    accumulate_grad_batches=4,  # Gradient accumulation
    precision="16-mixed",   # AMP
)

trainer.fit(model, dm)

11. Model Quantization

Dynamic Quantization

import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

# Inference-only quantization - easiest approach
model = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 128)
)
model.eval()

# Quantize Linear and LSTM layers to int8
quantized_model = quantize_dynamic(
    model,
    {nn.Linear, nn.LSTM},  # Layer types to quantize
    dtype=torch.qint8
)

# Size comparison
import os
torch.save(model.state_dict(), "model_fp32.pt")
torch.save(quantized_model.state_dict(), "model_int8.pt")
fp32_size = os.path.getsize("model_fp32.pt")
int8_size = os.path.getsize("model_int8.pt")
print(f"FP32 size: {fp32_size / 1024:.1f} KB")
print(f"INT8 size: {int8_size / 1024:.1f} KB")
print(f"Compression: {fp32_size / int8_size:.1f}x")

Static Quantization

import torch
from torch.ao.quantization import (
    get_default_qconfig,
    prepare,
    convert,
)

class QuantizableModel(nn.Module):
    """Small MLP bracketed by Quant/DeQuant stubs for static quantization.

    QuantStub/DeQuantStub behave as identity ops in float mode; `convert`
    later swaps them for real quantize/dequantize boundaries.
    """

    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(512, 256)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(256, 10)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # float -> quantized boundary, MLP body, back to float.
        hidden = self.relu(self.linear1(self.quant(x)))
        return self.dequant(self.linear2(hidden))

model = QuantizableModel()
model.eval()

# Quantization config
model.qconfig = get_default_qconfig("fbgemm")  # For x86 CPU

# Prepare (insert observers)
prepared_model = prepare(model)

# Collect statistics with calibration data
with torch.no_grad():
    for _ in range(100):
        calibration_data = torch.randn(32, 512)
        prepared_model(calibration_data)

# Convert to quantized
quantized_static = convert(prepared_model)

# Inference
x = torch.randn(1, 512)
with torch.no_grad():
    output = quantized_static(x)
print(f"Output shape: {output.shape}")

QAT (Quantization-Aware Training)

from torch.ao.quantization import (
    get_default_qat_qconfig,
    prepare_qat,
    convert
)

def dummy_dataloader():
    """Yield a few random (input, label) batches for the QAT demo.

    BUG FIX: defined BEFORE the training loop below -- in the original it
    was defined after its first use, which raises NameError when the
    snippet runs top to bottom.
    """
    for _ in range(10):
        yield torch.randn(32, 512), torch.randint(0, 10, (32,))


model = QuantizableModel()
model.train()

# QAT config
model.qconfig = get_default_qat_qconfig("fbgemm")

# Prepare QAT (insert fake quantization nodes)
prepared_qat = prepare_qat(model)

# Train (include quantization error in training)
optimizer = torch.optim.SGD(prepared_qat.parameters(), lr=0.0001)
for epoch in range(10):
    for x, y in dummy_dataloader():
        output = prepared_qat(x)
        loss = nn.functional.cross_entropy(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Convert after training: fake-quant observers become real int8 kernels
prepared_qat.eval()
quantized_qat = convert(prepared_qat)

12. Tensor Parallelism and Pipeline Parallelism

DeviceMesh and DTensor API

import torch
import torch.distributed as dist
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel
)
from torch.distributed._tensor import DeviceMesh

# Initialize distributed training
def setup_distributed():
    """Join the NCCL process group and bind this process to its GPU.

    Assumes torchrun-style environment variables (RANK, WORLD_SIZE,
    MASTER_ADDR, ...) are already set. NOTE(review): uses the GLOBAL rank
    as the CUDA device index, which is only correct for single-node jobs;
    multi-node setups should use LOCAL_RANK instead -- confirm deployment.
    """
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())

class TransformerMLP(nn.Module):
    """Transformer feed-forward sub-block: Linear -> GELU -> Linear."""

    def __init__(self, d_model=1024, dim_feedforward=4096):
        super().__init__()
        self.fc1 = nn.Linear(d_model, dim_feedforward)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim_feedforward, d_model)

    def forward(self, x):
        # Expand to dim_feedforward, apply nonlinearity, project back.
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)

# Apply Tensor Parallelism
def apply_tensor_parallel(model, mesh):
    """
    fc1: Column-wise sharding (split output dimension)
    fc2: Row-wise sharding (split input dimension)

    NOTE(review): the Colwise -> Rowwise pairing is the standard
    Megatron-style layout that keeps the intermediate activation sharded
    between the two linears -- confirm against the TP docs for your
    torch version.
    """
    # parallelize_module mutates `model` in place; the return value is
    # only for call-site convenience.
    parallelize_module(
        model,
        mesh,
        {
            "fc1": ColwiseParallel(),
            "fc2": RowwiseParallel(),
        }
    )
    return model

# Example (requires 2-GPU setup)
# device_mesh = DeviceMesh("cuda", [0, 1])
# model = TransformerMLP()
# model = apply_tensor_parallel(model, device_mesh)

FSDP (Fully Sharded Data Parallel)

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy
)
import functools

def setup_fsdp_model(model):
    """FSDP setup - distribute large models across multiple GPUs"""

    # Mixed precision config: bf16 parameters/buffers with fp32 gradient
    # reduction for numerically stable cross-rank reduce ops.
    mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.bfloat16,
    )

    # Set TransformerBlock as wrapping unit
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock}
    )

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        # FULL_SHARD: shard parameters, gradients and optimizer state.
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        # Prefetch the next unit's all-gather during backward to overlap
        # communication with computation.
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=torch.cuda.current_device()
    )

    return model

Summary

This guide covered PyTorch's advanced techniques:

  1. torch.compile: Over 2x performance improvement without code changes
  2. Custom Autograd: Implementing specialized gradient computation
  3. CUDA Extensions: Integrating GPU kernels with PyTorch
  4. Memory Optimization: Gradient Checkpointing, AMP, 8-bit Optimizer
  5. functorch/vmap: Batch processing and meta-learning via functional API
  6. PyTorch Profiler: Analyzing performance bottlenecks
  7. TorchScript/export: Deployment optimization
  8. PyTorch Lightning: Structuring code and automating training
  9. Quantization: Reducing model size with INT8
  10. Distributed Training: Tensor Parallel, FSDP

These techniques are complementary, and in real projects it is common to combine several of them. For large-scale model training, the combination of AMP + Gradient Checkpointing + FSDP + torch.compile is particularly powerful.

References