- Authors

- Name
- Youngju Kim
- @fjvbn20031
PyTorch高度なテクニック完全ガイド
PyTorchはリサーチとプロダクション両方において最も人気の深層学習フレームワークの一つです。しかし、ほとんどの開発者は基本的なテンソル演算とnn.Moduleの定義で止まっています。このガイドでは、実際のプロダクション環境でパフォーマンスを最大化し、カスタム演算を実装し、メモリを効率的に管理するための高度なテクニックを解説します。
1. torch.compile(PyTorch 2.0以降)
PyTorch 2.0で導入されたtorch.compileは、モデルをコンパイルして実行速度を劇的に改善する機能です。TorchScriptやONNXエクスポートとは異なり、torch.compileは最小限のコード変更で2倍以上のスピードアップをもたらすことができます。
導入とメリット
torch.compileは3つのコアコンポーネントで構成されています:
- TorchDynamo: Pythonバイトコードを傍受してFXグラフを生成
- AOTAutograd: 自動微分グラフを事前コンパイル
- Inductor: TorchInductorバックエンド(Triton GPUカーネルまたはC++ CPUカーネル)を通じて最適化されたカーネルを生成
import torch
import torch.nn as nn
import time
# ベースモデルを定義
class TransformerBlock(nn.Module):
def __init__(self, d_model=512, nhead=8, dim_feedforward=2048):
super().__init__()
self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
nn.ReLU(),
nn.Linear(dim_feedforward, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x):
attn_out, _ = self.attention(x, x, x)
x = self.norm1(x + attn_out)
ff_out = self.feed_forward(x)
x = self.norm2(x + ff_out)
return x
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TransformerBlock().to(device)
# torch.compileを適用
compiled_model = torch.compile(model)
# ウォームアップ
x = torch.randn(32, 128, 512, device=device)
for _ in range(3):
_ = compiled_model(x)
# パフォーマンス比較
N = 100
# Eagerモード
start = time.perf_counter()
for _ in range(N):
_ = model(x)
if device == "cuda":
torch.cuda.synchronize()
elapsed_eager = time.perf_counter() - start
# コンパイルモード
start = time.perf_counter()
for _ in range(N):
_ = compiled_model(x)
if device == "cuda":
torch.cuda.synchronize()
elapsed_compiled = time.perf_counter() - start
print(f"Eagerモード: {elapsed_eager:.3f}s")
print(f"コンパイルモード: {elapsed_compiled:.3f}s")
print(f"スピードアップ: {elapsed_eager / elapsed_compiled:.2f}x")
コンパイルモード
torch.compileは3つのモードを提供します:
# デフォルトモード - 高速コンパイル、良好なパフォーマンス
model_default = torch.compile(model, mode="default")
# reduce-overheadモード - 小規模モデルに有効
model_reduce = torch.compile(model, mode="reduce-overhead")
# max-autotune - 長時間コンパイル、最高パフォーマンス
model_autotune = torch.compile(model, mode="max-autotune")
# フルグラフモード - 動的グラフなし(厳格)
model_full = torch.compile(model, fullgraph=True)
# バックエンド選択
model_eager = torch.compile(model, backend="eager") # コンパイルなし(デバッグ用)
model_aot = torch.compile(model, backend="aot_eager") # AOTのみ
model_inductor = torch.compile(model, backend="inductor") # デフォルト
動的シェイプのサポート
import torch._dynamo as dynamo
# 動的シェイプを有効化
model_dynamic = torch.compile(model, dynamic=True)
# 異なるバッチサイズでも再コンパイルなしで動作
for batch_size in [8, 16, 32, 64]:
x = torch.randn(batch_size, 128, 512, device=device)
out = model_dynamic(x)
print(f"Batch {batch_size}: 出力シェイプ {out.shape}")
# コンパイルキャッシュを確認
print(dynamo.explain(model)(x))
既存コードへの移行
# 学習ループにtorch.compileを適用
def train_epoch(model, optimizer, dataloader, criterion):
model.train()
total_loss = 0
for batch_idx, (data, target) in enumerate(dataloader):
data, target = data.cuda(), target.cuda()
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(dataloader)
# モデルをコンパイルするだけ
model = MyModel().cuda()
compiled_model = torch.compile(model) # この行を追加するだけ!
optimizer = torch.optim.Adam(compiled_model.parameters())
criterion = nn.CrossEntropyLoss()
2. カスタムAutograd
PyTorchの自動微分エンジンは強力ですが、数値安定性や効率のためにカスタム勾配計算が必要な場合があります。
torch.autograd.Functionのサブクラス化
import torch
from torch.autograd import Function
class SigmoidFunction(Function):
"""数値的に安定したシグモイド実装"""
@staticmethod
def forward(ctx, input):
# sigmoid = 1 / (1 + exp(-x))
sigmoid = torch.sigmoid(input)
# backwardで使用するテンソルを保存
ctx.save_for_backward(sigmoid)
return sigmoid
@staticmethod
def backward(ctx, grad_output):
# 保存したシグモイドを取得
sigmoid, = ctx.saved_tensors
# gradient = sigmoid * (1 - sigmoid) * grad_output
grad_input = sigmoid * (1 - sigmoid) * grad_output
return grad_input
# 使用方法
def custom_sigmoid(x):
return SigmoidFunction.apply(x)
x = torch.randn(3, 4, requires_grad=True)
y = custom_sigmoid(x)
loss = y.sum()
loss.backward()
print(f"勾配: {x.grad}")
より複雑な例: カスタムBackwardを持つLeaky ReLU
class LeakyReLUFunction(Function):
@staticmethod
def forward(ctx, input, negative_slope=0.01):
ctx.save_for_backward(input)
ctx.negative_slope = negative_slope
return input.clamp(min=0) + negative_slope * input.clamp(max=0)
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
negative_slope = ctx.negative_slope
grad_input = grad_output.clone()
grad_input[input < 0] *= negative_slope
# negative_slopeに対する勾配はNone
return grad_input, None
class CustomLeakyReLU(nn.Module):
def __init__(self, negative_slope=0.01):
super().__init__()
self.negative_slope = negative_slope
def forward(self, x):
return LeakyReLUFunction.apply(x, self.negative_slope)
数値勾配チェック
from torch.autograd import gradcheck
def test_custom_op():
# 数値勾配チェック(float64推奨)
input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
# gradcheckはヤコビアンを数値的に計算してautogradと比較
result = gradcheck(SigmoidFunction.apply, (input,), eps=1e-6, atol=1e-4)
print(f"勾配チェック合格: {result}")
# 二重逆伝播テスト
input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
result = gradcheck(
SigmoidFunction.apply,
(input,),
eps=1e-6,
atol=1e-4,
check_grad_dtypes=True
)
test_custom_op()
二重逆伝播
class SquaredFunction(Function):
"""二重逆伝播をサポートするカスタムx^2実装"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x ** 2
@staticmethod
def backward(ctx, grad_output):
x, = ctx.saved_tensors
# 二重逆伝播をサポートするにはcreate_graph=Trueが必要
return 2 * x * grad_output
# 二重逆伝播の例(MAMLなどのメタ学習で有用)
x = torch.randn(3, requires_grad=True)
y = SquaredFunction.apply(x)
grad_x = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
# grad_x = 2x、再び逆伝播
grad_grad_x = torch.autograd.grad(grad_x.sum(), x)[0]
# grad_grad_x = 2 (定数)
print(f"二階微分: {grad_grad_x}")
3. カスタムCUDAオペレータ
torch.utils.cpp_extensionの概要
PyTorchはC++/CUDA拡張を書いてPythonから使用するためのツールを提供します。
# JITコンパイル(開発/プロトタイプに適しています)
from torch.utils.cpp_extension import load_inline
# C++ CPUオペレータ
cpp_source = """
#include <torch/extension.h>
torch::Tensor relu_forward(torch::Tensor input) {
return input.clamp_min(0);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("relu_forward", &relu_forward, "ReLU forward");
}
"""
# インラインコンパイル
custom_relu_cpp = load_inline(
name="custom_relu_cpp",
cpp_sources=cpp_source,
functions=["relu_forward"],
verbose=False
)
x = torch.randn(5)
result = custom_relu_cpp.relu_forward(x)
print(result)
CUDAカーネルの例: Fused Softmax
# CUDAソース
cuda_source = """
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
__global__ void fused_softmax_kernel(
const float* input,
float* output,
int rows,
int cols
) {
int row = blockIdx.x;
if (row >= rows) return;
const float* row_input = input + row * cols;
float* row_output = output + row * cols;
// 最大値を求める(数値安定性のため)
float max_val = row_input[0];
for (int i = 1; i < cols; i++) {
max_val = fmaxf(max_val, row_input[i]);
}
// exp(x - max)と合計を計算
float sum = 0.0f;
for (int i = 0; i < cols; i++) {
row_output[i] = expf(row_input[i] - max_val);
sum += row_output[i];
}
// 正規化
for (int i = 0; i < cols; i++) {
row_output[i] /= sum;
}
}
torch::Tensor fused_softmax_cuda(torch::Tensor input) {
auto output = torch::zeros_like(input);
int rows = input.size(0);
int cols = input.size(1);
fused_softmax_kernel<<<rows, 1>>>(
input.data_ptr<float>(),
output.data_ptr<float>(),
rows,
cols
);
return output;
}
"""
cpp_source_cuda = """
#include <torch/extension.h>
torch::Tensor fused_softmax_cuda(torch::Tensor input);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_softmax", &fused_softmax_cuda, "Fused Softmax CUDA");
}
"""
# CUDAが利用可能な場合のみコンパイル
if torch.cuda.is_available():
from torch.utils.cpp_extension import load_inline
fused_softmax_ext = load_inline(
name="fused_softmax",
cpp_sources=cpp_source_cuda,
cuda_sources=cuda_source,
functions=["fused_softmax"],
verbose=True
)
# テスト
x = torch.randn(4, 8, device="cuda")
result = fused_softmax_ext.fused_softmax(x)
expected = torch.softmax(x, dim=1)
print(f"最大差分: {(result - expected).abs().max().item():.6f}")
setup.pyでパッケージをビルド
# setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name="custom_ops",
ext_modules=[
CUDAExtension(
name="custom_ops",
sources=[
"custom_ops/ops.cpp",
"custom_ops/ops_cuda.cu",
],
extra_compile_args={
"cxx": ["-O3"],
"nvcc": ["-O3", "--use_fast_math"],
}
)
],
cmdclass={
"build_ext": BuildExtension
}
)
# ビルド: python setup.py install
4. メモリ最適化テクニック
GPUメモリのプロファイリング
import torch
def print_gpu_memory_stats():
"""GPUメモリ統計を表示するユーティリティ"""
if torch.cuda.is_available():
print(f"割り当て済み: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"予約済み: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
print(f"最大割り当て: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
print(torch.cuda.memory_summary(abbreviated=True))
# メモリトラッキングを開始
torch.cuda.reset_peak_memory_stats()
print_gpu_memory_stats()
# モデルを読み込む
model = TransformerBlock().cuda()
print_gpu_memory_stats()
Gradient Checkpointing(活性化の再計算)
Gradient Checkpointingは順伝播中の中間活性化値を保存せず、逆伝播中に必要に応じて再計算します。これにより約30%多い計算コストで大幅なメモリ節約ができます。
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
import torch.nn as nn
class DeepTransformer(nn.Module):
def __init__(self, num_layers=12, d_model=512, nhead=8):
super().__init__()
self.layers = nn.ModuleList([
TransformerBlock(d_model, nhead) for _ in range(num_layers)
])
def forward_with_checkpointing(self, x):
"""各レイヤーにgradient checkpointingを適用"""
for layer in self.layers:
# checkpointは中間活性化を保存しない
x = checkpoint(layer, x, use_reentrant=False)
return x
def forward_sequential_checkpointing(self, x):
"""シーケンシャルモジュールにcheckpointingを適用"""
# 4レイヤーのセグメントにグループ化
x = checkpoint_sequential(self.layers, segments=3, input=x)
return x
# メモリ比較
model = DeepTransformer(num_layers=24).cuda()
x = torch.randn(16, 512, 512, device="cuda")
# 通常の順伝播
torch.cuda.reset_peak_memory_stats()
out = model.forward_with_checkpointing(x)
checkpoint_mem = torch.cuda.max_memory_allocated()
print(f"Checkpointメモリ: {checkpoint_mem / 1024**3:.2f} GB")
勾配蓄積
def train_with_gradient_accumulation(
model, optimizer, dataloader, criterion,
accumulation_steps=4
):
"""
限られたGPUメモリで大きな有効バッチサイズを実装
"""
model.train()
optimizer.zero_grad()
total_loss = 0
for batch_idx, (data, target) in enumerate(dataloader):
data, target = data.cuda(), target.cuda()
# 順伝播
output = model(data)
# lossをaccumulation_stepsでスケール
loss = criterion(output, target) / accumulation_steps
loss.backward()
total_loss += loss.item() * accumulation_steps
# accumulation_stepsごとに更新
if (batch_idx + 1) % accumulation_steps == 0:
# 勾配クリッピング(オプション)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
return total_loss / len(dataloader)
混合精度学習(AMP)
from torch.cuda.amp import autocast, GradScaler
def train_with_amp(model, optimizer, dataloader, criterion):
"""メモリ節約+速度向上のための自動混合精度"""
scaler = GradScaler()
model.train()
for data, target in dataloader:
data, target = data.cuda(), target.cuda()
optimizer.zero_grad()
# autocastコンテキスト内でfloat16演算
with autocast():
output = model(data)
loss = criterion(output, target)
# スケールされた勾配での逆伝播
scaler.scale(loss).backward()
# スケール解除後に勾配クリッピング
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# オプティマイザーステップ(NaN/Infチェックあり)
scaler.step(optimizer)
scaler.update()
8ビットオプティマイザー
# bitsandbytesが必要: pip install bitsandbytes
try:
import bitsandbytes as bnb
# 通常のAdamの代わりに8ビットAdamを使用
optimizer_8bit = bnb.optim.Adam8bit(
model.parameters(),
lr=1e-4,
betas=(0.9, 0.999)
)
# PagedAdam(CPUにオフロード可能)
optimizer_paged = bnb.optim.PagedAdam(
model.parameters(),
lr=1e-4
)
print("8ビットオプティマイザーが正常に読み込まれました")
except ImportError:
print("bitsandbytesがインストールされていません。通常のAdamを使用します")
optimizer_8bit = torch.optim.Adam(model.parameters(), lr=1e-4)
5. functorchとvmap
vmap - バッチ処理
vmapは単一サンプルを処理する関数をバッチに効率的に適用します。
import torch
from torch import vmap
# 単一サンプルを処理する関数
def single_linear(weight, bias, x):
return weight @ x + bias
# vmapでバッチ処理
batched_linear = vmap(single_linear)
# バッチデータ
batch_size = 32
weight = torch.randn(batch_size, 10, 5)
bias = torch.randn(batch_size, 10)
x = torch.randn(batch_size, 5)
# バッチ内で自動的に処理
result = batched_linear(weight, bias, x)
print(f"結果シェイプ: {result.shape}") # (32, 10)
grad - 関数型勾配
from torch.func import grad, vmap, functional_call
# 関数型勾配
def scalar_loss(params, x, y):
pred = functional_call(model, params, (x,))
return ((pred - y) ** 2).mean()
# パラメータに対する勾配
params = dict(model.named_parameters())
grad_fn = grad(scalar_loss)
x = torch.randn(1, 10)
y = torch.randn(1, 5)
grads = grad_fn(params, x, y)
print({k: v.shape for k, v in grads.items()})
vmapでアンサンブルモデル
from torch.func import stack_module_state, functional_call, vmap
def create_ensemble(model_class, num_models, *args, **kwargs):
"""vmapを使用した効率的なアンサンブル"""
models = [model_class(*args, **kwargs) for _ in range(num_models)]
# 全モデルのパラメータをスタック
params, buffers = stack_module_state(models)
# 単一モデルの順伝播
base_model = model_class(*args, **kwargs)
def single_forward(params, buffers, x):
return functional_call(base_model, (params, buffers), (x,))
# vmapで全モデルを並列実行
ensemble_forward = vmap(single_forward, in_dims=(0, 0, None))
return ensemble_forward, params, buffers
# 使用例
ensemble_fn, params, buffers = create_ensemble(
nn.Linear, num_models=5, in_features=10, out_features=5
)
x = torch.randn(32, 10)
ensemble_out = ensemble_fn(params, buffers, x)
print(f"アンサンブル出力シェイプ: {ensemble_out.shape}") # (5, 32, 5)
gradとvmapによるメタ学習(MAML)
from torch.func import grad, vmap, functional_call
def inner_loop(params, support_x, support_y, base_model, lr=0.01, steps=5):
"""MAMLの内部ループ"""
adapted_params = {k: v.clone() for k, v in params.items()}
for _ in range(steps):
def loss_fn(params):
pred = functional_call(base_model, params, (support_x,))
return ((pred - support_y) ** 2).mean()
grads = grad(loss_fn)(adapted_params)
adapted_params = {
k: p - lr * grads[k]
for k, p in adapted_params.items()
}
return adapted_params
6. PyTorch Profiler
基本的なプロファイリング
import torch
from torch.profiler import profile, record_function, ProfilerActivity
model = TransformerBlock().cuda()
x = torch.randn(32, 128, 512, device="cuda")
# プロファイラーを実行
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
profile_memory=True,
with_stack=True
) as prof:
with record_function("model_inference"):
for _ in range(10):
output = model(x)
# 結果を表示
print(prof.key_averages().table(
sort_by="cuda_time_total",
row_limit=20
))
# Chrome Traceをエクスポート
prof.export_chrome_trace("trace.json")
# chrome://tracingで開く
TensorBoardとの統合
from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(
wait=1, # 1ステップ待機
warmup=1, # 1ウォームアップステップ
active=3, # 3ステップをプロファイル
repeat=2 # 2回繰り返す
),
on_trace_ready=tensorboard_trace_handler("./log/profiler"),
record_shapes=True,
profile_memory=True
) as prof:
for step, (data, target) in enumerate(dataloader):
train_step(model, optimizer, data, target)
prof.step() # スケジュールに従ってプロファイル
# tensorboard --logdir=./log/profiler
詳細なメモリ分析
# メモリスナップショット(PyTorch 2.1以降)
torch.cuda.memory._record_memory_history(max_entries=100000)
# コードを実行
x = torch.randn(100, 100, device="cuda")
y = x @ x.T
z = y.sum()
# スナップショットを保存
snapshot = torch.cuda.memory._snapshot()
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
print(f"アクティブな割り当て: {len(snapshot['segments'])}")
7. TorchScript
torch.jit.scriptとtorch.jit.trace
import torch
import torch.nn as nn
class ConditionalModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 5)
def forward(self, x, flag: bool = True):
if flag: # 制御フローによりトレースは不可
return torch.relu(self.linear(x))
else:
return torch.sigmoid(self.linear(x))
model = ConditionalModel()
# script: 制御フローを含む(推奨)
scripted_model = torch.jit.script(model)
print(scripted_model.code)
# trace: 単一パスのみキャプチャ
x = torch.randn(1, 10)
traced_model = torch.jit.trace(model, x)
# 注意: flag=Falseのパスはキャプチャされない
# 保存と読み込み
scripted_model.save("model_scripted.pt")
loaded = torch.jit.load("model_scripted.pt")
TorchScript最適化
# 最適化パスを適用
scripted = torch.jit.script(model)
optimized = torch.jit.optimize_for_inference(scripted)
# C++環境へのエクスポート
# C++では: torch::jit::script::Module m = torch::jit::load("model.pt");
8. 動的シェイプとtorch.export
torch.exportの使用
import torch
from torch.export import export
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 5)
def forward(self, x):
return self.linear(x)
model = MyModel()
example_inputs = (torch.randn(2, 10),)
# 動的シェイプを指定
from torch.export import Dim
batch = Dim("batch", min=1, max=100)
# モデルをエクスポート
exported = export(
model,
example_inputs,
dynamic_shapes={"x": {0: batch}}
)
print(exported)
print(exported.graph_module.code)
# ExportedProgramを実行
result = exported.module()(torch.randn(5, 10))
print(f"結果シェイプ: {result.shape}")
9. カスタムデータセットとサンプラー
IterableDataset
from torch.utils.data import IterableDataset, DataLoader
import torch
class StreamingDataset(IterableDataset):
"""大規模データセットを効率的にストリーミング"""
def __init__(self, data_paths, transform=None):
self.data_paths = data_paths
self.transform = transform
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
# シングルプロセス
paths = self.data_paths
else:
# マルチプロセス: ワーカー間でデータを分割
per_worker = len(self.data_paths) // worker_info.num_workers
start = worker_info.id * per_worker
end = start + per_worker
paths = self.data_paths[start:end]
for path in paths:
# ファイルからデータをストリーミング
data = self._load_file(path)
for sample in data:
if self.transform:
sample = self.transform(sample)
yield sample
def _load_file(self, path):
# 実際の実装ではファイルから読み込む
return [torch.randn(10) for _ in range(100)]
# DataLoaderでpersistent_workersを使用
dataset = StreamingDataset(data_paths=["file1.pt", "file2.pt"])
dataloader = DataLoader(
dataset,
batch_size=32,
num_workers=4,
persistent_workers=True, # ワーカープロセスを再利用
pin_memory=True, # GPU転送を最適化
prefetch_factor=2 # プリフェッチするバッチ数
)
カスタムサンプラー
from torch.utils.data import Sampler
import numpy as np
class BalancedClassSampler(Sampler):
"""クラス不均衡問題のための重み付きサンプリング"""
def __init__(self, dataset, num_samples_per_class=None):
self.dataset = dataset
labels = [dataset[i][1] for i in range(len(dataset))]
self.labels = torch.tensor(labels)
# クラスごとのインデックス
self.class_indices = {}
for cls in torch.unique(self.labels):
self.class_indices[cls.item()] = (
self.labels == cls
).nonzero(as_tuple=True)[0].tolist()
self.num_classes = len(self.class_indices)
self.num_samples_per_class = (
num_samples_per_class or
max(len(v) for v in self.class_indices.values())
)
def __iter__(self):
indices = []
for cls_idx in self.class_indices.values():
# 各クラスから同じ数をサンプリング(復元抽出)
sampled = np.random.choice(
cls_idx,
self.num_samples_per_class,
replace=True
).tolist()
indices.extend(sampled)
# シャッフル
np.random.shuffle(indices)
return iter(indices)
def __len__(self):
return self.num_classes * self.num_samples_per_class
10. PyTorch Lightning
完全なLightningModuleの例
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
class LightningTransformer(pl.LightningModule):
def __init__(
self,
d_model=512,
nhead=8,
num_layers=6,
num_classes=10,
learning_rate=1e-4
):
super().__init__()
self.save_hyperparameters() # HParamsを自動保存
self.encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
batch_first=True
),
num_layers=num_layers
)
self.classifier = nn.Linear(d_model, num_classes)
self.criterion = nn.CrossEntropyLoss()
def forward(self, x):
encoded = self.encoder(x)
# シーケンスの平均プーリング
pooled = encoded.mean(dim=1)
return self.classifier(pooled)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = self.criterion(logits, y)
acc = (logits.argmax(dim=1) == y).float().mean()
# ロギング
self.log("train_loss", loss, prog_bar=True)
self.log("train_acc", acc, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = self.criterion(logits, y)
acc = (logits.argmax(dim=1) == y).float().mean()
self.log("val_loss", loss, prog_bar=True)
self.log("val_acc", acc, prog_bar=True)
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.hparams.learning_rate,
weight_decay=0.01
)
# コサインアニーリングスケジューラー
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=100
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": "val_loss"
}
}
class LightningDataModule(pl.LightningDataModule):
def __init__(self, batch_size=32):
super().__init__()
self.batch_size = batch_size
def setup(self, stage=None):
# ダミーデータ
x = torch.randn(1000, 32, 512)
y = torch.randint(0, 10, (1000,))
dataset = TensorDataset(x, y)
n = len(dataset)
self.train_ds = torch.utils.data.Subset(dataset, range(int(0.8*n)))
self.val_ds = torch.utils.data.Subset(dataset, range(int(0.8*n), n))
def train_dataloader(self):
return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)
def val_dataloader(self):
return DataLoader(self.val_ds, batch_size=self.batch_size)
# 学習を実行
model = LightningTransformer()
dm = LightningDataModule()
# コールバックを設定
callbacks = [
ModelCheckpoint(
monitor="val_loss",
save_top_k=3,
mode="min",
filename="transformer-{epoch:02d}-{val_loss:.2f}"
),
EarlyStopping(monitor="val_loss", patience=10, mode="min")
]
# トレーナー
trainer = pl.Trainer(
max_epochs=100,
accelerator="auto", # GPUを自動検出
devices="auto",
callbacks=callbacks,
logger=TensorBoardLogger("tb_logs", name="transformer"),
gradient_clip_val=1.0, # 勾配クリッピング
accumulate_grad_batches=4, # 勾配蓄積
precision="16-mixed", # AMP
)
trainer.fit(model, dm)
11. モデル量子化
動的量子化
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic
# 推論専用量子化 - 最も簡単なアプローチ
model = nn.Sequential(
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 128)
)
model.eval()
# LinearとLSTMレイヤーをint8に量子化
quantized_model = quantize_dynamic(
model,
{nn.Linear, nn.LSTM}, # 量子化するレイヤータイプ
dtype=torch.qint8
)
# サイズ比較
import os
torch.save(model.state_dict(), "model_fp32.pt")
torch.save(quantized_model.state_dict(), "model_int8.pt")
fp32_size = os.path.getsize("model_fp32.pt")
int8_size = os.path.getsize("model_int8.pt")
print(f"FP32サイズ: {fp32_size / 1024:.1f} KB")
print(f"INT8サイズ: {int8_size / 1024:.1f} KB")
print(f"圧縮率: {fp32_size / int8_size:.1f}x")
静的量子化
import torch
from torch.ao.quantization import (
get_default_qconfig,
prepare,
convert,
)
class QuantizableModel(nn.Module):
def __init__(self):
super().__init__()
self.quant = torch.ao.quantization.QuantStub()
self.linear1 = nn.Linear(512, 256)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(256, 10)
self.dequant = torch.ao.quantization.DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
x = self.dequant(x)
return x
model = QuantizableModel()
model.eval()
# 量子化設定
model.qconfig = get_default_qconfig("fbgemm") # x86 CPU用
# 準備(オブザーバーを挿入)
prepared_model = prepare(model)
# キャリブレーションデータで統計を収集
with torch.no_grad():
for _ in range(100):
calibration_data = torch.randn(32, 512)
prepared_model(calibration_data)
# 量子化に変換
quantized_static = convert(prepared_model)
# 推論
x = torch.randn(1, 512)
with torch.no_grad():
output = quantized_static(x)
print(f"出力シェイプ: {output.shape}")
QAT(量子化対応学習)
from torch.ao.quantization import (
get_default_qat_qconfig,
prepare_qat,
convert
)
model = QuantizableModel()
model.train()
# QAT設定
model.qconfig = get_default_qat_qconfig("fbgemm")
# QATの準備(疑似量子化ノードを挿入)
prepared_qat = prepare_qat(model)
# 学習(量子化誤差を学習に含める)
optimizer = torch.optim.SGD(prepared_qat.parameters(), lr=0.0001)
for epoch in range(10):
for x, y in dummy_dataloader():
output = prepared_qat(x)
loss = nn.functional.cross_entropy(output, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 学習後に変換
prepared_qat.eval()
quantized_qat = convert(prepared_qat)
def dummy_dataloader():
for _ in range(10):
yield torch.randn(32, 512), torch.randint(0, 10, (32,))
12. テンソル並列とパイプライン並列
DeviceMeshとDTensor API
import torch
import torch.distributed as dist
from torch.distributed.tensor.parallel import (
parallelize_module,
ColwiseParallel,
RowwiseParallel
)
from torch.distributed._tensor import DeviceMesh
# 分散学習を初期化
def setup_distributed():
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
class TransformerMLP(nn.Module):
def __init__(self, d_model=1024, dim_feedforward=4096):
super().__init__()
self.fc1 = nn.Linear(d_model, dim_feedforward)
self.act = nn.GELU()
self.fc2 = nn.Linear(dim_feedforward, d_model)
def forward(self, x):
return self.fc2(self.act(self.fc1(x)))
# テンソル並列を適用
def apply_tensor_parallel(model, mesh):
"""
fc1: 列方向シャーディング(出力次元を分割)
fc2: 行方向シャーディング(入力次元を分割)
"""
parallelize_module(
model,
mesh,
{
"fc1": ColwiseParallel(),
"fc2": RowwiseParallel(),
}
)
return model
# 例(2-GPUセットアップが必要)
# device_mesh = DeviceMesh("cuda", [0, 1])
# model = TransformerMLP()
# model = apply_tensor_parallel(model, device_mesh)
FSDP(完全シャーディングデータ並列)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy
)
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy
)
import functools
def setup_fsdp_model(model):
"""FSDPセットアップ - 複数GPUに大規模モデルを分散"""
# 混合精度設定
mp_policy = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.bfloat16,
)
# TransformerBlockをラッピング単位として設定
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={TransformerBlock}
)
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mp_policy,
sharding_strategy=ShardingStrategy.FULL_SHARD,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
device_id=torch.cuda.current_device()
)
return model
まとめ
このガイドではPyTorchの高度なテクニックを解説しました:
- torch.compile: コード変更なしで2倍以上のパフォーマンス向上
- カスタムAutograd: 特殊な勾配計算の実装
- CUDA拡張: GPUカーネルとPyTorchの統合
- メモリ最適化: Gradient Checkpointing、AMP、8ビットオプティマイザー
- functorch/vmap: 関数型APIによるバッチ処理とメタ学習
- PyTorch Profiler: パフォーマンスボトルネックの分析
- TorchScript/export: デプロイ最適化
- PyTorch Lightning: コード構造化と学習の自動化
- 量子化: INT8でモデルサイズを削減
- 分散学習: テンソル並列、FSDP
これらのテクニックは相互補完的であり、実際のプロジェクトではいくつかを組み合わせるのが一般的です。大規模モデル学習では、AMP + Gradient Checkpointing + FSDP + torch.compileの組み合わせが特に強力です。