- Authors

- Name
- Youngju Kim
- @fjvbn20031
目次
最適化の基礎
凸最適化 (Convex Optimization)
関数 が**凸(convex)**であるとは、任意の2点 と に対して次が成立することです。
凸関数の重要な性質:
- すべての局所最小値が大域最小値と一致する
- 勾配降下法の収束が保証される
- 深層学習の損失関数はほとんど非凸だが、凸解析の手法が依然として有用
強凸性(Strongly Convex): が存在して が成立すれば、収束が線形(線形収束)になります。
ラグランジュ乗数法
等式制約付き最適化問題を扱います。
ラグランジアン:
最適解では 、 が成立します。
KKT条件
不等式制約を含む一般的な最適化問題:
KKT条件(必要条件):
- 停留性(Stationarity):
- 主実行可能性(Primal feasibility): 、
- 双対実行可能性(Dual feasibility):
- 相補スラック(Complementary slackness):
凸問題ではKKT条件は十分条件にもなります。
鞍点 (Saddle Point)
深層学習の最適化では、局所最小値より鞍点の方が大きな問題です。鞍点では勾配がゼロになりますが、ある方向では関数値が増加し、別の方向では減少します。SGDの確率的ノイズが鞍点からの脱出を助けます。
勾配降下法系
SGDとその変形
基本SGD:
SGD with Momentum:
モメンタム が標準的で、過去の勾配方向を保持してoscillationを低減します。
Nesterov加速勾配(NAG):
現在位置ではなく「先読み」位置で勾配を計算します。
AdaGrad、RMSProp、Adam
AdaGrad: パラメータごとの適応学習率
頻出する特徴は学習率が小さく、稀な特徴は大きい学習率になります。欠点: 学習率が単調減少して学習が停滞します。
RMSProp: AdaGradの蓄積問題を解決
Adam (Adaptive Moment Estimation):
バイアス補正:
デフォルトハイパーパラメータ: 、、
import torch
import torch.optim as optim
model = ... # モデル定義
# 標準Adam
optimizer_adam = optim.Adam(
model.parameters(),
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8
)
# AdamW (weight decayを分離)
optimizer_adamw = optim.AdamW(
model.parameters(),
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.01 # 勾配スケーリングとは独立して適用
)
AdamWとLion
AdamW: weight decayをL2ペナルティとしてではなく、パラメータ更新に直接適用します。
AdamにL2正則化を追加する場合と数学的に同等ではありません(クイズを参照)。
Lion (EvoLved Sign Momentum):
符号(sign)のみを使用するため、更新量が均一でメモリ効率が良いです。
| Optimizer | メモリ | 収束速度 | 適したシナリオ |
|---|---|---|---|
| SGD+Momentum | 低 | 遅い | コンピュータビジョン、大バッチ |
| Adam | 中 | 速い | NLP、汎用 |
| AdamW | 中 | 速い | Transformerの学習 |
| Lion | 低 | 速い | 大規模モデル |
| L-BFGS | 高 | 非常に速い | 小規模モデル |
2次最適化法
Newton法
2次微分(Hessian)を活用します。
ここで はHessian行列です。2次収束しますが、 行列の逆行列計算が となり深層学習では非現実的です。
L-BFGS (Limited-memory BFGS)
Hessianを直接保存せず、最近 個の勾配差分で近似します。
ここで 、 です。
import torch
import torch.optim as optim
# L-BFGSはクロージャ(closure)関数が必要
optimizer = optim.LBFGS(
model.parameters(),
lr=1.0,
max_iter=20,
history_size=10,
line_search_fn='strong_wolfe'
)
def closure():
optimizer.zero_grad()
output = model(input_data)
loss = criterion(output, target)
loss.backward()
return loss
optimizer.step(closure)
自然勾配降下法 (Natural Gradient Descent)
Fisher情報行列を使用してパラメータ空間の曲率を考慮します。
Fisher行列:
K-FAC (Kronecker-factored Approximate Curvature) は自然勾配法を層ごとに分解して実用的に実装します。
学習率スケジューリング
Warmup
初期に学習率を徐々に増加させてトレーニングを安定化します。
Cosine Annealing
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR, ReduceLROnPlateau
# Cosine Annealing
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
# OneCycleLR: Warmup + コサイン減衰
scheduler = OneCycleLR(
optimizer,
max_lr=1e-3,
total_steps=1000,
pct_start=0.3, # 30% warmupフェーズ
anneal_strategy='cos'
)
# ReduceLROnPlateau: 検証損失が改善しない場合に減少
scheduler = ReduceLROnPlateau(
optimizer,
mode='min',
factor=0.5,
patience=10,
min_lr=1e-6
)
Cyclical Learning Rate (CLR)
学習率を周期的に変動させて鞍点からの脱出を助けます。
| スケジューラ | 特徴 | 適したシナリオ |
|---|---|---|
| Cosine Annealing | 滑らかな減少 | Transformerの事前学習 |
| OneCycleLR | Warmup + 急速な減少 | ファインチューニング、短い学習 |
| ReduceLROnPlateau | 適応型 | 一般的な学習 |
| Cyclical LR | 周期的変動 | 鞍点回避 |
| Linear Warmup | 初期安定化 | LLM学習 |
正則化手法
L1 / L2 正則化
L2正則化 (Ridge):
勾配:
L1正則化 (Lasso):
L1はスパース(疎)な解を誘導して多くの重みを正確にゼロにします。
Batch NormalizationとLayer Normalization
Batch Normalization (BN):
ここで 、 はミニバッチ内の統計量です。バッチ方向に正規化します。
Layer Normalization (LN):
ここで統計量は各サンプルの特徴次元を使って計算されます。
| 正規化 | 統計量の計算軸 | 適したシナリオ |
|---|---|---|
| Batch Norm | バッチ方向(同じ特徴) | CNN、大バッチ |
| Layer Norm | 特徴方向(同じサンプル) | Transformer、RNN |
| Instance Norm | 空間方向(同じチャネル) | スタイル転送 |
| Group Norm | チャネルグループ | 小さいバッチ |
Weight DecayとL2正則化
SGDでは:
この場合、weight decayとL2正則化は同一です。しかしAdamでは:
- L2 Adam: 勾配に を加えた後、適応スケーリング係数で除算 → 正則化効果が弱まる
- AdamW: パラメータ更新後に を直接引く → すべてのパラメータに均等なweight decay
import torch.nn as nn
class RegularizedModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(512, 256)
self.bn1 = nn.BatchNorm1d(256)
self.ln1 = nn.LayerNorm(256)
self.dropout = nn.Dropout(p=0.3)
def forward(self, x):
x = self.fc1(x)
x = self.bn1(x) # Transformer向けはself.ln1(x)
x = torch.relu(x)
x = self.dropout(x)
return x
損失関数の設計
Cross-Entropy Loss
二値分類:
Focal Loss
クラス不均衡問題を解決します。簡単なサンプルの寄与を減らします。
ここで は正解クラスの予測確率、 はfocusing parameterです。 の場合は通常のCross-Entropyと同じです。
import torch
import torch.nn.functional as F
class FocalLoss(torch.nn.Module):
def __init__(self, alpha=0.25, gamma=2.0):
super().__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, logits, targets):
bce_loss = F.binary_cross_entropy_with_logits(
logits, targets.float(), reduction='none'
)
p = torch.sigmoid(logits)
p_t = p * targets + (1 - p) * (1 - targets)
focal_weight = (1 - p_t) ** self.gamma
alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
loss = alpha_t * focal_weight * bce_loss
return loss.mean()
Contrastive LossとTriplet Loss
Contrastive Loss (Siameseネットワーク):
ここで 、 が類似ペア、 が非類似ペアです。
Triplet Loss:
アンカー(a)、ポジティブ(p)、ネガティブ(n)サンプルを使用します。
InfoNCE Loss (NT-Xent)
対照学習(Contrastive Learning)の核心損失関数です。
ここで はtemperatureパラメータ、 はコサイン類似度です。
import torch
import torch.nn.functional as F
def info_nce_loss(features, temperature=0.07):
"""
features: (2N, D) - 各画像の2つのaugmentationビュー
"""
N = features.shape[0] // 2
features = F.normalize(features, dim=1)
# 類似度行列を計算
similarity = torch.matmul(features, features.T) / temperature
# 自己類似度を除去(対角を-infに)
mask = torch.eye(2 * N, dtype=torch.bool, device=features.device)
similarity.masked_fill_(mask, float('-inf'))
# ポジティブペア: iとi+N、i+Nとi
labels = torch.cat([
torch.arange(N, 2 * N),
torch.arange(N)
]).to(features.device)
loss = F.cross_entropy(similarity, labels)
return loss
LLMトレーニング最適化
勾配クリッピング (Gradient Clipping)
勾配爆発(exploding gradient)を防ぎます。
import torch
def train_with_clipping(model, optimizer, loss, max_norm=1.0):
optimizer.zero_grad()
loss.backward()
# クリッピング前の勾配ノルムを監視
total_norm = 0
for p in model.parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
# クリッピング適用
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
optimizer.step()
return total_norm
ZeRO Optimizer (Zero Redundancy Optimizer)
モデル学習時のメモリを3段階で最適化します。
| ZeROステージ | 分散対象 | メモリ削減 |
|---|---|---|
| Stage 1 | Optimizer状態 | 約4倍 |
| Stage 2 | + 勾配 | 約8倍 |
| Stage 3 | + パラメータ | 約64倍(N GPU基準) |
混合精度(FP16/BF16) + ZeRO-3で数十億パラメータのモデルを単一ノードで学習可能です。
8-bit Adam
量子化を通じてoptimizer状態をFP32ではなくINT8で保存します。
- Optimizer状態のメモリをFP32比で75%削減
- ブロックごとの量子化で精度損失を最小化
bitsandbytesライブラリで実装可能
# bitsandbytes 8-bit Adam
import bitsandbytes as bnb
optimizer = bnb.optim.Adam8bit(
model.parameters(),
lr=1e-4,
betas=(0.9, 0.999)
)
Adafactor
Adamの2次モーメントを行列分解で近似します。
パラメータサイズに比例したメモリのみ使用(行ベクトル + 列ベクトル)。T5、PaLMなどの超大規模モデルの学習に使用されます。
| Optimizer | メモリ(パラメータ比) | LLM適合度 |
|---|---|---|
| Adam | 8倍(params + 2 states) | 普通 |
| AdamW | 8倍 | 良い |
| 8-bit Adam | 6倍 | 良い |
| Adafactor | 約2倍 | 非常に良い |
| Lion | 6倍 | 良い |
クイズ
Q1. Adam optimizerでbias correctionが必要な理由は何ですか?
答え: モーメント推定値のゼロ初期化によるバイアスを補正するためです。
解説: Adamでは 、 で初期化します。初期のタイムステップ では、 と が実際の勾配のモーメントを過小評価します。例えば では となり、その期待値 は よりはるかに小さいです。 で割ることでこれを補正します。 で の場合、補正係数は です。 が大きくなると となり補正係数が1に近づいて効果がなくなります。
Q2. Weight DecayとL2正則化がAdamで同等でない理由(AdamW登場の背景)は?
答え: Adamの適応学習率がL2ペナルティの勾配をスケーリングするため、正則化効果が弱まるからです。
解説: SGDでは により両者は数学的に同一です。しかしAdamにL2正則化を追加すると、勾配が となり、適応スケーリング係数 で除算されます。勾配分散が大きいパラメータ(大きな )ではL2ペナルティも小さくなります。AdamWはweight decayを勾配更新から分離して と処理することで、すべてのパラメータに均等な正則化を適用します。
Q3. Batch NormalizationとLayer Normalizationの違いと各々が適した状況は?
答え: BNはバッチ次元で、LNは各サンプルの特徴次元で正規化します。
解説: BNはミニバッチ内の同じ特徴(ニューロン)の平均・分散で正規化します。バッチサイズに依存し、小さいバッチでは統計量の推定が不安定になります。空間的特徴があり十分なバッチサイズを持つCNNに適しています。LNは各サンプルの特徴次元に沿って正規化するためバッチサイズに依存しません。系列長が可変なTransformerや、バッチ統計の維持が困難なRNN、またオンライン推論シナリオに適しています。
Q4. Focal LossがCross-Entropyよりクラス不均衡問題に効果的な数学的原理は?
答え: の重みが簡単なサンプルの寄与を動的に減少させるからです。
解説: 通常のCE損失 は多数クラスの簡単なサンプルも同等に扱います。Focal Lossの を見ると、(簡単なサンプル)の場合 で重みが非常に小さくなります。一方 (難しいサンプル)の場合 でほぼそのまま維持されます。 を使用すると簡単なサンプルの損失が100倍減少します。これにより、モデルが難しい少数クラスのサンプルに集中して学習します。
Q5. InfoNCE Lossが対照学習で良い表現を学習する原理は?
答え: 同じ画像の異なるaugmentationペアを類似させ、異なる画像は遠ざけるように学習するからです。
解説: InfoNCEは相互情報量(mutual information)の下限を最大化します。分子 は同じ画像の2つのビュー(ポジティブペア)の類似度を高め、分母には 個のネガティブペアが含まれます。Temperature は分布の鋭さを調整します。 が小さいほど競争が激しくなり表現空間がより識別的になります。大規模バッチで多様なネガティブサンプルを提供するほど表現がより一般化されます。SimCLR、MoCo、CLIPなどの主要な対照学習モデルがこの損失関数を使用しています。