Skip to content
Published on

torchvision完全ガイド — 画像分類からObject Detection、Segmentationまで

Authors
  • Name
    Twitter
torchvision Guide

はじめに

torchvisionはPyTorchの公式コンピュータビジョンライブラリです。画像変換、事前学習済みモデル、データセット、そしてObject Detectionまで — CVに必要なほぼすべてが含まれています。

pip install torch torchvision

Part 1: Transforms v2 — 画像前処理の革新

基本変換

import torch
from torchvision import transforms
from torchvision.transforms import v2  # v2推奨!
from PIL import Image

# Transforms v2(最新、推奨)
transform = v2.Compose([
    v2.RandomResizedCrop(224),         # ランダムクロップ+リサイズ
    v2.RandomHorizontalFlip(p=0.5),    # 50%の確率で左右反転
    v2.ColorJitter(                     # 色変換
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1
    ),
    v2.ToImage(),                       # PIL → Tensor画像
    v2.ToDtype(torch.float32, scale=True),  # [0, 255] → [0.0, 1.0]
    v2.Normalize(                       # ImageNet正規化
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

img = Image.open("cat.jpg")
tensor = transform(img)
print(tensor.shape)  # torch.Size([3, 224, 224])

v2の核心 — Bounding Box + Maskの同時変換

# v1では画像のみ変換 → BBoxがずれる!
# v2では画像 + BBox + Mask + Labelを同時に変換

from torchvision import tv_tensors

# Object Detection用変換
det_transform = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomPhotometricDistort(),
    v2.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104)}),
    v2.RandomIoUCrop(),
    v2.SanitizeBoundingBoxes(),  # 無効なbboxを除去
    v2.ToDtype(torch.float32, scale=True),
])

# 画像 + bboxを一緒に変換すると、bboxも自動的に追従!
image = tv_tensors.Image(torch.randint(0, 256, (3, 500, 500), dtype=torch.uint8))
boxes = tv_tensors.BoundingBoxes(
    [[100, 100, 300, 300], [200, 200, 400, 400]],
    format="XYXY",
    canvas_size=(500, 500)
)
labels = torch.tensor([1, 2])

# 同時変換!
out_img, out_boxes, out_labels = det_transform(image, boxes, labels)

よく使うAugmentationレシピ

# 学習用(強いaugmentation)
train_transform = v2.Compose([
    v2.RandomResizedCrop(224, scale=(0.6, 1.0)),
    v2.RandomHorizontalFlip(),
    v2.RandAugment(num_ops=2, magnitude=9),  # AutoAugment系列
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    v2.RandomErasing(p=0.25),  # CutOut効果
])

# 検証/推論用(変形なし)
val_transform = v2.Compose([
    v2.Resize(256),
    v2.CenterCrop(224),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

Part 2: 事前学習済みモデル(Model Zoo)

画像分類モデル

from torchvision import models
from torchvision.models import (
    ResNet50_Weights, EfficientNet_V2_S_Weights,
    ViT_B_16_Weights, ConvNeXt_Small_Weights
)

# ResNet-50 (2015, CNNの基本)
resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# Top-1: 80.858%, パラメータ: 25.6M

# EfficientNet V2 (2021, 効率性の王者)
effnet = models.efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
# Top-1: 84.228%, パラメータ: 21.5M

# Vision Transformer (2020, TransformerのCV進出)
vit = models.vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1)
# Top-1: 85.304%, パラメータ: 86.6M

# ConvNeXt (2022, CNNの逆襲 — Transformer手法をCNNに)
convnext = models.convnext_small(weights=ConvNeXt_Small_Weights.IMAGENET1K_V1)
# Top-1: 83.616%, パラメータ: 50.2M
モデル選択ガイド:
├── 高速推論が必要 → MobileNet V3 / EfficientNet-Lite
├── 精度最優先 → ViT-L / Swin Transformer V2
├── バランス(実務推奨) → EfficientNet V2 / ConvNeXt
└── 学習/理解目的 → ResNet-50(基本)

推論(Inference)

from torchvision.models import ViT_B_16_Weights

# モデル+前処理の読み込み
weights = ViT_B_16_Weights.IMAGENET1K_V1
model = models.vit_b_16(weights=weights).eval()
preprocess = weights.transforms()

# 推論
img = Image.open("cat.jpg")
batch = preprocess(img).unsqueeze(0)  # [1, 3, 224, 224]

with torch.no_grad():
    logits = model(batch)
    probs = torch.softmax(logits, dim=1)
    top5 = torch.topk(probs, 5)

# 結果
categories = weights.meta["categories"]
for prob, idx in zip(top5.values[0], top5.indices[0]):
    print(f"  {categories[idx]:30s} {prob:.2%}")
# tabby cat                       87.23%
# Egyptian cat                    8.41%
# tiger cat                       3.12%

ファインチューニング(Transfer Learning)

import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# 1. 事前学習済みモデルの読み込み
model = models.efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)

# 2. 最終分類レイヤーのみ置換
num_classes = 10  # データセットのクラス数
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

# 3. バックボーンのフリーズ(オプション)
for param in model.features.parameters():
    param.requires_grad = False  # バックボーン固定

# 4. 分類レイヤーのみ学習
optimizer = AdamW(model.classifier.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=20)
criterion = nn.CrossEntropyLoss()

# 5. 学習ループ
model.train()
for epoch in range(20):
    for images, labels in train_loader:
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    scheduler.step()

# 6. アンフリーズ(ファインチューニング第2段階)
for param in model.parameters():
    param.requires_grad = True
optimizer = AdamW(model.parameters(), lr=1e-5)  # 小さいLR!
# 追加10エポック学習...

Part 3: Object Detection

Faster R-CNN

from torchvision.models.detection import (
    fasterrcnn_resnet50_fpn_v2,
    FasterRCNN_ResNet50_FPN_V2_Weights
)

# 事前学習済みモデル(COCO 91クラス)
weights = FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1
model = fasterrcnn_resnet50_fpn_v2(weights=weights).eval()
preprocess = weights.transforms()

# 推論
img = Image.open("street.jpg")
batch = [preprocess(img)]

with torch.no_grad():
    predictions = model(batch)[0]

# 結果パース
for box, label, score in zip(
    predictions['boxes'], predictions['labels'], predictions['scores']
):
    if score > 0.7:
        category = weights.meta["categories"][label]
        x1, y1, x2, y2 = box.tolist()
        print(f"  {category}: {score:.2%} at ({x1:.0f},{y1:.0f},{x2:.0f},{y2:.0f})")

カスタムデータセットでのDetection学習

from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# 1. 事前学習済みモデルの読み込み
model = fasterrcnn_resnet50_fpn_v2(weights="COCO_V1")

# 2. 分類ヘッドの置換
num_classes = 5 + 1  # 5クラス + 背景
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# 3. カスタムデータセット
class MyDetectionDataset(torch.utils.data.Dataset):
    def __getitem__(self, idx):
        img = ...  # 画像読み込み
        target = {
            "boxes": torch.tensor([[x1,y1,x2,y2], ...], dtype=torch.float32),
            "labels": torch.tensor([1, 3, ...], dtype=torch.int64),
        }
        return img, target

# 4. 学習
model.train()
for images, targets in train_loader:
    loss_dict = model(images, targets)
    # loss_dict: {'loss_classifier', 'loss_box_reg', 'loss_objectness', 'loss_rpn_box_reg'}
    total_loss = sum(loss_dict.values())
    total_loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Part 4: Semantic Segmentation

from torchvision.models.segmentation import (
    deeplabv3_resnet101, DeepLabV3_ResNet101_Weights
)

# DeepLab V3 (COCO 21クラス)
weights = DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1
model = deeplabv3_resnet101(weights=weights).eval()
preprocess = weights.transforms()

img = Image.open("city.jpg")
batch = preprocess(img).unsqueeze(0)

with torch.no_grad():
    output = model(batch)["out"]  # [1, 21, H, W]
    pred_mask = output.argmax(dim=1)  # [1, H, W] — ピクセルごとのクラス

# 可視化
import matplotlib.pyplot as plt
plt.imshow(pred_mask[0].cpu(), cmap="tab20")
plt.title("Semantic Segmentation")
plt.savefig("segmentation.png")

Part 5: データセット

from torchvision.datasets import (
    CIFAR10, CIFAR100, ImageNet, MNIST,
    FashionMNIST, STL10, Food101, Flowers102,
    CocoDetection, VOCDetection
)

# CIFAR-10 (10クラス, 32x32)
train_set = CIFAR10(root="./data", train=True, download=True, transform=train_transform)

# ImageFolder(カスタムデータセット)
from torchvision.datasets import ImageFolder

# ディレクトリ構造:
# data/train/
#   ├── cat/
#   │   ├── cat001.jpg
#   │   └── cat002.jpg
#   └── dog/
#       ├── dog001.jpg
#       └── dog002.jpg

train_set = ImageFolder(root="data/train", transform=train_transform)
print(train_set.classes)      # ['cat', 'dog']
print(train_set.class_to_idx) # {'cat': 0, 'dog': 1}

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=32, shuffle=True, num_workers=4, pin_memory=True
)

Part 6: ユーティリティ

可視化

from torchvision.utils import make_grid, draw_bounding_boxes, draw_segmentation_masks
import torchvision.transforms.functional as F

# バッチ画像グリッド
grid = make_grid(batch_images, nrow=8, padding=2, normalize=True)
plt.imshow(grid.permute(1, 2, 0))

# Bounding Box可視化
from torchvision.utils import draw_bounding_boxes
img_with_boxes = draw_bounding_boxes(
    img_tensor,         # uint8, [3, H, W]
    boxes,              # [N, 4]
    labels=["cat", "dog"],
    colors=["red", "blue"],
    width=3,
    font_size=20
)

# Segmentation Mask可視化
img_with_mask = draw_segmentation_masks(
    img_tensor,
    masks=pred_mask.bool(),
    alpha=0.5,
    colors=["red", "green", "blue"]
)

Feature Extraction(中間レイヤー出力)

from torchvision.models.feature_extraction import create_feature_extractor

model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# 特定レイヤーの出力のみ抽出
feature_extractor = create_feature_extractor(model, {
    'layer2': 'mid_features',    # 512次元
    'layer4': 'high_features',   # 2048次元
    'avgpool': 'embedding',       # 2048次元(グローバル)
})

with torch.no_grad():
    features = feature_extractor(batch)

print(features['mid_features'].shape)   # [1, 512, 28, 28]
print(features['high_features'].shape)  # [1, 2048, 7, 7]
print(features['embedding'].shape)      # [1, 2048, 1, 1]

クイズ — torchvision(クリックして確認!)

Q1. Transforms v2がv1より優れている核心的な理由は? ||画像とBounding Box、Segmentation Maskを同時に変換できる。v1では画像のみが変換され、BBoxがずれる問題があった。||

Q2. ImageNet正規化のmeanとstdの値は? ||mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]。ImageNet学習データのRGBチャンネルごとの統計値。||

Q3. Transfer Learningにおける2段階ファインチューニングとは? ||第1段階: バックボーンをフリーズ+分類ヘッドのみ学習(大きいLR)。第2段階: 全体をアンフリーズ+小さいLRで全体ファインチューニング。事前学習済み重みを保存しながら段階的に適応。||

Q4. Faster R-CNNの4つのlossとは? ||loss_classifier(分類)、loss_box_reg(ボックス回帰)、loss_objectness(オブジェクト存在有無)、loss_rpn_box_reg(RPNボックス回帰)。||

Q5. create_feature_extractorの用途は? ||モデルの中間レイヤー出力を抽出する。完全なforwardなしで特定レイヤーのfeature mapのみ取得でき、埋め込み抽出やfeature可視化などに活用。||

Q6. RandAugmentの核心パラメータ2つは? ||num_ops: 適用するaugmentation操作の数。magnitude: 変換の強度(0〜30)。AutoAugmentとは異なり、探索なしで2つのパラメータだけで強力なaugmentationを実現。||

関連シリーズ&おすすめ記事

GitHub