The Complete torchvision Guide — From Image Classification to Object Detection and Segmentation


Introduction

torchvision is the official computer vision library for PyTorch. It bundles image transforms, pretrained models, datasets, and ready-made models for object detection and segmentation: nearly everything you need for computer vision work.

pip install torch torchvision

Part 1: Transforms v2 — A Revolution in Image Preprocessing

Basic Transforms

import torch
from torchvision import transforms
from torchvision.transforms import v2  # v2 recommended!
from PIL import Image

# Transforms v2 (latest, recommended)
transform = v2.Compose([
    v2.RandomResizedCrop(224),         # Random crop + resize
    v2.RandomHorizontalFlip(p=0.5),    # 50% chance horizontal flip
    v2.ColorJitter(                     # Color augmentation
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1
    ),
    v2.ToImage(),                       # PIL -> Tensor image
    v2.ToDtype(torch.float32, scale=True),  # [0, 255] -> [0.0, 1.0]
    v2.Normalize(                       # ImageNet normalization
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

img = Image.open("cat.jpg")
tensor = transform(img)
print(tensor.shape)  # torch.Size([3, 224, 224])

The Key Innovation of v2 — Simultaneous Bounding Box + Mask Transforms

# In v1, only the image was transformed -> BBoxes became misaligned!
# In v2, image + BBox + Mask + Label are transformed simultaneously

from torchvision import tv_tensors

# Object Detection transforms
det_transform = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomPhotometricDistort(),
    v2.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104)}),
    v2.RandomIoUCrop(),
    v2.SanitizeBoundingBoxes(),  # Remove invalid bboxes
    v2.ToDtype(torch.float32, scale=True),
])

# When transforming image + bbox together, bboxes automatically follow!
image = tv_tensors.Image(torch.randint(0, 256, (3, 500, 500), dtype=torch.uint8))
boxes = tv_tensors.BoundingBoxes(
    [[100, 100, 300, 300], [200, 200, 400, 400]],
    format="XYXY",
    canvas_size=(500, 500)
)
labels = torch.tensor([1, 2])

# Simultaneous transform!
out_img, out_boxes, out_labels = det_transform(image, boxes, labels)

Commonly Used Augmentation Recipes

# Training (strong augmentation)
train_transform = v2.Compose([
    v2.RandomResizedCrop(224, scale=(0.6, 1.0)),
    v2.RandomHorizontalFlip(),
    v2.RandAugment(num_ops=2, magnitude=9),  # AutoAugment family
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    v2.RandomErasing(p=0.25),  # CutOut effect
])

# Validation/Inference (no augmentation)
val_transform = v2.Compose([
    v2.Resize(256),
    v2.CenterCrop(224),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

Part 2: Pretrained Models (Model Zoo)

Image Classification Models

from torchvision import models
from torchvision.models import (
    ResNet50_Weights, EfficientNet_V2_S_Weights,
    ViT_B_16_Weights, ConvNeXt_Small_Weights
)

# ResNet-50 (2015, CNN fundamentals)
resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# Top-1: 80.858%, Parameters: 25.6M

# EfficientNet V2 (2021, the efficiency champion)
effnet = models.efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
# Top-1: 84.228%, Parameters: 21.5M

# Vision Transformer (2020, Transformer enters CV)
vit = models.vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1)
# Top-1: 85.304%, Parameters: 86.6M

# ConvNeXt (2022, CNN strikes back — Transformer techniques applied to CNN)
convnext = models.convnext_small(weights=ConvNeXt_Small_Weights.IMAGENET1K_V1)
# Top-1: 83.616%, Parameters: 50.2M

Model selection guide:
+-- Fast inference needed -> MobileNet V3 / EfficientNet-Lite
+-- Accuracy first -> ViT-L / Swin Transformer V2
+-- Balance (recommended for production) -> EfficientNet V2 / ConvNeXt
+-- Learning purposes -> ResNet-50 (fundamentals)

Inference

from torchvision.models import ViT_B_16_Weights

# Load model + preprocessing
weights = ViT_B_16_Weights.IMAGENET1K_V1
model = models.vit_b_16(weights=weights).eval()
preprocess = weights.transforms()

# Inference
img = Image.open("cat.jpg")
batch = preprocess(img).unsqueeze(0)  # [1, 3, 224, 224]

with torch.no_grad():
    logits = model(batch)
    probs = torch.softmax(logits, dim=1)
    top5 = torch.topk(probs, 5)

# Results
categories = weights.meta["categories"]
for prob, idx in zip(top5.values[0], top5.indices[0]):
    print(f"  {categories[idx]:30s} {prob:.2%}")
# tabby cat                       87.23%
# Egyptian cat                    8.41%
# tiger cat                       3.12%

Fine-Tuning (Transfer Learning)

import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# 1. Load pretrained model
model = models.efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)

# 2. Replace the final classification layer
num_classes = 10  # Number of classes in your dataset
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

# 3. Freeze the backbone (optional)
for param in model.features.parameters():
    param.requires_grad = False  # Freeze backbone

# 4. Train only the classification layer
optimizer = AdamW(model.classifier.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=20)
criterion = nn.CrossEntropyLoss()

# 5. Training loop
model.train()
for epoch in range(20):
    for images, labels in train_loader:
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    scheduler.step()

# 6. Unfreeze (Fine-tuning phase 2)
for param in model.parameters():
    param.requires_grad = True
optimizer = AdamW(model.parameters(), lr=1e-5)  # Small LR!
# Train for an additional 10 epochs...
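To decide when to move from the frozen phase to full fine-tuning, you typically watch validation accuracy between epochs. A minimal evaluation sketch, assuming a model and loader like the ones above:

```python
import torch

@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    """Top-1 accuracy over a DataLoader of (images, labels) batches."""
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)   # predicted class per sample
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return correct / total
```

Call it after each epoch (e.g. `acc = evaluate(model, val_loader)`) and switch to phase 2 once the head has converged.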

Part 3: Object Detection

Faster R-CNN

from torchvision.models.detection import (
    fasterrcnn_resnet50_fpn_v2,
    FasterRCNN_ResNet50_FPN_V2_Weights
)

# Pretrained model (COCO 91 classes)
weights = FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1
model = fasterrcnn_resnet50_fpn_v2(weights=weights).eval()
preprocess = weights.transforms()

# Inference
img = Image.open("street.jpg")
batch = [preprocess(img)]

with torch.no_grad():
    predictions = model(batch)[0]

# Parse results
for box, label, score in zip(
    predictions['boxes'], predictions['labels'], predictions['scores']
):
    if score > 0.7:
        category = weights.meta["categories"][label]
        x1, y1, x2, y2 = box.tolist()
        print(f"  {category}: {score:.2%} at ({x1:.0f},{y1:.0f},{x2:.0f},{y2:.0f})")

Training Detection on a Custom Dataset

from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# 1. Load pretrained model
model = fasterrcnn_resnet50_fpn_v2(weights="COCO_V1")

# 2. Replace the classification head
num_classes = 5 + 1  # 5 classes + background
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# 3. Custom dataset
class MyDetectionDataset(torch.utils.data.Dataset):
    def __len__(self):
        return ...  # number of samples

    def __getitem__(self, idx):
        img = ...  # Load image
        target = {
            "boxes": torch.tensor([[x1,y1,x2,y2], ...], dtype=torch.float32),  # XYXY, absolute pixels
            "labels": torch.tensor([1, 3, ...], dtype=torch.int64),
        }
        return img, target

# 4. Training
model.train()
for images, targets in train_loader:
    loss_dict = model(images, targets)
    # loss_dict: {'loss_classifier', 'loss_box_reg', 'loss_objectness', 'loss_rpn_box_reg'}
    total_loss = sum(loss_dict.values())
    total_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
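One gotcha when wiring this dataset into a DataLoader: detection samples cannot be stacked into a single batch tensor (image sizes and box counts vary), so you need a collate_fn that keeps them as tuples. A sketch with a made-up two-sample batch:

```python
import torch

def collate_fn(batch):
    # list of (image, target) pairs -> (tuple of images, tuple of targets)
    return tuple(zip(*batch))

# Made-up batch: different image sizes, different box counts
fake_batch = [
    (torch.rand(3, 480, 640), {"boxes": torch.zeros(2, 4)}),
    (torch.rand(3, 300, 400), {"boxes": torch.zeros(5, 4)}),
]
images, targets = collate_fn(fake_batch)

# Wire it in like this:
# loader = torch.utils.data.DataLoader(dataset, batch_size=4,
#                                      shuffle=True, collate_fn=collate_fn)
```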

Part 4: Semantic Segmentation

from torchvision.models.segmentation import (
    deeplabv3_resnet101, DeepLabV3_ResNet101_Weights
)

# DeepLab V3 (trained on COCO images with the 21 Pascal VOC labels)
weights = DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1
model = deeplabv3_resnet101(weights=weights).eval()
preprocess = weights.transforms()

img = Image.open("city.jpg")
batch = preprocess(img).unsqueeze(0)

with torch.no_grad():
    output = model(batch)["out"]  # [1, 21, H, W]
    pred_mask = output.argmax(dim=1)  # [1, H, W] — per-pixel class

# Visualization
import matplotlib.pyplot as plt
plt.imshow(pred_mask[0].cpu(), cmap="tab20")
plt.title("Semantic Segmentation")
plt.savefig("segmentation.png")

Part 5: Datasets

from torchvision.datasets import (
    CIFAR10, CIFAR100, ImageNet, MNIST,
    FashionMNIST, STL10, Food101, Flowers102,
    CocoDetection, VOCDetection
)

# CIFAR-10 (10 classes, 32x32)
train_set = CIFAR10(root="./data", train=True, download=True, transform=train_transform)

# ImageFolder (custom dataset)
from torchvision.datasets import ImageFolder

# Directory structure:
# data/train/
#   +-- cat/
#   |   +-- cat001.jpg
#   |   +-- cat002.jpg
#   +-- dog/
#       +-- dog001.jpg
#       +-- dog002.jpg

train_set = ImageFolder(root="data/train", transform=train_transform)
print(train_set.classes)      # ['cat', 'dog']
print(train_set.class_to_idx) # {'cat': 0, 'dog': 1}

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=32, shuffle=True, num_workers=4, pin_memory=True
)
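If you only have a single labeled folder, torch.utils.data.random_split gives you a train/validation split. A sketch using a placeholder list (full_set) as a stand-in for the ImageFolder dataset above:

```python
import torch
from torch.utils.data import random_split

full_set = list(range(1000))  # hypothetical stand-in for an ImageFolder dataset
n_val = int(0.2 * len(full_set))  # hold out 20% for validation

train_subset, val_subset = random_split(
    full_set, [len(full_set) - n_val, n_val],
    generator=torch.Generator().manual_seed(42),  # reproducible split
)
print(len(train_subset), len(val_subset))  # 800 200
```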

Part 6: Utilities

Visualization

from torchvision.utils import make_grid, draw_bounding_boxes, draw_segmentation_masks

# Batch image grid
grid = make_grid(batch_images, nrow=8, padding=2, normalize=True)
plt.imshow(grid.permute(1, 2, 0))

# Bounding box visualization
img_with_boxes = draw_bounding_boxes(
    img_tensor,         # uint8, [3, H, W]
    boxes,              # [N, 4]
    labels=["cat", "dog"],
    colors=["red", "blue"],
    width=3,
    font_size=20        # only takes effect when font= is also set
)

# Segmentation mask visualization
img_with_mask = draw_segmentation_masks(
    img_tensor,                  # uint8, [3, H, W]
    masks=pred_mask.bool(),      # boolean masks, [num_masks, H, W]
    alpha=0.5,
    colors=["red", "green", "blue"]
)

Feature Extraction (Intermediate Layer Output)

from torchvision.models.feature_extraction import create_feature_extractor

model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Extract outputs from specific layers
feature_extractor = create_feature_extractor(model, {
    'layer2': 'mid_features',    # 512 channels
    'layer4': 'high_features',   # 2048 channels
    'avgpool': 'embedding',      # 2048 channels (globally pooled)
})

with torch.no_grad():
    features = feature_extractor(batch)

print(features['mid_features'].shape)   # [1, 512, 28, 28]
print(features['high_features'].shape)  # [1, 2048, 7, 7]
print(features['embedding'].shape)      # [1, 2048, 1, 1]

Quiz — torchvision (click to reveal!)

Q1. What is the key advantage of Transforms v2 over v1? ||It transforms images, Bounding Boxes, and Segmentation Masks simultaneously. In v1, only images were transformed, causing BBox misalignment issues.||

Q2. What are the ImageNet normalization mean and std values? ||mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]. These are per-channel RGB statistics from the ImageNet training data.||

Q3. What is two-stage fine-tuning in Transfer Learning? ||Stage 1: Freeze backbone + train only the classification head (large LR). Stage 2: Unfreeze all layers + fine-tune with a small LR. This preserves pretrained weights while gradually adapting.||

Q4. What are the four losses in Faster R-CNN? ||loss_classifier (classification), loss_box_reg (box regression), loss_objectness (object presence), loss_rpn_box_reg (RPN box regression).||

Q5. What is the purpose of create_feature_extractor? ||It extracts intermediate layer outputs from a model. You can obtain feature maps from specific layers without running a full forward pass, useful for embedding extraction and feature visualization.||

Q6. What are the two key parameters of RandAugment? ||num_ops: number of augmentation operations to apply. magnitude: transformation strength (0~30). Unlike AutoAugment, it delivers powerful augmentation with just two parameters and no search.||
