Skip to content
Published on

GPUメモリ最適化と混合精度トレーニング完全ガイド

Authors
  • Name
    Twitter

1. GPUメモリ構成の分析

ディープラーニングモデルの学習時、GPUメモリにはモデルパラメータ以外にも多くのデータが格納されます。学習プロセス全体でGPUメモリを占有する構成要素は、大きく4つのカテゴリに分類できます。これらの構成要素を正確に理解することが、効果的なメモリ最適化戦略を立てる上で不可欠です。

1.1 モデルパラメータ

モデルパラメータとは、ニューラルネットワークの重みとバイアスを指します。学習プロセス全体を通じてGPUメモリ上に常駐し、パラメータ数とデータ型によってメモリ使用量が決まります。

  • FP32基準: パラメータあたり4バイト
  • FP16/BF16基準: パラメータあたり2バイト

例えば、7B(70億)パラメータのモデルをFP32で格納するには約28GB、FP16では約14GBが必要です。

1.2 勾配

勾配は、損失関数に対する各パラメータの偏微分値です。逆伝播時に計算され、各パラメータに対応する勾配値が存在するため、パラメータと同量のメモリを占有します。

  • FP32学習: パラメータ数 x 4バイト
  • Mixed Precision学習: パラメータ数 x 2バイト(FP16で格納)

1.3 オプティマイザの状態

オプティマイザの状態は、メモリの中で最も大きな割合を占めることが多い要素です。SGDはモメンタムのみを保持しますが、Adam/AdamWオプティマイザは各パラメータに対して**第1モーメント(平均)第2モーメント(分散)**の2つの追加状態を保持します。

FP32学習時のAdamオプティマイザのオプティマイザ状態メモリ:

パラメータ数をΦとすると:
- マスターウェイト (FP32): 4Φ バイト
- モメンタム (FP32): 4Φ バイト
- 分散 (FP32): 4Φ バイト
- 合計: 12Φ バイト

7BモデルでAdamを使用すると、オプティマイザの状態だけで約84GBが必要になります。これが、大規模モデル学習においてオプティマイザの状態の最適化が極めて重要である理由です。

1.4 活性化値

活性化値は、順伝播時の各層の出力であり、逆伝播時の勾配計算に必要なため保持される必要があります。活性化値のメモリは以下の要因に比例します:

  • バッチサイズ: 大きいほどメモリ消費が増加
  • シーケンス長: Transformerモデルでは特に重要(Attentionで O(n^2))
  • 隠れ層の次元数: モデルが大きいほど増加
  • レイヤー数: モデルが深いほど増加

大規模モデルでは、活性化値のメモリがパラメータメモリを上回ることも頻繁にあります。

1.5 全体メモリのまとめ

Mixed Precision TrainingでAdamオプティマイザを使用する場合、Φパラメータに対する総メモリ要件は以下の通りです:

構成要素データ型メモリ
モデルパラメータ (FP16コピー)FP162Φ バイト
勾配 (FP16)FP162Φ バイト
マスターウェイトFP324Φ バイト
オプティマイザのモメンタムFP324Φ バイト
オプティマイザの分散FP324Φ バイト
合計16Φ バイト

さらに、活性化値のメモリ、一時バッファ、メモリの断片化により、実際には約20〜30%の追加オーバーヘッドが発生します。


2. FP32 vs FP16 vs BF16 vs FP8 数値表現の比較

NVIDIA GPUがサポートする浮動小数点フォーマットにはそれぞれ固有の特性があります。学習の安定性とパフォーマンスのトレードオフを理解するために、各フォーマットのビット構成と表現範囲を比較します。

2.1 ビット構成

IEEE 754に基づくすべての浮動小数点数は、符号指数部仮数部の3つの部分で構成されます。

フォーマット符号指数部仮数部総ビット数メモリ
FP321823324バイト
FP161510162バイト
BF16187162バイト
FP8 (E4M3)14381バイト
FP8 (E5M2)15281バイト

2.2 ダイナミックレンジと精度

各フォーマットの表現可能な範囲と精度の比較:

FP32(単精度)

  • ダイナミックレンジ: 約1.2 x 10^-38 〜 3.4 x 10^38
  • 精度: 約7桁の10進数
  • ディープラーニングのデフォルトデータ型であり、最も高い精度を提供します。

FP16(半精度)

  • ダイナミックレンジ: 約6.1 x 10^-5 〜 6.55 x 10^4
  • 精度: 約3桁の10進数
  • 狭いダイナミックレンジにより、勾配のアンダーフロー/オーバーフローが発生しやすくなります。必ずLoss Scalingと併用する必要があります。

BF16(Brain Floating Point 16)

  • ダイナミックレンジ: 約1.2 x 10^-38 〜 3.4 x 10^38(FP32と同じ)
  • 精度: 約2桁の10進数
  • FP32と同じ8ビットの指数部を持つため、ダイナミックレンジが同一です。精度は低いものの、Loss Scalingなしで安定した学習が可能であり、大規模モデル学習の新たな標準として台頭しています。NVIDIA Ampere(A100)以降でサポートされています。

FP8 E4M3

  • ダイナミックレンジ: 約±448
  • 精度: 最も低い
  • 主に順伝播で使用されます。狭いレンジを補うために、テンソルごとのスケーリングが不可欠です。

FP8 E5M2

  • ダイナミックレンジ: 約±57,344
  • 精度: E4M3より低いが、より広いレンジ
  • 主に逆伝播(勾配計算)で使用されます。勾配の広いダイナミックレンジに対応するためです。

2.3 各フォーマットの適切な使用シーン

FP32  → オプティマイザの状態、マスターウェイト(精密な累積演算が必要)
BF16  → 順伝播/逆伝播演算(広いレンジ、Loss Scaling不要)
FP16  → 順伝播/逆伝播演算(Tensor Core活用、Loss Scaling必要)
FP8Hopper/Blackwell GPUでの最大パフォーマンス(Transformer Engine使用)

3. Mixed Precision Trainingの原理

Mixed Precision Trainingは、NVIDIAが2017年の論文「Mixed Precision Training」(Micikevicius et al.)で提案した手法で、学習時にFP16(またはBF16)とFP32を混在させて使用します。その核心的な目的は、Tensor Coreによる計算の高速化とメモリ節約を実現しつつ、FP32レベルの学習精度を維持することです。

3.1 3つのコア技術

NVIDIA公式ドキュメントによると、Mixed Precision Trainingは3つのコア技術で構成されています:

(1) マスターウェイト(FP32コピーの維持)

FP16で演算を行う場合でも、モデルパラメータの別のFP32マスターコピーを維持します。ウェイト更新時には、勾配をFP32マスターウェイトに適用し、次の順伝播用にFP16に変換します。これは、FP16の低い精度では小さな勾配がウェイト更新に反映されない問題(スワンピング問題)があるためです。

順伝播: FP16ウェイト → FP16活性化値 → FP16損失
逆伝播: FP16勾配の計算
ウェイト更新: FP32マスターウェイト += 学習率 × FP32勾配
FP16ウェイト = cast(FP32マスターウェイト)

(2) Loss Scaling

FP16のダイナミックレンジ(約6.1 x 10^-5 〜 6.55 x 10^4)では、勾配値を収容するのに不十分な場合があります。実際の勾配値は10^-10以下であることも一般的ですが、FP16の最小の正の表現可能値は約6 x 10^-8です。これにより勾配のアンダーフローが発生し、小さな勾配がゼロとして扱われます。

Loss Scalingはこの問題を解決します:

  1. 順伝播で計算された損失値にスケールファクターを掛ける。
  2. 連鎖律により、逆伝播で計算されるすべての勾配にも同じスケールファクターが掛けられる。
  3. ウェイト更新前に、勾配をスケールファクターで割って元の大きさに戻す。

(3) Dynamic Loss Scaling

Dynamic Loss Scalingは、固定値ではなく学習中にスケールファクターを動的に調整します。NVIDIAの実装は以下のように動作します:

  1. 初期スケールファクターを大きく設定(例: 2^24 = 16,777,216)。
  2. 各イテレーションで勾配にinf/NaNが含まれていないか確認
  3. オーバーフローが発生しない場合: 現在のスケールファクターを維持し、N回連続成功後にスケールファクターを2倍にする(デフォルトN=2000)。
  4. オーバーフローが発生した場合: そのイテレーションのウェイト更新をスキップし、スケールファクターを半分にする。

このメカニズムにより、学習中の勾配分布の変化に自動的に適応します。学習後半で勾配の大きさが減少すると、スケールファクターは自然に増加してアンダーフローを防止します。

3.2 BF16学習の利点

BF16はFP32と同じダイナミックレンジを持つため、Loss Scalingが不要です。これにより実装が大幅に簡素化され、オーバーフロー/アンダーフロー関連の問題が根本的に排除されます。GoogleのTPUで最初に採用され、NVIDIA Ampere(A100)GPU以降でハードウェアレベルでサポートされています。

ただし、BF16はFP16の10ビットに対して仮数部が7ビットしかないため、精度が低くなります。一部のモデルでは、BF16学習が精度不足による収束問題を引き起こす可能性があるため、マスターウェイトは常にFP32で維持する必要があります。


4. PyTorch AMP公式ドキュメントの機能分析とコード例

PyTorchはtorch.ampモジュールを通じてAutomatic Mixed Precision(AMP)を公式にサポートしています。AMPはtorch.amp.autocasttorch.amp.GradScalerの2つのコアコンポーネントで構成されています。

注意: PyTorch 2.x以降、torch.cuda.amp.autocasttorch.cuda.amp.GradScalerは非推奨です。代わりにtorch.amp.autocast("cuda", ...)torch.amp.GradScaler("cuda", ...)を使用してください。

4.1 torch.amp.autocast

autocastはコンテキストマネージャまたはデコレータとして使用され、そのスコープ内の演算を適切な精度で自動的に実行します。すべての演算を一律にFP16に変換するのではなく、各演算タイプに対して最適なデータ型を選択します。

  • FP16で実行: Conv、Linear、MatMulなど(Tensor Coreを活用できる演算)
  • FP32を維持: Softmax、LayerNorm、損失計算など(数値安定性が重要な演算)

4.2 torch.amp.GradScaler

GradScalerはLoss Scalingを自動管理します。Dynamic Loss Scalingのプロセス全体(スケール適用、オーバーフロー確認、スケールファクター調整、ウェイト更新スキップ)を抽象化し、ユーザーが手動で管理する必要をなくします。

4.3 基本的な使用パターン

import torch
from torch.amp import autocast, GradScaler

model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler("cuda")

for epoch in range(num_epochs):
    for batch in dataloader:
        inputs, targets = batch
        inputs = inputs.cuda()
        targets = targets.cuda()

        optimizer.zero_grad()

        # autocast領域: 順伝播をMixed Precisionで実行
        with autocast("cuda"):
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

        # GradScaler: 損失をスケーリングして逆伝播を実行
        scaler.scale(loss).backward()

        # GradScaler: 勾配のアンスケール → オーバーフロー確認 → オプティマイザステップ
        scaler.step(optimizer)

        # スケールファクターの更新
        scaler.update()

4.4 勾配クリッピングとの併用

Mixed Precisionで勾配クリッピングを適用する場合、クリッピング前に明示的にscaler.unscale_()を呼び出す必要があります:

scaler.scale(loss).backward()

# まず勾配のアンスケールを実行
scaler.unscale_(optimizer)

# 元の大きさに戻された勾配にクリッピングを適用
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# オプティマイザステップ(内部的にinf/NaNを確認してから実行)
scaler.step(optimizer)
scaler.update()

4.5 マルチGPU(DistributedDataParallel)での使用

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(model, device_ids=[local_rank])
scaler = GradScaler("cuda")

with autocast("cuda"):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

DDPとAMPを併用する場合、GradScalerは各GPU上で独立して動作します。DDPのAllReduceはスケーリングされた勾配に対して実行されるため、アンスケーリングはAllReduce後に行う必要があります。PyTorchの実装ではこれが自動的に処理されます。

4.6 BF16の使用

BF16を使用する場合、Loss Scalingは不要なため、GradScalerなしでautocastのみを使用します:

with autocast("cuda", dtype=torch.bfloat16):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)

loss.backward()
optimizer.step()
optimizer.zero_grad()

このアプローチはコードがはるかに簡潔になり、Dynamic Loss Scalingによるステップスキップが発生しないため、学習がより安定します。


5. Gradient Checkpointing(Activation Checkpointing)の分析

Gradient Checkpointingは、メモリと計算時間のトレードオフを利用する手法です。順伝播時にすべての中間活性化値を保持する代わりに、一部のみを保持し、逆伝播時に残りを再計算します。

5.1 原理

通常の学習プロセスでは、順伝播時のすべてのレイヤー出力(活性化値)がメモリに格納されます。これは逆伝播時の勾配計算に必要なためです。Gradient Checkpointingはこのプロセスを変更します:

  1. 順伝播: 指定されたチェックポイント境界の入力のみを保存し、中間の活性化値は保持しない。
  2. 逆伝播: 保存された入力から各セグメントの順伝播を再実行して活性化値を再計算し、その後勾配を計算する。

これにより、活性化値のメモリはO(n)からO(sqrt(n))に削減されます(nはレイヤー数)。その代わり、順伝播の計算が約33%増加します(追加の順伝播1回分)。

5.2 PyTorchの実装: torch.utils.checkpoint

PyTorchはtorch.utils.checkpointモジュールを通じて2つのAPIを提供しています:

checkpoint関数

個々の関数やレイヤーにチェックポイントを適用します:

from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.attention = MultiHeadAttention(...)
        self.ffn = FeedForward(...)
        self.norm1 = nn.LayerNorm(...)
        self.norm2 = nn.LayerNorm(...)

    def forward(self, x):
        # このブロックの中間活性化値を保持しない
        # 逆伝播時に再計算する
        return checkpoint(self._forward, x, use_reentrant=False)

    def _forward(self, x):
        x = x + self.attention(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

checkpoint_sequential関数

シーケンシャル構造のレイヤーグループにチェックポイントを適用します:

from torch.utils.checkpoint import checkpoint_sequential

class DeepModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            *[TransformerBlock(...) for _ in range(24)]
        )

    def forward(self, x):
        # 24レイヤーを4グループ(各6レイヤー)に分割してチェックポイントを適用
        segments = 4
        return checkpoint_sequential(self.layers, segments, x,
                                     use_reentrant=False)

5.3 ReentrantとNon-reentrantチェックポイント

PyTorchは2つのチェックポイントモードを提供しています:

  • Reentrant (use_reentrant=True): レガシーアプローチ。逆伝播時に関数全体を再実行します。一部の制約があります(特定のAutograd機能と互換性がない場合がある)。
  • Non-reentrant (use_reentrant=False): 改良されたアプローチ。必要な中間活性化値のみを再計算します。PyTorch公式ドキュメントで推奨されており、将来的にデフォルトになる予定です。

5.4 実践的な適用戦略

すべてのレイヤーにチェックポイントを適用すると、大きな計算オーバーヘッドが発生します。効果的な戦略は以下の通りです:

  • Attentionレイヤーのみに適用: Self-attentionはO(n^2)の活性化メモリを持ち、最大の消費元です。
  • N番目のレイヤーごとに適用: 例えば、2レイヤーごとにチェックポイントを設定し、メモリ節約と計算オーバーヘッドのバランスを取ります。
  • Hugging Face Transformers: 1行で有効化可能: model.gradient_checkpointing_enable()

6. Gradient Accumulation手法

Gradient Accumulationは、物理的なバッチサイズの制約を克服するための手法です。GPUメモリが限られ大きなバッチを使用できない場合、複数の小さなマイクロバッチを順次処理し、勾配を蓄積してから1回のウェイト更新を行います。

6.1 原理

実効バッチサイズ = マイクロバッチサイズ x 蓄積ステップ数

例えば、マイクロバッチサイズが8で蓄積ステップ数が4の場合、実効バッチサイズは32になります。

重要なポイントは、PyTorchではloss.backward()を呼び出すと、.grad属性に勾配が**累積(加算)**されることです。optimizer.zero_grad()が呼び出されるまで勾配はリセットされないため、複数回の逆伝播の結果を自然に蓄積できます。

6.2 実装

accumulation_steps = 4
optimizer.zero_grad()

for i, (inputs, targets) in enumerate(dataloader):
    inputs, targets = inputs.cuda(), targets.cuda()

    with autocast("cuda"):
        outputs = model(inputs)
        # 蓄積ステップ数で損失を割って平均を計算
        loss = loss_fn(outputs, targets) / accumulation_steps

    scaler.scale(loss).backward()

    # accumulation_stepsごとにウェイト更新を実行
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

重要: accumulation_stepsで損失を割ることを忘れないでください。勾配が蓄積されるため、割らないと実質的に学習率がaccumulation_steps倍になってしまいます。

6.3 注意事項

  • Batch Normalization: BatchNormはマイクロバッチごとに統計量を計算するため、実効バッチ全体の統計量と異なる場合があります。このような場合、GroupNormやLayerNormの使用が望ましいです。
  • 学習率スケジューリング: ステップベースのスケジューラを使用する場合、蓄積ステップ数を考慮してスケジューラの呼び出し頻度を調整する必要があります。
  • DDPとの組み合わせ: DDPでGradient Accumulationを使用する場合、蓄積中はAllReduceを無効にして通信オーバーヘッドを削減できます。model.no_sync()コンテキストマネージャを使用します。
for i, (inputs, targets) in enumerate(dataloader):
    # 最後の蓄積ステップでない限りAllReduceをスキップ
    context = model.no_sync() if (i + 1) % accumulation_steps != 0 else nullcontext()
    with context:
        with autocast("cuda"):
            outputs = model(inputs)
            loss = loss_fn(outputs, targets) / accumulation_steps
        scaler.scale(loss).backward()

    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

7. torch.cuda.memory_summary()によるメモリプロファイリング

メモリ最適化の第一歩は、現在のメモリ使用状況を正確に把握することです。PyTorchは様々なCUDAメモリプロファイリングツールを提供しています。

7.1 torch.cuda.memory_summary()

最も基本的でありながら有用なツールです。GPUメモリの割り当て状況の詳細情報を提供します:

import torch

model = MyLargeModel().cuda()
inputs = torch.randn(32, 3, 224, 224).cuda()

# 順伝播後のメモリ状態を確認
with autocast("cuda"):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)

print(torch.cuda.memory_summary(device=0, abbreviated=False))

出力例(一部):

|===========================================================================|
|                  PyTorch CUDA memory summary                              |
|===========================================================================|
|            CUDA OOMs: 0                                                   |
|        cudaMallocs:   234                                                 |
|---------------------------------------------------------------------------+
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------+
| Allocated memory      |  4096 MiB  |  6234 MiB  | 12345 MiB  |  8249 MiB  |
| Active memory         |  3800 MiB  |  5900 MiB  | 11000 MiB  |  7200 MiB  |
| Requested memory      |  3750 MiB  |  5850 MiB  | 10800 MiB  |  7050 MiB  |
| GPU reserved memory   |  6400 MiB  |  6400 MiB  |  6400 MiB  |     0 MiB  |
|---------------------------------------------------------------------------+

主要メトリクスの解釈:

  • Allocated memory: 現在割り当てられているメモリ(テンソルが実際に占有しているメモリ)
  • Active memory: 現在使用中のメモリ
  • GPU reserved memory: PyTorchがCUDAから確保した総メモリ(キャッシングアロケータ)
  • Peak Usage: 学習中の最大メモリ使用量(OOM発生の判断に重要)

7.2 個別メモリクエリ関数

コード内の特定のポイントでメモリ使用量を追跡するのに便利です:

# 現在割り当てられているメモリ(バイト)
allocated = torch.cuda.memory_allocated(device=0)

# 現在予約されているメモリ(バイト)
reserved = torch.cuda.memory_reserved(device=0)

# 学習開始以降の最大割り当てメモリ
max_allocated = torch.cuda.max_memory_allocated(device=0)

print(f"Allocated: {allocated / 1024**3:.2f} GB")
print(f"Reserved:  {reserved / 1024**3:.2f} GB")
print(f"Peak:      {max_allocated / 1024**3:.2f} GB")

# ピーク統計のリセット(セグメントごとの測定用)
torch.cuda.reset_peak_memory_stats(device=0)

7.3 Memory Snapshotによる詳細分析

PyTorch 2.x以降で利用可能なMemory Snapshotは、メモリの割り当て/解放イベントを時系列順に記録して可視化できます:

# メモリ記録の開始
torch.cuda.memory._record_memory_history(max_entries=100000)

# 学習コードの実行
train_one_epoch(model, dataloader, optimizer)

# スナップショットの保存
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# 記録の停止
torch.cuda.memory._record_memory_history(enabled=None)

保存されたスナップショットは、PyTorch公式のMemory Visualizer(https://pytorch.org/memory_viz)にアップロードして視覚的に分析できます。

7.4 OOMデバッグ戦略

Out of Memoryが発生した場合、以下の順序で原因を診断します:

  1. torch.cuda.memory_summary()で全体像を把握
  2. max_memory_allocated()でピークメモリを確認
  3. Memory Snapshotを使用してどの演算がメモリスパイクを引き起こしているか特定
  4. バッチサイズ、シーケンス長、隠れ層の次元数のどの要因が支配的か分析
  5. 適切な最適化手法(Mixed Precision、Gradient Checkpointing、Gradient Accumulation)を適用

8. NVIDIA Transformer EngineとFP8学習

NVIDIA Transformer Engine(TE)は、Hopper(H100)以降のGPUでFP8精度を使用してTransformerモデルの学習と推論を加速するライブラリです。FP8学習はFP16/BF16と比較して最大2倍のスループット向上を実現できます。

8.1 2つのFP8フォーマット

H100 GPUは、ハードウェアレベルで2つのFP8フォーマットをサポートしています:

  • E4M3(4ビット指数部、3ビット仮数部): レンジ約±448。比較的高い精度。順伝播の活性化値とウェイトに使用。
  • E5M2(5ビット指数部、2ビット仮数部): レンジ約±57,344。広いダイナミックレンジ。逆伝播の勾配に使用。

この分離戦略により、各段階の数値特性に最も適したフォーマットを選択し、精度の低下を最小限に抑えます。

8.2 テンソルごとのスケーリング

FP8の狭いダイナミックレンジを補うために、Transformer Engineはテンソルごとのスケーリングを適用します。各テンソルに個別のスケールファクターを維持し、テンソルの値分布をFP8の表現可能範囲に収まるように調整します。

Transformer Engineはいくつかのスケーリング戦略を提供しています:

  • Delayed Scaling: 前のイテレーションの統計に基づいてスケールファクターを決定(デフォルト)
  • Current Scaling(Just-in-time): 現在のテンソルの値に基づいてスケールファクターを即座に計算
  • Block Scaling: テンソルをブロックに分割し、各ブロックに個別のスケールファクターを適用

8.3 Transformer Engineの使用例

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# FP8学習レシピの設定
fp8_recipe = DelayedScaling(
    margin=0,
    fp8_format=Format.HYBRID,  # 順伝播: E4M3、逆伝播: E5M2
    amax_history_len=1024,
    amax_compute_algo="max",
)

# Transformer Engineのレイヤーを使用
model = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=16384,
    num_attention_heads=32,
    layer_number=1,
)

# FP8学習ループ
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)

loss.backward()
optimizer.step()

8.4 Blackwell世代での進化

NVIDIA Blackwell(B200)GPUはFP8に加えて以下のフォーマットをサポートしています:

  • MXFP8: Microscaling FP8。より細かいブロックレベルのスケーリングをサポート
  • NVFP4: 4ビット浮動小数点。主に推論用で、メモリ効率を最大化

Transformer Engine 2.x以降では、Recipeモジュールを通じてこれらの新しいフォーマットを統合的にサポートしています。


9. 実践事例: 限られたGPUでの大規模モデル学習

これまで取り上げたすべてのテクニックを組み合わせ、限られたGPU環境での大規模モデル学習の実践的な戦略を概説します。

9.1 シナリオ: 24GB GPUでの7Bモデルのファインチューニング

単一のA10G(24GB VRAM)で7Bパラメータモデルをファインチューニングする場合を想定します。

メモリ要件の分析(FP32基準):

構成要素メモリ
モデルパラメータ (FP32)28 GB
勾配 (FP32)28 GB
Adamオプティマイザの状態56 GB
活性化値 (batch=1)約2 GB
合計約114 GB

FP32ではまったく不可能です。段階的に最適化を適用していきましょう。

9.2 段階的な最適化

ステップ1: Mixed Precision Training(BF16)

from torch.amp import autocast

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                              torch_dtype=torch.bfloat16)
構成要素メモリ
モデルパラメータ (BF16)14 GB
勾配 (BF16)14 GB
Adamマスターウェイト (FP32)28 GB
Adamモメンタム (FP32)28 GB
Adam分散 (FP32)28 GB
合計約112 GB

オプティマイザの状態のため、依然として不十分です。

ステップ2: LoRAの適用(学習可能パラメータの削減)

LoRA(Low-Rank Adaptation)は、全パラメータの約0.1〜1%のみを学習対象に設定します。7BモデルにRank=16のLoRAを適用すると、約20Mの学習可能パラメータになります。

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
)
model = get_peft_model(model, lora_config)
構成要素メモリ
フルモデル (BF16、凍結)14 GB
LoRAパラメータ (BF16)約0.04 GB
LoRA勾配 (BF16)約0.04 GB
Adam状態 (FP32、LoRAのみ)約0.24 GB
合計約14.3 GB

24GB GPUに余裕を持って収まります。バッチサイズを増やす余地があります。

ステップ3: Gradient Checkpointingの追加

残りの約10GBで活性化値のメモリを可能な限り節約し、バッチサイズを増やします:

model.gradient_checkpointing_enable()

ステップ4: Gradient Accumulationで実効バッチサイズを確保

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,  # 実効バッチサイズ = 32
    bf16=True,
    gradient_checkpointing=True,
    learning_rate=2e-4,
    num_train_epochs=3,
    logging_steps=10,
)

9.3 最適化手法の組み合わせガイド

GPU VRAMモデル規模推奨組み合わせ
8 GB約1BBF16 + LoRA(r=8) + GC + GA
16 GB約3BBF16 + LoRA(r=16) + GC + GA
24 GB約7BBF16 + LoRA(r=16) + GC + GA
40 GB約13BBF16 + LoRA(r=32) + GC + GA
80 GB約7Bフル FTBF16 + GC + GA
8x80 GB約70BBF16 + FSDP/DeepSpeed ZeRO-3 + GC

GC = Gradient Checkpointing、GA = Gradient Accumulation、フル FT = フルファインチューニング

9.4 統合メモリプロファイリングスクリプト

学習前にメモリ使用量を事前確認するために、以下のスクリプトを使用できます:

import torch
from torch.amp import autocast

def profile_memory(model, dummy_input, dtype=torch.bfloat16):
    """学習中のGPUメモリ使用量をプロファイリングする。"""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

    model = model.cuda()
    dummy_input = dummy_input.cuda()

    # 順伝播
    with autocast("cuda", dtype=dtype):
        outputs = model(dummy_input)
        if isinstance(outputs, dict):
            loss = outputs["loss"]
        else:
            loss = outputs.sum()

    print("=== 順伝播後 ===")
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"Reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

    # 逆伝播
    loss.backward()

    print("\n=== 逆伝播後 ===")
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"Reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    print(f"Peak:      {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")

    print("\n=== メモリ全体のサマリー ===")
    print(torch.cuda.memory_summary(abbreviated=True))

# 使用例
# profile_memory(model, dummy_input)

10. 結論

GPUメモリ最適化は単一の手法ではなく、複数の手法の組み合わせです。以下に要点をまとめます:

  1. メモリ構成の理解が最優先: パラメータ、勾配、オプティマイザの状態、活性化値のそれぞれがどの程度のメモリを消費するかを把握し、どこを最適化すべきかを判断する必要があります。

  2. Mixed Precisionが基本: BF16を使用することで、Loss Scalingなしでメモリを半減し計算速度を安定的に向上させることができます。Ampere以降のGPUではBF16をデフォルトとして使用しましょう。

  3. Gradient Checkpointingは活性化メモリに対応: 約33%の計算オーバーヘッドと引き換えに、活性化メモリを大幅に削減します。大規模モデルの学習にはほぼ必須です。

  4. Gradient Accumulationはバッチサイズの問題を解決: 追加のメモリ負担なしに実効バッチサイズを増やすことができます。

  5. FP8は次世代の標準: Hopper/Blackwell GPUではTransformer Engineを通じたFP8学習が可能であり、FP16/BF16を超えるさらなるパフォーマンス向上をもたらします。

  6. プロファイリングを習慣に: torch.cuda.memory_summary()やMemory Snapshotを通じて定期的にメモリ使用量を監視することで、OOM問題を事前に防止できます。


参考文献

クイズ

Q1: 「GPUメモリ最適化と混合精度トレーニング完全ガイド」の主なトピックは何ですか?

NVIDIA公式ドキュメントに基づき、GPUメモリの構成要素を分析し、Mixed Precision TrainingやGradient Checkpointingなどのメモリ最適化手法を網羅的に解説します。

Q2: 1 モデルパラメータとは何ですか? モデルパラメータとは、ニューラルネットワークの重みとバイアスを指します。学習プロセス全体を通じてGPUメモリ上に常駐し、パラメータ数とデータ型によってメモリ使用量が決まります。 FP32基準: パラメータあたり4バイト FP16/BF16基準: パラメータあたり2バイト 例えば、7B(70億)パラメータのモデルをFP32で格納するには約28GB、FP16では約14GBが必要です。

Q3: 2 勾配の核心的な概念を説明してください。 勾配は、損失関数に対する各パラメータの偏微分値です。逆伝播時に計算され、各パラメータに対応する勾配値が存在するため、パラメータと同量のメモリを占有します。 FP32学習: パラメータ数 x 4バイト Mixed Precision学習: パラメータ数 x 2バイト(FP16で格納)

Q4: 3 オプティマイザの状態の主な特徴は何ですか? オプティマイザの状態は、メモリの中で最も大きな割合を占めることが多い要素です。SGDはモメンタムのみを保持しますが、Adam/AdamWオプティマイザは各パラメータに対して第1モーメント(平均)と第2モーメント(分散)の2つの追加状態を保持します。 FP32学習時のAdamオプティマイザのオプティマイザ状態メモリ: 7BモデルでAdamを使用すると、オプティマイザの状態だけで約84GBが必要になります。これが、大規模モデル学習においてオプティマイザの状態の最適化が極めて重要である理由です。

Q5: 4 活性化値はどのように機能しますか? 活性化値は、順伝播時の各層の出力であり、逆伝播時の勾配計算に必要なため保持される必要があります。活性化値のメモリは以下の要因に比例します: バッチサイズ: 大きいほどメモリ消費が増加 シーケンス長: Transformerモデルでは特に重要(Attentionで O(n^2)) 隠れ層の次元数: モデルが大きいほど増加 レイヤー数: モデルが深いほど増加 大規模モデルでは、活性化値のメモリがパラメータメモリを上回ることも頻繁にあります。