Deep Learning Training Methods Complete Guide: From Optimization to Distributed Training

Introduction

Over the past decade, deep learning has achieved revolutionary results in virtually every AI domain — computer vision, natural language processing, speech recognition, and reinforcement learning. However, simply designing a neural network architecture is not enough to build a high-performing model. How you train it is the decisive factor.

This guide covers the main techniques for training deep learning models effectively. Starting from the fundamentals of gradient descent, it progresses through advanced optimizers, learning rate scheduling, regularization, transfer learning, mixed precision training, and large-scale distributed training, with practical code examples throughout.


1. Gradient Descent Fundamentals

1.1 Understanding the Loss Function

In deep learning, the loss function quantifies the discrepancy between model predictions and ground-truth labels. The goal of training is to find model parameters (weights) that minimize this loss value.

The loss function L depends on model parameters theta and data (x, y). Expressed mathematically:

L(theta) = (1/N) * sum_{i=1}^{N} l(f(x_i; theta), y_i)

Here f is the model function, l is the per-sample loss, and N is the dataset size.

1.2 Intuitive Understanding of Gradient Descent

A helpful analogy for gradient descent is a hiker descending a mountain with eyes closed. At each step, the hiker moves in the direction of steepest descent (opposite to the gradient). Repeating this process eventually leads to the valley floor (minimum).

Mathematically, the update rule is:

theta_{t+1} = theta_t - lr * grad_L(theta_t)

Here lr is the learning rate and grad_L is the gradient of the loss function.
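
As a minimal sketch of the update rule, the following fits a one-parameter model y = w * x by plain gradient descent on the mean squared error. The data, the starting point, and the value lr = 0.05 are illustrative assumptions, not values from this guide.

```python
# Minimal sketch: fit y = w * x with plain gradient descent (standard library only).
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]   # generated by a hypothetical "true" w = 2

def loss(w):
    # L(w) = (1/N) * sum (w * x_i - y_i)^2
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def grad(w):
    # dL/dw = (2/N) * sum (w * x_i - y_i) * x_i
    return 2.0 * sum((w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.0
lr = 0.05
history = []
for step in range(50):
    w = w - lr * grad(w)   # theta_{t+1} = theta_t - lr * grad_L(theta_t)
    history.append(loss(w))
```

Each iteration applies exactly the rule above; `history` records the loss after every step so the decrease can be inspected.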

1.3 Batch GD vs Mini-batch GD vs SGD

Batch Gradient Descent

  • Computes gradients over the entire dataset
  • Stable but memory-intensive and slow
  • Impractical for large datasets

Stochastic Gradient Descent (SGD)

  • Computes gradients from a single sample
  • Fast but noisy and unstable
  • Suitable for online learning

Mini-batch Gradient Descent

  • Typically uses 32–512 samples per gradient computation
  • Combines advantages of both Batch GD and SGD
  • The most widely used approach in practice

import torch
import torch.nn as nn
import numpy as np

# Gradient descent implementation with simple linear regression
class LinearRegression(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        return self.linear(x)

# Mini-batch gradient descent
def train_minibatch(model, X, y, batch_size=32, lr=0.01, epochs=100):
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    losses = []

    N = len(X)
    for epoch in range(epochs):
        perm = torch.randperm(N)
        X_shuffled = X[perm]
        y_shuffled = y[perm]

        epoch_loss = 0
        for i in range(0, N, batch_size):
            x_batch = X_shuffled[i:i+batch_size]
            y_batch = y_shuffled[i:i+batch_size]

            optimizer.zero_grad()
            pred = model(x_batch)
            loss = criterion(pred, y_batch)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

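        # Average over the actual number of batches (the last one may be partial)
        num_batches = (N + batch_size - 1) // batch_size
        losses.append(epoch_loss / num_batches)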
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss = {losses[-1]:.4f}")

    return losses

torch.manual_seed(42)
X = torch.randn(1000, 10)
true_w = torch.randn(10, 1)
y = X @ true_w + 0.1 * torch.randn(1000, 1)

model = LinearRegression(10)
losses = train_minibatch(model, X, y)

1.4 The Critical Role of Learning Rate

The learning rate is one of the most important hyperparameters in deep learning.

  • Too large: Loss diverges or oscillates around the minimum
  • Too small: Training is extremely slow and may get stuck in local minima
  • Just right: Fast convergence to a good minimum

Common starting values are 0.1, 0.01, and 0.001, though the optimal value depends on network architecture and data.
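
The three regimes can be seen on a toy problem. This sketch minimizes f(theta) = theta^2, whose gradient is 2 * theta, with three learning rates chosen purely for illustration.

```python
def run_gd(lr, steps=20, theta0=1.0):
    """Minimize f(theta) = theta^2 (gradient 2 * theta) and return the final theta."""
    theta = theta0
    for _ in range(steps):
        theta -= lr * 2.0 * theta
    return theta

too_large = run_gd(lr=1.1)    # each step multiplies theta by (1 - 2.2) = -1.2: diverges
too_small = run_gd(lr=0.001)  # each step multiplies theta by 0.998: barely moves
reasonable = run_gd(lr=0.1)   # each step multiplies theta by 0.8: steady convergence
```

On real networks the thresholds are not known in advance, which is why the rate is usually tuned empirically.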

1.5 Mathematical Derivation (Partial Derivatives, Chain Rule)

Backpropagation in neural networks uses the chain rule to compute gradients for each layer.

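For a network with one hidden layer and a sigmoid activation (the setup used in the implementation that follows):

forward: x -> z1=W1*x+b1 -> a1=sigmoid(z1) -> z2=W2*a1+b2 -> output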
loss: L = MSE(output, y)

backward (chain rule):
dL/dW2 = dL/d_output * d_output/dz2 * dz2/dW2
dL/dW1 = dL/d_output * ... * da1/dz1 * dz1/dW1

import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def sigmoid_deriv(x):
    s = sigmoid(x)
    return s * (1 - s)

class SimpleNet:
    def __init__(self, input_dim, hidden_dim, output_dim):
        self.W1 = np.random.randn(input_dim, hidden_dim) * np.sqrt(2/input_dim)
        self.b1 = np.zeros(hidden_dim)
        self.W2 = np.random.randn(hidden_dim, output_dim) * np.sqrt(2/hidden_dim)
        self.b2 = np.zeros(output_dim)

    def forward(self, x):
        self.x = x
        self.z1 = x @ self.W1 + self.b1
        self.a1 = sigmoid(self.z1)
        self.z2 = self.a1 @ self.W2 + self.b2
        return self.z2

    def backward(self, y, lr=0.01):
        N = len(y)
        dL_dz2 = 2 * (self.z2 - y.reshape(-1, 1)) / N

        dL_dW2 = self.a1.T @ dL_dz2
        dL_db2 = dL_dz2.sum(axis=0)

        dL_da1 = dL_dz2 @ self.W2.T
        dL_dz1 = dL_da1 * sigmoid_deriv(self.z1)

        dL_dW1 = self.x.T @ dL_dz1
        dL_db1 = dL_dz1.sum(axis=0)

        self.W2 -= lr * dL_dW2
        self.b2 -= lr * dL_db2
        self.W1 -= lr * dL_dW1
        self.b1 -= lr * dL_db1

# Test
net = SimpleNet(10, 32, 1)
X_np = np.random.randn(100, 10)
y_np = np.random.randn(100)

for i in range(100):
    pred = net.forward(X_np)
    loss = np.mean((pred.flatten() - y_np) ** 2)
    net.backward(y_np)
    if i % 20 == 0:
        print(f"Step {i}: MSE = {loss:.4f}")
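
A common way to sanity-check hand-written backpropagation is to compare it with finite differences. The sketch below does this for a scalar version of the same network (one weight per layer, no biases); the specific input values and the step size h are arbitrary choices for illustration.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def forward_loss(w1, w2, x, y):
    # Scalar network: x -> z1 = w1 * x -> a1 = sigmoid(z1) -> z2 = w2 * a1, squared error
    a1 = sigmoid(w1 * x)
    z2 = w2 * a1
    return (z2 - y) ** 2

def analytic_grads(w1, w2, x, y):
    z1 = w1 * x
    a1 = sigmoid(z1)
    z2 = w2 * a1
    dL_dz2 = 2.0 * (z2 - y)
    dL_dw2 = dL_dz2 * a1                 # dz2/dw2 = a1
    dL_da1 = dL_dz2 * w2                 # dz2/da1 = w2
    dL_dz1 = dL_da1 * a1 * (1.0 - a1)    # da1/dz1 = sigmoid'(z1)
    dL_dw1 = dL_dz1 * x                  # dz1/dw1 = x
    return dL_dw1, dL_dw2

def numeric_grads(w1, w2, x, y, h=1e-6):
    # Central differences approximate each partial derivative
    g1 = (forward_loss(w1 + h, w2, x, y) - forward_loss(w1 - h, w2, x, y)) / (2 * h)
    g2 = (forward_loss(w1, w2 + h, x, y) - forward_loss(w1, w2 - h, x, y)) / (2 * h)
    return g1, g2

a = analytic_grads(0.5, -1.5, 2.0, 1.0)
n = numeric_grads(0.5, -1.5, 2.0, 1.0)
```

If the chain-rule terms are right, `a` and `n` agree to several decimal places.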

2. Advanced Optimizers

2.1 Momentum SGD

Plain SGD follows gradients directly, causing zigzag movement in narrow valley-shaped loss landscapes. Momentum introduces the physics concept of inertia, allowing the optimizer to remember previous movement directions.

v_t = beta * v_{t-1} + (1 - beta) * grad_t
theta_{t+1} = theta_t - lr * v_t

The momentum coefficient (beta) is typically set to 0.9.

import torch
import torch.optim as optim

# Momentum SGD
optimizer_momentum = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    nesterov=False
)

# Nesterov Accelerated Gradient (NAG) - look-ahead gradient
optimizer_nag = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    nesterov=True
)
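
One detail worth checking when moving between the formula and the library: the equation above averages the gradient with a (1 - beta) factor, whereas PyTorch's optim.SGD with its default dampening=0 accumulates it as v = beta * v + grad. The sketch below compares the two buffers under a constant gradient, which is an artificial assumption made only to expose the scale difference.

```python
beta = 0.9
g = 1.0            # a constant gradient, purely for illustration

v_ema = 0.0        # averaged form: v = beta * v + (1 - beta) * g
v_sum = 0.0        # accumulated form: v = beta * v + g
for _ in range(200):
    v_ema = beta * v_ema + (1 - beta) * g
    v_sum = beta * v_sum + g

ratio = v_sum / v_ema   # approaches 1 / (1 - beta)
```

With beta = 0.9 the accumulated buffer ends up about ten times larger, so the same lr produces roughly ten times larger steps than the averaged form.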

2.2 Adagrad (Adaptive Learning Rate)

Adagrad applies an individual learning rate to each parameter. Parameters that have accumulated large squared gradients receive smaller effective learning rates, while rarely updated parameters keep comparatively larger ones.

G_t = G_{t-1} + grad_t^2
theta_{t+1} = theta_t - (lr / sqrt(G_t + epsilon)) * grad_t

Effective for sparse data, but G_t accumulates indefinitely, causing the learning rate to shrink toward zero.

optimizer_adagrad = optim.Adagrad(
    model.parameters(),
    lr=0.01,
    eps=1e-8,
    weight_decay=0
)

2.3 RMSprop

RMSprop resolves Adagrad's learning rate decay problem by using an exponential moving average of squared gradients.

E[g^2]_t = rho * E[g^2]_{t-1} + (1 - rho) * grad_t^2
theta_{t+1} = theta_t - (lr / sqrt(E[g^2]_t + epsilon)) * grad_t

optimizer_rmsprop = optim.RMSprop(
    model.parameters(),
    lr=0.001,
    alpha=0.99,
    eps=1e-8,
    momentum=0,
    centered=False
)
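
The difference between the two accumulators is easiest to see numerically. This sketch tracks the effective step size lr / sqrt(accumulator + epsilon) for Adagrad and RMSprop under a constant gradient; the constant gradient and the 1000-step horizon are illustrative assumptions.

```python
import math

lr, eps, rho = 0.01, 1e-8, 0.99
g = 1.0                      # constant gradient, an illustrative assumption

G = 0.0                      # Adagrad: running sum of squared gradients
E = 0.0                      # RMSprop: exponential moving average of squared gradients
adagrad_steps, rmsprop_steps = [], []
for t in range(1, 1001):
    G = G + g ** 2
    E = rho * E + (1 - rho) * g ** 2
    adagrad_steps.append(lr / math.sqrt(G + eps))
    rmsprop_steps.append(lr / math.sqrt(E + eps))
```

Adagrad's step keeps shrinking like 1 / sqrt(t), while RMSprop's settles near lr / |g|.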

2.4 Adam (Adaptive Moment Estimation)

Adam combines Momentum and RMSprop, tracking a running estimate of the first moment (the mean of the gradients) and of the second raw moment (the uncentered variance). It is one of the most widely used optimizers in practice.

The algorithm:

m_t = beta1 * m_{t-1} + (1 - beta1) * g_t       # 1st moment (before bias correction)
v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2     # 2nd moment (before bias correction)

m_hat = m_t / (1 - beta1^t)                       # bias correction
v_hat = v_t / (1 - beta2^t)                       # bias correction

theta_{t+1} = theta_t - lr * m_hat / (sqrt(v_hat) + epsilon)

Default hyperparameters: lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8

optimizer_adam = optim.Adam(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0
)
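
To make the role of bias correction concrete, here is a scalar sketch of the Adam update applied to f(theta) = (theta - 3)^2. The objective, the starting point, and lr = 0.01 are assumptions for illustration; `adam_step` is a hypothetical helper, not a library function.

```python
import math

def adam_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    m = beta1 * m + (1 - beta1) * g          # 1st moment
    v = beta2 * v + (1 - beta2) * g * g      # 2nd moment
    m_hat = m / (1 - beta1 ** t)             # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

# Minimize f(theta) = (theta - 3)^2 starting from theta = 0
theta, m, v = 0.0, 0.0, 0.0
first_step = None
for t in range(1, 2001):
    g = 2.0 * (theta - 3.0)
    new_theta, m, v = adam_step(theta, g, m, v, t, lr=1e-2)
    if t == 1:
        first_step = abs(new_theta - theta)
    theta = new_theta
```

At t = 1 the corrected moments are m_hat = g and v_hat = g^2, so the first step has magnitude close to lr regardless of how large the gradient is.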

2.5 AdamW (Decoupled Weight Decay)

In standard Adam, L2 regularization is coupled with gradients and thus affected by the adaptive learning rate. AdamW applies weight decay directly to parameter updates, decoupled from the gradient-based update.

theta_{t+1} = theta_t - lr * (m_hat / (sqrt(v_hat) + epsilon) + lambda * theta_t)

AdamW has become the standard for training Transformer models (BERT, GPT, etc.).

optimizer_adamw = optim.AdamW(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01
)
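
The practical difference between coupled L2 and decoupled decay shows up even in a single step. The sketch below assumes a flat loss (zero loss gradient) so that only the decay term acts, and uses the fact that at t = 1 Adam's bias-corrected ratio reduces to g / (|g| + epsilon). All numeric values are illustrative.

```python
def adam_first_step_direction(g, eps=1e-8):
    # At t = 1 the bias-corrected moments are m_hat = g and v_hat = g^2
    return g / (abs(g) + eps)

lr, wd = 0.1, 0.01
theta0 = 1.0
loss_grad = 0.0   # pretend the loss is flat so only the decay term acts

# Adam with L2: the decay term is folded into the gradient and then normalized
theta_adam_l2 = theta0 - lr * adam_first_step_direction(loss_grad + wd * theta0)

# AdamW: the adaptive part sees only the loss gradient; decay is applied directly
theta_adamw = theta0 - lr * (adam_first_step_direction(loss_grad) + wd * theta0)
```

With coupled L2 the tiny decay gradient is normalized up to a full step of about lr, while AdamW shrinks the weight by the intended factor (1 - lr * lambda).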

2.6 LARS and LAMB (Large-Batch Training)

When using very large batch sizes (thousands), standard Adam degrades in performance. LARS (Layer-wise Adaptive Rate Scaling) and LAMB adjust learning rates per layer.

LARS: lr_l = lr * ||w_l|| / (||g_l|| + lambda * ||w_l||)
LAMB: applies a per-layer trust ratio to the Adam update
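
The LARS formula can be evaluated directly. This sketch computes the layer-wise rate for two hypothetical layers that share the same weights but have very different gradient magnitudes; the trust coefficient used in full implementations is omitted, matching the formula above.

```python
import math

def l2_norm(values):
    return math.sqrt(sum(v * v for v in values))

def lars_layer_lr(lr, weights, grads, weight_decay=0.0):
    """Layer-wise rate following the formula above (trust coefficient omitted)."""
    w_norm = l2_norm(weights)
    g_norm = l2_norm(grads)
    return lr * w_norm / (g_norm + weight_decay * w_norm)

# Two hypothetical layers with the same weights but very different gradient scales
w = [3.0, 4.0]                                         # ||w|| = 5
lr_small_grad = lars_layer_lr(0.1, w, [0.03, 0.04])    # ||g|| = 0.05
lr_large_grad = lars_layer_lr(0.1, w, [30.0, 40.0])    # ||g|| = 50
```

In both cases the resulting update norm, lr_l * ||g_l||, equals lr * ||w_l||, which is the point of the scaling: every layer moves by a similar fraction of its weight norm.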

2.7 Lion Optimizer (2023)

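Google Brain's Lion (EvoLved Sign Momentum) applies only the sign of its momentum-interpolated update, so every parameter moves by the same magnitude per step. Because it keeps a single momentum buffer rather than Adam's two moment estimates, it needs less optimizer memory while delivering competitive performance.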

class Lion(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                lr = group['lr']
                beta1, beta2 = group['betas']
                wd = group['weight_decay']

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']

                # Lion update
                update = exp_avg * beta1 + grad * (1 - beta1)
                p.data.mul_(1 - lr * wd)
                p.data.add_(update.sign_(), alpha=-lr)

                # Momentum update
                exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss

2.8 Optimizer Comparison Experiment

import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        return self.net(x)

def train_and_compare(optimizers_dict, X, y, epochs=200):
    results = {}

    for name, opt_fn in optimizers_dict.items():
        model = MLP()
        optimizer = opt_fn(model.parameters())
        criterion = nn.MSELoss()
        losses = []

        for epoch in range(epochs):
            optimizer.zero_grad()
            pred = model(X)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        results[name] = losses
        print(f"{name}: Final Loss = {losses[-1]:.4f}")

    return results

X = torch.randn(500, 2)
y = (X[:, 0] * 2 + X[:, 1] * 3 + torch.randn(500) * 0.1).unsqueeze(1)

optimizers = {
    'SGD': lambda p: torch.optim.SGD(p, lr=0.01),
    'SGD+Momentum': lambda p: torch.optim.SGD(p, lr=0.01, momentum=0.9),
    'Adam': lambda p: torch.optim.Adam(p, lr=0.001),
    'AdamW': lambda p: torch.optim.AdamW(p, lr=0.001, weight_decay=0.01),
    'RMSprop': lambda p: torch.optim.RMSprop(p, lr=0.001),
}

results = train_and_compare(optimizers, X, y)

3. Learning Rate Scheduling

A fixed learning rate is rarely optimal. Learning rate scheduling dynamically adjusts the rate during training to achieve faster convergence and better final performance.

3.1 Step and Exponential Decay

import torch
import torch.optim as optim

model = MLP()
optimizer = optim.SGD(model.parameters(), lr=0.1)

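# The schedulers in this section are shown side by side for reference;
# a training run would normally attach just one of them to the optimizer.

# Step Decay: reduce by gamma every step_size epochs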
step_scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=30,
    gamma=0.1
)

# MultiStep Decay: reduce at specified milestones
multistep_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[30, 60, 80],
    gamma=0.1
)

# Exponential Decay: reduce exponentially every epoch
exp_scheduler = optim.lr_scheduler.ExponentialLR(
    optimizer,
    gamma=0.95
)
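
For intuition, the three schedules have simple closed forms. The functions below are plain-Python sketches of those formulas (hypothetical helper names, not the library's internals), using the same hyperparameters as the schedulers above.

```python
def step_lr(base_lr, epoch, step_size=30, gamma=0.1):
    # Multiply by gamma once per completed block of step_size epochs
    return base_lr * gamma ** (epoch // step_size)

def multistep_lr(base_lr, epoch, milestones=(30, 60, 80), gamma=0.1):
    # Multiply by gamma once for every milestone already reached
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed

def exponential_lr(base_lr, epoch, gamma=0.95):
    # Multiply by gamma every epoch
    return base_lr * gamma ** epoch

schedule_preview = [(e, step_lr(0.1, e), multistep_lr(0.1, e), exponential_lr(0.1, e))
                    for e in (0, 29, 30, 60, 85)]
```

Printing `schedule_preview` gives a quick side-by-side view of how each rule decays from the same starting rate.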

3.2 Cosine Annealing

Cosine Annealing smoothly decreases the learning rate following a cosine curve. Cosine Annealing with Warm Restarts periodically resets the learning rate for exploration.

# Cosine Annealing
cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=100,
    eta_min=1e-6
)

# Cosine Annealing with Warm Restarts (SGDR)
cosine_restart = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,
    T_mult=2,
    eta_min=1e-6
)
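
The cosine curve itself is a one-line formula, and the restart points follow from T_0 and T_mult. The helpers below are hypothetical names for a plain-Python sketch of both, using the parameter values from the schedulers above.

```python
import math

def cosine_annealing_lr(base_lr, t, T_max=100, eta_min=1e-6):
    # eta_t = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * t / T_max))
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t / T_max))

def restart_epochs(T_0=10, T_mult=2, count=3):
    # Each cycle is T_mult times longer than the previous one
    epochs, period, start = [], T_0, 0
    for _ in range(count):
        start += period
        epochs.append(start)
        period *= T_mult
    return epochs

lr_start = cosine_annealing_lr(0.1, 0)
lr_mid = cosine_annealing_lr(0.1, 50)
lr_end = cosine_annealing_lr(0.1, 100)
restarts = restart_epochs()
```

With T_0=10 and T_mult=2 the cycles last 10, 20, and 40 epochs, so the rate jumps back to its initial value at epochs 10, 30, and 70.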

3.3 Warmup + Cosine Schedule

The standard schedule for training Transformer models. The learning rate increases linearly during warmup, then decreases following a cosine curve.

import math
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5):
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda)

optimizer = optim.AdamW(model.parameters(), lr=5e-5)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1000,
    num_training_steps=10000
)

3.4 OneCycleLR

OneCycleLR aggressively ramps the learning rate up and then down for fast convergence. Introduced by Leslie Smith and popularized by FastAI.

optimizer = optim.SGD(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.1,
    steps_per_epoch=len(train_loader),
    epochs=10,
    pct_start=0.3,
    anneal_strategy='cos',
    div_factor=25.0,
    final_div_factor=1e4
)

for epoch in range(10):
    for batch in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(batch[0]), batch[1])
        loss.backward()
        optimizer.step()
        scheduler.step()  # OneCycleLR steps per batch
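
The rates OneCycleLR actually uses follow from max_lr and the two divisors, not from the lr passed to the optimizer, which the scheduler overrides. The arithmetic is sketched below; the loader length of 500 batches is a made-up value for illustration.

```python
max_lr = 0.1
div_factor = 25.0
final_div_factor = 1e4
pct_start = 0.3
total_steps = 10 * 500          # epochs * steps_per_epoch; 500 is a made-up loader length

initial_lr = max_lr / div_factor                 # where the ramp starts
min_lr = initial_lr / final_div_factor           # where the schedule ends
warmup_steps = round(pct_start * total_steps)    # approximate number of ramp-up steps
```

So with these settings the schedule climbs from 0.004 to 0.1 over roughly the first 30% of training and then anneals down to about 4e-7.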

3.5 Learning Rate Finder

Automatically identifies an appropriate learning rate range before training.

from torch_lr_finder import LRFinder

model = MLP()
optimizer = optim.SGD(model.parameters(), lr=1e-7, weight_decay=1e-2)
criterion = nn.MSELoss()

lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
lr_finder.range_test(train_loader, end_lr=100, num_iter=100)
lr_finder.plot()
lr_finder.reset()

# Select the LR at the steepest loss decline
# Typically use 1/10 to 1/3 of the value at the minimum

4. Loss Functions

4.1 Regression Loss Functions

import torch
import torch.nn as nn
import torch.nn.functional as F

# MSE - sensitive to outliers
mse_loss = nn.MSELoss()

# MAE - robust to outliers
mae_loss = nn.L1Loss()

# Huber Loss - compromise between MSE and MAE
# |y - y_hat| < delta: 0.5 * (y - y_hat)^2
# |y - y_hat| >= delta: delta * (|y - y_hat| - 0.5 * delta)
huber_loss = nn.HuberLoss(delta=1.0)

def huber_loss_manual(pred, target, delta=1.0):
    residual = torch.abs(pred - target)
    condition = residual < delta
    squared_loss = 0.5 * residual ** 2
    linear_loss = delta * residual - 0.5 * delta ** 2
    return torch.where(condition, squared_loss, linear_loss).mean()

4.2 Classification Loss Functions

# Cross-Entropy Loss (multi-class)
ce_loss = nn.CrossEntropyLoss()

# Binary Cross-Entropy (binary classification)
bce_loss = nn.BCEWithLogitsLoss()

# Label Smoothing Cross-Entropy (reduces overconfidence)
ce_smooth = nn.CrossEntropyLoss(label_smoothing=0.1)

# Focal Loss (addresses class imbalance)
class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.alpha is not None:
            alpha_t = self.alpha[targets]
            focal_loss = alpha_t * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss
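
The effect of the (1 - pt)^gamma factor is clearer with numbers. This sketch evaluates the per-sample focal term for one confidently correct prediction and one badly wrong prediction; the probabilities 0.9 and 0.1 are illustrative, and `focal_term` is a hypothetical helper.

```python
import math

def focal_term(pt, gamma=2.0):
    """Per-sample focal loss given pt, the predicted probability of the true class."""
    ce = -math.log(pt)
    return (1 - pt) ** gamma * ce

easy_ce, easy_focal = -math.log(0.9), focal_term(0.9)   # well-classified example
hard_ce, hard_focal = -math.log(0.1), focal_term(0.1)   # misclassified example

easy_scale = easy_focal / easy_ce   # (1 - 0.9)^2
hard_scale = hard_focal / hard_ce   # (1 - 0.1)^2
```

With gamma = 2 the easy example's loss is scaled down to about 1% of its cross-entropy, while the hard example keeps about 81%, so training focuses on the hard cases.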

4.3 Segmentation Loss Functions

def bce_loss_fn(pred, target):
    return F.binary_cross_entropy_with_logits(pred, target)

# Dice Loss (robust to class imbalance)
def dice_loss(pred, target, smooth=1.0):
    pred = torch.sigmoid(pred)
    pred_flat = pred.view(-1)
    target_flat = target.view(-1)

    intersection = (pred_flat * target_flat).sum()
    dice = (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
    return 1 - dice

# BCE + Dice combination (common in segmentation)
def bce_dice_loss(pred, target, bce_weight=0.5):
    bce = bce_loss_fn(pred, target)
    dice = dice_loss(pred, target)
    return bce_weight * bce + (1 - bce_weight) * dice

4.4 Metric Learning Loss Functions

# Contrastive Loss (pull similar pairs together, push dissimilar apart)
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # label=1: same class, label=0: different class
        euclidean_dist = F.pairwise_distance(output1, output2)
        loss = (label * euclidean_dist.pow(2) +
                (1 - label) * F.relu(self.margin - euclidean_dist).pow(2))
        return loss.mean()

# Triplet Loss (anchor, positive, negative)
class TripletLoss(nn.Module):
    def __init__(self, margin=0.3):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        pos_dist = F.pairwise_distance(anchor, positive)
        neg_dist = F.pairwise_distance(anchor, negative)
        loss = F.relu(pos_dist - neg_dist + self.margin)
        return loss.mean()

5. Regularization Techniques

Techniques to prevent overfitting and improve generalization.

5.1 L1/L2 Regularization

# L2 Regularization (Weight Decay)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# L1 Regularization (manual implementation)
def l1_regularization(model, lambda_l1):
    l1_penalty = 0
    for param in model.parameters():
        l1_penalty += torch.abs(param).sum()
    return lambda_l1 * l1_penalty

# Elastic Net (L1 + L2)
def elastic_net_loss(model, criterion, outputs, targets, lambda_l1=1e-5, lambda_l2=1e-4):
    base_loss = criterion(outputs, targets)
    l1_penalty = sum(torch.abs(p).sum() for p in model.parameters())
    l2_penalty = sum((p ** 2).sum() for p in model.parameters())
    return base_loss + lambda_l1 * l1_penalty + lambda_l2 * l2_penalty

5.2 Dropout

Dropout randomly deactivates neurons during training to prevent co-adaptation. Inverted Dropout divides by the keep probability during training, so no scaling is needed at inference.

class ModelWithDropout(nn.Module):
    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        return self.net(x)

# Training mode: dropout is active
model.train()

# Inference mode: dropout is disabled
model.eval()
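
The inverted-dropout scaling described above can be checked without any framework. This is a minimal sketch on a list of floats, with a seeded random generator so the run is repeatable; `inverted_dropout` is a hypothetical helper, not the nn.Dropout implementation.

```python
import random

def inverted_dropout(values, p, training=True, rng=random):
    """Zero each value with probability p and rescale survivors by 1 / (1 - p)."""
    if not training or p == 0.0:
        return list(values)
    keep = 1.0 - p
    return [v / keep if rng.random() < keep else 0.0 for v in values]

rng = random.Random(0)
activations = [1.0] * 100_000
dropped = inverted_dropout(activations, p=0.5, rng=rng)
train_mean = sum(dropped) / len(dropped)                          # close to 1.0 in expectation
eval_out = inverted_dropout(activations, p=0.5, training=False)   # unchanged at inference
```

Because survivors are scaled up during training, the expected activation matches the inference-time value and no extra scaling is required in eval mode.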

5.3 Data Augmentation

from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Mixup Augmentation
def mixup_data(x, y, alpha=1.0):
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

# CutMix Augmentation
def cutmix_data(x, y, alpha=1.0):
    lam = np.random.beta(alpha, alpha)
    batch_size, C, H, W = x.size()
    index = torch.randperm(batch_size)

    cut_ratio = np.sqrt(1. - lam)
    cut_w = int(W * cut_ratio)
    cut_h = int(H * cut_ratio)

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    mixed_x = x.clone()
    mixed_x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))

    return mixed_x, y, y[index], lam

5.4 Early Stopping

class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.001, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.counter = 0
        self.best_loss = None
        self.best_weights = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_weights = {k: v.clone() for k, v in model.state_dict().items()}
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.best_weights = {k: v.clone() for k, v in model.state_dict().items()}
            self.counter = 0

    def restore(self, model):
        if self.restore_best_weights and self.best_weights:
            model.load_state_dict(self.best_weights)
            print("Restored best model weights")

6. Normalization Layers

6.1 Batch Normalization

Proposed by Sergey Ioffe and Christian Szegedy in 2015, Batch Normalization normalizes features within each mini-batch to address the internal covariate shift problem.

The process:

1. Mini-batch mean: mu_B = (1/m) * sum(x_i)
2. Mini-batch variance: sigma_B^2 = (1/m) * sum((x_i - mu_B)^2)
3. Normalize: x_hat_i = (x_i - mu_B) / sqrt(sigma_B^2 + epsilon)
4. Scale and shift: y_i = gamma * x_hat_i + beta

gamma (scale) and beta (shift) are learnable parameters.
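
The four steps can be traced by hand on a tiny batch. This sketch applies them to a single feature across four samples; the batch values are arbitrary, and gamma and beta are left at their initial values of 1 and 0.

```python
import math

def batch_norm_1d(xs, gamma=1.0, beta=0.0, eps=1e-5):
    m = len(xs)
    mu = sum(xs) / m                                        # step 1: mini-batch mean
    var = sum((x - mu) ** 2 for x in xs) / m                # step 2: mini-batch variance
    x_hat = [(x - mu) / math.sqrt(var + eps) for x in xs]   # step 3: normalize
    return [gamma * xh + beta for xh in x_hat]              # step 4: scale and shift

batch = [2.0, 4.0, 6.0, 8.0]
out = batch_norm_1d(batch)
out_mean = sum(out) / len(out)
out_var = sum((o - out_mean) ** 2 for o in out) / len(out)
```

The output has mean approximately 0 and variance approximately 1; the learnable gamma and beta then let the network move away from that if it helps.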

import torch
import torch.nn as nn

class BatchNormNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        return self.net(x)

# Manual BatchNorm implementation
class BatchNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        self.eps = eps
        self.momentum = momentum

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        if self.training:
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)
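            # Update running statistics outside the autograd graph so the buffers
            # do not keep a reference to this batch's computation graph
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)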
        else:
            mean = self.running_mean
            var = self.running_var

        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta

6.2 Layer Normalization (Transformer Standard)

Layer Normalization normalizes across the feature dimension rather than the batch dimension. It is independent of batch size, making it suitable for RNNs and Transformers.

class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = normalized_shape
        self.gamma = nn.Parameter(torch.ones(normalized_shape))
        self.beta = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta

# Transformer Block with Pre-LayerNorm (modern GPT-style)
class TransformerBlock(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, nhead)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Linear(dim_feedforward, d_model)
        )

    def forward(self, x):
        h = self.norm1(x)  # normalize once and reuse as query, key, and value
        attn_out, _ = self.attention(h, h, h)
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x

6.3 Instance, Group, and RMS Normalization

# Instance Normalization (per-sample, per-channel)
# Effective for style transfer
instance_norm = nn.InstanceNorm2d(64)

# Group Normalization (normalize within channel groups)
# Alternative to BN when batch size is small
group_norm = nn.GroupNorm(num_groups=8, num_channels=64)

# RMS Normalization (used in LLaMA, T5)
# Removes mean centering from LayerNorm for speed
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self._norm(x.float()).type_as(x)

# Summary of when to use each normalization:
# BatchNorm: CNN, batch-dependent, best with batch size >= 16
# LayerNorm: Transformer/RNN, batch-independent
# InstanceNorm: style transfer, per-sample per-channel
# GroupNorm: small batches, detection/segmentation
# RMSNorm: LLMs, lightweight LayerNorm alternative

7. Weight Initialization

7.1 Xavier/He Initialization

Weight initialization sets the starting point for training. Poor initialization can trigger vanishing or exploding gradients.

import torch
import torch.nn as nn

class WeightInitDemo(nn.Module):
    def __init__(self, init_method='xavier_uniform'):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(256, 256) for _ in range(5)
        ])
        self.apply_init(init_method)

    def apply_init(self, method):
        for layer in self.layers:
            if method == 'zeros':
                nn.init.zeros_(layer.weight)        # Bad: symmetry problem
            elif method == 'random_small':
                nn.init.normal_(layer.weight, std=0.01)
            elif method == 'xavier_uniform':
                nn.init.xavier_uniform_(layer.weight)  # For sigmoid/tanh
            elif method == 'xavier_normal':
                nn.init.xavier_normal_(layer.weight)
            elif method == 'kaiming_uniform':
                nn.init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='relu')
            elif method == 'kaiming_normal':
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')  # For ReLU
            else:
                # Fail loudly instead of silently keeping PyTorch's default init
                raise ValueError(f"Unknown init method: {method}")
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x

# Initialization comparison
x = torch.randn(100, 256)
for method in ['zeros', 'random_small', 'xavier_uniform', 'kaiming_normal']:
    model = WeightInitDemo(method)
    with torch.no_grad():
        out = model(x)
    print(f"{method}: output mean={out.mean():.4f}, std={out.std():.4f}")

Xavier/Glorot Initialization is designed for sigmoid/tanh activations:

  • Uniform: weights drawn from Uniform(-limit, +limit) where limit = sqrt(6 / (fan_in + fan_out))

He/Kaiming Initialization is designed for ReLU activations:

  • Normal: weights drawn from a normal distribution with mean 0 and standard deviation sqrt(2 / fan_in)
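
Plugging in the layer size from the demo gives a feel for the scales involved. This sketch evaluates both formulas for a 256 x 256 layer; the helper names are hypothetical.

```python
import math

def xavier_uniform_limit(fan_in, fan_out):
    return math.sqrt(6.0 / (fan_in + fan_out))

def he_normal_std(fan_in):
    return math.sqrt(2.0 / fan_in)

limit = xavier_uniform_limit(256, 256)   # bound of the uniform distribution
std = he_normal_std(256)                 # standard deviation of the normal distribution
xavier_std = limit / math.sqrt(3.0)      # std of Uniform(-limit, +limit)
```

For this square layer the Xavier weights have standard deviation 1/16, and the He weights are larger by a factor of sqrt(2), which compensates for ReLU zeroing roughly half of the activations.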

8. Gradient Problem Solutions

8.1 Vanishing and Exploding Gradients

Vanishing Gradient: Gradients shrink toward zero as they propagate back through layers, preventing early layers from learning. Common with sigmoid and tanh activations in deep networks.

Exploding Gradient: Gradients grow exponentially, causing NaN or Inf values. Common in RNNs with long sequences.

import torch.nn.utils as utils

# Method 1: Gradient norm clipping
max_norm = 1.0
total_norm = utils.clip_grad_norm_(model.parameters(), max_norm)
print(f"Gradient norm: {total_norm:.4f}")

# Method 2: Gradient value clipping
utils.clip_grad_value_(model.parameters(), clip_value=0.5)

# Usage in training loop
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for batch in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(batch[0]), batch[1])
    loss.backward()

    # Clip after backward, before optimizer step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()

8.2 Residual Connections (Skip Connections)

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)  # Skip Connection
        out = self.relu(out)
        return out

8.3 Gradient Checkpointing

For very deep models, trade compute for memory: discard intermediate activations and recompute them during the backward pass.

from torch.utils.checkpoint import checkpoint, checkpoint_sequential

class DeepModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(*[
            nn.Sequential(nn.Linear(512, 512), nn.ReLU())
            for _ in range(20)
        ])

    def forward(self, x):
        # Standard: stores all activations O(N) memory
        # return self.layers(x)

        # Gradient Checkpointing: roughly O(sqrt(N)) memory when segments is near sqrt(N)
        return checkpoint_sequential(self.layers, segments=5, input=x, use_reentrant=False)
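The square-root figure comes from a simple trade-off, sketched below under a deliberately simplified cost model (an assumption, not a measurement). With N layers split into k segments, roughly k segment inputs are kept, plus the N/k activations of whichever segment is being recomputed. The function name `stored_activations` is hypothetical.

```python
import math

def stored_activations(num_layers, segments):
    """Rough count of activations held at once under segment checkpointing.

    Simplified model: one unit per layer output, `segments` boundary inputs
    kept, plus the layers of one segment while it is being recomputed.
    """
    layers_per_segment = math.ceil(num_layers / segments)
    return segments + layers_per_segment

num_layers = 20
for segments in (1, 2, 4, 5, 10, 20):
    print(segments, stored_activations(num_layers, segments))

# The count is smallest when segments is close to sqrt(num_layers)
best = min(range(1, num_layers + 1),
           key=lambda s: stored_activations(num_layers, s))
print("best segment count:", best)
```

Under this model the 20-layer network above holds about 9 activations with 4 or 5 segments, compared with 20 when every activation is stored. The price is one extra forward pass through each checkpointed segment.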

9. Transfer Learning and Fine-tuning

9.1 Feature Extraction vs Fine-tuning

import torchvision.models as models

# Feature Extraction: freeze pretrained weights
def feature_extraction(num_classes):
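    # `pretrained=True` is deprecated in torchvision >= 0.13; use the weights enum
    model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)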
    for param in model.parameters():
        param.requires_grad = False

    # Replace only the classifier head
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# Fine-tuning: selectively unfreeze layers
def fine_tuning(num_classes, unfreeze_layers=None):
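    # `pretrained=True` is deprecated in torchvision >= 0.13; use the weights enum
    model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)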
    for param in model.parameters():
        param.requires_grad = False

    model.fc = nn.Linear(model.fc.in_features, num_classes)

    if unfreeze_layers:
        for name, param in model.named_parameters():
            for layer in unfreeze_layers:
                if layer in name:
                    param.requires_grad = True

    return model

9.2 Progressive Unfreezing and Discriminative Learning Rates

def discriminative_lr_optimizer(model, base_lr=1e-4, lr_multiplier=10):
    # Assign lower LR to early layers, higher LR to later layers.
    # The ResNet stem (conv1, bn1) is in no group, so this optimizer never updates it.
    param_groups = [
        {'params': model.layer1.parameters(), 'lr': base_lr / (lr_multiplier**3)},
        {'params': model.layer2.parameters(), 'lr': base_lr / (lr_multiplier**2)},
        {'params': model.layer3.parameters(), 'lr': base_lr / lr_multiplier},
        {'params': model.layer4.parameters(), 'lr': base_lr},
        {'params': model.fc.parameters(), 'lr': base_lr * lr_multiplier},
    ]
    return torch.optim.Adam(param_groups)
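It can help to print the learning rates such a scheme produces before committing to a run. The sketch below reproduces the same arithmetic without PyTorch; `discriminative_lrs` is a hypothetical helper, and the group names simply follow the ResNet attributes used above.

```python
def discriminative_lrs(group_names, base_lr=1e-4, lr_multiplier=10):
    """Map each group to its learning rate; the second-to-last group gets base_lr."""
    anchor = len(group_names) - 2
    return {
        name: base_lr * float(lr_multiplier) ** (i - anchor)
        for i, name in enumerate(group_names)
    }

lrs = discriminative_lrs(["layer1", "layer2", "layer3", "layer4", "fc"])
for name, lr in lrs.items():
    print(f"{name}: {lr:.0e}")
```

With the defaults, the rates span four orders of magnitude, from about 1e-7 for `layer1` to 1e-3 for the new head. A multiplier of 10 is fairly aggressive; smaller values such as 2 to 3 are also common choices.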

9.3 LoRA (Low-Rank Adaptation)

LoRA is a parameter-efficient fine-tuning technique for large language models. It freezes the original weight matrices and learns a low-rank decomposition.

For an original weight matrix W with shape d by k, LoRA learns W' = W + BA, where B has shape d by r and A has shape r by k. The rank r is chosen to be much smaller than both d and k.
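A quick parameter count shows why a small rank matters. The sketch below compares the size of W with the combined size of B and A; the dimensions (4096 by 4096, rank 8) are assumptions chosen for illustration, and `lora_param_counts` is a hypothetical helper.

```python
def lora_param_counts(d, k, r):
    """Return (parameters in W, parameters in the LoRA factors B and A)."""
    full = d * k        # frozen matrix W with shape d x k
    lora = r * (d + k)  # B with shape d x r, plus A with shape r x k
    return full, lora

full, lora = lora_param_counts(d=4096, k=4096, r=8)
print(f"full: {full:,}  lora: {lora:,}  ratio: {100 * lora / full:.2f}%")
```

For these dimensions the trainable factors hold 65,536 values against 16,777,216 in the frozen matrix, or about 0.39%.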

class LoRALayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, alpha=1.0):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        # Frozen original weights (random placeholder; load pretrained values in practice)
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features),
            requires_grad=False
        )

        # LoRA matrix A (random init)
        self.lora_A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        # LoRA matrix B (zero init -> identical to original at start)
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        # The bias belongs to the frozen base layer, so it is not trained either
        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False)

    def forward(self, x):
        base_output = nn.functional.linear(x, self.weight, self.bias)
        lora_output = (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return base_output + lora_output

# Using HuggingFace PEFT library
from peft import get_peft_model, LoraConfig, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none"
)

# peft_model = get_peft_model(base_model, lora_config)
# peft_model.print_trainable_parameters()
# trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.062

10. Hyperparameter Tuning

10.1 Bayesian Optimization with Optuna

import optuna
import torch
import torch.nn as nn

def objective(trial):
    lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)
    n_layers = trial.suggest_int('n_layers', 1, 5)
    n_units = trial.suggest_categorical('n_units', [64, 128, 256, 512])
    dropout_rate = trial.suggest_float('dropout', 0.0, 0.5)
    optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'AdamW', 'SGD'])

    layers = []
    in_dim = 784
    for _ in range(n_layers):
        layers.extend([
            nn.Linear(in_dim, n_units),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        ])
        in_dim = n_units
    layers.append(nn.Linear(in_dim, 10))
    model = nn.Sequential(*layers)

    if optimizer_name == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif optimizer_name == 'AdamW':
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

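    # Placeholder: replace with a real training loop. For MedianPruner to have any
    # effect, call trial.report(val_accuracy, epoch) each epoch and raise
    # optuna.TrialPruned() when trial.should_prune() returns True.
    val_accuracy = 0.95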
    return val_accuracy

study = optuna.create_study(
    direction='maximize',
    sampler=optuna.samplers.TPESampler(),
    pruner=optuna.pruners.MedianPruner()
)

study.optimize(objective, n_trials=100, timeout=3600)

print(f"Best trial: {study.best_trial.value:.4f}")
print(f"Best params: {study.best_trial.params}")

11. Mixed Precision Training

11.1 FP32 vs FP16 vs BF16

| Format | Exponent bits | Mantissa bits | Range | Primary use |
| --- | --- | --- | --- | --- |
| FP32 | 8 | 23 | ±3.4e38 | Default training |
| FP16 | 5 | 10 | ±65504 | Inference / training (overflow risk) |
| BF16 | 8 | 7 | ±3.4e38 | LLM training (A100, TPU) |
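The range column follows directly from the bit counts. As a rough check, the sketch below derives the largest finite value and the relative spacing of each format from its exponent and mantissa widths, assuming the usual IEEE-style layout with the top exponent code reserved for infinity and NaN. The helper name `max_finite` is hypothetical.

```python
def max_finite(exponent_bits, mantissa_bits):
    """Largest finite value of an IEEE-style binary floating-point format."""
    max_exponent = 2 ** (exponent_bits - 1) - 1  # e.g. 15 for FP16
    return (2 - 2.0 ** -mantissa_bits) * 2.0 ** max_exponent

for name, e_bits, m_bits in [("FP32", 8, 23), ("FP16", 5, 10), ("BF16", 8, 7)]:
    print(f"{name}: max {max_finite(e_bits, m_bits):.4g}, "
          f"relative step {2.0 ** -m_bits:.2g}")
```

BF16 keeps the exponent width of FP32, so it covers nearly the same range but with much coarser steps. FP16 has finer steps than BF16 but runs out of range at 65504.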

11.2 PyTorch AMP (Automatic Mixed Precision)

import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()

        # FP16 computation within autocast context
        with autocast(dtype=torch.float16):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        # Scale loss, then backprop
        scaler.scale(loss).backward()

        # Unscale for gradient clipping
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Step (skips if NaN/Inf gradients detected)
        scaler.step(optimizer)
        scaler.update()

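# BF16 (more stable, requires Ampere or newer GPU). BF16 keeps FP32's exponent
# range, so loss scaling with GradScaler is generally unnecessary here.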
with autocast(dtype=torch.bfloat16):
    outputs = model(inputs)
    loss = criterion(outputs, labels)

Key benefits of AMP:

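  • Memory reduction: activations are stored in half precision, which can cut activation memory by up to roughly half; under autocast the parameters themselves stay in FP32
  • Speed improvement: often 1.5x–3x on Tensor Core GPUs, depending on the model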
  • Near-identical accuracy to FP32 in most tasks
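Loss scaling exists because small gradients can fall below what FP16 can represent. The following stdlib-only sketch rounds values to half precision with `struct` to show the effect. The gradient magnitude and the scale of 2**16 are assumptions chosen for illustration, and `to_fp16` is a hypothetical helper.

```python
import struct

def to_fp16(value):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", value))[0]

tiny_grad = 1e-8   # hypothetical gradient magnitude
scale = 65536.0    # hypothetical loss scale (2**16)

unscaled = to_fp16(tiny_grad)        # too small for FP16, flushes to zero
scaled = to_fp16(tiny_grad * scale)  # representable once scaled up
recovered = scaled / scale           # unscale in higher precision

print("stored directly:", unscaled)
print("scaled, stored, unscaled:", recovered)
```

Multiplying the loss by the scale multiplies every gradient by the same factor (by the chain rule), so information that would have been flushed to zero survives the FP16 backward pass and is divided back out before the optimizer step.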

12. Distributed Training

12.1 Data Parallelism with DDP

Distribute data across multiple GPUs. Each GPU independently computes forward and backward passes, then gradients are aggregated.

import os

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

def setup(rank, world_size):
    # Single-node rendezvous: every process must agree on address and port
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

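def train_ddp(rank, world_size, model_class, dataset, num_epochs=10):
    setup(rank, world_size)

    # Bind this process to its own GPU before creating CUDA tensors
    torch.cuda.set_device(rank)
    device = torch.device(f'cuda:{rank}')
    model = model_class().to(device)
    model = DDP(model, device_ids=[rank])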

    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=32,
        sampler=sampler,
        num_workers=4,
        pin_memory=True
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3 * world_size)
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler()

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # Different shuffle per epoch

        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        if rank == 0:
            print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

    cleanup()

import torch.multiprocessing as mp

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(train_ddp, args=(world_size, MyModel, dataset), nprocs=world_size, join=True)
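To see what the sampler contributes, here is a simplified, unshuffled sketch of how dataset indices might be divided among ranks. It is modelled on my understanding of `DistributedSampler` with its default padding behaviour and is not the library code; `shard_indices` is a hypothetical name.

```python
def shard_indices(dataset_len, num_replicas, rank):
    """Simplified, unshuffled sketch of per-rank index assignment."""
    indices = list(range(dataset_len))
    # Pad by repeating leading indices so every rank gets the same count.
    per_rank = -(-dataset_len // num_replicas)  # ceiling division
    total = per_rank * num_replicas
    indices += indices[: total - dataset_len]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total:num_replicas]

for rank in range(4):
    print(rank, shard_indices(10, num_replicas=4, rank=rank))
```

With 10 samples on 4 ranks, each rank receives 3 indices and two samples are repeated as padding. Equal counts matter because every rank must take part in the same number of gradient synchronizations.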

12.2 FSDP (Fully Sharded Data Parallel)

FSDP shards model parameters, gradients, and optimizer states across all GPUs. Essential for training models with billions of parameters that exceed single-GPU memory.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import functools

bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16
)

auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock}
)

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=bf16_policy,
    auto_wrap_policy=auto_wrap_policy,
    device_id=rank
)

FSDP sharding strategies:

  • FULL_SHARD: shard params, gradients, and optimizer states (maximum memory savings)
  • SHARD_GRAD_OP: shard gradients and optimizer states only
  • NO_SHARD: equivalent to DDP
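A back-of-the-envelope estimate shows why sharding matters. The sketch below assumes plain FP32 training with Adam (4 bytes of weight, 4 bytes of gradient, and 8 bytes for the two moment buffers per parameter) and ignores activations and temporary buffers. The 7B model size and the helper `model_state_gb` are hypothetical.

```python
def model_state_gb(num_params, world_size=1, bytes_per_param=16):
    """Approximate per-GPU memory (GB) for parameters, gradients, and Adam states.

    Assumes FP32 training with Adam: 4 bytes weight + 4 bytes gradient
    + 8 bytes for the two Adam moment buffers. Activations are excluded.
    """
    return num_params * bytes_per_param / world_size / 1e9

params = 7_000_000_000  # hypothetical 7B-parameter model
print(model_state_gb(params))                # replicated on every GPU (DDP-style)
print(model_state_gb(params, world_size=8))  # fully sharded over 8 GPUs
```

Under these assumptions the replicated state needs about 112 GB per GPU, more than a single 80 GB accelerator offers, while full sharding over 8 GPUs brings it to about 14 GB each.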

12.3 Gradient Accumulation

Simulate large batch sizes on limited GPU memory by accumulating gradients across multiple micro-batches.

model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()

micro_batch_size = 8    # batch_size that train_loader is assumed to use
accumulation_steps = 8  # Effective batch size: 8 * 8 = 64

optimizer.zero_grad()
for step, (inputs, labels) in enumerate(train_loader):
    inputs, labels = inputs.cuda(), labels.cuda()

    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss = loss / accumulation_steps  # Normalize loss

    scaler.scale(loss).backward()

    if (step + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
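Dividing the loss by `accumulation_steps` is what makes the accumulated gradient match a single large batch. The sketch below checks this on a one-parameter least-squares model in plain Python. The data and the helper `mse_grad` are made up for illustration, and the equality relies on equally sized micro-batches with a mean-reduced loss.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]

# Reference: one gradient over the full batch of 8
full_batch_grad = mse_grad(w, xs, ys)

# Accumulate over 4 micro-batches of 2, scaling each contribution
micro_batch_size = 2
accumulation_steps = len(xs) // micro_batch_size
accumulated = 0.0
for i in range(0, len(xs), micro_batch_size):
    micro_grad = mse_grad(w, xs[i:i + micro_batch_size], ys[i:i + micro_batch_size])
    accumulated += micro_grad / accumulation_steps  # same effect as loss / accumulation_steps

print(full_batch_grad, accumulated)
```

Without the division, the accumulated gradient would be `accumulation_steps` times too large, which acts like an unintended increase in learning rate.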

13. Large Language Model Training Techniques

13.1 Instruction Tuning

Instruction tuning trains models to follow natural language instructions. It was central to the success of FLAN, InstructGPT, and LLaMA-2.

# Instruction tuning data format
instruction_data = [
    {
        "instruction": "Analyze the sentiment of the following text.",
        "input": "The weather today is absolutely gorgeous, I feel wonderful!",
        "output": "Positive sentiment. The text expresses satisfaction with the weather and a feeling of happiness."
    },
    {
        "instruction": "Summarize the following article.",
        "input": "...(long text)...",
        "output": "...(summary)..."
    }
]

# Alpaca-style prompt template
def format_instruction(sample):
    if sample.get('input'):
        return f"""### Instruction:
{sample['instruction']}

### Input:
{sample['input']}

### Response:
{sample['output']}"""
    else:
        return f"""### Instruction:
{sample['instruction']}

### Response:
{sample['output']}"""

13.2 RLHF (Reinforcement Learning from Human Feedback)

RLHF involves three stages:

  • Stage 1, SFT (Supervised Fine-tuning): fine-tune on high-quality human demonstrations
  • Stage 2, Reward Model: train a reward model to predict human preferences
  • Stage 3, PPO: optimize the policy with RL using the reward model

# Stage 2: Reward Model (Bradley-Terry preference model)
class RewardModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        self.reward_head = nn.Linear(base_model.config.hidden_size, 1)

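    def forward(self, input_ids, attention_mask):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state
        # Index of the last non-padding token per sequence (works for left or
        # right padding); [:, -1, :] would read a pad position when right-padded.
        seq_len = attention_mask.size(1)
        last_idx = seq_len - 1 - attention_mask.flip(dims=[1]).argmax(dim=1)
        batch_idx = torch.arange(hidden.size(0), device=hidden.device)
        reward = self.reward_head(hidden[batch_idx, last_idx]).squeeze(-1)
        return reward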

# Preference loss (Bradley-Terry model)
def preference_loss(reward_chosen, reward_rejected):
    # p(chosen > rejected) = sigmoid(r_chosen - r_rejected)
    # logsigmoid avoids log(0) when the margin is strongly negative
    return -nn.functional.logsigmoid(reward_chosen - reward_rejected).mean()
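The numerical form of this loss matters. Taking the log of a sigmoid literally breaks down for strongly negative margins, while the equivalent softplus form stays finite. The plain-Python sketch below illustrates the idea on scalar margins. The function names are hypothetical, and pure Python raises an `OverflowError` where a tensor library would more likely return `inf`.

```python
import math

def naive_loss(margin):
    """-log(sigmoid(margin)) computed literally."""
    return -math.log(1.0 / (1.0 + math.exp(-margin)))

def stable_loss(margin):
    """Same quantity written as softplus(-margin), safe for large |margin|."""
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))

for margin in (0.0, 2.0, -50.0):
    print(margin, stable_loss(margin))

try:
    naive_loss(-800.0)
except OverflowError as exc:
    print("naive form failed:", exc)
print("stable form:", stable_loss(-800.0))
```

At a margin of zero the loss equals log(2), which is the value to expect from an untrained reward head that cannot yet tell the two responses apart.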

13.3 DPO (Direct Preference Optimization)

DPO simplifies RLHF by eliminating the need for PPO, directly optimizing the policy on preference data using a closed-form reparameterization.

import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps,
    policy_rejected_logps,
    reference_chosen_logps,
    reference_rejected_logps,
    beta=0.1
):
    # Log ratios between policy and reference model
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps)

    # DPO loss: -log(sigmoid(chosen_rewards - rejected_rewards))
    loss = -F.logsigmoid(chosen_rewards - rejected_rewards).mean()

    chosen_reward = chosen_rewards.detach().mean()
    rejected_reward = rejected_rewards.detach().mean()
    reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()

    return loss, chosen_reward, rejected_reward, reward_accuracy

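DPO advantages over PPO-based RLHF:

  • No need to train a separate reward model
  • No PPO rollout sampling or PPO-specific hyperparameter tuning
  • Training uses a supervised-style loss, which tends to be more stable
  • Alignment quality is often reported as comparable to PPO, though results vary by task and data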
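As a sanity check on the DPO loss, it helps to work one preference pair by hand. The sketch below uses plain floats in place of tensors. The log-probabilities are made-up numbers and `dpo_example` is a hypothetical helper that follows the same arithmetic as `dpo_loss` above for a single pair.

```python
import math

def dpo_example(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Per-pair DPO loss and reward margin from four sequence log-probabilities."""
    chosen_reward = beta * (policy_chosen - ref_chosen)
    rejected_reward = beta * (policy_rejected - ref_rejected)
    margin = chosen_reward - rejected_reward
    # -log(sigmoid(margin)), written in a numerically stable form
    loss = max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))
    return loss, margin

# Policy identical to the reference: zero margin, loss = log(2)
print(dpo_example(-12.0, -15.0, -12.0, -15.0))
# Policy has raised the chosen log-prob and lowered the rejected one
print(dpo_example(-10.0, -18.0, -12.0, -15.0))
```

Because the policy usually starts as a copy of the reference model, the first logged loss should sit near log(2), about 0.693. A noticeably different starting value often points to a mismatch between the two models or their inputs.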

14. Complete Training Pipeline

14.1 Production-Grade Trainer

import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

class Trainer:
    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        criterion,
        device='cuda',
        use_amp=True,
        grad_clip=1.0,
        accumulation_steps=1
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.device = device
        self.use_amp = use_amp
        self.grad_clip = grad_clip
        self.accumulation_steps = accumulation_steps
        self.scaler = GradScaler() if use_amp else None

    def train_epoch(self):
        self.model.train()
        total_loss = 0
        self.optimizer.zero_grad()

        for step, (inputs, labels) in enumerate(self.train_loader):
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            if self.use_amp:
                with autocast():
                    outputs = self.model(inputs)
                    loss = self.criterion(outputs, labels) / self.accumulation_steps
                self.scaler.scale(loss).backward()
            else:
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels) / self.accumulation_steps
                loss.backward()

            if (step + 1) % self.accumulation_steps == 0:
                if self.use_amp:
                    self.scaler.unscale_(self.optimizer)

                if self.grad_clip:
                    nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)

                if self.use_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                if self.scheduler:
                    self.scheduler.step()

                self.optimizer.zero_grad()

            total_loss += loss.item() * self.accumulation_steps

        return total_loss / len(self.train_loader)

    @torch.no_grad()
    def evaluate(self):
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0

        for inputs, labels in self.val_loader:
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            # Gradients are already disabled by the decorator; only toggle autocast here
            with autocast(enabled=self.use_amp):
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        return total_loss / len(self.val_loader), 100. * correct / total

    def fit(self, epochs, save_path=None):
        best_val_acc = 0
        early_stopping = EarlyStopping(patience=10)

        for epoch in range(epochs):
            train_loss = self.train_epoch()
            val_loss, val_acc = self.evaluate()

            print(f"Epoch {epoch+1}/{epochs}: "
                  f"Train Loss: {train_loss:.4f}, "
                  f"Val Loss: {val_loss:.4f}, "
                  f"Val Acc: {val_acc:.2f}%")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                if save_path:
                    torch.save(self.model.state_dict(), save_path)

            early_stopping(val_loss, self.model)
            if early_stopping.early_stop:
                print("Early stopping triggered!")
                break

        return best_val_acc

Conclusion and Best Practices

A summary of core principles for effective deep learning training:

Optimizer selection

  • General tasks: AdamW (lr=1e-3 to 1e-4, weight_decay=0.01)
  • Transformers: AdamW + Warmup + Cosine Schedule
  • Large-batch training: LAMB or LARS
  • Memory-constrained: Lion

Regularization strategy

  • Dropout is typically set between 0.1 and 0.5
  • Small datasets: stronger regularization (larger weight decay, higher dropout)
  • Large datasets: weak or no regularization

Learning rate scheduling

  • CNNs: OneCycleLR or Step Decay
  • Transformers: Warmup + Cosine or Inverse Square Root

Mixed precision

  • Use AMP by default on Tensor Core GPUs (often 1.5x–3x speedup, lower activation memory)
  • A100/H100 and newer: prefer BF16
  • Older GPUs: use FP16 + Loss Scaling

Distributed training

  • Multi-GPU single server: DDP + NCCL
  • Billion-parameter models: FSDP
  • Use Gradient Accumulation when the target batch size does not fit in GPU memory

References