Mathematical Optimization for Machine Learning: From Adam to Convex Optimization and ZeRO

Table of Contents

  1. Optimization Fundamentals: Convex Optimization and KKT Conditions
  2. Gradient Descent Family of Optimizers
  3. Second-Order Optimization Methods
  4. Learning Rate Scheduling
  5. Regularization Techniques
  6. Loss Function Design
  7. LLM Training Optimization
  8. Quiz

Optimization Fundamentals

Convex Optimization

A function $f: \mathbb{R}^n \to \mathbb{R}$ is convex if for any two points $x, y$ and any $\lambda \in [0,1]$:

$$f(\lambda x + (1-\lambda)y) \leq \lambda f(x) + (1-\lambda)f(y)$$

Key properties of convex functions:

  • Every local minimum is a global minimum
  • Gradient descent with an appropriate step size is guaranteed to converge to the global minimum
  • Deep learning loss surfaces are mostly non-convex, but convex analysis techniques remain useful

Strongly Convex: If there exists $m > 0$ such that $f(y) \geq f(x) + \nabla f(x)^T(y-x) + \frac{m}{2}\|y-x\|^2$ for all $x, y$, convergence is linear (exponentially fast).
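The convexity inequality above is easy to probe numerically. The sketch below (function name and tolerance are my own, and random sampling is evidence rather than a proof) tests it on sampled points:

```python
import numpy as np

def is_convex_on_samples(f, n_trials=1000, dim=3, seed=0):
    """Empirically test f(lam*x + (1-lam)*y) <= lam*f(x) + (1-lam)*f(y)."""
    rng = np.random.default_rng(seed)
    for _ in range(n_trials):
        x, y = rng.normal(size=dim), rng.normal(size=dim)
        lam = rng.uniform()
        lhs = f(lam * x + (1 - lam) * y)
        rhs = lam * f(x) + (1 - lam) * f(y)
        if lhs > rhs + 1e-9:   # a single violation disproves convexity
            return False
    return True

print(is_convex_on_samples(lambda x: np.sum(x**2)))    # True:  ||x||^2 is convex
print(is_convex_on_samples(lambda x: -np.sum(x**2)))   # False: -||x||^2 is concave
```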

Lagrange Multipliers

Handles equality-constrained optimization problems:

$$\min_x f(x) \quad \text{subject to} \quad g_i(x) = 0, \; i = 1, \ldots, m$$

Lagrangian:

$$\mathcal{L}(x, \lambda) = f(x) + \sum_{i=1}^{m} \lambda_i g_i(x)$$

At the optimum: $\nabla_x \mathcal{L} = 0$ and $\nabla_\lambda \mathcal{L} = 0$.
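As a small worked example (my own, not from the article): minimize $f(x, y) = x^2 + y^2$ subject to $x + y = 1$. Stationarity gives $2x + \lambda = 0$ and $2y + \lambda = 0$, feasibility gives $x + y = 1$, so $x = y = 0.5$ and $\lambda = -1$. The stationarity conditions can be verified numerically:

```python
import numpy as np

# Candidate optimum from solving the stationarity + feasibility system by hand
x, y, lam = 0.5, 0.5, -1.0

# Gradient of the Lagrangian L(x, y, lam) = x^2 + y^2 + lam * (x + y - 1)
grad_L = np.array([
    2 * x + lam,    # dL/dx
    2 * y + lam,    # dL/dy
    x + y - 1,      # dL/dlam (recovers the constraint)
])
print(grad_L)  # [0. 0. 0.]
```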

KKT Conditions

For the general constrained optimization problem:

$$\min_x f(x) \quad \text{s.t.} \quad g_i(x) \leq 0, \; h_j(x) = 0$$

KKT necessary conditions:

  1. Stationarity: $\nabla f(x^*) + \sum_i \mu_i \nabla g_i(x^*) + \sum_j \lambda_j \nabla h_j(x^*) = 0$
  2. Primal feasibility: $g_i(x^*) \leq 0$, $h_j(x^*) = 0$
  3. Dual feasibility: $\mu_i \geq 0$
  4. Complementary slackness: $\mu_i g_i(x^*) = 0$

For convex problems, KKT conditions are also sufficient.
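To make the four conditions concrete, here is a tiny worked example (mine, not from the article): minimize $f(x) = x^2$ subject to $g(x) = 1 - x \leq 0$. The constraint is active at the optimum, so $x^* = 1$ with multiplier $\mu = 2$, and each condition can be checked directly:

```python
x_star, mu = 1.0, 2.0

stationarity = 2 * x_star + mu * (-1.0)   # f'(x*) + mu * g'(x*) = 0
primal_feasible = (1 - x_star) <= 0       # g(x*) <= 0
dual_feasible = mu >= 0                   # mu >= 0
slackness = mu * (1 - x_star)             # mu * g(x*) = 0

print(stationarity, primal_feasible, dual_feasible, slackness)  # 0.0 True True 0.0
```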

Saddle Points

In deep learning, saddle points are a greater concern than local minima. At a saddle point, the gradient is zero but it is neither a local min nor max. The stochastic noise in SGD helps escape saddle points.


Gradient Descent Family

SGD and Its Variants

Vanilla SGD:

$$\theta_{t+1} = \theta_t - \eta \nabla_\theta \mathcal{L}(\theta_t; x_i, y_i)$$

SGD with Momentum:

$$v_{t+1} = \beta v_t + \nabla_\theta \mathcal{L}(\theta_t)$$
$$\theta_{t+1} = \theta_t - \eta v_{t+1}$$

Momentum $\beta = 0.9$ is standard; it accumulates past gradients to reduce oscillation.

Nesterov Accelerated Gradient (NAG):

$$v_{t+1} = \beta v_t + \nabla_\theta \mathcal{L}(\theta_t - \eta \beta v_t)$$
$$\theta_{t+1} = \theta_t - \eta v_{t+1}$$

Computes the gradient at a "lookahead" position rather than the current one.
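The three update rules can be compared side by side on an ill-conditioned quadratic. This NumPy sketch (objective, step sizes, and iteration counts are my own illustrative choices) implements each rule directly from the equations above:

```python
import numpy as np

A = np.diag([1.0, 10.0])                 # ill-conditioned quadratic: f = 0.5 x^T A x
grad = lambda theta: A @ theta

def run(method, steps=200, eta=0.05, beta=0.9):
    theta, v = np.array([5.0, 5.0]), np.zeros(2)
    for _ in range(steps):
        if method == "sgd":
            theta = theta - eta * grad(theta)
        elif method == "momentum":        # accumulate past gradients
            v = beta * v + grad(theta)
            theta = theta - eta * v
        elif method == "nag":             # gradient taken at the lookahead point
            v = beta * v + grad(theta - eta * beta * v)
            theta = theta - eta * v
    return np.linalg.norm(theta)          # distance from the minimum at 0

for m in ["sgd", "momentum", "nag"]:
    print(m, run(m))                      # all three converge toward 0
```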

AdaGrad, RMSProp, Adam

AdaGrad: Per-parameter adaptive learning rate

$$G_t = G_{t-1} + g_t^2$$
$$\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{G_t + \epsilon}} g_t$$

Frequent features get smaller updates; rare features get larger updates. Drawback: monotonically shrinking learning rates cause learning to stall.

RMSProp: Fixes AdaGrad's accumulation problem

$$v_t = \beta v_{t-1} + (1-\beta) g_t^2$$
$$\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{v_t + \epsilon}} g_t$$

Adam (Adaptive Moment Estimation):

$$m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t \quad \text{(1st moment)}$$
$$v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2 \quad \text{(2nd moment)}$$

Bias correction:

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

$$\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t$$

Default hyperparameters: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$
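The update can also be written from scratch in a few lines of NumPy, following the equations above (the toy objective $f(\theta) = \|\theta\|^2$ and learning rate are my own choices for illustration):

```python
import numpy as np

def adam_step(theta, g, m, v, t, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    m = beta1 * m + (1 - beta1) * g          # 1st moment
    v = beta2 * v + (1 - beta2) * g**2       # 2nd moment
    m_hat = m / (1 - beta1**t)               # bias correction (t starts at 1)
    v_hat = v / (1 - beta2**t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

theta, m, v = np.array([1.0, -1.0]), np.zeros(2), np.zeros(2)
for t in range(1, 1001):
    g = 2 * theta                            # gradient of ||theta||^2
    theta, m, v = adam_step(theta, g, m, v, t)
print(np.linalg.norm(theta))                 # close to the minimum at 0
```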

import torch
import torch.optim as optim

model = ...  # define your model

# Standard Adam
optimizer_adam = optim.Adam(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8
)

# AdamW (decoupled weight decay)
optimizer_adamw = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01  # applied independently of gradient scaling
)

AdamW and Lion

AdamW: Applies weight decay directly to parameter updates, separate from the gradient-based term.

$$\theta_{t+1} = \theta_t - \eta \left(\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t\right)$$

This is not mathematically equivalent to adding an L2 penalty inside Adam (see the Quiz for details).

Lion (EvoLved Sign Momentum):

$$u_t = \beta_1 m_{t-1} + (1-\beta_1) g_t$$
$$\theta_{t+1} = \theta_t - \eta \cdot \text{sign}(u_t)$$
$$m_t = \beta_2 m_{t-1} + (1-\beta_2) g_t$$

Lion applies only the sign of the interpolated update, so every parameter moves by the same magnitude $\eta$, and it stores a single momentum buffer instead of Adam's two, reducing optimizer memory.
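A one-step NumPy sketch of the rule above, showing that coordinates with wildly different gradient scales all move by exactly $\eta$ (toy values are mine; the optional weight-decay term is omitted):

```python
import numpy as np

def lion_step(theta, g, m, lr=1e-4, beta1=0.9, beta2=0.99):
    u = beta1 * m + (1 - beta1) * g      # interpolation used only for the update
    theta = theta - lr * np.sign(u)      # uniform step of size lr per coordinate
    m = beta2 * m + (1 - beta2) * g      # momentum tracks a slower average
    return theta, m

theta = np.array([1.0, -1.0])
g = np.array([1e-6, 1e3])                # very different gradient scales
new_theta, m = lion_step(theta, g, np.zeros(2))
print(np.abs(new_theta - theta))         # both coordinates moved by exactly lr
```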

| Optimizer | Memory | Convergence | Best Use Case |
| --- | --- | --- | --- |
| SGD+Momentum | Low | Slow | Computer vision, large batch |
| Adam | Medium | Fast | NLP, general purpose |
| AdamW | Medium | Fast | Transformer training |
| Lion | Low | Fast | Large-scale models |
| L-BFGS | High | Very fast | Small models |

Second-Order Optimization

Newton's Method

Uses second-order derivatives (Hessian):

$$\theta_{t+1} = \theta_t - H_t^{-1} \nabla f(\theta_t)$$

where $H_t = \nabla^2 f(\theta_t)$ is the Hessian matrix. Achieves quadratic convergence, but inverting an $n \times n$ matrix requires $O(n^3)$ computation, which is impractical for deep learning.

L-BFGS (Limited-memory BFGS)

Approximates the inverse Hessian from the last $m$ curvature pairs $\{s_k, y_k\}$, where $s_k = \theta_{k+1} - \theta_k$ and $y_k = \nabla f_{k+1} - \nabla f_k$, without ever storing the full matrix; the product $H_t^{-1} \nabla f$ is computed by a two-loop recursion over these pairs.

import torch
import torch.optim as optim

# L-BFGS requires a closure that re-evaluates the loss;
# model, criterion, input_data, and target are assumed to be defined elsewhere
optimizer = optim.LBFGS(
    model.parameters(),
    lr=1.0,
    max_iter=20,
    history_size=10,
    line_search_fn='strong_wolfe'
)

def closure():
    optimizer.zero_grad()
    output = model(input_data)
    loss = criterion(output, target)
    loss.backward()
    return loss

optimizer.step(closure)

Natural Gradient Descent

Uses the Fisher Information Matrix to account for the curvature of the parameter space:

$$\theta_{t+1} = \theta_t - \eta F(\theta_t)^{-1} \nabla \mathcal{L}(\theta_t)$$

Fisher Matrix: $F(\theta) = \mathbb{E}\left[\nabla \log p(y|x;\theta)\, \nabla \log p(y|x;\theta)^T\right]$

K-FAC (Kronecker-factored Approximate Curvature) provides a practical implementation by factoring the Fisher matrix layer-wise.


Learning Rate Scheduling

Linear Warmup

Gradually increases the learning rate at the start to stabilize training:

$$\eta_t = \eta_{\max} \cdot \frac{t}{T_{\text{warmup}}}, \quad t \leq T_{\text{warmup}}$$

Cosine Annealing

$$\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\frac{\pi t}{T}\right)$$

from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR, ReduceLROnPlateau

# Cosine Annealing
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)

# OneCycleLR: Warmup + cosine decay
scheduler = OneCycleLR(
    optimizer,
    max_lr=1e-3,
    total_steps=1000,
    pct_start=0.3,         # 30% warmup phase
    anneal_strategy='cos'
)

# ReduceLROnPlateau: reduce when validation loss stalls
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10,
    min_lr=1e-6
)
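For LLM-style training, linear warmup and cosine decay are commonly fused into a single schedule. A plain-Python sketch (the constants and function name are my own illustrative choices):

```python
import math

def warmup_cosine_lr(step, max_lr=3e-4, min_lr=3e-5,
                     warmup_steps=2000, total_steps=100_000):
    """Linear warmup to max_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

print(warmup_cosine_lr(0))        # 0.0 at the start of warmup
print(warmup_cosine_lr(2000))     # peak learning rate
print(warmup_cosine_lr(100_000))  # decayed to the floor
```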

Cyclical Learning Rate (CLR)

Periodically varies the learning rate to help escape saddle points:

$$\eta_t = \eta_{\min} + (\eta_{\max} - \eta_{\min}) \cdot \max\left(0,\, 1 - \left|\frac{t}{\text{step\_size}} - 2k + 1\right|\right)$$

where $k = \left\lfloor 1 + \frac{t}{2 \cdot \text{step\_size}} \right\rfloor$ indexes the current cycle.

| Scheduler | Characteristic | Best Use Case |
| --- | --- | --- |
| Cosine Annealing | Smooth decay | Transformer pretraining |
| OneCycleLR | Warmup + fast decay | Fine-tuning, short runs |
| ReduceLROnPlateau | Adaptive | General training |
| Cyclical LR | Periodic oscillation | Saddle point escape |
| Linear Warmup | Initial stabilization | LLM training |

Regularization Techniques

L1 / L2 Regularization

L2 Regularization (Ridge):

$$\mathcal{L}_{\text{reg}} = \mathcal{L} + \frac{\lambda}{2} \|\theta\|_2^2$$

Gradient: $\nabla_\theta \mathcal{L}_{\text{reg}} = \nabla_\theta \mathcal{L} + \lambda \theta$

L1 Regularization (Lasso):

$$\mathcal{L}_{\text{reg}} = \mathcal{L} + \lambda \|\theta\|_1$$

L1 induces sparse solutions, driving many weights to exactly zero.
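The sparsity mechanism can be seen directly through soft-thresholding, the proximal operator of the L1 norm (a sketch; the function name and numbers are mine):

```python
import numpy as np

def soft_threshold(theta, t):
    """Proximal operator of t * ||.||_1: shrink toward zero, zeroing |theta| <= t."""
    return np.sign(theta) * np.maximum(np.abs(theta) - t, 0.0)

w = np.array([0.8, -0.05, 0.02, -1.5])
print(soft_threshold(w, 0.1))   # small weights snap exactly to zero
```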

Batch Normalization vs Layer Normalization

Batch Normalization (BN):

$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} \cdot \gamma + \beta$$

where $\mu_B$ and $\sigma_B^2$ are the mean and variance computed across the mini-batch dimension for each feature.

Layer Normalization (LN):

$$\hat{x} = \frac{x - \mu_L}{\sqrt{\sigma_L^2 + \epsilon}} \cdot \gamma + \beta$$

Statistics are computed over the feature dimension of each individual sample.

| Normalization | Statistic Axis | Best Use Case |
| --- | --- | --- |
| Batch Norm | Batch (same feature) | CNN, large batch |
| Layer Norm | Feature (same sample) | Transformer, RNN |
| Instance Norm | Spatial (same channel) | Style transfer |
| Group Norm | Channel groups | Small batch |

Weight Decay vs L2 Regularization

With SGD:

$$\theta_{t+1} = \theta_t - \eta(\nabla \mathcal{L} + \lambda \theta_t) = (1 - \eta\lambda)\theta_t - \eta \nabla \mathcal{L}$$

Weight decay and L2 regularization are equivalent here. With Adam, however:

  • L2 in Adam: $\lambda\theta$ is added to the gradient and then divided by the adaptive scaling factor, so the regularization effect is weakened for parameters with large gradient variance
  • AdamW: $\lambda\theta$ is applied directly to the parameters, decoupled from the adaptive update, giving uniform decay regardless of gradient scale

import torch
import torch.nn as nn

class RegularizedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(512, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.ln1 = nn.LayerNorm(256)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)   # or self.ln1(x) for transformers
        x = torch.relu(x)
        x = self.dropout(x)
        return x
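The Adam-vs-AdamW distinction above can be made concrete with a single bias-corrected step from zero optimizer state, where $\hat{m}_1 = g$ and $\hat{v}_1 = g^2$, so the Adam direction reduces to roughly $\text{sign}(g)$. The toy numbers below are my own:

```python
import numpy as np

lr, lam, eps = 0.1, 0.1, 1e-8
theta = np.array([1.0, 1.0])
g = np.array([0.01, 10.0])                    # very different gradient scales

normalize = lambda x: x / (np.abs(x) + eps)   # first-step Adam direction

adam_plain = theta - lr * normalize(g)                   # no regularization
adam_l2    = theta - lr * normalize(g + lam * theta)     # L2 folded into gradient
adamw      = theta - lr * (normalize(g) + lam * theta)   # decoupled weight decay

print(np.allclose(adam_plain, adam_l2))   # True: the L2 term is scaled away
print(np.allclose(adam_plain, adamw))     # False: AdamW actually shrinks weights
```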

Loss Function Design

Cross-Entropy Loss

$$\mathcal{L}_{CE} = -\sum_{c=1}^{C} y_c \log \hat{p}_c$$

Binary cross-entropy: $\mathcal{L}_{BCE} = -[y \log p + (1-y)\log(1-p)]$
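Computed from logits in NumPy (a minimal numerically-stable sketch; the function names are mine):

```python
import numpy as np

def softmax(z):
    z = z - z.max()              # shift logits for numerical stability
    e = np.exp(z)
    return e / e.sum()

def cross_entropy(logits, target_idx):
    """-log p_c for the ground-truth class c."""
    return -np.log(softmax(logits)[target_idx])

logits = np.array([2.0, 0.5, -1.0])
print(cross_entropy(logits, 0))  # small loss: class 0 has the largest logit
print(cross_entropy(logits, 2))  # large loss: class 2 is an unlikely prediction
```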

Focal Loss

Addresses class imbalance by down-weighting easy examples:

$$\mathcal{L}_{FL} = -(1-p_t)^\gamma \log(p_t)$$

where $p_t$ is the predicted probability for the ground-truth class and $\gamma \geq 0$ is the focusing parameter. When $\gamma = 0$, this reduces to standard cross-entropy.

import torch
import torch.nn.functional as F

class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        bce_loss = F.binary_cross_entropy_with_logits(
            logits, targets.float(), reduction='none'
        )
        p = torch.sigmoid(logits)
        p_t = p * targets + (1 - p) * (1 - targets)
        focal_weight = (1 - p_t) ** self.gamma
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        loss = alpha_t * focal_weight * bce_loss
        return loss.mean()

Contrastive Loss and Triplet Loss

Contrastive Loss (Siamese Networks):

$$\mathcal{L} = (1-y)\frac{d^2}{2} + y \cdot \max(0, m - d)^2$$

where $d = \|f(x_1) - f(x_2)\|_2$, $y=0$ for similar pairs, $y=1$ for dissimilar pairs, and $m$ is the margin.

Triplet Loss:

$$\mathcal{L}_{trip} = \max\left(0, \|f(a) - f(p)\|_2^2 - \|f(a) - f(n)\|_2^2 + m\right)$$

Uses anchor ($a$), positive ($p$), and negative ($n$) samples; the margin $m$ enforces a minimum gap between positive and negative distances.
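A direct NumPy transcription of the formula (the sample vectors are my own toy values):

```python
import numpy as np

def triplet_loss(a, p, n, margin=1.0):
    """max(0, ||a - p||^2 - ||a - n||^2 + margin)."""
    d_ap = np.sum((a - p) ** 2)
    d_an = np.sum((a - n) ** 2)
    return max(0.0, d_ap - d_an + margin)

a = np.zeros(2)
print(triplet_loss(a, np.array([0.1, 0.0]), np.array([2.0, 0.0])))  # 0.0: negative far enough
print(triplet_loss(a, np.array([0.1, 0.0]), np.array([0.5, 0.0])))  # 0.76: margin violated
```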

InfoNCE Loss (NT-Xent)

The core loss function for contrastive self-supervised learning:

$$\mathcal{L}_{InfoNCE} = -\log \frac{\exp(\text{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbf{1}_{[k \neq i]} \exp(\text{sim}(z_i, z_k)/\tau)}$$

where $\tau$ is the temperature parameter and $\text{sim}$ denotes cosine similarity.

import torch
import torch.nn.functional as F

def info_nce_loss(features, temperature=0.07):
    """
    features: (2N, D) - two augmentation views of each image
    """
    N = features.shape[0] // 2
    features = F.normalize(features, dim=1)

    # Compute similarity matrix
    similarity = torch.matmul(features, features.T) / temperature

    # Mask self-similarity (set diagonal to -inf)
    mask = torch.eye(2 * N, dtype=torch.bool, device=features.device)
    similarity.masked_fill_(mask, float('-inf'))

    # Positive pairs: i with i+N, and i+N with i
    labels = torch.cat([
        torch.arange(N, 2 * N),
        torch.arange(N)
    ]).to(features.device)

    loss = F.cross_entropy(similarity, labels)
    return loss

LLM Training Optimization

Gradient Clipping

Prevents exploding gradients during training:

$$g \leftarrow g \cdot \min\left(1, \frac{\text{clip\_norm}}{\|g\|_2}\right)$$

import torch

def train_with_clipping(model, optimizer, loss, max_norm=1.0):
    optimizer.zero_grad()
    loss.backward()

    # clip_grad_norm_ clips in place and returns the pre-clipping total norm
    total_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm=max_norm
    ).item()
    optimizer.step()
    return total_norm

ZeRO Optimizer (Zero Redundancy Optimizer)

Partitions model training state across GPUs in three progressive stages:

| ZeRO Stage | Partitioned State | Memory Reduction |
| --- | --- | --- |
| Stage 1 | Optimizer states | ~4x |
| Stage 2 | + Gradients | ~8x |
| Stage 3 | + Parameters | Scales with GPU count (e.g., ~64x on 64 GPUs) |

Mixed precision (FP16/BF16) combined with ZeRO-3 enables training multi-billion parameter models on a single node.
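A minimal DeepSpeed-style configuration sketch showing where the ZeRO stage is selected (the specific values are illustrative, not a tuned recipe; consult the DeepSpeed documentation for the full schema):

```python
# Expressed as a Python dict; DeepSpeed accepts the equivalent JSON file.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "bf16": {"enabled": True},        # mixed precision
    "gradient_clipping": 1.0,
    "zero_optimization": {
        "stage": 3,                   # 1, 2, or 3: how much state to partition
        "overlap_comm": True,         # overlap communication with compute
        "contiguous_gradients": True,
    },
}
```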

8-bit Adam

Uses quantization to store optimizer states in INT8 instead of FP32:

  • Reduces optimizer state memory by 75% compared to FP32
  • Block-wise quantization minimizes precision loss
  • Available via the bitsandbytes library
# 8-bit Adam via bitsandbytes
import bitsandbytes as bnb

optimizer = bnb.optim.Adam8bit(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999)
)

Adafactor

Approximates Adam's second moment matrix via low-rank factorization:

$$V_t \approx \hat{r}_t \hat{c}_t^T \quad \text{(rank-1 approximation from row and column statistics)}$$

Memory for the second-moment estimate shrinks to a row vector plus a column vector per weight matrix, rather than the full matrix. Used to train T5, PaLM, and other massive models.
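The factored reconstruction is exact whenever the second-moment matrix is itself rank-1, as this NumPy sketch (my own toy example) demonstrates, along with the memory saving:

```python
import numpy as np

def factored_reconstruction(R, C):
    """Adafactor-style rank-1 reconstruction: V ~= outer(R, C) / sum(R)."""
    return np.outer(R, C) / R.sum()

G2 = np.outer(np.arange(1.0, 5.0), np.arange(1.0, 6.0))  # rank-1 "g^2" matrix, 4x5
R, C = G2.sum(axis=1), G2.sum(axis=0)                    # row and column sums
V = factored_reconstruction(R, C)

print(np.allclose(V, G2))        # True: exact for rank-1 input
print(R.size + C.size, G2.size)  # store 9 numbers instead of 20
```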

| Optimizer | Memory (relative to params) | LLM Suitability |
| --- | --- | --- |
| Adam | 8x (params + 2 states) | Moderate |
| AdamW | 8x | Good |
| 8-bit Adam | 6x | Good |
| Adafactor | ~2x | Excellent |
| Lion | 6x | Good |

Quiz

Q1. Why is bias correction necessary in the Adam optimizer?

Answer: To correct the bias introduced by initializing the moment estimates at zero.

Explanation: Adam initializes $m_0 = 0$ and $v_0 = 0$. In early timesteps, $m_t$ and $v_t$ underestimate the true moments of the gradients. For example, at $t=1$: $m_1 = (1-\beta_1)g_1$, whose expected value $(1-\beta_1)\mathbb{E}[g_1]$ is much smaller than $\mathbb{E}[g_1]$. Dividing by $(1-\beta_1^t)$ corrects this. With $\beta_1 = 0.9$ at $t=1$, the correction factor is $1/0.1 = 10$. As $t$ grows large, $\beta_1^t \to 0$ and the correction factor approaches 1, becoming negligible.
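The claim is easy to verify with a constant gradient stream, where the corrected estimate recovers the gradient exactly at every step (a small sketch, my own):

```python
beta1, g = 0.9, 1.0      # constant gradient
m = 0.0                  # first moment, initialized at zero
for t in range(1, 6):
    m = beta1 * m + (1 - beta1) * g
    m_hat = m / (1 - beta1 ** t)                    # bias-corrected estimate
    print(f"t={t}  m={m:.4f}  m_hat={m_hat:.4f}")   # m_hat recovers g exactly
```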

Q2. Why are weight decay and L2 regularization not equivalent in Adam (and how does AdamW fix this)?

Answer: Because Adam's adaptive learning rate scales the L2 penalty gradient, weakening its regularization effect.

Explanation: In SGD, $\theta \leftarrow \theta - \eta(\nabla \mathcal{L} + \lambda\theta)$ makes both approaches mathematically identical. In Adam with L2 regularization, the combined gradient becomes $g_t + \lambda\theta_t$, which is then scaled by the adaptive factor $1/\sqrt{\hat{v}_t}$. Parameters with high gradient variance (large $v_t$) receive proportionally smaller regularization. AdamW decouples weight decay from the gradient update: $\theta \leftarrow \theta(1-\eta\lambda) - \eta\hat{m}_t/(\sqrt{\hat{v}_t}+\epsilon)$, applying uniform decay to all parameters regardless of their gradient scale.

Q3. How do Batch Normalization and Layer Normalization differ, and when is each appropriate?

Answer: BN normalizes across the batch dimension; LN normalizes across the feature dimension of each sample.

Explanation: BN computes mean and variance over the mini-batch for each feature position. It depends on batch size — small batches yield unstable statistics. It is best suited for CNNs with fixed spatial structure and sufficiently large batches. LN computes statistics over the feature dimension of each sample independently, making it batch-size agnostic. It is ideal for Transformers (variable sequence lengths), RNNs, and online inference scenarios where batch statistics are unavailable or unreliable.

Q4. What is the mathematical principle behind Focal Loss outperforming Cross-Entropy on imbalanced datasets?

Answer: The $(1-p_t)^\gamma$ modulating factor dynamically down-weights easy examples during training.

Explanation: Standard CE loss $-\log(p_t)$ treats every sample equally regardless of prediction confidence. Focal Loss introduces $(1-p_t)^\gamma$: for an easy example with $p_t = 0.9$, the weight is $(1-0.9)^2 = 0.01$, reducing its contribution by 100x. For a hard example with $p_t = 0.1$, the weight is $(1-0.1)^2 = 0.81$, preserving nearly the full loss signal. With $\gamma = 2$, easy well-classified majority-class samples effectively stop contributing, forcing the model to focus training on the rare, hard minority-class examples.

Q5. How does InfoNCE Loss enable contrastive learning to produce useful representations?

Answer: By maximizing similarity between augmented views of the same image while pushing apart views from different images.

Explanation: InfoNCE maximizes a lower bound on mutual information. The numerator $\exp(\text{sim}(z_i, z_j)/\tau)$ increases cosine similarity between the two augmented views of the same image (positive pair). The denominator includes $2N-1$ negative pairs (other images in the batch). The temperature $\tau$ controls distribution sharpness: smaller $\tau$ creates a more peaked distribution, forcing tighter positive-pair clusters. Large batches provide more diverse negatives, improving representation quality. SimCLR, MoCo, and CLIP all rely on this loss formulation to learn generalizable visual and multimodal representations.