Split View: 머신러닝을 위한 수학적 최적화: Adam부터 볼록 최적화, ZeRO optimizer까지
머신러닝을 위한 수학적 최적화: Adam부터 볼록 최적화, ZeRO optimizer까지
목차
최적화 기초
볼록 최적화 (Convex Optimization)
함수 이 **볼록(convex)**이면, 임의의 두 점 와 에 대해 다음이 성립합니다.
볼록 함수의 핵심 성질:
- 모든 지역 최솟값(local minimum)이 전역 최솟값(global minimum)
- 경사 하강법이 수렴이 보장됨
- 딥러닝 손실 함수는 대부분 비볼록(non-convex)이지만, 볼록 분석 기법이 여전히 유용
강볼록(Strongly Convex): 이 존재하여 가 성립하면, 수렴 속도가 선형(linear convergence)으로 빨라집니다.
라그랑주 승수법
등호 제약 최적화 문제를 다룹니다.
라그랑주안(Lagrangian):
최적해에서 , 이 성립합니다.
KKT 조건
부등호 제약까지 포함한 일반적인 최적화:
KKT 조건 (필요 조건):
- 정류성(Stationarity):
- 원시 실현가능성(Primal feasibility): ,
- 쌍대 실현가능성(Dual feasibility):
- 상보 여유(Complementary slackness):
볼록 문제에서 KKT 조건은 충분 조건도 됩니다.
안장점 (Saddle Point)
딥러닝 최적화에서 큰 문제는 지역 최솟값보다 안장점입니다. 안장점에서는 일부 방향으로는 함수값이 증가하고 다른 방향으로는 감소하여 경사가 0이 됩니다. SGD의 확률적 노이즈가 안장점 탈출에 도움이 됩니다.
경사 하강법 계열
SGD와 그 변형
기본 SGD:
SGD with Momentum:
모멘텀 가 보통 사용되며, 이전 기울기 방향을 유지해 oscillation을 줄입니다.
Nesterov Accelerated Gradient (NAG):
현재 위치가 아니라 "앞을 내다본" 위치에서 기울기를 계산합니다.
AdaGrad, RMSProp, Adam
AdaGrad: 파라미터별 학습률 적응
자주 나타나는 특성은 학습률이 줄어들고, 드물게 나타나는 특성은 학습률이 큽니다. 단점: 학습률이 단조 감소하여 학습이 멈출 수 있습니다.
RMSProp: AdaGrad의 누적 문제 해결
Adam (Adaptive Moment Estimation):
Bias correction:
기본 하이퍼파라미터: , ,
import torch
import torch.optim as optim
model = ... # 모델 정의
# Adam optimizer
optimizer_adam = optim.Adam(
model.parameters(),
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8
)
# AdamW optimizer (weight decay 분리)
optimizer_adamw = optim.AdamW(
model.parameters(),
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.01 # L2와 독립적으로 적용
)
AdamW와 Lion
AdamW: Adam에서 weight decay를 L2 패널티가 아닌 파라미터 업데이트에 직접 적용합니다.
일반 Adam의 L2 정규화와 수학적으로 동등하지 않습니다 (자세한 설명은 퀴즈 참고).
Lion (EvoLved Sign Momentum):
Lion은 부호(sign)만 사용하므로 메모리 효율적이며, 업데이트 크기가 균일합니다.
| Optimizer | 메모리 | 수렴 속도 | 적합한 상황 |
|---|---|---|---|
| SGD+Momentum | 낮음 | 느림 | 컴퓨터 비전, 큰 배치 |
| Adam | 중간 | 빠름 | NLP, 범용 |
| AdamW | 중간 | 빠름 | Transformer 학습 |
| Lion | 낮음 | 빠름 | 대규모 모델 |
| L-BFGS | 높음 | 매우 빠름 | 소규모 모델 |
2차 최적화
Newton 방법
2차 미분(Hessian)을 활용합니다.
여기서 는 Hessian 행렬입니다. 이차 수렴(quadratic convergence)하지만, Hessian의 역행렬 계산이 으로 딥러닝에서 비현실적입니다.
L-BFGS (Limited-memory BFGS)
Hessian을 직접 저장하지 않고, 최근 개의 기울기 차이로 근사합니다.
여기서 , 입니다.
import torch
import torch.optim as optim
# L-BFGS는 클로저(closure) 함수 필요
optimizer = optim.LBFGS(
model.parameters(),
lr=1.0,
max_iter=20,
history_size=10,
line_search_fn='strong_wolfe'
)
def closure():
optimizer.zero_grad()
output = model(input_data)
loss = criterion(output, target)
loss.backward()
return loss
optimizer.step(closure)
자연 경사 하강법 (Natural Gradient)
Fisher Information Matrix를 사용하여 파라미터 공간의 곡률을 고려합니다.
Fisher Matrix:
K-FAC(Kronecker-factored Approximate Curvature)은 자연 경사법을 실용적으로 구현합니다.
학습률 스케줄링
Warmup
초기에 학습률을 서서히 증가시켜 학습을 안정화합니다.
Cosine Annealing
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR, ReduceLROnPlateau
# Cosine Annealing
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
# OneCycleLR: Warmup + Cosine Decay
scheduler = OneCycleLR(
optimizer,
max_lr=1e-3,
total_steps=1000,
pct_start=0.3, # 30% warmup
anneal_strategy='cos'
)
# ReduceLROnPlateau: 검증 손실이 개선되지 않으면 감소
scheduler = ReduceLROnPlateau(
optimizer,
mode='min',
factor=0.5,
patience=10,
min_lr=1e-6
)
Cyclical Learning Rate (CLR)
학습률을 주기적으로 변동시켜 안장점 탈출을 돕습니다.
| 스케줄러 | 특징 | 적합한 상황 |
|---|---|---|
| Cosine Annealing | 부드러운 감소 | Transformer 사전학습 |
| OneCycleLR | Warmup + 빠른 감소 | 파인튜닝, 짧은 학습 |
| ReduceLROnPlateau | 적응형 | 일반 학습, 검증 필요 |
| Cyclical LR | 주기적 변동 | 안장점 회피 |
| Linear Warmup | 초기 안정화 | LLM 학습 |
정규화 기법
L1 / L2 정규화
L2 정규화 (Ridge):
기울기:
L1 정규화 (Lasso):
L1은 희소(sparse) 솔루션을 유도합니다.
Batch Normalization vs Layer Normalization
Batch Normalization (BN):
여기서 , 는 미니배치 내 통계량입니다. 배치 방향으로 정규화합니다.
Layer Normalization (LN):
여기서 통계량은 각 샘플의 피처 차원을 따라 계산됩니다.
| 정규화 | 통계량 계산 방향 | 적합한 상황 |
|---|---|---|
| Batch Norm | 배치 방향 (같은 피처) | CNN, 큰 배치 |
| Layer Norm | 피처 방향 (같은 샘플) | Transformer, RNN |
| Instance Norm | 공간 방향 (같은 채널) | 스타일 전이 |
| Group Norm | 채널 그룹 | 작은 배치 |
Weight Decay vs L2 정규화
SGD에서:
이 경우 weight decay와 L2 정규화는 동일합니다. 하지만 Adam에서는:
- L2 Adam: 기울기에 를 더한 후 적응형 학습률 적용 → 적응 계수로 나누어져 정규화 효과 약화
- AdamW: 파라미터 업데이트 후 를 직접 빼냄 → 모든 파라미터에 균등한 weight decay
# Dropout
import torch.nn as nn
class RegularizedModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(512, 256)
self.bn1 = nn.BatchNorm1d(256)
self.ln1 = nn.LayerNorm(256)
self.dropout = nn.Dropout(p=0.3)
def forward(self, x):
x = self.fc1(x)
x = self.bn1(x) # 또는 self.ln1(x)
x = torch.relu(x)
x = self.dropout(x)
return x
손실 함수 설계
Cross-Entropy Loss
이진 분류:
Focal Loss
클래스 불균형 문제를 해결합니다. 쉬운 샘플의 기여를 줄입니다.
여기서 는 정답 클래스의 예측 확률, 는 focusing parameter입니다. 이면 일반 Cross-Entropy와 동일합니다.
import torch
import torch.nn.functional as F
class FocalLoss(torch.nn.Module):
def __init__(self, alpha=0.25, gamma=2.0):
super().__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, logits, targets):
bce_loss = F.binary_cross_entropy_with_logits(
logits, targets.float(), reduction='none'
)
p = torch.sigmoid(logits)
p_t = p * targets + (1 - p) * (1 - targets)
focal_weight = (1 - p_t) ** self.gamma
alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
loss = alpha_t * focal_weight * bce_loss
return loss.mean()
Contrastive Loss와 Triplet Loss
Contrastive Loss (Siamese Network):
여기서 , 이면 유사 쌍, 이면 비유사 쌍입니다.
Triplet Loss:
앵커(a), 포지티브(p), 네거티브(n) 샘플을 사용합니다.
InfoNCE Loss (NT-Xent)
대조 학습(Contrastive Learning)의 핵심 손실 함수입니다.
여기서 는 temperature parameter, 은 코사인 유사도입니다.
import torch
import torch.nn.functional as F
def info_nce_loss(features, temperature=0.07):
"""
features: (2N, D) - 각 이미지의 두 augmentation view
"""
N = features.shape[0] // 2
features = F.normalize(features, dim=1)
# 유사도 행렬 계산
similarity = torch.matmul(features, features.T) / temperature
# 자기 자신 제거 (대각선을 -inf로)
mask = torch.eye(2 * N, dtype=torch.bool, device=features.device)
similarity.masked_fill_(mask, float('-inf'))
# 포지티브 쌍: i와 i+N, i+N과 i
labels = torch.cat([
torch.arange(N, 2 * N),
torch.arange(N)
]).to(features.device)
loss = F.cross_entropy(similarity, labels)
return loss
LLM 학습 최적화
Gradient Clipping
기울기 폭발(exploding gradient)을 방지합니다.
import torch
def train_with_clipping(model, optimizer, loss, max_norm=1.0):
optimizer.zero_grad()
loss.backward()
# 기울기 노름 모니터링
total_norm = 0
for p in model.parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
# 클리핑 적용
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
optimizer.step()
return total_norm
ZeRO Optimizer (Zero Redundancy Optimizer)
모델 학습 시 메모리를 3단계로 최적화합니다.
| ZeRO 단계 | 분산 대상 | 메모리 절감 |
|---|---|---|
| Stage 1 | Optimizer 상태 | ~4x |
| Stage 2 | + 기울기 | ~8x |
| Stage 3 | + 파라미터 | ~64x (N GPU 기준) |
혼합 정밀도(mixed precision) + ZeRO-3로 수십억 파라미터 모델을 단일 노드에서 학습 가능합니다.
8-bit Adam
양자화(quantization)를 통해 optimizer 상태를 FP32 대신 INT8로 저장합니다.
- Optimizer 상태 메모리를 75% 절감 (FP32 대비)
- Block-wise quantization으로 정밀도 손실 최소화
bitsandbytes라이브러리로 구현 가능
# bitsandbytes 8-bit Adam
import bitsandbytes as bnb
optimizer = bnb.optim.Adam8bit(
model.parameters(),
lr=1e-4,
betas=(0.9, 0.999)
)
Adafactor
Adam의 2차 모멘트를 행렬 분해로 근사합니다.
파라미터 크기에 비례한 메모리만 사용 (행 + 열 벡터). T5, PaLM 등 초대형 모델 학습에 사용됩니다.
| Optimizer | 메모리 (파라미터 기준 배수) | LLM 적합도 |
|---|---|---|
| Adam | 8x (params + 2 states) | 보통 |
| AdamW | 8x | 좋음 |
| 8-bit Adam | 6x | 좋음 |
| Adafactor | ~2x | 매우 좋음 |
| Lion | 6x | 좋음 |
퀴즈
Q1. Adam optimizer에서 bias correction이 필요한 이유는 무엇인가요?
정답: 초기 모멘트 추정값의 0 초기화로 인한 편향을 보정하기 위해서입니다.
설명: Adam에서 , 으로 초기화하면, 초기 타임스텝 에서 와 는 실제 기울기의 모멘트를 과소평가합니다. 예를 들어 에서 이고, 이는 실제 기댓값 인데, 이를 로 복원하려면 을 곱해야 합니다. 일 때 에서 로 보정됩니다. 가 충분히 크면 이므로 보정 계수가 1에 수렴하여 효과가 사라집니다.
Q2. Weight Decay와 L2 정규화가 Adam에서 동등하지 않은 이유는?
정답: Adam의 적응형 학습률이 L2 페널티 기울기를 스케일링하기 때문입니다.
설명: SGD에서는 로 두 방식이 수학적으로 동일합니다. 그러나 Adam에서 L2 정규화를 추가하면 기울기가 가 되고, 이것이 적응형 스케일링 계수 로 나누어집니다. 따라서 가중치가 큰 파라미터(큰 )는 L2 페널티도 작아집니다. AdamW는 weight decay를 적응형 스케일링 밖으로 꺼내어 로 처리함으로써, 모든 파라미터에 균등한 정규화를 적용합니다.
Q3. Batch Normalization과 Layer Normalization의 차이와 적합한 상황은?
정답: BN은 배치 차원에서, LN은 피처 차원에서 정규화합니다.
설명: BN은 미니배치 내 같은 피처(뉴런)들의 평균/분산으로 정규화합니다. 따라서 배치 크기에 의존하며, 배치 크기가 작으면 통계량 추정이 불안정해집니다. CNN처럼 공간적 피처가 있고 배치 크기가 충분한 경우에 적합합니다. LN은 각 샘플의 피처 차원을 따라 정규화하므로 배치 크기에 독립적입니다. Transformer처럼 시퀀스 길이가 가변적이거나 RNN처럼 배치 통계를 유지하기 어려운 상황에 적합합니다. 추론 시에도 배치 통계가 필요 없으므로 온라인 추론에 유리합니다.
Q4. Focal Loss가 Cross-Entropy보다 클래스 불균형에 효과적인 수학적 원리는?
정답: 가중치가 쉬운 샘플의 기여를 동적으로 감소시키기 때문입니다.
설명: 일반 CE 손실은 로 쉽게 분류되는 다수 클래스 샘플도 동일하게 기여합니다. Focal Loss의 항을 살펴보면, (쉬운 샘플)일 때 로 가중치가 매우 작아집니다. 반면 (어려운 샘플)일 때 로 거의 그대로 유지됩니다. 를 사용하면 쉬운 샘플의 손실이 100배 감소합니다. 이를 통해 모델이 어려운 소수 클래스 샘플에 집중하여 학습합니다.
Q5. InfoNCE Loss가 대조 학습에서 좋은 표현을 학습하는 원리는?
정답: 같은 이미지의 다른 augmentation 쌍을 유사하게, 다른 이미지는 멀어지도록 학습합니다.
설명: InfoNCE는 상호 정보량(mutual information)의 하한을 최대화합니다. 분자 는 포지티브 쌍(같은 이미지의 두 뷰)의 유사도를 높이고, 분모는 개의 네거티브 쌍을 포함합니다. Temperature 는 분포의 날카로움을 조절합니다. 가 작을수록 경쟁이 치열해져 표현 공간이 더 구별적이 됩니다. 대규모 배치에서 다양한 네거티브 샘플을 제공하면 표현이 더욱 일반화됩니다. SimCLR, MoCo, CLIP 등 주요 대조 학습 모델이 이 손실 함수를 사용합니다.
Mathematical Optimization for Machine Learning: From Adam to Convex Optimization and ZeRO
Table of Contents
- Optimization Fundamentals: Convex Optimization and KKT Conditions
- Gradient Descent Family of Optimizers
- Second-Order Optimization Methods
- Learning Rate Scheduling
- Regularization Techniques
- Loss Function Design
- LLM Training Optimization
- Quiz
Optimization Fundamentals
Convex Optimization
A function is convex if for any two points and :
Key properties of convex functions:
- Every local minimum is a global minimum
- Gradient descent is guaranteed to converge
- Deep learning loss surfaces are mostly non-convex, but convex analysis techniques remain useful
Strongly Convex: If there exists such that , convergence is linear (exponentially fast).
Lagrange Multipliers
Handles equality-constrained optimization problems:
Lagrangian:
At the optimum: and .
KKT Conditions
For the general constrained optimization problem:
KKT necessary conditions:
- Stationarity:
- Primal feasibility: ,
- Dual feasibility:
- Complementary slackness:
For convex problems, KKT conditions are also sufficient.
Saddle Points
In deep learning, saddle points are a greater concern than local minima. At a saddle point, the gradient is zero but it is neither a local min nor max. The stochastic noise in SGD helps escape saddle points.
Gradient Descent Family
SGD and Its Variants
Vanilla SGD:
SGD with Momentum:
Momentum is standard; it accumulates past gradients to reduce oscillation.
Nesterov Accelerated Gradient (NAG):
Computes the gradient at a "lookahead" position rather than the current one.
AdaGrad, RMSProp, Adam
AdaGrad: Per-parameter adaptive learning rate
Frequent features get smaller updates; rare features get larger updates. Drawback: monotonically shrinking learning rates cause learning to stall.
RMSProp: Fixes AdaGrad's accumulation problem
Adam (Adaptive Moment Estimation):
Bias correction:
Default hyperparameters: , ,
import torch
import torch.optim as optim
model = ... # define your model
# Standard Adam
optimizer_adam = optim.Adam(
model.parameters(),
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8
)
# AdamW (decoupled weight decay)
optimizer_adamw = optim.AdamW(
model.parameters(),
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.01 # applied independently of gradient scaling
)
AdamW and Lion
AdamW: Applies weight decay directly to parameter updates, separate from the gradient-based term.
This is mathematically inequivalent to adding L2 regularization inside Adam (see Quiz for details).
Lion (EvoLved Sign Momentum):
Lion uses only the sign of the update, providing uniform update magnitude and better memory efficiency.
| Optimizer | Memory | Convergence | Best Use Case |
|---|---|---|---|
| SGD+Momentum | Low | Slow | Computer vision, large batch |
| Adam | Medium | Fast | NLP, general purpose |
| AdamW | Medium | Fast | Transformer training |
| Lion | Low | Fast | Large-scale models |
| L-BFGS | High | Very fast | Small models |
Second-Order Optimization
Newton's Method
Uses second-order derivatives (Hessian):
where is the Hessian matrix. Achieves quadratic convergence, but inverting an matrix requires computation — impractical for deep learning.
L-BFGS (Limited-memory BFGS)
Approximates the inverse Hessian using the last gradient differences, without storing the full matrix.
where and .
import torch
import torch.optim as optim
# L-BFGS requires a closure function
optimizer = optim.LBFGS(
model.parameters(),
lr=1.0,
max_iter=20,
history_size=10,
line_search_fn='strong_wolfe'
)
def closure():
optimizer.zero_grad()
output = model(input_data)
loss = criterion(output, target)
loss.backward()
return loss
optimizer.step(closure)
Natural Gradient Descent
Uses the Fisher Information Matrix to account for the curvature of the parameter space:
Fisher Matrix:
K-FAC (Kronecker-factored Approximate Curvature) provides a practical implementation by factoring the Fisher matrix layer-wise.
Learning Rate Scheduling
Linear Warmup
Gradually increases the learning rate at the start to stabilize training:
Cosine Annealing
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR, ReduceLROnPlateau
# Cosine Annealing
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
# OneCycleLR: Warmup + cosine decay
scheduler = OneCycleLR(
optimizer,
max_lr=1e-3,
total_steps=1000,
pct_start=0.3, # 30% warmup phase
anneal_strategy='cos'
)
# ReduceLROnPlateau: reduce when validation loss stalls
scheduler = ReduceLROnPlateau(
optimizer,
mode='min',
factor=0.5,
patience=10,
min_lr=1e-6
)
Cyclical Learning Rate (CLR)
Periodically varies the learning rate to help escape saddle points:
| Scheduler | Characteristic | Best Use Case |
|---|---|---|
| Cosine Annealing | Smooth decay | Transformer pretraining |
| OneCycleLR | Warmup + fast decay | Fine-tuning, short runs |
| ReduceLROnPlateau | Adaptive | General training |
| Cyclical LR | Periodic oscillation | Saddle point escape |
| Linear Warmup | Initial stabilization | LLM training |
Regularization Techniques
L1 / L2 Regularization
L2 Regularization (Ridge):
Gradient:
L1 Regularization (Lasso):
L1 induces sparse solutions, driving many weights to exactly zero.
Batch Normalization vs Layer Normalization
Batch Normalization (BN):
where , are computed across the mini-batch dimension. Normalizes along the batch axis.
Layer Normalization (LN):
Statistics are computed over the feature dimension of each individual sample.
| Normalization | Statistic Axis | Best Use Case |
|---|---|---|
| Batch Norm | Batch (same feature) | CNN, large batch |
| Layer Norm | Feature (same sample) | Transformer, RNN |
| Instance Norm | Spatial (same channel) | Style transfer |
| Group Norm | Channel groups | Small batch |
Weight Decay vs L2 Regularization
With SGD:
Weight decay and L2 regularization are equivalent here. However with Adam:
- L2 Adam: is added to the gradient, then divided by the adaptive scaling factor — regularization effect is weakened for parameters with large gradient variance
- AdamW: is applied after the gradient update — uniform decay for all parameters regardless of gradient scale
import torch.nn as nn
class RegularizedModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(512, 256)
self.bn1 = nn.BatchNorm1d(256)
self.ln1 = nn.LayerNorm(256)
self.dropout = nn.Dropout(p=0.3)
def forward(self, x):
x = self.fc1(x)
x = self.bn1(x) # or self.ln1(x) for transformers
x = torch.relu(x)
x = self.dropout(x)
return x
Loss Function Design
Cross-Entropy Loss
Binary cross-entropy:
Focal Loss
Addresses class imbalance by down-weighting easy examples:
where is the predicted probability for the ground-truth class and is the focusing parameter. When , this reduces to standard cross-entropy.
import torch
import torch.nn.functional as F
class FocalLoss(torch.nn.Module):
def __init__(self, alpha=0.25, gamma=2.0):
super().__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, logits, targets):
bce_loss = F.binary_cross_entropy_with_logits(
logits, targets.float(), reduction='none'
)
p = torch.sigmoid(logits)
p_t = p * targets + (1 - p) * (1 - targets)
focal_weight = (1 - p_t) ** self.gamma
alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
loss = alpha_t * focal_weight * bce_loss
return loss.mean()
Contrastive Loss and Triplet Loss
Contrastive Loss (Siamese Networks):
where , for similar pairs and for dissimilar pairs.
Triplet Loss:
Uses anchor (a), positive (p), and negative (n) samples.
InfoNCE Loss (NT-Xent)
The core loss function for contrastive self-supervised learning:
where is the temperature parameter and is cosine similarity.
import torch
import torch.nn.functional as F
def info_nce_loss(features, temperature=0.07):
"""
features: (2N, D) - two augmentation views of each image
"""
N = features.shape[0] // 2
features = F.normalize(features, dim=1)
# Compute similarity matrix
similarity = torch.matmul(features, features.T) / temperature
# Mask self-similarity (set diagonal to -inf)
mask = torch.eye(2 * N, dtype=torch.bool, device=features.device)
similarity.masked_fill_(mask, float('-inf'))
# Positive pairs: i with i+N, and i+N with i
labels = torch.cat([
torch.arange(N, 2 * N),
torch.arange(N)
]).to(features.device)
loss = F.cross_entropy(similarity, labels)
return loss
LLM Training Optimization
Gradient Clipping
Prevents exploding gradients during training:
import torch
def train_with_clipping(model, optimizer, loss, max_norm=1.0):
optimizer.zero_grad()
loss.backward()
# Monitor gradient norm before clipping
total_norm = 0
for p in model.parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
# Apply clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
optimizer.step()
return total_norm
ZeRO Optimizer (Zero Redundancy Optimizer)
Partitions model training state across GPUs in three progressive stages:
| ZeRO Stage | Partitioned State | Memory Reduction |
|---|---|---|
| Stage 1 | Optimizer states | ~4x |
| Stage 2 | + Gradients | ~8x |
| Stage 3 | + Parameters | ~64x (N GPUs) |
Mixed precision (FP16/BF16) combined with ZeRO-3 enables training multi-billion parameter models on a single node.
8-bit Adam
Uses quantization to store optimizer states in INT8 instead of FP32:
- Reduces optimizer state memory by 75% compared to FP32
- Block-wise quantization minimizes precision loss
- Available via the
bitsandbyteslibrary
# 8-bit Adam via bitsandbytes
import bitsandbytes as bnb
optimizer = bnb.optim.Adam8bit(
model.parameters(),
lr=1e-4,
betas=(0.9, 0.999)
)
Adafactor
Approximates Adam's second moment matrix via low-rank factorization:
Requires memory proportional to parameter size (row + column vectors only). Used to train T5, PaLM, and other massive models.
| Optimizer | Memory (relative to params) | LLM Suitability |
|---|---|---|
| Adam | 8x (params + 2 states) | Moderate |
| AdamW | 8x | Good |
| 8-bit Adam | 6x | Good |
| Adafactor | ~2x | Excellent |
| Lion | 6x | Good |
Quiz
Q1. Why is bias correction necessary in the Adam optimizer?
Answer: To correct the bias introduced by initializing the moment estimates at zero.
Explanation: Adam initializes and . In early timesteps, and underestimate the true moments of the gradients. For example, at : , whose expected value is much smaller than . Dividing by corrects this. With at , the correction factor is . As grows large, and the correction factor approaches 1, becoming negligible.
Q2. Why are weight decay and L2 regularization not equivalent in Adam (and how does AdamW fix this)?
Answer: Because Adam's adaptive learning rate scales the L2 penalty gradient, weakening its regularization effect.
Explanation: In SGD, makes both approaches mathematically identical. In Adam with L2 regularization, the combined gradient becomes , which is then divided by the adaptive factor . Parameters with high gradient variance (large ) receive proportionally smaller regularization. AdamW decouples weight decay from the gradient update: , applying uniform decay to all parameters regardless of their gradient scale.
Q3. How do Batch Normalization and Layer Normalization differ, and when is each appropriate?
Answer: BN normalizes across the batch dimension; LN normalizes across the feature dimension of each sample.
Explanation: BN computes mean and variance over the mini-batch for each feature position. It depends on batch size — small batches yield unstable statistics. It is best suited for CNNs with fixed spatial structure and sufficiently large batches. LN computes statistics over the feature dimension of each sample independently, making it batch-size agnostic. It is ideal for Transformers (variable sequence lengths), RNNs, and online inference scenarios where batch statistics are unavailable or unreliable.
Q4. What is the mathematical principle behind Focal Loss outperforming Cross-Entropy on imbalanced datasets?
Answer: The modulating factor dynamically down-weights easy examples during training.
Explanation: Standard CE loss treats every sample equally regardless of prediction confidence. Focal Loss introduces : for an easy example with , the weight is , reducing its contribution by 100x. For a hard example with , the weight is , preserving nearly the full loss signal. With , easy well-classified majority-class samples effectively stop contributing, forcing the model to focus training on the rare, hard minority-class examples.
Q5. How does InfoNCE Loss enable contrastive learning to produce useful representations?
Answer: By maximizing similarity between augmented views of the same image while pushing apart views from different images.
Explanation: InfoNCE maximizes a lower bound on mutual information. The numerator increases cosine similarity between the two augmented views of the same image (positive pair). The denominator includes negative pairs (other images in the batch). The temperature controls distribution sharpness: smaller creates a more peaked distribution, forcing tighter positive-pair clusters. Large batches provide more diverse negatives, improving representation quality. SimCLR, MoCo, and CLIP all rely on this loss formulation to learn generalizable visual and multimodal representations.