The Complete torchvision Guide — From Image Classification to Object Detection and Segmentation
- Introduction
- Part 1: Transforms v2 — A Revolution in Image Preprocessing
- Part 2: Pretrained Models (Model Zoo)
- Part 3: Object Detection
- Part 4: Semantic Segmentation
- Part 5: Datasets
- Part 6: Utilities
- Related Series and Recommended Posts
- Quiz

Introduction
torchvision is the official computer vision library for PyTorch. It includes image transforms, pretrained models, datasets, and Object Detection — nearly everything you need for CV.
pip install torch torchvision
Part 1: Transforms v2 — A Revolution in Image Preprocessing
Basic Transforms
import torch
from torchvision import transforms
from torchvision.transforms import v2 # v2 recommended!
from PIL import Image
# Transforms v2 (latest, recommended)
transform = v2.Compose([
    v2.RandomResizedCrop(224),               # Random crop + resize
    v2.RandomHorizontalFlip(p=0.5),          # 50% chance horizontal flip
    v2.ColorJitter(                          # Color augmentation
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1,
    ),
    v2.ToImage(),                            # PIL -> Tensor image
    v2.ToDtype(torch.float32, scale=True),   # [0, 255] -> [0.0, 1.0]
    v2.Normalize(                            # ImageNet normalization
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
img = Image.open("cat.jpg")
tensor = transform(img)
print(tensor.shape) # torch.Size([3, 224, 224])
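The last two steps of this pipeline are plain per-channel arithmetic. As a minimal pure-Python sketch (the helper name `normalize_pixel` is hypothetical, not a torchvision function), `ToDtype(..., scale=True)` divides uint8 values by 255, and `Normalize` then subtracts the mean and divides by the std:

```python
# Per-channel ImageNet statistics used in the pipeline above.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb_uint8):
    """Mimic ToDtype(scale=True) followed by Normalize for a single RGB pixel."""
    scaled = [v / 255.0 for v in rgb_uint8]  # [0, 255] -> [0.0, 1.0]
    return [(s - m) / d for s, m, d in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)]


white = normalize_pixel([255, 255, 255])
black = normalize_pixel([0, 0, 0])
print([round(v, 3) for v in white])
print([round(v, 3) for v in black])
```

After normalization the values are no longer confined to [0, 1]: a pure white pixel lands above 2 and a pure black pixel below -1.8 in every channel.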
The Key Innovation of v2 — Simultaneous Bounding Box + Mask Transforms
# In v1, only the image was transformed -> BBoxes became misaligned!
# In v2, image + BBox + Mask are transformed together, and
# SanitizeBoundingBoxes drops the labels of any boxes it removes
from torchvision import tv_tensors
# Object Detection transforms
det_transform = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomPhotometricDistort(),
    # "others" sets the fill value for non-image inputs such as boxes and masks
    v2.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104), "others": 0}),
    v2.RandomIoUCrop(),
    v2.SanitizeBoundingBoxes(),  # Remove invalid bboxes (and their labels)
    v2.ToDtype(torch.float32, scale=True),
])
# When transforming image + bbox together, bboxes automatically follow!
image = tv_tensors.Image(torch.randint(0, 256, (3, 500, 500), dtype=torch.uint8))
boxes = tv_tensors.BoundingBoxes(
    [[100, 100, 300, 300], [200, 200, 400, 400]],
    format="XYXY",
    canvas_size=(500, 500),
)
labels = torch.tensor([1, 2])
# Simultaneous transform! Boxes and labels go in a dict so that
# SanitizeBoundingBoxes can find the "labels" entry with its default heuristic
target = {"boxes": boxes, "labels": labels}
out_img, out_target = det_transform(image, target)
out_boxes, out_labels = out_target["boxes"], out_target["labels"]
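To make "bboxes automatically follow" concrete, here is a minimal pure-Python sketch of what a horizontal flip has to do to an XYXY box. The helper `hflip_box_xyxy` is a hypothetical name used only for illustration; torchvision applies the equivalent arithmetic internally when a `BoundingBoxes` object passes through `RandomHorizontalFlip`.

```python
def hflip_box_xyxy(box, canvas_width):
    """Mirror an (x1, y1, x2, y2) box around the vertical center line."""
    x1, y1, x2, y2 = box
    # The new left edge is the old right edge measured from the right border.
    return [canvas_width - x2, y1, canvas_width - x1, y2]


boxes = [[100, 100, 300, 300], [200, 200, 400, 400]]
flipped = [hflip_box_xyxy(b, canvas_width=500) for b in boxes]
print(flipped)  # [[200, 100, 400, 300], [100, 200, 300, 400]]
```

If only the image were flipped, the boxes would still point at the un-mirrored locations, which is the v1 misalignment problem described above.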
Commonly Used Augmentation Recipes
# Training (strong augmentation)
train_transform = v2.Compose([
    v2.RandomResizedCrop(224, scale=(0.6, 1.0)),
    v2.RandomHorizontalFlip(),
    v2.RandAugment(num_ops=2, magnitude=9),  # AutoAugment family
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    v2.RandomErasing(p=0.25),  # CutOut effect
])
# Validation/Inference (no augmentation)
val_transform = v2.Compose([
    v2.Resize(256),
    v2.CenterCrop(224),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
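The validation recipe is deterministic, so its geometry can be worked out by hand. The sketch below is a pure-Python approximation of the integer arithmetic behind `Resize(256)` (shorter side scaled to 256, aspect ratio kept) and `CenterCrop(224)`. Both helper names are hypothetical, and the rounding is my understanding of torchvision's behavior rather than a guaranteed match.

```python
def resize_shorter_side(height, width, size):
    """Output (height, width) when the shorter side is resized to `size`."""
    short, long = (width, height) if width <= height else (height, width)
    new_short, new_long = size, int(size * long / short)
    return (new_long, new_short) if width <= height else (new_short, new_long)


def center_crop_origin(height, width, crop):
    """Top-left corner (top, left) of a centered crop x crop window."""
    return int(round((height - crop) / 2.0)), int(round((width - crop) / 2.0))


h, w = resize_shorter_side(480, 640, 256)  # a 640x480 landscape photo
top, left = center_crop_origin(h, w, 224)
print((h, w), (top, left))  # (256, 341) (16, 58)
```

For a 640x480 photo, the crop keeps the central 224x224 region and discards roughly 58 columns on each side.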
Part 2: Pretrained Models (Model Zoo)
Image Classification Models
from torchvision import models
from torchvision.models import (
    ResNet50_Weights, EfficientNet_V2_S_Weights,
    ViT_B_16_Weights, ConvNeXt_Small_Weights,
)
# ResNet-50 (2015, CNN fundamentals)
resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# Top-1: 80.858%, Parameters: 25.6M
# EfficientNet V2 (2021, the efficiency champion)
effnet = models.efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
# Top-1: 84.228%, Parameters: 21.5M
# Vision Transformer (2020, Transformer enters CV)
vit = models.vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1)
# Top-1: 85.304%, Parameters: 86.9M (these weights expect 384x384 inputs)
# ConvNeXt (2022, CNN strikes back — Transformer techniques applied to CNN)
convnext = models.convnext_small(weights=ConvNeXt_Small_Weights.IMAGENET1K_V1)
# Top-1: 83.616%, Parameters: 50.2M
Model selection guide:
+-- Fast inference needed -> MobileNet V3 / EfficientNet-Lite (the latter is not included in torchvision.models)
+-- Accuracy first -> ViT-L / Swin Transformer V2
+-- Balance (recommended for production) -> EfficientNet V2 / ConvNeXt
+-- Learning purposes -> ResNet-50 (fundamentals)
Inference
from torchvision.models import ViT_B_16_Weights
# Load model + preprocessing
weights = ViT_B_16_Weights.IMAGENET1K_V1
model = models.vit_b_16(weights=weights).eval()
preprocess = weights.transforms()
# Inference
img = Image.open("cat.jpg")
batch = preprocess(img).unsqueeze(0)  # [1, 3, 224, 224]
with torch.no_grad():
    logits = model(batch)
    probs = torch.softmax(logits, dim=1)
    top5 = torch.topk(probs, 5)
# Results
categories = weights.meta["categories"]
for prob, idx in zip(top5.values[0], top5.indices[0]):
    print(f" {categories[idx]:30s} {prob:.2%}")
# Illustrative output (actual labels and scores depend on the image):
# tabby cat 87.23%
# Egyptian cat 8.41%
# tiger cat 3.12%
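The post-processing after the forward pass is just a softmax followed by a top-k selection. As a rough illustration of what `torch.softmax` and `torch.topk` compute, here is a pure-Python sketch on made-up logits for four classes; the function names and values are illustrative only.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)  # Subtracting the max avoids overflow in exp()
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def topk(probs, k):
    """Return (probability, index) pairs for the k largest probabilities."""
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return [(p, i) for i, p in ranked[:k]]


logits = [2.0, 0.5, -1.0, 3.0]  # made-up scores for four classes
probs = softmax(logits)
for p, i in topk(probs, 2):
    print(f"class {i}: {p:.2%}")
```

The index returned by the top-k step is what gets looked up in `weights.meta["categories"]` to obtain a human-readable label.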
Fine-Tuning (Transfer Learning)
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
# 1. Load pretrained model
model = models.efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
# 2. Replace the final classification layer
num_classes = 10  # Number of classes in your dataset
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
# 3. Freeze the backbone (optional)
for param in model.features.parameters():
    param.requires_grad = False  # Freeze backbone
# 4. Train only the classification layer
optimizer = AdamW(model.classifier.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=20)
criterion = nn.CrossEntropyLoss()
# 5. Training loop (train_loader is a DataLoader over your dataset; see Part 5)
model.train()
for epoch in range(20):
    for images, labels in train_loader:
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    scheduler.step()  # Once per epoch, matching T_max=20
# 6. Unfreeze (Fine-tuning phase 2)
for param in model.parameters():
    param.requires_grad = True
optimizer = AdamW(model.parameters(), lr=1e-5)  # Small LR!
# Train for an additional 10 epochs...
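To see what learning rates the phase-1 scheduler produces, the closed-form cosine annealing curve can be evaluated directly. This is a sketch of the standard formula, assuming `eta_min=0` (the `CosineAnnealingLR` default) and one scheduler step per epoch; `cosine_annealing_lr` is a hypothetical helper name.

```python
import math


def cosine_annealing_lr(epoch, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


schedule = [cosine_annealing_lr(e, base_lr=1e-3, t_max=20) for e in range(21)]
print(f"{schedule[0]:.2e} {schedule[10]:.2e} {schedule[20]:.2e}")
```

The rate starts at 1e-3, passes through half of that at the midpoint, and decays to zero by epoch 20, which is why phase 2 creates a fresh optimizer with its own small learning rate.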
Part 3: Object Detection
Faster R-CNN
from torchvision.models.detection import (
    fasterrcnn_resnet50_fpn_v2,
    FasterRCNN_ResNet50_FPN_V2_Weights,
)
# Pretrained model (COCO 91 classes)
weights = FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1
model = fasterrcnn_resnet50_fpn_v2(weights=weights).eval()
preprocess = weights.transforms()
# Inference
img = Image.open("street.jpg")
batch = [preprocess(img)]  # Detection models take a list of [3, H, W] tensors
with torch.no_grad():
    predictions = model(batch)[0]
# Parse results
for box, label, score in zip(
    predictions['boxes'], predictions['labels'], predictions['scores']
):
    if score > 0.7:
        category = weights.meta["categories"][label]
        x1, y1, x2, y2 = box.tolist()
        print(f" {category}: {score:.2%} at ({x1:.0f},{y1:.0f},{x2:.0f},{y2:.0f})")
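The predicted boxes are in (x1, y1, x2, y2) pixel coordinates. When inspecting results, a common way to judge how much two boxes overlap is Intersection over Union (IoU). The following is a minimal pure-Python sketch of that standard definition; the helper names are illustrative and not part of torchvision.

```python
def box_area(box):
    """Area of an (x1, y1, x2, y2) box; degenerate boxes count as zero."""
    x1, y1, x2, y2 = box
    return max(0.0, x2 - x1) * max(0.0, y2 - y1)


def iou_xyxy(a, b):
    """Intersection over Union of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = box_area(a) + box_area(b) - inter
    return inter / union if union > 0 else 0.0


print(round(iou_xyxy([100, 100, 300, 300], [200, 200, 400, 400]), 4))  # 0.1429
```

Here the two boxes share a 100x100 region out of a combined area of 70,000 pixels, giving an IoU of 1/7.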
Training Detection on a Custom Dataset
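from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
# 1. Load pretrained model
model = fasterrcnn_resnet50_fpn_v2(weights="COCO_V1")
# 2. Replace the classification head
num_classes = 5 + 1  # 5 classes + background
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
# 3. Custom dataset (skeleton: fill in image and annotation loading)
class MyDetectionDataset(torch.utils.data.Dataset):
    def __len__(self):
        ...  # Number of samples

    def __getitem__(self, idx):
        img = ...  # Load image as a float tensor [3, H, W]
        target = {
            # boxes: [N, 4] in (x1, y1, x2, y2) pixel coordinates
            "boxes": torch.tensor([[x1, y1, x2, y2], ...], dtype=torch.float32),
            # labels: [N], class indices starting at 1 (0 is background)
            "labels": torch.tensor([1, 3, ...], dtype=torch.int64),
        }
        return img, target
# 4. Training
# The model takes a list of images and a list of target dicts, so train_loader
# needs a collate_fn that does not stack, e.g. lambda batch: tuple(zip(*batch))
# The optimizer must be built from this model's parameters; the SGD settings
# below are example values, not taken from the original post
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=5e-4)
model.train()
for images, targets in train_loader:
    loss_dict = model(list(images), list(targets))
    # loss_dict: {'loss_classifier', 'loss_box_reg', 'loss_objectness', 'loss_rpn_box_reg'}
    total_loss = sum(loss_dict.values())
    total_loss.backward()
    optimizer.step()
    optimizer.zero_grad()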
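One practical detail of the training loop is how batches are assembled. torchvision detection models take a list of images and a list of target dicts, and images may differ in size, so samples are usually grouped rather than stacked. The sketch below shows the widely used `tuple(zip(*batch))` pattern on stand-in data; `detection_collate` is a hypothetical name.

```python
def detection_collate(batch):
    """Turn [(img, target), ...] into ((img, ...), (target, ...)) without stacking."""
    return tuple(zip(*batch))


# Stand-ins for real samples: strings instead of image tensors.
samples = [
    ("img0", {"boxes": [[10, 10, 50, 50]], "labels": [1]}),
    ("img1", {"boxes": [], "labels": []}),
]
images, targets = detection_collate(samples)
print(images)        # ('img0', 'img1')
print(len(targets))  # 2
```

With a real dataset this function would be passed as `collate_fn=detection_collate` when constructing the `DataLoader`.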
Part 4: Semantic Segmentation
from torchvision.models.segmentation import (
    deeplabv3_resnet101, DeepLabV3_ResNet101_Weights,
)
# DeepLab V3 (trained on a COCO subset with the 20 Pascal VOC categories + background = 21 classes)
weights = DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1
model = deeplabv3_resnet101(weights=weights).eval()
preprocess = weights.transforms()
img = Image.open("city.jpg")
batch = preprocess(img).unsqueeze(0)
with torch.no_grad():
    output = model(batch)["out"]  # [1, 21, H, W]
pred_mask = output.argmax(dim=1)  # [1, H, W], per-pixel class index
# Visualization
import matplotlib.pyplot as plt
plt.imshow(pred_mask[0].cpu().numpy(), cmap="tab20")
plt.title("Semantic Segmentation")
plt.savefig("segmentation.png")
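The `argmax(dim=1)` step is what turns per-class score maps into a single label map. As a small illustration, here is a pure-Python sketch of the same operation on made-up scores for 3 classes and a 2x2 image; `argmax_per_pixel` is a hypothetical helper.

```python
def argmax_per_pixel(scores):
    """scores[c][y][x] -> mask[y][x] holding the index of the highest-scoring class."""
    num_classes = len(scores)
    height, width = len(scores[0]), len(scores[0][0])
    return [
        [max(range(num_classes), key=lambda c: scores[c][y][x]) for x in range(width)]
        for y in range(height)
    ]


# Made-up scores for 3 classes on a 2x2 image.
scores = [
    [[0.9, 0.1], [0.2, 0.3]],   # class 0
    [[0.05, 0.8], [0.1, 0.3]],  # class 1
    [[0.05, 0.1], [0.7, 0.4]],  # class 2
]
mask = argmax_per_pixel(scores)
print(mask)  # [[0, 1], [2, 2]]
```

Each entry of the result is a class index, which is why the mask can be shown directly with a categorical colormap.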
Part 5: Datasets
from torchvision.datasets import (
    CIFAR10, CIFAR100, ImageNet, MNIST,
    FashionMNIST, STL10, Food101, Flowers102,
    CocoDetection, VOCDetection,
)
# CIFAR-10 (10 classes, 32x32)
train_set = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
# ImageFolder (custom dataset)
from torchvision.datasets import ImageFolder
# Directory structure:
# data/train/
# +-- cat/
# |   +-- cat001.jpg
# |   +-- cat002.jpg
# +-- dog/
#     +-- dog001.jpg
#     +-- dog002.jpg
train_set = ImageFolder(root="data/train", transform=train_transform)
print(train_set.classes)  # ['cat', 'dog']
print(train_set.class_to_idx)  # {'cat': 0, 'dog': 1}
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=32, shuffle=True, num_workers=4, pin_memory=True
)
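The class indices that `ImageFolder` assigns come from the alphabetically sorted subdirectory names. The sketch below reproduces that mapping with the standard library only, using a temporary directory; `find_classes` here is a standalone illustrative function, written under the assumption that it mirrors what `ImageFolder` does internally.

```python
import os
import tempfile


def find_classes(root):
    """Sorted subdirectory names and their indices, like ImageFolder's class_to_idx."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return classes, {name: i for i, name in enumerate(classes)}


with tempfile.TemporaryDirectory() as root:
    for name in ["dog", "cat"]:  # Created out of order on purpose
        os.makedirs(os.path.join(root, name))
    classes, class_to_idx = find_classes(root)

print(classes)       # ['cat', 'dog']
print(class_to_idx)  # {'cat': 0, 'dog': 1}
```

Because the mapping depends only on directory names, adding a new class folder later can shift the indices of existing classes.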
Part 6: Utilities
Visualization
from torchvision.utils import make_grid, draw_bounding_boxes, draw_segmentation_masks
# Batch image grid (batch_images: tensor of shape [B, 3, H, W])
grid = make_grid(batch_images, nrow=8, padding=2, normalize=True)
plt.imshow(grid.permute(1, 2, 0))
# Bounding box visualization
img_with_boxes = draw_bounding_boxes(
    img_tensor,  # uint8, [3, H, W]
    boxes,  # [N, 4] in XYXY format; N = 2 to match the labels below
    labels=["cat", "dog"],
    colors=["red", "blue"],
    width=3,
    font_size=20,
)
# Segmentation mask visualization
# pred_mask holds class indices [1, H, W]; draw_segmentation_masks expects one
# boolean mask per class, and img_tensor must have the same H and W as the masks
num_seg_classes = 21
class_masks = pred_mask[0] == torch.arange(num_seg_classes)[:, None, None]  # [21, H, W]
img_with_mask = draw_segmentation_masks(
    img_tensor,
    masks=class_masks,
    alpha=0.5,
)
Feature Extraction (Intermediate Layer Output)
from torchvision.models.feature_extraction import create_feature_extractor
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).eval()
# Extract outputs from specific layers
feature_extractor = create_feature_extractor(model, {
    'layer2': 'mid_features',   # 512 channels
    'layer4': 'high_features',  # 2048 channels
    'avgpool': 'embedding',     # 2048 channels (globally pooled)
})
# The shapes below assume a 224x224 input, i.e. batch is [1, 3, 224, 224]
with torch.no_grad():
    features = feature_extractor(batch)
print(features['mid_features'].shape)  # [1, 512, 28, 28]
print(features['high_features'].shape)  # [1, 2048, 7, 7]
print(features['embedding'].shape)  # [1, 2048, 1, 1]
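The spatial sizes in those shapes follow from the cumulative stride of each ResNet-50 stage. As a quick sanity check, this pure-Python sketch derives the expected shape from the stride and channel count; it assumes a square input whose side is divisible by 32, and the table and helper name are illustrative.

```python
# Cumulative stride and output channels of each ResNet-50 stage.
RESNET50_STAGES = {
    "layer1": (4, 256),
    "layer2": (8, 512),
    "layer3": (16, 1024),
    "layer4": (32, 2048),
}


def feature_shape(stage, input_size=224, batch_size=1):
    """Expected [N, C, H, W] of a stage output for a square input divisible by 32."""
    stride, channels = RESNET50_STAGES[stage]
    side = input_size // stride
    return [batch_size, channels, side, side]


print(feature_shape("layer2"))  # [1, 512, 28, 28]
print(feature_shape("layer4"))  # [1, 2048, 7, 7]
```

Feeding a larger image, for example 448x448, doubles each spatial side while the channel counts stay the same.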
Quiz: torchvision
Q1. What is the key advantage of Transforms v2 over v1?
A1. It transforms images, Bounding Boxes, and Segmentation Masks together. In v1, only images were transformed, so BBoxes became misaligned.
Q2. What are the ImageNet normalization mean and std values?
A2. mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]. These are per-channel RGB statistics from the ImageNet training data.
Q3. What is two-stage fine-tuning in Transfer Learning?
A3. Stage 1: freeze the backbone and train only the classification head (large LR). Stage 2: unfreeze all layers and fine-tune with a small LR. This preserves the pretrained weights while adapting gradually.
Q4. What are the four losses in Faster R-CNN?
A4. loss_classifier (classification), loss_box_reg (box regression), loss_objectness (object presence), loss_rpn_box_reg (RPN box regression).
Q5. What is the purpose of create_feature_extractor?
A5. It builds a model that returns the outputs of chosen intermediate layers. Layers after the last requested one are pruned, so no computation is spent on them. This is useful for embedding extraction and feature visualization.
Q6. What are the two key parameters of RandAugment?
A6. num_ops: the number of augmentation operations to apply. magnitude: the transformation strength (0 to 30 with the default number of bins). Unlike AutoAugment, it delivers strong augmentation with just two parameters and no policy search.
Related Series and Recommended Posts
- The Complete torchaudio Guide — Audio AI (companion post)
- The Complete Math for AI Guide — Mathematics needed for CNN/ViT
- Build Your Own GPT — The Transformer, the origin of ViT