Skip to content
Published on

PyTorch高度なテクニック完全ガイド: torch.compile、カスタムOps、メモリ最適化

Authors

PyTorch高度なテクニック完全ガイド

PyTorchはリサーチとプロダクション両方において最も人気の深層学習フレームワークの一つです。しかし、ほとんどの開発者は基本的なテンソル演算とnn.Moduleの定義で止まっています。このガイドでは、実際のプロダクション環境でパフォーマンスを最大化し、カスタム演算を実装し、メモリを効率的に管理するための高度なテクニックを解説します。

1. torch.compile(PyTorch 2.0以降)

PyTorch 2.0で導入されたtorch.compileは、モデルをコンパイルして実行速度を劇的に改善する機能です。TorchScriptやONNXエクスポートとは異なり、torch.compileは最小限のコード変更で2倍以上のスピードアップをもたらすことができます。

導入とメリット

torch.compileは3つのコアコンポーネントで構成されています:

  1. TorchDynamo: Pythonバイトコードを傍受してFXグラフを生成
  2. AOTAutograd: 自動微分グラフを事前コンパイル
  3. Inductor: TorchInductorバックエンド(Triton GPUカーネルまたはC++ CPUカーネル)を通じて最適化されたカーネルを生成
import torch
import torch.nn as nn
import time

# ベースモデルを定義
class TransformerBlock(nn.Module):
    def __init__(self, d_model=512, nhead=8, dim_feedforward=2048):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        ff_out = self.feed_forward(x)
        x = self.norm2(x + ff_out)
        return x

device = "cuda" if torch.cuda.is_available() else "cpu"
model = TransformerBlock().to(device)

# torch.compileを適用
compiled_model = torch.compile(model)

# ウォームアップ
x = torch.randn(32, 128, 512, device=device)
for _ in range(3):
    _ = compiled_model(x)

# パフォーマンス比較
N = 100

# Eagerモード
start = time.perf_counter()
for _ in range(N):
    _ = model(x)
if device == "cuda":
    torch.cuda.synchronize()
elapsed_eager = time.perf_counter() - start

# コンパイルモード
start = time.perf_counter()
for _ in range(N):
    _ = compiled_model(x)
if device == "cuda":
    torch.cuda.synchronize()
elapsed_compiled = time.perf_counter() - start

print(f"Eagerモード: {elapsed_eager:.3f}s")
print(f"コンパイルモード: {elapsed_compiled:.3f}s")
print(f"スピードアップ: {elapsed_eager / elapsed_compiled:.2f}x")

コンパイルモード

torch.compileは3つのモードを提供します:

# デフォルトモード - 高速コンパイル、良好なパフォーマンス
model_default = torch.compile(model, mode="default")

# reduce-overheadモード - 小規模モデルに有効
model_reduce = torch.compile(model, mode="reduce-overhead")

# max-autotune - 長時間コンパイル、最高パフォーマンス
model_autotune = torch.compile(model, mode="max-autotune")

# フルグラフモード - 動的グラフなし(厳格)
model_full = torch.compile(model, fullgraph=True)

# バックエンド選択
model_eager = torch.compile(model, backend="eager")    # コンパイルなし(デバッグ用)
model_aot = torch.compile(model, backend="aot_eager")  # AOTのみ
model_inductor = torch.compile(model, backend="inductor")  # デフォルト

動的シェイプのサポート

import torch._dynamo as dynamo

# 動的シェイプを有効化
model_dynamic = torch.compile(model, dynamic=True)

# 異なるバッチサイズでも再コンパイルなしで動作
for batch_size in [8, 16, 32, 64]:
    x = torch.randn(batch_size, 128, 512, device=device)
    out = model_dynamic(x)
    print(f"Batch {batch_size}: 出力シェイプ {out.shape}")

# コンパイルキャッシュを確認
print(dynamo.explain(model)(x))

既存コードへの移行

# 学習ループにtorch.compileを適用
def train_epoch(model, optimizer, dataloader, criterion):
    model.train()
    total_loss = 0

    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(dataloader)

# モデルをコンパイルするだけ
model = MyModel().cuda()
compiled_model = torch.compile(model)  # この行を追加するだけ!

optimizer = torch.optim.Adam(compiled_model.parameters())
criterion = nn.CrossEntropyLoss()

2. カスタムAutograd

PyTorchの自動微分エンジンは強力ですが、数値安定性や効率のためにカスタム勾配計算が必要な場合があります。

torch.autograd.Functionのサブクラス化

import torch
from torch.autograd import Function

class SigmoidFunction(Function):
    """数値的に安定したシグモイド実装"""

    @staticmethod
    def forward(ctx, input):
        # sigmoid = 1 / (1 + exp(-x))
        sigmoid = torch.sigmoid(input)
        # backwardで使用するテンソルを保存
        ctx.save_for_backward(sigmoid)
        return sigmoid

    @staticmethod
    def backward(ctx, grad_output):
        # 保存したシグモイドを取得
        sigmoid, = ctx.saved_tensors
        # gradient = sigmoid * (1 - sigmoid) * grad_output
        grad_input = sigmoid * (1 - sigmoid) * grad_output
        return grad_input

# 使用方法
def custom_sigmoid(x):
    return SigmoidFunction.apply(x)

x = torch.randn(3, 4, requires_grad=True)
y = custom_sigmoid(x)
loss = y.sum()
loss.backward()
print(f"勾配: {x.grad}")

より複雑な例: カスタムBackwardを持つLeaky ReLU

class LeakyReLUFunction(Function):
    @staticmethod
    def forward(ctx, input, negative_slope=0.01):
        ctx.save_for_backward(input)
        ctx.negative_slope = negative_slope
        return input.clamp(min=0) + negative_slope * input.clamp(max=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        negative_slope = ctx.negative_slope
        grad_input = grad_output.clone()
        grad_input[input < 0] *= negative_slope
        # negative_slopeに対する勾配はNone
        return grad_input, None

class CustomLeakyReLU(nn.Module):
    def __init__(self, negative_slope=0.01):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, x):
        return LeakyReLUFunction.apply(x, self.negative_slope)

数値勾配チェック

from torch.autograd import gradcheck

def test_custom_op():
    # 数値勾配チェック(float64推奨)
    input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)

    # gradcheckはヤコビアンを数値的に計算してautogradと比較
    result = gradcheck(SigmoidFunction.apply, (input,), eps=1e-6, atol=1e-4)
    print(f"勾配チェック合格: {result}")

    # 二重逆伝播テスト
    input = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
    result = gradcheck(
        SigmoidFunction.apply,
        (input,),
        eps=1e-6,
        atol=1e-4,
        check_grad_dtypes=True
    )

test_custom_op()

二重逆伝播

class SquaredFunction(Function):
    """二重逆伝播をサポートするカスタムx^2実装"""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # 二重逆伝播をサポートするにはcreate_graph=Trueが必要
        return 2 * x * grad_output

# 二重逆伝播の例(MAMLなどのメタ学習で有用)
x = torch.randn(3, requires_grad=True)
y = SquaredFunction.apply(x)
grad_x = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
# grad_x = 2x、再び逆伝播
grad_grad_x = torch.autograd.grad(grad_x.sum(), x)[0]
# grad_grad_x = 2 (定数)
print(f"二階微分: {grad_grad_x}")

3. カスタムCUDAオペレータ

torch.utils.cpp_extensionの概要

PyTorchはC++/CUDA拡張を書いてPythonから使用するためのツールを提供します。

# JITコンパイル(開発/プロトタイプに適しています)
from torch.utils.cpp_extension import load_inline

# C++ CPUオペレータ
cpp_source = """
#include <torch/extension.h>

torch::Tensor relu_forward(torch::Tensor input) {
    return input.clamp_min(0);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("relu_forward", &relu_forward, "ReLU forward");
}
"""

# インラインコンパイル
custom_relu_cpp = load_inline(
    name="custom_relu_cpp",
    cpp_sources=cpp_source,
    functions=["relu_forward"],
    verbose=False
)

x = torch.randn(5)
result = custom_relu_cpp.relu_forward(x)
print(result)

CUDAカーネルの例: Fused Softmax

# CUDAソース
cuda_source = """
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

__global__ void fused_softmax_kernel(
    const float* input,
    float* output,
    int rows,
    int cols
) {
    int row = blockIdx.x;
    if (row >= rows) return;

    const float* row_input = input + row * cols;
    float* row_output = output + row * cols;

    // 最大値を求める(数値安定性のため)
    float max_val = row_input[0];
    for (int i = 1; i < cols; i++) {
        max_val = fmaxf(max_val, row_input[i]);
    }

    // exp(x - max)と合計を計算
    float sum = 0.0f;
    for (int i = 0; i < cols; i++) {
        row_output[i] = expf(row_input[i] - max_val);
        sum += row_output[i];
    }

    // 正規化
    for (int i = 0; i < cols; i++) {
        row_output[i] /= sum;
    }
}

torch::Tensor fused_softmax_cuda(torch::Tensor input) {
    auto output = torch::zeros_like(input);
    int rows = input.size(0);
    int cols = input.size(1);

    fused_softmax_kernel<<<rows, 1>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        rows,
        cols
    );

    return output;
}
"""

cpp_source_cuda = """
#include <torch/extension.h>
torch::Tensor fused_softmax_cuda(torch::Tensor input);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_softmax", &fused_softmax_cuda, "Fused Softmax CUDA");
}
"""

# CUDAが利用可能な場合のみコンパイル
if torch.cuda.is_available():
    from torch.utils.cpp_extension import load_inline
    fused_softmax_ext = load_inline(
        name="fused_softmax",
        cpp_sources=cpp_source_cuda,
        cuda_sources=cuda_source,
        functions=["fused_softmax"],
        verbose=True
    )

    # テスト
    x = torch.randn(4, 8, device="cuda")
    result = fused_softmax_ext.fused_softmax(x)
    expected = torch.softmax(x, dim=1)
    print(f"最大差分: {(result - expected).abs().max().item():.6f}")

setup.pyでパッケージをビルド

# setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="custom_ops",
    ext_modules=[
        CUDAExtension(
            name="custom_ops",
            sources=[
                "custom_ops/ops.cpp",
                "custom_ops/ops_cuda.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": ["-O3", "--use_fast_math"],
            }
        )
    ],
    cmdclass={
        "build_ext": BuildExtension
    }
)
# ビルド: python setup.py install

4. メモリ最適化テクニック

GPUメモリのプロファイリング

import torch

def print_gpu_memory_stats():
    """GPUメモリ統計を表示するユーティリティ"""
    if torch.cuda.is_available():
        print(f"割り当て済み: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"予約済み:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
        print(f"最大割り当て: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
        print(torch.cuda.memory_summary(abbreviated=True))

# メモリトラッキングを開始
torch.cuda.reset_peak_memory_stats()
print_gpu_memory_stats()

# モデルを読み込む
model = TransformerBlock().cuda()
print_gpu_memory_stats()

Gradient Checkpointing(活性化の再計算)

Gradient Checkpointingは順伝播中の中間活性化値を保存せず、逆伝播中に必要に応じて再計算します。これにより約30%多い計算コストで大幅なメモリ節約ができます。

from torch.utils.checkpoint import checkpoint, checkpoint_sequential
import torch.nn as nn

class DeepTransformer(nn.Module):
    def __init__(self, num_layers=12, d_model=512, nhead=8):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, nhead) for _ in range(num_layers)
        ])

    def forward_with_checkpointing(self, x):
        """各レイヤーにgradient checkpointingを適用"""
        for layer in self.layers:
            # checkpointは中間活性化を保存しない
            x = checkpoint(layer, x, use_reentrant=False)
        return x

    def forward_sequential_checkpointing(self, x):
        """シーケンシャルモジュールにcheckpointingを適用"""
        # 4レイヤーのセグメントにグループ化
        x = checkpoint_sequential(self.layers, segments=3, input=x)
        return x

# メモリ比較
model = DeepTransformer(num_layers=24).cuda()
x = torch.randn(16, 512, 512, device="cuda")

# 通常の順伝播
torch.cuda.reset_peak_memory_stats()
out = model.forward_with_checkpointing(x)
checkpoint_mem = torch.cuda.max_memory_allocated()

print(f"Checkpointメモリ: {checkpoint_mem / 1024**3:.2f} GB")

勾配蓄積

def train_with_gradient_accumulation(
    model, optimizer, dataloader, criterion,
    accumulation_steps=4
):
    """
    限られたGPUメモリで大きな有効バッチサイズを実装
    """
    model.train()
    optimizer.zero_grad()
    total_loss = 0

    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.cuda(), target.cuda()

        # 順伝播
        output = model(data)
        # lossをaccumulation_stepsでスケール
        loss = criterion(output, target) / accumulation_steps
        loss.backward()

        total_loss += loss.item() * accumulation_steps

        # accumulation_stepsごとに更新
        if (batch_idx + 1) % accumulation_steps == 0:
            # 勾配クリッピング(オプション)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

    return total_loss / len(dataloader)

混合精度学習(AMP)

from torch.cuda.amp import autocast, GradScaler

def train_with_amp(model, optimizer, dataloader, criterion):
    """メモリ節約+速度向上のための自動混合精度"""
    scaler = GradScaler()
    model.train()

    for data, target in dataloader:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()

        # autocastコンテキスト内でfloat16演算
        with autocast():
            output = model(data)
            loss = criterion(output, target)

        # スケールされた勾配での逆伝播
        scaler.scale(loss).backward()

        # スケール解除後に勾配クリッピング
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # オプティマイザーステップ(NaN/Infチェックあり)
        scaler.step(optimizer)
        scaler.update()

8ビットオプティマイザー

# bitsandbytesが必要: pip install bitsandbytes
try:
    import bitsandbytes as bnb

    # 通常のAdamの代わりに8ビットAdamを使用
    optimizer_8bit = bnb.optim.Adam8bit(
        model.parameters(),
        lr=1e-4,
        betas=(0.9, 0.999)
    )

    # PagedAdam(CPUにオフロード可能)
    optimizer_paged = bnb.optim.PagedAdam(
        model.parameters(),
        lr=1e-4
    )

    print("8ビットオプティマイザーが正常に読み込まれました")
except ImportError:
    print("bitsandbytesがインストールされていません。通常のAdamを使用します")
    optimizer_8bit = torch.optim.Adam(model.parameters(), lr=1e-4)

5. functorchとvmap

vmap - バッチ処理

vmapは単一サンプルを処理する関数をバッチに効率的に適用します。

import torch
from torch import vmap

# 単一サンプルを処理する関数
def single_linear(weight, bias, x):
    return weight @ x + bias

# vmapでバッチ処理
batched_linear = vmap(single_linear)

# バッチデータ
batch_size = 32
weight = torch.randn(batch_size, 10, 5)
bias = torch.randn(batch_size, 10)
x = torch.randn(batch_size, 5)

# バッチ内で自動的に処理
result = batched_linear(weight, bias, x)
print(f"結果シェイプ: {result.shape}")  # (32, 10)

grad - 関数型勾配

from torch.func import grad, vmap, functional_call

# 関数型勾配
def scalar_loss(params, x, y):
    pred = functional_call(model, params, (x,))
    return ((pred - y) ** 2).mean()

# パラメータに対する勾配
params = dict(model.named_parameters())
grad_fn = grad(scalar_loss)

x = torch.randn(1, 10)
y = torch.randn(1, 5)
grads = grad_fn(params, x, y)
print({k: v.shape for k, v in grads.items()})

vmapでアンサンブルモデル

from torch.func import stack_module_state, functional_call, vmap

def create_ensemble(model_class, num_models, *args, **kwargs):
    """vmapを使用した効率的なアンサンブル"""
    models = [model_class(*args, **kwargs) for _ in range(num_models)]

    # 全モデルのパラメータをスタック
    params, buffers = stack_module_state(models)

    # 単一モデルの順伝播
    base_model = model_class(*args, **kwargs)

    def single_forward(params, buffers, x):
        return functional_call(base_model, (params, buffers), (x,))

    # vmapで全モデルを並列実行
    ensemble_forward = vmap(single_forward, in_dims=(0, 0, None))

    return ensemble_forward, params, buffers

# 使用例
ensemble_fn, params, buffers = create_ensemble(
    nn.Linear, num_models=5, in_features=10, out_features=5
)

x = torch.randn(32, 10)
ensemble_out = ensemble_fn(params, buffers, x)
print(f"アンサンブル出力シェイプ: {ensemble_out.shape}")  # (5, 32, 5)

gradとvmapによるメタ学習(MAML)

from torch.func import grad, vmap, functional_call

def inner_loop(params, support_x, support_y, base_model, lr=0.01, steps=5):
    """MAMLの内部ループ"""
    adapted_params = {k: v.clone() for k, v in params.items()}

    for _ in range(steps):
        def loss_fn(params):
            pred = functional_call(base_model, params, (support_x,))
            return ((pred - support_y) ** 2).mean()

        grads = grad(loss_fn)(adapted_params)
        adapted_params = {
            k: p - lr * grads[k]
            for k, p in adapted_params.items()
        }

    return adapted_params

6. PyTorch Profiler

基本的なプロファイリング

import torch
from torch.profiler import profile, record_function, ProfilerActivity

model = TransformerBlock().cuda()
x = torch.randn(32, 128, 512, device="cuda")

# プロファイラーを実行
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True
) as prof:
    with record_function("model_inference"):
        for _ in range(10):
            output = model(x)

# 結果を表示
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20
))

# Chrome Traceをエクスポート
prof.export_chrome_trace("trace.json")
# chrome://tracingで開く

TensorBoardとの統合

from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(
        wait=1,    # 1ステップ待機
        warmup=1,  # 1ウォームアップステップ
        active=3,  # 3ステップをプロファイル
        repeat=2   # 2回繰り返す
    ),
    on_trace_ready=tensorboard_trace_handler("./log/profiler"),
    record_shapes=True,
    profile_memory=True
) as prof:
    for step, (data, target) in enumerate(dataloader):
        train_step(model, optimizer, data, target)
        prof.step()  # スケジュールに従ってプロファイル

# tensorboard --logdir=./log/profiler

詳細なメモリ分析

# メモリスナップショット(PyTorch 2.1以降)
torch.cuda.memory._record_memory_history(max_entries=100000)

# コードを実行
x = torch.randn(100, 100, device="cuda")
y = x @ x.T
z = y.sum()

# スナップショットを保存
snapshot = torch.cuda.memory._snapshot()
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

print(f"アクティブな割り当て: {len(snapshot['segments'])}")

7. TorchScript

torch.jit.scriptとtorch.jit.trace

import torch
import torch.nn as nn

class ConditionalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x, flag: bool = True):
        if flag:  # 制御フローによりトレースは不可
            return torch.relu(self.linear(x))
        else:
            return torch.sigmoid(self.linear(x))

model = ConditionalModel()

# script: 制御フローを含む(推奨)
scripted_model = torch.jit.script(model)
print(scripted_model.code)

# trace: 単一パスのみキャプチャ
x = torch.randn(1, 10)
traced_model = torch.jit.trace(model, x)
# 注意: flag=Falseのパスはキャプチャされない

# 保存と読み込み
scripted_model.save("model_scripted.pt")
loaded = torch.jit.load("model_scripted.pt")

TorchScript最適化

# 最適化パスを適用
scripted = torch.jit.script(model)
optimized = torch.jit.optimize_for_inference(scripted)

# C++環境へのエクスポート
# C++では: torch::jit::script::Module m = torch::jit::load("model.pt");

8. 動的シェイプとtorch.export

torch.exportの使用

import torch
from torch.export import export

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)

model = MyModel()
example_inputs = (torch.randn(2, 10),)

# 動的シェイプを指定
from torch.export import Dim
batch = Dim("batch", min=1, max=100)

# モデルをエクスポート
exported = export(
    model,
    example_inputs,
    dynamic_shapes={"x": {0: batch}}
)

print(exported)
print(exported.graph_module.code)

# ExportedProgramを実行
result = exported.module()(torch.randn(5, 10))
print(f"結果シェイプ: {result.shape}")

9. カスタムデータセットとサンプラー

IterableDataset

from torch.utils.data import IterableDataset, DataLoader
import torch

class StreamingDataset(IterableDataset):
    """大規模データセットを効率的にストリーミング"""

    def __init__(self, data_paths, transform=None):
        self.data_paths = data_paths
        self.transform = transform

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is None:
            # シングルプロセス
            paths = self.data_paths
        else:
            # マルチプロセス: ワーカー間でデータを分割
            per_worker = len(self.data_paths) // worker_info.num_workers
            start = worker_info.id * per_worker
            end = start + per_worker
            paths = self.data_paths[start:end]

        for path in paths:
            # ファイルからデータをストリーミング
            data = self._load_file(path)
            for sample in data:
                if self.transform:
                    sample = self.transform(sample)
                yield sample

    def _load_file(self, path):
        # 実際の実装ではファイルから読み込む
        return [torch.randn(10) for _ in range(100)]

# DataLoaderでpersistent_workersを使用
dataset = StreamingDataset(data_paths=["file1.pt", "file2.pt"])
dataloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,
    persistent_workers=True,  # ワーカープロセスを再利用
    pin_memory=True,           # GPU転送を最適化
    prefetch_factor=2          # プリフェッチするバッチ数
)

カスタムサンプラー

from torch.utils.data import Sampler
import numpy as np

class BalancedClassSampler(Sampler):
    """クラス不均衡問題のための重み付きサンプリング"""

    def __init__(self, dataset, num_samples_per_class=None):
        self.dataset = dataset
        labels = [dataset[i][1] for i in range(len(dataset))]
        self.labels = torch.tensor(labels)

        # クラスごとのインデックス
        self.class_indices = {}
        for cls in torch.unique(self.labels):
            self.class_indices[cls.item()] = (
                self.labels == cls
            ).nonzero(as_tuple=True)[0].tolist()

        self.num_classes = len(self.class_indices)
        self.num_samples_per_class = (
            num_samples_per_class or
            max(len(v) for v in self.class_indices.values())
        )

    def __iter__(self):
        indices = []
        for cls_idx in self.class_indices.values():
            # 各クラスから同じ数をサンプリング(復元抽出)
            sampled = np.random.choice(
                cls_idx,
                self.num_samples_per_class,
                replace=True
            ).tolist()
            indices.extend(sampled)

        # シャッフル
        np.random.shuffle(indices)
        return iter(indices)

    def __len__(self):
        return self.num_classes * self.num_samples_per_class

10. PyTorch Lightning

完全なLightningModuleの例

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

class LightningTransformer(pl.LightningModule):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_layers=6,
        num_classes=10,
        learning_rate=1e-4
    ):
        super().__init__()
        self.save_hyperparameters()  # HParamsを自動保存

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                batch_first=True
            ),
            num_layers=num_layers
        )
        self.classifier = nn.Linear(d_model, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        encoded = self.encoder(x)
        # シーケンスの平均プーリング
        pooled = encoded.mean(dim=1)
        return self.classifier(pooled)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        # ロギング
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=0.01
        )
        # コサインアニーリングスケジューラー
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=100
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }


class LightningDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        # ダミーデータ
        x = torch.randn(1000, 32, 512)
        y = torch.randint(0, 10, (1000,))
        dataset = TensorDataset(x, y)
        n = len(dataset)
        self.train_ds = torch.utils.data.Subset(dataset, range(int(0.8*n)))
        self.val_ds = torch.utils.data.Subset(dataset, range(int(0.8*n), n))

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)


# 学習を実行
model = LightningTransformer()
dm = LightningDataModule()

# コールバックを設定
callbacks = [
    ModelCheckpoint(
        monitor="val_loss",
        save_top_k=3,
        mode="min",
        filename="transformer-{epoch:02d}-{val_loss:.2f}"
    ),
    EarlyStopping(monitor="val_loss", patience=10, mode="min")
]

# トレーナー
trainer = pl.Trainer(
    max_epochs=100,
    accelerator="auto",    # GPUを自動検出
    devices="auto",
    callbacks=callbacks,
    logger=TensorBoardLogger("tb_logs", name="transformer"),
    gradient_clip_val=1.0,  # 勾配クリッピング
    accumulate_grad_batches=4,  # 勾配蓄積
    precision="16-mixed",   # AMP
)

trainer.fit(model, dm)

11. モデル量子化

動的量子化

import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

# 推論専用量子化 - 最も簡単なアプローチ
model = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 128)
)
model.eval()

# LinearとLSTMレイヤーをint8に量子化
quantized_model = quantize_dynamic(
    model,
    {nn.Linear, nn.LSTM},  # 量子化するレイヤータイプ
    dtype=torch.qint8
)

# サイズ比較
import os
torch.save(model.state_dict(), "model_fp32.pt")
torch.save(quantized_model.state_dict(), "model_int8.pt")
fp32_size = os.path.getsize("model_fp32.pt")
int8_size = os.path.getsize("model_int8.pt")
print(f"FP32サイズ: {fp32_size / 1024:.1f} KB")
print(f"INT8サイズ: {int8_size / 1024:.1f} KB")
print(f"圧縮率: {fp32_size / int8_size:.1f}x")

静的量子化

import torch
from torch.ao.quantization import (
    get_default_qconfig,
    prepare,
    convert,
)

class QuantizableModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(512, 256)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(256, 10)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.dequant(x)
        return x

model = QuantizableModel()
model.eval()

# 量子化設定
model.qconfig = get_default_qconfig("fbgemm")  # x86 CPU用

# 準備(オブザーバーを挿入)
prepared_model = prepare(model)

# キャリブレーションデータで統計を収集
with torch.no_grad():
    for _ in range(100):
        calibration_data = torch.randn(32, 512)
        prepared_model(calibration_data)

# 量子化に変換
quantized_static = convert(prepared_model)

# 推論
x = torch.randn(1, 512)
with torch.no_grad():
    output = quantized_static(x)
print(f"出力シェイプ: {output.shape}")

QAT(量子化対応学習)

from torch.ao.quantization import (
    get_default_qat_qconfig,
    prepare_qat,
    convert
)

model = QuantizableModel()
model.train()

# QAT設定
model.qconfig = get_default_qat_qconfig("fbgemm")

# QATの準備(疑似量子化ノードを挿入)
prepared_qat = prepare_qat(model)

# 学習(量子化誤差を学習に含める)
optimizer = torch.optim.SGD(prepared_qat.parameters(), lr=0.0001)
for epoch in range(10):
    for x, y in dummy_dataloader():
        output = prepared_qat(x)
        loss = nn.functional.cross_entropy(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# 学習後に変換
prepared_qat.eval()
quantized_qat = convert(prepared_qat)


def dummy_dataloader():
    for _ in range(10):
        yield torch.randn(32, 512), torch.randint(0, 10, (32,))

12. テンソル並列とパイプライン並列

DeviceMeshとDTensor API

import torch
import torch.distributed as dist
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel
)
from torch.distributed._tensor import DeviceMesh

# 分散学習を初期化
def setup_distributed():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())

class TransformerMLP(nn.Module):
    def __init__(self, d_model=1024, dim_feedforward=4096):
        super().__init__()
        self.fc1 = nn.Linear(d_model, dim_feedforward)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim_feedforward, d_model)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

# テンソル並列を適用
def apply_tensor_parallel(model, mesh):
    """
    fc1: 列方向シャーディング(出力次元を分割)
    fc2: 行方向シャーディング(入力次元を分割)
    """
    parallelize_module(
        model,
        mesh,
        {
            "fc1": ColwiseParallel(),
            "fc2": RowwiseParallel(),
        }
    )
    return model

# 例(2-GPUセットアップが必要)
# device_mesh = DeviceMesh("cuda", [0, 1])
# model = TransformerMLP()
# model = apply_tensor_parallel(model, device_mesh)

FSDP(完全シャーディングデータ並列)

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy
)
import functools

def setup_fsdp_model(model):
    """FSDPセットアップ - 複数GPUに大規模モデルを分散"""

    # 混合精度設定
    mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.bfloat16,
    )

    # TransformerBlockをラッピング単位として設定
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock}
    )

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=torch.cuda.current_device()
    )

    return model

まとめ

このガイドではPyTorchの高度なテクニックを解説しました:

  1. torch.compile: コード変更なしで2倍以上のパフォーマンス向上
  2. カスタムAutograd: 特殊な勾配計算の実装
  3. CUDA拡張: GPUカーネルとPyTorchの統合
  4. メモリ最適化: Gradient Checkpointing、AMP、8ビットオプティマイザー
  5. functorch/vmap: 関数型APIによるバッチ処理とメタ学習
  6. PyTorch Profiler: パフォーマンスボトルネックの分析
  7. TorchScript/export: デプロイ最適化
  8. PyTorch Lightning: コード構造化と学習の自動化
  9. 量子化: INT8でモデルサイズを削減
  10. 分散学習: テンソル並列、FSDP

これらのテクニックは相互補完的であり、実際のプロジェクトではいくつかを組み合わせるのが一般的です。大規模モデル学習では、AMP + Gradient Checkpointing + FSDP + torch.compileの組み合わせが特に強力です。

参考文献