Skip to content
Published on

機械学習のための数学的最適化: Adamから凸最適化、ZeRO optimizerまで

Authors

目次

  1. 最適化の基礎: 凸最適化とKKT条件
  2. 勾配降下法系のoptimizer
  3. 2次最適化法
  4. 学習率スケジューリング
  5. 正則化手法
  6. 損失関数の設計
  7. LLMトレーニング最適化
  8. クイズ

最適化の基礎

凸最適化 (Convex Optimization)

関数 f:RnRf: \mathbb{R}^n \to \mathbb{R} が**凸(convex)**であるとは、任意の2点 x,yx, yλ[0,1]\lambda \in [0,1] に対して次が成立することです。

f(λx+(1λ)y)λf(x)+(1λ)f(y)f(\lambda x + (1-\lambda)y) \leq \lambda f(x) + (1-\lambda)f(y)

凸関数の重要な性質:

  • すべての局所最小値が大域最小値と一致する
  • 勾配降下法の収束が保証される
  • 深層学習の損失関数はほとんど非凸だが、凸解析の手法が依然として有用

強凸性(Strongly Convex): m>0m > 0 が存在して f(y)f(x)+f(x)T(yx)+m2yx2f(y) \geq f(x) + \nabla f(x)^T(y-x) + \frac{m}{2}\|y-x\|^2 が成立すれば、収束が線形(線形収束)になります。

ラグランジュ乗数法

等式制約付き最適化問題を扱います。

minxf(x)subject togi(x)=0,  i=1,,m\min_x f(x) \quad \text{subject to} \quad g_i(x) = 0, \; i = 1, \ldots, m

ラグランジアン:

L(x,λ)=f(x)+i=1mλigi(x)\mathcal{L}(x, \lambda) = f(x) + \sum_{i=1}^{m} \lambda_i g_i(x)

最適解では xL=0\nabla_x \mathcal{L} = 0λL=0\nabla_\lambda \mathcal{L} = 0 が成立します。

KKT条件

不等式制約を含む一般的な最適化問題:

minxf(x)s.t.gi(x)0,  hj(x)=0\min_x f(x) \quad \text{s.t.} \quad g_i(x) \leq 0, \; h_j(x) = 0

KKT条件(必要条件):

  1. 停留性(Stationarity): f(x)+iμigi(x)+jλjhj(x)=0\nabla f(x^*) + \sum_i \mu_i \nabla g_i(x^*) + \sum_j \lambda_j \nabla h_j(x^*) = 0
  2. 主実行可能性(Primal feasibility): gi(x)0g_i(x^*) \leq 0hj(x)=0h_j(x^*) = 0
  3. 双対実行可能性(Dual feasibility): μi0\mu_i \geq 0
  4. 相補スラック(Complementary slackness): μigi(x)=0\mu_i g_i(x^*) = 0

凸問題ではKKT条件は十分条件にもなります。

鞍点 (Saddle Point)

深層学習の最適化では、局所最小値より鞍点の方が大きな問題です。鞍点では勾配がゼロになりますが、ある方向では関数値が増加し、別の方向では減少します。SGDの確率的ノイズが鞍点からの脱出を助けます。


勾配降下法系

SGDとその変形

基本SGD:

θt+1=θtηθL(θt;xi,yi)\theta_{t+1} = \theta_t - \eta \nabla_\theta \mathcal{L}(\theta_t; x_i, y_i)

SGD with Momentum:

vt+1=βvt+θL(θt)v_{t+1} = \beta v_t + \nabla_\theta \mathcal{L}(\theta_t) θt+1=θtηvt+1\theta_{t+1} = \theta_t - \eta v_{t+1}

モメンタム β=0.9\beta = 0.9 が標準的で、過去の勾配方向を保持してoscillationを低減します。

Nesterov加速勾配(NAG):

vt+1=βvt+θL(θtβvt)v_{t+1} = \beta v_t + \nabla_\theta \mathcal{L}(\theta_t - \beta v_t) θt+1=θtηvt+1\theta_{t+1} = \theta_t - \eta v_{t+1}

現在位置ではなく「先読み」位置で勾配を計算します。

AdaGrad、RMSProp、Adam

AdaGrad: パラメータごとの適応学習率

Gt=Gt1+gt2G_t = G_{t-1} + g_t^2 θt+1=θtηGt+ϵgt\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{G_t + \epsilon}} g_t

頻出する特徴は学習率が小さく、稀な特徴は大きい学習率になります。欠点: 学習率が単調減少して学習が停滞します。

RMSProp: AdaGradの蓄積問題を解決

vt=βvt1+(1β)gt2v_t = \beta v_{t-1} + (1-\beta) g_t^2 θt+1=θtηvt+ϵgt\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{v_t + \epsilon}} g_t

Adam (Adaptive Moment Estimation):

mt=β1mt1+(1β1)gt(1次モーメント)m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t \quad \text{(1次モーメント)} vt=β2vt1+(1β2)gt2(2次モーメント)v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2 \quad \text{(2次モーメント)}

バイアス補正:

m^t=mt1β1t,v^t=vt1β2t\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}

θt+1=θtηv^t+ϵm^t\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t

デフォルトハイパーパラメータ: β1=0.9\beta_1 = 0.9β2=0.999\beta_2 = 0.999ϵ=108\epsilon = 10^{-8}

import torch
import torch.optim as optim

model = ...  # モデル定義

# 標準Adam
optimizer_adam = optim.Adam(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8
)

# AdamW (weight decayを分離)
optimizer_adamw = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01  # 勾配スケーリングとは独立して適用
)

AdamWとLion

AdamW: weight decayをL2ペナルティとしてではなく、パラメータ更新に直接適用します。

θt+1=θtη(m^tv^t+ϵ+λθt)\theta_{t+1} = \theta_t - \eta \left(\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t\right)

AdamにL2正則化を追加する場合と数学的に同等ではありません(クイズを参照)。

Lion (EvoLved Sign Momentum):

ut=β1mt1+(1β1)gtu_t = \beta_1 m_{t-1} + (1-\beta_1) g_t θt+1=θtηsign(ut)\theta_{t+1} = \theta_t - \eta \cdot \text{sign}(u_t) mt=β2mt1+(1β2)gtm_t = \beta_2 m_{t-1} + (1-\beta_2) g_t

符号(sign)のみを使用するため、更新量が均一でメモリ効率が良いです。

Optimizerメモリ収束速度適したシナリオ
SGD+Momentum遅いコンピュータビジョン、大バッチ
Adam速いNLP、汎用
AdamW速いTransformerの学習
Lion速い大規模モデル
L-BFGS非常に速い小規模モデル

2次最適化法

Newton法

2次微分(Hessian)を活用します。

θt+1=θtHt1f(θt)\theta_{t+1} = \theta_t - H_t^{-1} \nabla f(\theta_t)

ここで Ht=2f(θt)H_t = \nabla^2 f(\theta_t) はHessian行列です。2次収束しますが、n×nn \times n 行列の逆行列計算が O(n3)O(n^3) となり深層学習では非現実的です。

L-BFGS (Limited-memory BFGS)

Hessianを直接保存せず、最近 mm 個の勾配差分で近似します。

Ht1ベクトル列 {sk},{yk} による近似H_t^{-1} \approx \text{ベクトル列 } \{s_k\}, \{y_k\}\text{ による近似}

ここで sk=θk+1θks_k = \theta_{k+1} - \theta_kyk=fk+1fky_k = \nabla f_{k+1} - \nabla f_k です。

import torch
import torch.optim as optim

# L-BFGSはクロージャ(closure)関数が必要
optimizer = optim.LBFGS(
    model.parameters(),
    lr=1.0,
    max_iter=20,
    history_size=10,
    line_search_fn='strong_wolfe'
)

def closure():
    optimizer.zero_grad()
    output = model(input_data)
    loss = criterion(output, target)
    loss.backward()
    return loss

optimizer.step(closure)

自然勾配降下法 (Natural Gradient Descent)

Fisher情報行列を使用してパラメータ空間の曲率を考慮します。

θt+1=θtηF(θt)1L(θt)\theta_{t+1} = \theta_t - \eta F(\theta_t)^{-1} \nabla \mathcal{L}(\theta_t)

Fisher行列: F(θ)=E[logp(yx;θ)logp(yx;θ)T]F(\theta) = \mathbb{E}\left[\nabla \log p(y|x;\theta) \nabla \log p(y|x;\theta)^T\right]

K-FAC (Kronecker-factored Approximate Curvature) は自然勾配法を層ごとに分解して実用的に実装します。


学習率スケジューリング

Warmup

初期に学習率を徐々に増加させてトレーニングを安定化します。

ηt=ηmaxtTwarmup(tTwarmup)\eta_t = \eta_{\max} \cdot \frac{t}{T_{\text{warmup}}} \quad (t \leq T_{\text{warmup}})

Cosine Annealing

ηt=ηmin+12(ηmaxηmin)(1+cosπtT)\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\frac{\pi t}{T}\right)

from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR, ReduceLROnPlateau

# Cosine Annealing
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)

# OneCycleLR: Warmup + コサイン減衰
scheduler = OneCycleLR(
    optimizer,
    max_lr=1e-3,
    total_steps=1000,
    pct_start=0.3,         # 30% warmupフェーズ
    anneal_strategy='cos'
)

# ReduceLROnPlateau: 検証損失が改善しない場合に減少
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10,
    min_lr=1e-6
)

Cyclical Learning Rate (CLR)

学習率を周期的に変動させて鞍点からの脱出を助けます。

ηt=ηmin+(ηmaxηmin)max(0,1tstep_size2k+1)\eta_t = \eta_{\min} + (\eta_{\max} - \eta_{\min}) \cdot \max\left(0, 1 - \left|\frac{t}{\text{step\_size}} - 2k + 1\right|\right)

スケジューラ特徴適したシナリオ
Cosine Annealing滑らかな減少Transformerの事前学習
OneCycleLRWarmup + 急速な減少ファインチューニング、短い学習
ReduceLROnPlateau適応型一般的な学習
Cyclical LR周期的変動鞍点回避
Linear Warmup初期安定化LLM学習

正則化手法

L1 / L2 正則化

L2正則化 (Ridge):

Lreg=L+λ2θ22\mathcal{L}_{\text{reg}} = \mathcal{L} + \frac{\lambda}{2} \|\theta\|_2^2

勾配: θLreg=θL+λθ\nabla_\theta \mathcal{L}_{\text{reg}} = \nabla_\theta \mathcal{L} + \lambda \theta

L1正則化 (Lasso):

Lreg=L+λθ1\mathcal{L}_{\text{reg}} = \mathcal{L} + \lambda \|\theta\|_1

L1はスパース(疎)な解を誘導して多くの重みを正確にゼロにします。

Batch NormalizationとLayer Normalization

Batch Normalization (BN):

x^i=xiμBσB2+ϵγ+β\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} \cdot \gamma + \beta

ここで μB\mu_BσB2\sigma_B^2 はミニバッチ内の統計量です。バッチ方向に正規化します。

Layer Normalization (LN):

x^=xμLσL2+ϵγ+β\hat{x} = \frac{x - \mu_L}{\sqrt{\sigma_L^2 + \epsilon}} \cdot \gamma + \beta

ここで統計量は各サンプルの特徴次元を使って計算されます。

正規化統計量の計算軸適したシナリオ
Batch Normバッチ方向(同じ特徴)CNN、大バッチ
Layer Norm特徴方向(同じサンプル)Transformer、RNN
Instance Norm空間方向(同じチャネル)スタイル転送
Group Normチャネルグループ小さいバッチ

Weight DecayとL2正則化

SGDでは:

θt+1=θtη(L+λθt)=(1ηλ)θtηL\theta_{t+1} = \theta_t - \eta(\nabla \mathcal{L} + \lambda \theta_t) = (1 - \eta\lambda)\theta_t - \eta \nabla \mathcal{L}

この場合、weight decayとL2正則化は同一です。しかしAdamでは:

  • L2 Adam: 勾配に λθ\lambda\theta を加えた後、適応スケーリング係数で除算 → 正則化効果が弱まる
  • AdamW: パラメータ更新後に λθ\lambda\theta を直接引く → すべてのパラメータに均等なweight decay
import torch.nn as nn

class RegularizedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(512, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.ln1 = nn.LayerNorm(256)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)   # Transformer向けはself.ln1(x)
        x = torch.relu(x)
        x = self.dropout(x)
        return x

損失関数の設計

Cross-Entropy Loss

LCE=c=1Cyclogp^c\mathcal{L}_{CE} = -\sum_{c=1}^{C} y_c \log \hat{p}_c

二値分類: LBCE=[ylogp+(1y)log(1p)]\mathcal{L}_{BCE} = -[y \log p + (1-y)\log(1-p)]

Focal Loss

クラス不均衡問題を解決します。簡単なサンプルの寄与を減らします。

LFL=(1pt)γlog(pt)\mathcal{L}_{FL} = -(1-p_t)^\gamma \log(p_t)

ここで ptp_t は正解クラスの予測確率、γ0\gamma \geq 0 はfocusing parameterです。γ=0\gamma = 0 の場合は通常のCross-Entropyと同じです。

import torch
import torch.nn.functional as F

class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        bce_loss = F.binary_cross_entropy_with_logits(
            logits, targets.float(), reduction='none'
        )
        p = torch.sigmoid(logits)
        p_t = p * targets + (1 - p) * (1 - targets)
        focal_weight = (1 - p_t) ** self.gamma
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        loss = alpha_t * focal_weight * bce_loss
        return loss.mean()

Contrastive LossとTriplet Loss

Contrastive Loss (Siameseネットワーク):

L=(1y)d22+ymax(0,md)2\mathcal{L} = (1-y)\frac{d^2}{2} + y \cdot \max(0, m - d)^2

ここで d=f(x1)f(x2)2d = \|f(x_1) - f(x_2)\|_2y=0y=0 が類似ペア、y=1y=1 が非類似ペアです。

Triplet Loss:

Ltrip=max(0,f(a)f(p)22f(a)f(n)22+m)\mathcal{L}_{trip} = \max(0, \|f(a) - f(p)\|_2^2 - \|f(a) - f(n)\|_2^2 + m)

アンカー(a)、ポジティブ(p)、ネガティブ(n)サンプルを使用します。

InfoNCE Loss (NT-Xent)

対照学習(Contrastive Learning)の核心損失関数です。

LInfoNCE=logexp(sim(zi,zj)/τ)k=12N1kiexp(sim(zi,zk)/τ)\mathcal{L}_{InfoNCE} = -\log \frac{\exp(\text{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbf{1}_{k \neq i} \exp(\text{sim}(z_i, z_k)/\tau)}

ここで τ\tau はtemperatureパラメータ、sim\text{sim} はコサイン類似度です。

import torch
import torch.nn.functional as F

def info_nce_loss(features, temperature=0.07):
    """
    features: (2N, D) - 各画像の2つのaugmentationビュー
    """
    N = features.shape[0] // 2
    features = F.normalize(features, dim=1)

    # 類似度行列を計算
    similarity = torch.matmul(features, features.T) / temperature

    # 自己類似度を除去(対角を-infに)
    mask = torch.eye(2 * N, dtype=torch.bool, device=features.device)
    similarity.masked_fill_(mask, float('-inf'))

    # ポジティブペア: iとi+N、i+Nとi
    labels = torch.cat([
        torch.arange(N, 2 * N),
        torch.arange(N)
    ]).to(features.device)

    loss = F.cross_entropy(similarity, labels)
    return loss

LLMトレーニング最適化

勾配クリッピング (Gradient Clipping)

勾配爆発(exploding gradient)を防ぎます。

ggmin(1,clip_normg2)g \leftarrow g \cdot \min\left(1, \frac{\text{clip\_norm}}{\|g\|_2}\right)

import torch

def train_with_clipping(model, optimizer, loss, max_norm=1.0):
    optimizer.zero_grad()
    loss.backward()

    # クリッピング前の勾配ノルムを監視
    total_norm = 0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5

    # クリッピング適用
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optimizer.step()
    return total_norm

ZeRO Optimizer (Zero Redundancy Optimizer)

モデル学習時のメモリを3段階で最適化します。

ZeROステージ分散対象メモリ削減
Stage 1Optimizer状態約4倍
Stage 2+ 勾配約8倍
Stage 3+ パラメータ約64倍(N GPU基準)

混合精度(FP16/BF16) + ZeRO-3で数十億パラメータのモデルを単一ノードで学習可能です。

8-bit Adam

量子化を通じてoptimizer状態をFP32ではなくINT8で保存します。

  • Optimizer状態のメモリをFP32比で75%削減
  • ブロックごとの量子化で精度損失を最小化
  • bitsandbytes ライブラリで実装可能
# bitsandbytes 8-bit Adam
import bitsandbytes as bnb

optimizer = bnb.optim.Adam8bit(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999)
)

Adafactor

Adamの2次モーメントを行列分解で近似します。

Vtr^tv^tT(ランク1近似)V_t \approx \hat{r}_t \hat{v}_t^T \quad \text{(ランク1近似)}

パラメータサイズに比例したメモリのみ使用(行ベクトル + 列ベクトル)。T5、PaLMなどの超大規模モデルの学習に使用されます。

Optimizerメモリ(パラメータ比)LLM適合度
Adam8倍(params + 2 states)普通
AdamW8倍良い
8-bit Adam6倍良い
Adafactor約2倍非常に良い
Lion6倍良い

クイズ

Q1. Adam optimizerでbias correctionが必要な理由は何ですか?

答え: モーメント推定値のゼロ初期化によるバイアスを補正するためです。

解説: Adamでは m0=0m_0 = 0v0=0v_0 = 0 で初期化します。初期のタイムステップ tt では、mtm_tvtv_t が実際の勾配のモーメントを過小評価します。例えば t=1t=1 では m1=(1β1)g1m_1 = (1-\beta_1)g_1 となり、その期待値 (1β1)E[g1](1-\beta_1)\mathbb{E}[g_1]E[g1]\mathbb{E}[g_1] よりはるかに小さいです。(1β1t)(1-\beta_1^t) で割ることでこれを補正します。β1=0.9\beta_1 = 0.9t=1t=1 の場合、補正係数は 1/0.1=101/0.1 = 10 です。tt が大きくなると β1t0\beta_1^t \to 0 となり補正係数が1に近づいて効果がなくなります。

Q2. Weight DecayとL2正則化がAdamで同等でない理由(AdamW登場の背景)は?

答え: Adamの適応学習率がL2ペナルティの勾配をスケーリングするため、正則化効果が弱まるからです。

解説: SGDでは θθη(L+λθ)\theta \leftarrow \theta - \eta(\nabla \mathcal{L} + \lambda\theta) により両者は数学的に同一です。しかしAdamにL2正則化を追加すると、勾配が gt+λθtg_t + \lambda\theta_t となり、適応スケーリング係数 1/v^t1/\sqrt{\hat{v}_t} で除算されます。勾配分散が大きいパラメータ(大きな vtv_t)ではL2ペナルティも小さくなります。AdamWはweight decayを勾配更新から分離して θθ(1ηλ)ηm^t/(v^t+ϵ)\theta \leftarrow \theta(1-\eta\lambda) - \eta\hat{m}_t/(\sqrt{\hat{v}_t}+\epsilon) と処理することで、すべてのパラメータに均等な正則化を適用します。

Q3. Batch NormalizationとLayer Normalizationの違いと各々が適した状況は?

答え: BNはバッチ次元で、LNは各サンプルの特徴次元で正規化します。

解説: BNはミニバッチ内の同じ特徴(ニューロン)の平均・分散で正規化します。バッチサイズに依存し、小さいバッチでは統計量の推定が不安定になります。空間的特徴があり十分なバッチサイズを持つCNNに適しています。LNは各サンプルの特徴次元に沿って正規化するためバッチサイズに依存しません。系列長が可変なTransformerや、バッチ統計の維持が困難なRNN、またオンライン推論シナリオに適しています。

Q4. Focal LossがCross-Entropyよりクラス不均衡問題に効果的な数学的原理は?

答え: (1pt)γ(1-p_t)^\gamma の重みが簡単なサンプルの寄与を動的に減少させるからです。

解説: 通常のCE損失 log(pt)-\log(p_t) は多数クラスの簡単なサンプルも同等に扱います。Focal Lossの (1pt)γ(1-p_t)^\gamma を見ると、pt=0.9p_t = 0.9(簡単なサンプル)の場合 (10.9)2=0.01(1-0.9)^2 = 0.01 で重みが非常に小さくなります。一方 pt=0.1p_t = 0.1(難しいサンプル)の場合 (10.1)2=0.81(1-0.1)^2 = 0.81 でほぼそのまま維持されます。γ=2\gamma = 2 を使用すると簡単なサンプルの損失が100倍減少します。これにより、モデルが難しい少数クラスのサンプルに集中して学習します。

Q5. InfoNCE Lossが対照学習で良い表現を学習する原理は?

答え: 同じ画像の異なるaugmentationペアを類似させ、異なる画像は遠ざけるように学習するからです。

解説: InfoNCEは相互情報量(mutual information)の下限を最大化します。分子 exp(sim(zi,zj)/τ)\exp(\text{sim}(z_i, z_j)/\tau) は同じ画像の2つのビュー(ポジティブペア)の類似度を高め、分母には 2N12N-1 個のネガティブペアが含まれます。Temperature τ\tau は分布の鋭さを調整します。τ\tau が小さいほど競争が激しくなり表現空間がより識別的になります。大規模バッチで多様なネガティブサンプルを提供するほど表現がより一般化されます。SimCLR、MoCo、CLIPなどの主要な対照学習モデルがこの損失関数を使用しています。