Mathematical Optimization for Machine Learning: From Adam to Convex Optimization and ZeRO
Author: Youngju Kim (@fjvbn20031)
Table of Contents
- Optimization Fundamentals: Convex Optimization and KKT Conditions
- Gradient Descent Family of Optimizers
- Second-Order Optimization Methods
- Learning Rate Scheduling
- Regularization Techniques
- Loss Function Design
- LLM Training Optimization
- Quiz
Optimization Fundamentals
Convex Optimization
A function $f$ is convex if for any two points $x, y$ and any $\lambda \in [0, 1]$:

$$f(\lambda x + (1-\lambda) y) \le \lambda f(x) + (1-\lambda) f(y)$$
Key properties of convex functions:
- Every local minimum is a global minimum
- Gradient descent is guaranteed to converge
- Deep learning loss surfaces are mostly non-convex, but convex analysis techniques remain useful
Strongly Convex: If there exists $\mu > 0$ such that $f(y) \ge f(x) + \nabla f(x)^\top (y - x) + \frac{\mu}{2} \|y - x\|^2$, convergence is linear (exponentially fast).
Lagrange Multipliers
Handles equality-constrained optimization problems:

$$\min_x f(x) \quad \text{s.t.} \quad h(x) = 0$$

Lagrangian:

$$\mathcal{L}(x, \lambda) = f(x) + \lambda^\top h(x)$$

At the optimum: $\nabla_x \mathcal{L} = 0$ and $h(x) = 0$.
KKT Conditions
For the general constrained optimization problem:

$$\min_x f(x) \quad \text{s.t.} \quad g_i(x) \le 0, \quad h_j(x) = 0$$

KKT necessary conditions:
- Stationarity: $\nabla f(x^*) + \sum_i \mu_i \nabla g_i(x^*) + \sum_j \lambda_j \nabla h_j(x^*) = 0$
- Primal feasibility: $g_i(x^*) \le 0$, $h_j(x^*) = 0$
- Dual feasibility: $\mu_i \ge 0$
- Complementary slackness: $\mu_i \, g_i(x^*) = 0$
For convex problems, KKT conditions are also sufficient.
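As a worked sketch, consider minimizing $x^2 + y^2$ subject to $1 - x - y \le 0$: stationarity plus an active constraint give $x = y = 1/2$ with multiplier $\mu = 1$. The NumPy check below verifies all four KKT conditions numerically (the problem and values are illustrative, not from the text above).

```python
import numpy as np

# Worked KKT check for: minimize x^2 + y^2  subject to  g(x, y) = 1 - x - y <= 0.
# Stationarity (2x - mu = 0, 2y - mu = 0) plus an active constraint x + y = 1
# give x = y = 1/2 and mu = 1.
x, y, mu = 0.5, 0.5, 1.0

grad_f = np.array([2 * x, 2 * y])   # gradient of the objective
grad_g = np.array([-1.0, -1.0])     # gradient of the constraint g(x, y) = 1 - x - y
g = 1 - x - y

assert np.allclose(grad_f + mu * grad_g, 0)  # stationarity
assert g <= 1e-12                            # primal feasibility
assert mu >= 0                               # dual feasibility
assert abs(mu * g) < 1e-12                   # complementary slackness
```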
Saddle Points
In deep learning, saddle points are a greater concern than local minima. At a saddle point, the gradient is zero but it is neither a local min nor max. The stochastic noise in SGD helps escape saddle points.
Gradient Descent Family
SGD and Its Variants
Vanilla SGD:

$$\theta_{t+1} = \theta_t - \eta \nabla L(\theta_t)$$

SGD with Momentum:

$$v_{t+1} = \gamma v_t + \eta \nabla L(\theta_t), \qquad \theta_{t+1} = \theta_t - v_{t+1}$$
A momentum coefficient of $\gamma = 0.9$ is the standard default; accumulating past gradients damps oscillation along steep directions of the loss surface.
Nesterov Accelerated Gradient (NAG):

$$v_{t+1} = \gamma v_t + \eta \nabla L(\theta_t - \gamma v_t), \qquad \theta_{t+1} = \theta_t - v_{t+1}$$
Computes the gradient at a "lookahead" position rather than the current one.
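The three update rules can be sketched side by side on a toy quadratic loss $L(\theta) = \frac{1}{2}\theta^2$, where the gradient is simply $\theta$ (all names and hyperparameter values below are illustrative):

```python
import numpy as np

# Toy quadratic loss L(theta) = 0.5 * theta^2, so grad(theta) = theta.
def grad(theta):
    return theta

eta, gamma = 0.1, 0.9

# Vanilla SGD
theta = 5.0
for _ in range(100):
    theta -= eta * grad(theta)

# SGD with momentum: v accumulates past gradients
theta_m, v = 5.0, 0.0
for _ in range(100):
    v = gamma * v + eta * grad(theta_m)
    theta_m -= v

# Nesterov: gradient evaluated at the lookahead point theta - gamma * v
theta_n, v = 5.0, 0.0
for _ in range(100):
    v = gamma * v + eta * grad(theta_n - gamma * v)
    theta_n -= v

print(theta, theta_m, theta_n)  # all three approach the minimum at 0
```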
AdaGrad, RMSProp, Adam
AdaGrad: Per-parameter adaptive learning rate

$$G_t = G_{t-1} + g_t^2, \qquad \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{G_t} + \epsilon} \, g_t$$

Frequently updated parameters get smaller steps; rarely updated parameters get larger steps. Drawback: $G_t$ grows monotonically, so the effective learning rate shrinks until learning stalls.
RMSProp: Fixes AdaGrad's accumulation problem with an exponential moving average

$$E[g^2]_t = \rho E[g^2]_{t-1} + (1 - \rho) g_t^2, \qquad \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{E[g^2]_t} + \epsilon} \, g_t$$
Adam (Adaptive Moment Estimation):

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$

Bias correction:

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}, \qquad \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \, \hat{m}_t$$

Default hyperparameters: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$
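To make the moment updates and bias correction concrete, here is a from-scratch sketch of Adam on the toy loss $L(\theta) = \frac{1}{2}\theta^2$ (the function name and toy setup are illustrative):

```python
import numpy as np

# From-scratch Adam on L(theta) = 0.5 * theta^2, where grad = theta.
def adam(theta=5.0, lr=1e-1, beta1=0.9, beta2=0.999, eps=1e-8, steps=200):
    m, v = 0.0, 0.0
    for t in range(1, steps + 1):
        g = theta                                # gradient of the toy loss
        m = beta1 * m + (1 - beta1) * g          # first moment (EMA of g)
        v = beta2 * v + (1 - beta2) * g ** 2     # second moment (EMA of g^2)
        m_hat = m / (1 - beta1 ** t)             # bias correction
        v_hat = v / (1 - beta2 ** t)
        theta -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta

print(adam())  # converges toward the minimum at 0
```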
```python
import torch
import torch.optim as optim

model = ...  # define your model

# Standard Adam
optimizer_adam = optim.Adam(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
)

# AdamW (decoupled weight decay)
optimizer_adamw = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,  # applied independently of gradient scaling
)
```
AdamW and Lion
AdamW: Applies weight decay directly to the parameter update, separate from the gradient-based term:

$$\theta_{t+1} = \theta_t - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t \right)$$

This is mathematically inequivalent to adding L2 regularization inside Adam (see Quiz for details).
Lion (EvoLved Sign Momentum):

$$\theta_{t+1} = \theta_t - \eta \left( \mathrm{sign}(\beta_1 m_{t-1} + (1 - \beta_1) g_t) + \lambda \theta_t \right), \qquad m_t = \beta_2 m_{t-1} + (1 - \beta_2) g_t$$

Lion applies only the sign of its interpolated momentum, giving every coordinate a uniform step magnitude, and keeps a single momentum state, making it more memory-efficient than Adam.
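A minimal sketch of one Lion step, following the update rule from the Lion paper (Chen et al., 2023); the function and variable names are illustrative:

```python
import numpy as np

# One Lion step: sign of an interpolated momentum, plus decoupled weight decay.
def lion_step(theta, g, m, lr=1e-4, beta1=0.9, beta2=0.99, wd=0.0):
    update = np.sign(beta1 * m + (1 - beta1) * g)   # sign -> uniform magnitude
    theta = theta - lr * (update + wd * theta)      # decoupled weight decay
    m = beta2 * m + (1 - beta2) * g                 # single momentum state
    return theta, m

theta, m = np.array([1.0, -2.0]), np.zeros(2)
theta, m = lion_step(theta, np.array([0.3, -0.7]), m)
print(theta)  # each coordinate moved by exactly lr in the sign direction
```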
| Optimizer | Memory | Convergence | Best Use Case |
|---|---|---|---|
| SGD+Momentum | Low | Slow | Computer vision, large batch |
| Adam | Medium | Fast | NLP, general purpose |
| AdamW | Medium | Fast | Transformer training |
| Lion | Low | Fast | Large-scale models |
| L-BFGS | High | Very fast | Small models |
Second-Order Optimization
Newton's Method
Uses second-order derivatives (Hessian):

$$\theta_{t+1} = \theta_t - H^{-1} \nabla L(\theta_t)$$

where $H = \nabla^2 L(\theta_t)$ is the Hessian matrix. Achieves quadratic convergence, but inverting an $n \times n$ matrix requires $O(n^3)$ computation, which is impractical for deep learning where $n$ is the number of parameters.
L-BFGS (Limited-memory BFGS)
Approximates the inverse Hessian using only the last $m$ gradient differences, without storing the full matrix:

$$s_k = \theta_{k+1} - \theta_k, \qquad y_k = \nabla L(\theta_{k+1}) - \nabla L(\theta_k)$$

The most recent $m$ pairs $(s_k, y_k)$ implicitly define the inverse-Hessian approximation via the BFGS recursion.
```python
import torch
import torch.optim as optim

# model, input_data, target, and criterion are assumed to be defined elsewhere
optimizer = optim.LBFGS(
    model.parameters(),
    lr=1.0,
    max_iter=20,
    history_size=10,
    line_search_fn='strong_wolfe',
)

# L-BFGS requires a closure that re-evaluates the loss
def closure():
    optimizer.zero_grad()
    output = model(input_data)
    loss = criterion(output, target)
    loss.backward()
    return loss

optimizer.step(closure)
```
Natural Gradient Descent
Uses the Fisher Information Matrix $F$ to account for the curvature of the parameter space:

$$\theta_{t+1} = \theta_t - \eta F^{-1} \nabla L(\theta_t)$$

Fisher Matrix:

$$F = \mathbb{E}_{p(x|\theta)} \left[ \nabla_\theta \log p(x|\theta) \, \nabla_\theta \log p(x|\theta)^\top \right]$$
K-FAC (Kronecker-factored Approximate Curvature) provides a practical implementation by factoring the Fisher matrix layer-wise.
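In practice the Fisher matrix is often estimated from per-example gradients (the "empirical Fisher") and damped for invertibility. A NumPy sketch with synthetic stand-in gradients (all sizes and values illustrative):

```python
import numpy as np

# Empirical-Fisher sketch: G holds per-example grad-log-likelihoods (N x d).
rng = np.random.default_rng(0)
N, d = 256, 4
G = rng.normal(size=(N, d))            # stand-in per-example gradients
grad = G.mean(axis=0)                  # average gradient

F = G.T @ G / N                        # empirical Fisher, d x d
damping = 1e-3                         # Tikhonov damping for invertibility
nat_grad = np.linalg.solve(F + damping * np.eye(d), grad)  # F^{-1} grad
```

K-FAC avoids even this $d \times d$ solve by factoring each layer's block of $F$ into a Kronecker product of two much smaller matrices.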
Learning Rate Scheduling
Linear Warmup
Gradually increases the learning rate at the start to stabilize training:

$$\eta_t = \eta_{\max} \cdot \frac{t}{T_{\text{warmup}}}, \qquad t \le T_{\text{warmup}}$$
Cosine Annealing

$$\eta_t = \eta_{\min} + \frac{1}{2} (\eta_{\max} - \eta_{\min}) \left( 1 + \cos\left(\frac{t\pi}{T_{\max}}\right) \right)$$
```python
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR, ReduceLROnPlateau

# Cosine Annealing
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)

# OneCycleLR: Warmup + cosine decay
scheduler = OneCycleLR(
    optimizer,
    max_lr=1e-3,
    total_steps=1000,
    pct_start=0.3,  # 30% warmup phase
    anneal_strategy='cos',
)

# ReduceLROnPlateau: reduce when validation loss stalls
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10,
    min_lr=1e-6,
)
```
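The warmup-then-cosine combination used for LLM training can also be hand-rolled with `LambdaLR`; a minimal sketch, where the step counts and floor ratio are illustrative:

```python
import math
import torch
from torch.optim.lr_scheduler import LambdaLR

# Linear warmup followed by cosine decay to a floor of min_ratio * base lr.
def warmup_cosine(step, warmup_steps=100, total_steps=1000, min_ratio=0.1):
    if step < warmup_steps:
        return step / max(1, warmup_steps)                  # linear warmup
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1 + math.cos(math.pi * progress))       # decays 1 -> 0
    return min_ratio + (1 - min_ratio) * cosine             # floor at min_ratio

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)
scheduler = LambdaLR(optimizer, lr_lambda=warmup_cosine)
```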
Cyclical Learning Rate (CLR)
Periodically varies the learning rate between a lower and upper bound to help escape saddle points and sharp minima.
| Scheduler | Characteristic | Best Use Case |
|---|---|---|
| Cosine Annealing | Smooth decay | Transformer pretraining |
| OneCycleLR | Warmup + fast decay | Fine-tuning, short runs |
| ReduceLROnPlateau | Adaptive | General training |
| Cyclical LR | Periodic oscillation | Saddle point escape |
| Linear Warmup | Initial stabilization | LLM training |
Regularization Techniques
L1 / L2 Regularization
L2 Regularization (Ridge):

$$L = L_0 + \frac{\lambda}{2} \|w\|_2^2$$

Gradient:

$$\nabla L = \nabla L_0 + \lambda w$$

L1 Regularization (Lasso):

$$L = L_0 + \lambda \|w\|_1$$

L1 induces sparse solutions, driving many weights to exactly zero.
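The sparsity-inducing behavior is easiest to see in the proximal (soft-thresholding) step for L1, contrasted with the multiplicative shrinkage of L2; the weights and penalty strength below are illustrative:

```python
import numpy as np

w = np.array([0.8, -0.05, 0.3, 0.02, -0.6])
lam, eta = 0.1, 1.0

# L1 proximal step: soft-thresholding zeroes out small weights exactly
w_l1 = np.sign(w) * np.maximum(np.abs(w) - eta * lam, 0.0)

# L2 step: multiplicative shrinkage, weights approach but never reach zero
w_l2 = w / (1 + eta * lam)

print(w_l1)  # small entries become exactly 0.0
print(w_l2)  # all entries shrink but stay nonzero
```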
Batch Normalization vs Layer Normalization
Batch Normalization (BN):

$$\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y = \gamma \hat{x} + \beta$$

where $\mu_B$, $\sigma_B^2$ are computed across the mini-batch dimension. Normalizes along the batch axis.
Layer Normalization (LN):

$$\hat{x} = \frac{x - \mu_L}{\sqrt{\sigma_L^2 + \epsilon}}, \qquad y = \gamma \hat{x} + \beta$$

where the statistics $\mu_L$, $\sigma_L^2$ are computed over the feature dimension of each individual sample.
| Normalization | Statistic Axis | Best Use Case |
|---|---|---|
| Batch Norm | Batch (same feature) | CNN, large batch |
| Layer Norm | Feature (same sample) | Transformer, RNN |
| Instance Norm | Spatial (same channel) | Style transfer |
| Group Norm | Channel groups | Small batch |
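The axis difference between BN and LN can be checked directly on a `(batch, features)` tensor; the shapes are illustrative, and `eps` matches PyTorch's default:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 16)

# BatchNorm-style: statistics per feature, across the batch (dim=0)
bn = (x - x.mean(dim=0)) / torch.sqrt(x.var(dim=0, unbiased=False) + 1e-5)

# LayerNorm-style: statistics per sample, across the features (dim=-1)
ln = (x - x.mean(dim=-1, keepdim=True)) / torch.sqrt(
    x.var(dim=-1, unbiased=False, keepdim=True) + 1e-5
)

# ln matches nn.functional.layer_norm without affine parameters
ref = torch.nn.functional.layer_norm(x, (16,))
assert torch.allclose(ln, ref, atol=1e-5)
```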
Weight Decay vs L2 Regularization
With SGD:

$$\theta_{t+1} = (1 - \eta \lambda) \theta_t - \eta \nabla L_0(\theta_t)$$

Weight decay and L2 regularization are equivalent here. However, with Adam:
- L2 in Adam: $\lambda \theta$ is added to the gradient, then divided by the adaptive scaling factor $\sqrt{\hat{v}_t} + \epsilon$; the regularization effect is weakened for parameters with large gradient variance
- AdamW: $\lambda \theta$ is applied after the adaptive gradient update, giving uniform decay for all parameters regardless of gradient scale
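The difference shows up on even a single toy parameter with a large constant gradient: PyTorch's `Adam` folds `weight_decay` into the gradient (L2 style), while `AdamW` decays the parameter directly. The setup below is illustrative:

```python
import torch

wd = 0.1
p_l2 = torch.nn.Parameter(torch.tensor([1.0]))
p_dec = torch.nn.Parameter(torch.tensor([1.0]))
opt_l2 = torch.optim.Adam([p_l2], lr=1e-2, weight_decay=wd)    # L2 added to grad
opt_dec = torch.optim.AdamW([p_dec], lr=1e-2, weight_decay=wd)  # decoupled decay

for _ in range(100):
    for p, opt in [(p_l2, opt_l2), (p_dec, opt_dec)]:
        opt.zero_grad()
        loss = 100.0 * p.sum()   # large, constant gradient
        loss.backward()
        opt.step()

# Adam's adaptive scaling almost erases the L2 term against the large gradient,
# while AdamW's decoupled decay accumulates, so the trajectories differ.
print(p_l2.item(), p_dec.item())
```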
```python
import torch
import torch.nn as nn

class RegularizedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(512, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.ln1 = nn.LayerNorm(256)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)  # or self.ln1(x) for transformers
        x = torch.relu(x)
        x = self.dropout(x)
        return x
```
Loss Function Design
Cross-Entropy Loss
Binary cross-entropy:

$$L = -\frac{1}{N} \sum_{i=1}^{N} \left[ y_i \log p_i + (1 - y_i) \log(1 - p_i) \right]$$
Focal Loss
Addresses class imbalance by down-weighting easy examples:

$$FL(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)$$

where $p_t$ is the predicted probability for the ground-truth class and $\gamma$ is the focusing parameter. When $\gamma = 0$, this reduces to standard (alpha-weighted) cross-entropy.
```python
import torch
import torch.nn.functional as F

class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        bce_loss = F.binary_cross_entropy_with_logits(
            logits, targets.float(), reduction='none'
        )
        p = torch.sigmoid(logits)
        p_t = p * targets + (1 - p) * (1 - targets)
        focal_weight = (1 - p_t) ** self.gamma
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        loss = alpha_t * focal_weight * bce_loss
        return loss.mean()
```
Contrastive Loss and Triplet Loss
Contrastive Loss (Siamese Networks):

$$L = y \, d^2 + (1 - y) \max(0, m - d)^2$$

where $d = \|f(x_1) - f(x_2)\|_2$, $y = 1$ for similar pairs, $y = 0$ for dissimilar pairs, and $m$ is the margin.
Triplet Loss:

$$L = \max(0, \; d(a, p) - d(a, n) + \text{margin})$$
Uses anchor (a), positive (p), and negative (n) samples.
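PyTorch ships this loss as `F.triplet_margin_loss`; a quick sketch on toy embeddings (the embedding sizes, noise scale, and margin are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

anchor = F.normalize(torch.randn(4, 8), dim=1)
positive = anchor + 0.05 * torch.randn(4, 8)   # close to the anchor
negative = F.normalize(torch.randn(4, 8), dim=1)

# Hinge on d(a, p) - d(a, n) + margin, averaged over the batch
loss = F.triplet_margin_loss(anchor, positive, negative, margin=1.0)
print(loss.item())
```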
InfoNCE Loss (NT-Xent)
The core loss function for contrastive self-supervised learning:

$$L_i = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k \ne i} \exp(\mathrm{sim}(z_i, z_k)/\tau)}$$

where $\tau$ is the temperature parameter, $\mathrm{sim}(\cdot, \cdot)$ is cosine similarity, and $z_j$ is the positive (the other augmented view of the same image).
```python
import torch
import torch.nn.functional as F

def info_nce_loss(features, temperature=0.07):
    """
    features: (2N, D) - two augmentation views of each image,
    ordered so that rows i and i + N are views of the same image
    """
    N = features.shape[0] // 2
    features = F.normalize(features, dim=1)

    # Compute similarity matrix
    similarity = torch.matmul(features, features.T) / temperature

    # Mask self-similarity (set diagonal to -inf)
    mask = torch.eye(2 * N, dtype=torch.bool, device=features.device)
    similarity.masked_fill_(mask, float('-inf'))

    # Positive pairs: i with i+N, and i+N with i
    labels = torch.cat([
        torch.arange(N, 2 * N),
        torch.arange(N)
    ]).to(features.device)

    loss = F.cross_entropy(similarity, labels)
    return loss
```
LLM Training Optimization
Gradient Clipping
Prevents exploding gradients by rescaling when the global gradient norm exceeds a threshold:

$$g \leftarrow g \cdot \frac{\text{max\_norm}}{\|g\|_2} \quad \text{if} \quad \|g\|_2 > \text{max\_norm}$$
```python
import torch

def train_with_clipping(model, optimizer, loss, max_norm=1.0):
    optimizer.zero_grad()
    loss.backward()

    # Monitor the global gradient norm before clipping
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5

    # Apply clipping (clip_grad_norm_ also returns this pre-clip norm)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optimizer.step()
    return total_norm
```
ZeRO Optimizer (Zero Redundancy Optimizer)
Partitions model training state across GPUs in three progressive stages:
| ZeRO Stage | Partitioned State | Memory Reduction |
|---|---|---|
| Stage 1 | Optimizer states | ~4x |
| Stage 2 | + Gradients | ~8x |
| Stage 3 | + Parameters | Linear in GPU count (e.g., ~64x on 64 GPUs) |
Mixed precision (FP16/BF16) combined with ZeRO-3 enables training multi-billion parameter models on a single node.
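A minimal ZeRO-3 configuration sketch, expressed as the Python dict DeepSpeed accepts; the field names follow DeepSpeed's config schema, and the values are illustrative:

```python
# Minimal DeepSpeed ZeRO-3 config sketch (passed to deepspeed.initialize).
ds_config = {
    "train_batch_size": 64,
    "bf16": {"enabled": True},   # mixed precision
    "zero_optimization": {
        "stage": 3,              # partition optimizer states, gradients, params
    },
    "gradient_clipping": 1.0,
}
```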
8-bit Adam
Uses quantization to store optimizer states in INT8 instead of FP32:
- Reduces optimizer state memory by 75% compared to FP32
- Block-wise quantization minimizes precision loss
- Available via the bitsandbytes library
```python
# 8-bit Adam via bitsandbytes
import bitsandbytes as bnb

optimizer = bnb.optim.Adam8bit(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
)
```
Adafactor
Approximates Adam's second-moment matrix $V \in \mathbb{R}^{n \times m}$ via a rank-1 factorization:

$$\hat{V} = \frac{R C^\top}{\mathbf{1}^\top R}$$

where $R$ holds row sums and $C$ holds column sums of the squared gradients. Memory is proportional to $n + m$ (row and column vectors only) rather than $nm$. Used to train T5, PaLM, and other massive models.
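The factored second moment from the Adafactor paper (Shazeer & Stern, 2018) can be sketched in a few lines of NumPy; the matrix sizes are illustrative:

```python
import numpy as np

# Keep only row sums R and column sums C of the squared gradients of a weight
# matrix, and reconstruct V ~ R C^T / sum(R): O(n + m) memory instead of O(nm).
rng = np.random.default_rng(0)
G = rng.normal(size=(64, 32))   # stand-in gradient of a 64 x 32 weight matrix
G2 = G ** 2

R = G2.sum(axis=1)                      # row statistics, shape (64,)
C = G2.sum(axis=0)                      # column statistics, shape (32,)
V_approx = np.outer(R, C) / R.sum()     # rank-1 approximation of E[g^2]

# The approximation preserves the total second-moment mass exactly
assert np.isclose(V_approx.sum(), G2.sum())
```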
| Optimizer | Memory (relative to params) | LLM Suitability |
|---|---|---|
| Adam | 8x (params + 2 states) | Moderate |
| AdamW | 8x | Good |
| 8-bit Adam | 6x | Good |
| Adafactor | ~2x | Excellent |
| Lion | 6x | Good |
Quiz
Q1. Why is bias correction necessary in the Adam optimizer?
Answer: To correct the bias introduced by initializing the moment estimates at zero.
Explanation: Adam initializes $m_0 = 0$ and $v_0 = 0$. In early timesteps, $m_t$ and $v_t$ underestimate the true moments of the gradients. For example, at $t = 1$: $m_1 = (1 - \beta_1) g_1$, whose expected magnitude is much smaller than that of $g_1$. Dividing by $1 - \beta_1^t$ corrects this. With $\beta_1 = 0.9$ at $t = 1$, the correction factor is $1/(1 - 0.9) = 10$. As $t$ grows large, $\beta_1^t \to 0$ and the correction factor approaches 1, becoming negligible.
Q2. Why are weight decay and L2 regularization not equivalent in Adam (and how does AdamW fix this)?
Answer: Because Adam's adaptive learning rate scales the L2 penalty gradient, weakening its regularization effect.
Explanation: In SGD, the L2 gradient term $\lambda \theta$ enters the update exactly as weight decay does, making both approaches mathematically identical. In Adam with L2 regularization, the combined gradient becomes $g_t + \lambda \theta_t$, which is then divided by the adaptive factor $\sqrt{\hat{v}_t} + \epsilon$. Parameters with high gradient variance (large $\hat{v}_t$) receive proportionally smaller regularization. AdamW decouples weight decay from the gradient update: $\theta_{t+1} = \theta_t - \eta \left( \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) + \lambda \theta_t \right)$, applying uniform decay to all parameters regardless of their gradient scale.
Q3. How do Batch Normalization and Layer Normalization differ, and when is each appropriate?
Answer: BN normalizes across the batch dimension; LN normalizes across the feature dimension of each sample.
Explanation: BN computes mean and variance over the mini-batch for each feature position. It depends on batch size — small batches yield unstable statistics. It is best suited for CNNs with fixed spatial structure and sufficiently large batches. LN computes statistics over the feature dimension of each sample independently, making it batch-size agnostic. It is ideal for Transformers (variable sequence lengths), RNNs, and online inference scenarios where batch statistics are unavailable or unreliable.
Q4. What is the mathematical principle behind Focal Loss outperforming Cross-Entropy on imbalanced datasets?
Answer: The modulating factor $(1 - p_t)^\gamma$ dynamically down-weights easy examples during training.
Explanation: Standard CE loss treats every sample equally regardless of prediction confidence. Focal Loss introduces the factor $(1 - p_t)^\gamma$: for an easy example with $p_t = 0.9$ and $\gamma = 2$, the weight is $(0.1)^2 = 0.01$, reducing its contribution by 100x. For a hard example with $p_t = 0.1$, the weight is $(0.9)^2 = 0.81$, preserving nearly the full loss signal. With $\gamma = 2$, easy well-classified majority-class samples effectively stop contributing, forcing the model to focus training on the rare, hard minority-class examples.
Q5. How does InfoNCE Loss enable contrastive learning to produce useful representations?
Answer: By maximizing similarity between augmented views of the same image while pushing apart views from different images.
Explanation: InfoNCE maximizes a lower bound on mutual information. The numerator increases cosine similarity between the two augmented views of the same image (the positive pair). The denominator includes negative pairs (other images in the batch). The temperature $\tau$ controls distribution sharpness: smaller $\tau$ creates a more peaked distribution, forcing tighter positive-pair clusters. Large batches provide more diverse negatives, improving representation quality. SimCLR, MoCo, and CLIP all rely on this loss formulation to learn generalizable visual and multimodal representations.