- Authors

- Name
- Youngju Kim
- @fjvbn20031
目次
- 自己教師あり学習とは?
- 初期のSSL手法
- 対照学習
- マスク自己エンコーダ(MAE)
- DINO - ラベルなし自己蒸留
- CLIP - 対照言語画像事前学習
- BEiT - 画像のためのBERT
- 言語モデル事前学習
- 実践的な応用
1. 自己教師あり学習とは?
1.1 学習パラダイムの比較
深層学習における3つの主要な学習パラダイムから理解していきましょう。
教師あり学習はラベル付きデータで学習します。ImageNetの1000万枚の画像にはそれぞれ人間が付けたラベルがあり、ResNetやViTといったモデルはこのデータで学習されます。問題はラベル収集が非常にコストがかかることです。医療画像1枚に専門家が数十分かけてアノテーションする必要がある場合もあります。
教師なし学習はラベルなしでデータのパターンを発見します。K-Meansクラスタリング、PCA、GANがこのカテゴリに属します。しかし従来の教師なし学習は、下流タスクに直接利用できる表現を学習するのに苦労します。
**自己教師あり学習(SSL)**は両者の長所を組み合わせます。大量のラベルなしデータを活用しながら、データ自体から監督信号を生成します。
核心的なアイデア:データ自体がラベルである。
インターネットには数十億枚の画像と数兆語のテキストが存在します。SSLはこの膨大なラベルなしデータを活用して強力な表現を学習します。
1.2 ラベル不足問題
実世界ではラベル収集が非常に困難です。
| ドメイン | ラベル収集コスト | 理由 |
|---|---|---|
| 医療画像 | 非常に高い | 専門医が必要 |
| 衛星画像 | 高い | 地理専門知識が必要 |
| 法律文書 | 高い | 法律専門家が必要 |
| 工業欠陥検出 | 高い | 欠陥サンプルが希少 |
| 音声認識 | 中程度 | 文字起こしが必要 |
SSLはこの問題を根本的に解決します。大規模なラベルなしデータで事前学習し、少量のラベルデータでファインチューニングするというアプローチです。
1.3 プレテキストタスク
SSLの鍵はプレテキストタスクです。データ自体から人工的な予測問題を作り出し、モデルに意味のある表現を学習させます。
良いプレテキストタスクの条件:
- ラベルなしで自動的に構築できること
- タスクをうまく解くためには意味のある表現が必要であること
- 難しすぎず簡単すぎないこと
例:
- 回転予測:画像を0度、90度、180度、270度回転させ、何度回転したか予測
- マスキング:画像やテキストの一部を隠して復元
- 対照学習:同じ画像の2つのビューを埋め込み空間で近づける
1.4 SSLの応用
SSLは現代AIの基盤となっています:
- GPT、BERT:テキストSSL → ChatGPT、Geminiの基礎
- CLIP:画像テキストSSL → DALL-E、Stable Diffusionの基礎
- MAE、DINO:画像SSL → 医療画像、自動運転の基礎
- wav2vec:音声SSL → 音声認識の基礎
2. 初期のSSL手法
初期のSSL研究では様々なプレテキストタスクが探索されました。
2.1 回転予測
2018年にGidarisらが提案。4つの回転(0度、90度、180度、270度)のどれが適用されたかをモデルが予測します。
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
class RotationSSL(nn.Module):
def __init__(self, backbone, num_classes=4):
super().__init__()
self.backbone = backbone
self.classifier = nn.Linear(backbone.output_dim, num_classes)
def forward(self, x):
features = self.backbone(x)
return self.classifier(features)
def create_rotation_dataset(images):
"""4方向に回転させた(画像, ラベル)ペアを作成"""
rotated_images = []
labels = []
for img in images:
for k in range(4): # 0, 90, 180, 270度
rotated = T.functional.rotate(img, k * 90)
rotated_images.append(rotated)
labels.append(k)
return rotated_images, labels
def train_rotation_ssl(model, dataloader, optimizer, epochs=10):
criterion = nn.CrossEntropyLoss()
model.train()
for epoch in range(epochs):
total_loss = 0
for images, labels in dataloader:
images, labels = images.cuda(), labels.cuda()
optimizer.zero_grad()
logits = model(images)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
制限:自然画像には明確な向きがありますが、テクスチャや空の画像には適さない場合があります。
2.2 ジグソーパズル解き
2016年にNoroozi & Favaroが提案。画像を3x3のグリッドに分割してシャッフルし、元の順序を予測します。
import itertools
import numpy as np
class JigsawSSL(nn.Module):
def __init__(self, backbone, num_permutations=100):
super().__init__()
self.backbone = backbone
self.patch_encoder = nn.Sequential(
backbone,
nn.Linear(backbone.output_dim, 512)
)
self.classifier = nn.Sequential(
nn.Linear(512 * 9, 4096),
nn.ReLU(),
nn.Linear(4096, num_permutations)
)
def forward(self, patches):
# patches: (batch, 9, C, H, W)
batch_size = patches.size(0)
patch_features = []
for i in range(9):
feat = self.patch_encoder(patches[:, i]) # (batch, 512)
patch_features.append(feat)
combined = torch.cat(patch_features, dim=1) # (batch, 512*9)
return self.classifier(combined)
def create_jigsaw_dataset(image, grid_size=3):
"""画像をgrid_size x grid_sizeのパッチに分割してシャッフル"""
h, w = image.shape[-2:]
patch_h, patch_w = h // grid_size, w // grid_size
patches = []
for i in range(grid_size):
for j in range(grid_size):
patch = image[...,
i*patch_h:(i+1)*patch_h,
j*patch_w:(j+1)*patch_w]
patches.append(patch)
permutation_idx = np.random.randint(0, len(PREDEFINED_PERMUTATIONS))
perm = PREDEFINED_PERMUTATIONS[permutation_idx]
shuffled = [patches[p] for p in perm]
return torch.stack(shuffled), permutation_idx
2.3 彩色
2016年にZhangらが提案。グレースケール画像を入力として、モデルが色を予測します。
class ColorizationSSL(nn.Module):
def __init__(self):
super().__init__()
# L チャンネル入力(明度)
self.encoder = nn.Sequential(
nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(),
)
# ab チャンネル出力(色)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
nn.Conv2d(64, 2, 3, padding=1), # ab チャンネル
nn.Tanh()
)
def forward(self, l_channel):
features = self.encoder(l_channel)
ab_channels = self.decoder(features)
return ab_channels
from skimage.color import rgb2lab, lab2rgb
def prepare_colorization_data(rgb_image):
lab = rgb2lab(rgb_image)
l_channel = lab[:, :, 0:1] # 明度チャンネル
ab_channels = lab[:, :, 1:] # 色チャンネル
return l_channel, ab_channels
3. 対照学習
対照学習は現代SSLの中核です。2020年頃に爆発的な発展を遂げました。
3.1 核心的なアイデア
似ているものを引き寄せ、異なるものを遠ざける。
同じ画像の2つの異なるビュー(クロップ、フリップ、カラージッターなど)は埋め込み空間で近くにあるべきであり、異なる画像は離れているべきです。
画像 x
├── 拡張ビュー v1 → z1 (埋め込み)
└── 拡張ビュー v2 → z2 (埋め込み)
目標:z1とz2を近づけ、z1とz3(異なる画像)を遠ざける
3.2 InfoNCE損失
対照学習の標準損失関数——NT-Xent(正規化温度スケーリングクロスエントロピー)とも呼ばれます。
simはコサイン類似度、τは温度パラメータです。
import torch
import torch.nn.functional as F
def info_nce_loss(z1, z2, temperature=0.5):
"""
InfoNCE / NT-Xent Loss
z1, z2: (batch_size, embedding_dim) - 同じ画像の2つのビュー
"""
batch_size = z1.size(0)
# L2正規化
z1 = F.normalize(z1, dim=1)
z2 = F.normalize(z2, dim=1)
# 結合: [z1_1, ..., z1_N, z2_1, ..., z2_N]
z = torch.cat([z1, z2], dim=0) # (2N, D)
# ペアワイズ類似度を計算
similarity = torch.mm(z, z.t()) / temperature # (2N, 2N)
# 自己類似度を除外(対角線 = -inf)
mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
similarity.masked_fill_(mask, float('-inf'))
# ポジティブペア:インデックス i と i+N
labels = torch.cat([
torch.arange(batch_size, 2 * batch_size),
torch.arange(batch_size)
]).to(z.device)
loss = F.cross_entropy(similarity, labels)
return loss
# 使用例
z1 = torch.randn(32, 128)
z2 = torch.randn(32, 128)
loss = info_nce_loss(z1, z2, temperature=0.5)
print(f"InfoNCE Loss: {loss.item():.4f}")
3.3 SimCLR
2020年にGoogleが発表。Simple Framework for Contrastive Learning。シンプルでありながら強力。
主要コンポーネント:
- データ拡張(強いオーグメンテーションが重要)
- エンコーダ(ResNet)
- プロジェクションヘッド(MLP)
- NT-Xent損失
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
class SimCLR(nn.Module):
def __init__(self, backbone='resnet50', projection_dim=128, temperature=0.5):
super().__init__()
self.temperature = temperature
if backbone == 'resnet50':
resnet = models.resnet50(pretrained=False)
self.encoder = nn.Sequential(*list(resnet.children())[:-1])
self.feature_dim = 2048
elif backbone == 'resnet18':
resnet = models.resnet18(pretrained=False)
self.encoder = nn.Sequential(*list(resnet.children())[:-1])
self.feature_dim = 512
# プロジェクションヘッド(ファインチューニング時は削除)
self.projection_head = nn.Sequential(
nn.Linear(self.feature_dim, self.feature_dim),
nn.ReLU(inplace=True),
nn.Linear(self.feature_dim, projection_dim)
)
def encode(self, x):
h = self.encoder(x)
h = h.squeeze(-1).squeeze(-1)
return h
def forward(self, x):
h = self.encode(x)
z = self.projection_head(h)
return z
def contrastive_loss(self, z1, z2):
return info_nce_loss(z1, z2, self.temperature)
class SimCLRDataAugmentation:
"""SimCLRの強いデータ拡張"""
def __init__(self, image_size=224, s=1.0):
color_jitter = T.ColorJitter(
brightness=0.8*s,
contrast=0.8*s,
saturation=0.8*s,
hue=0.2*s
)
self.transform = T.Compose([
T.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
T.RandomHorizontalFlip(p=0.5),
T.RandomApply([color_jitter], p=0.8),
T.RandomGrayscale(p=0.2),
T.GaussianBlur(kernel_size=int(0.1 * image_size)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def __call__(self, x):
return self.transform(x), self.transform(x)
def train_simclr(model, dataloader, optimizer, epochs=200):
model.train()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=epochs
)
for epoch in range(epochs):
total_loss = 0
for (x1, x2), _ in dataloader:
x1, x2 = x1.cuda(), x2.cuda()
z1 = model(x1)
z2 = model(x2)
loss = model.contrastive_loss(z1, z2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
scheduler.step()
avg_loss = total_loss / len(dataloader)
if (epoch + 1) % 10 == 0:
print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")
model = SimCLR(backbone='resnet50', projection_dim=128).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)
def get_finetuning_model(pretrained_simclr, num_classes):
encoder = pretrained_simclr.encoder
classifier = nn.Linear(pretrained_simclr.feature_dim, num_classes)
return nn.Sequential(encoder, nn.Flatten(), classifier)
3.4 MoCo(モメンタムコントラスト)
2020年にFacebook AIが発表。大きいバッチサイズを必要とせず、多数のネガティブサンプルが使えます。
核心的なアイデア:モメンタム更新による安定したキーエンコーダ + ネガティブサンプル管理のためのキュー
class MoCo(nn.Module):
def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07):
super().__init__()
self.K = K # キューサイズ
self.m = m # モメンタム係数
self.T = T # 温度
# クエリエンコーダとキーエンコーダ
self.encoder_q = base_encoder(num_classes=dim)
self.encoder_k = base_encoder(num_classes=dim)
# キーエンコーダ:勾配なし、モメンタムで更新
for param_q, param_k in zip(
self.encoder_q.parameters(),
self.encoder_k.parameters()
):
param_k.data.copy_(param_q.data)
param_k.requires_grad = False
# ネガティブサンプルキュー
self.register_buffer("queue", torch.randn(dim, K))
self.queue = F.normalize(self.queue, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
@torch.no_grad()
def _momentum_update_key_encoder(self):
for param_q, param_k in zip(
self.encoder_q.parameters(),
self.encoder_k.parameters()
):
param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
batch_size = keys.shape[0]
ptr = int(self.queue_ptr)
self.queue[:, ptr:ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.K
self.queue_ptr[0] = ptr
def forward(self, im_q, im_k):
q = self.encoder_q(im_q)
q = F.normalize(q, dim=1)
with torch.no_grad():
self._momentum_update_key_encoder()
k = self.encoder_k(im_k)
k = F.normalize(k, dim=1)
# ポジティブロジット: (batch, 1)
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
# ネガティブロジット: (batch, K)
l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
logits = torch.cat([l_pos, l_neg], dim=1) / self.T
labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
loss = F.cross_entropy(logits, labels)
self._dequeue_and_enqueue(k)
return loss
3.5 BYOL(Bootstrap Your Own Latent)
2020年にDeepMindが発表。ネガティブサンプルなしで機能する画期的な手法。
核心的な洞察:オンラインネットワークがターゲットネットワークの表現を予測します。ターゲットはモメンタムで更新されます。
import copy
class BYOL(nn.Module):
def __init__(self, backbone, projection_dim=256, prediction_dim=128, tau=0.996):
super().__init__()
self.tau = tau
# オンラインネットワーク
self.online_encoder = backbone
self.online_projector = nn.Sequential(
nn.Linear(backbone.output_dim, 4096),
nn.BatchNorm1d(4096),
nn.ReLU(inplace=True),
nn.Linear(4096, projection_dim)
)
self.online_predictor = nn.Sequential(
nn.Linear(projection_dim, 4096),
nn.BatchNorm1d(4096),
nn.ReLU(inplace=True),
nn.Linear(4096, prediction_dim)
)
# ターゲットネットワーク(モメンタムで更新)
self.target_encoder = copy.deepcopy(backbone)
self.target_projector = copy.deepcopy(self.online_projector)
for param in self.target_encoder.parameters():
param.requires_grad = False
for param in self.target_projector.parameters():
param.requires_grad = False
@torch.no_grad()
def update_target(self):
for online, target in zip(
self.online_encoder.parameters(),
self.target_encoder.parameters()
):
target.data = self.tau * target.data + (1 - self.tau) * online.data
def forward(self, x1, x2):
online_z1 = self.online_projector(self.online_encoder(x1))
online_z2 = self.online_projector(self.online_encoder(x2))
pred1 = self.online_predictor(online_z1)
pred2 = self.online_predictor(online_z2)
with torch.no_grad():
target_z1 = self.target_projector(self.target_encoder(x1))
target_z2 = self.target_projector(self.target_encoder(x2))
# BYOL損失(対称的なコサイン類似度)
loss1 = 2 - 2 * F.cosine_similarity(pred1, target_z2.detach()).mean()
loss2 = 2 - 2 * F.cosine_similarity(pred2, target_z1.detach()).mean()
return (loss1 + loss2) / 2
4. マスク自己エンコーダ(MAE)
MAE(Masked Autoencoders Are Scalable Vision Learners)はMetaのKaiming Heらが2021年に発表。ViTと組み合わせて非常に効果的です。
4.1 核心的なアイデア
画像パッチの75%をランダムにマスクし、マスクされたピクセルを復元する。
なぜ75%もマスクするのか?
- 画像の冗長性が高いため、低いマスク率では簡単すぎる
- 高いマスク率ではモデルがシーンの全体的な理解を学習せざるを得ない
- BERTの15%とは異なり、画像には空間的な局所性がある
4.2 完全なMAE実装
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class PatchEmbedding(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.patch_size = patch_size
num_patches = (img_size // patch_size) ** 2
self.projection = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
x = self.projection(x) # (B, embed_dim, H/P, W/P)
x = x.flatten(2).transpose(1, 2) # (B, num_patches, embed_dim)
return x
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
self.norm2 = nn.LayerNorm(dim)
mlp_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
x = x + self.mlp(self.norm2(x))
return x
class MAE(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_channels=3,
encoder_dim=768,
encoder_depth=12,
encoder_heads=12,
decoder_dim=512,
decoder_depth=8,
decoder_heads=16,
mask_ratio=0.75,
):
super().__init__()
self.mask_ratio = mask_ratio
self.patch_size = patch_size
num_patches = (img_size // patch_size) ** 2
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, encoder_dim)
self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_dim))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, encoder_dim),
requires_grad=False
)
self.encoder_blocks = nn.ModuleList([
TransformerBlock(encoder_dim, encoder_heads)
for _ in range(encoder_depth)
])
self.encoder_norm = nn.LayerNorm(encoder_dim)
self.decoder_embed = nn.Linear(encoder_dim, decoder_dim, bias=True)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
self.decoder_pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, decoder_dim),
requires_grad=False
)
self.decoder_blocks = nn.ModuleList([
TransformerBlock(decoder_dim, decoder_heads)
for _ in range(decoder_depth)
])
self.decoder_norm = nn.LayerNorm(decoder_dim)
self.decoder_pred = nn.Linear(decoder_dim, patch_size**2 * in_channels)
def random_masking(self, x, mask_ratio):
N, L, D = x.shape
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=x.device)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(
x, dim=1,
index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
)
mask = torch.ones(N, L, device=x.device)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
def forward_encoder(self, x, mask_ratio):
x = self.patch_embed(x)
x = x + self.pos_embed[:, 1:, :]
x, mask, ids_restore = self.random_masking(x, mask_ratio)
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
for block in self.encoder_blocks:
x = block(x)
x = self.encoder_norm(x)
return x, mask, ids_restore
def forward_decoder(self, x, ids_restore):
x = self.decoder_embed(x)
mask_tokens = self.mask_token.repeat(
x.shape[0],
ids_restore.shape[1] + 1 - x.shape[1],
1
)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
x_ = torch.gather(
x_, dim=1,
index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
)
x = torch.cat([x[:, :1, :], x_], dim=1)
x = x + self.decoder_pos_embed
for block in self.decoder_blocks:
x = block(x)
x = self.decoder_norm(x)
x = self.decoder_pred(x)
x = x[:, 1:, :] # CLSトークンを削除
return x
def forward(self, imgs, mask_ratio=0.75):
latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
pred = self.forward_decoder(latent, ids_restore)
loss = self.forward_loss(imgs, pred, mask)
return loss, pred, mask
def forward_loss(self, imgs, pred, mask):
target = self.patchify(imgs)
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.e-6).sqrt()
loss = (pred - target) ** 2
loss = loss.mean(dim=-1)
loss = (loss * mask).sum() / mask.sum()
return loss
def patchify(self, imgs):
p = self.patch_size
h = w = imgs.shape[2] // p
x = imgs.reshape(imgs.shape[0], 3, h, p, w, p)
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(imgs.shape[0], h * w, p**2 * 3)
return x
# MAEモデルの作成と学習
mae_model = MAE(
img_size=224,
patch_size=16,
encoder_dim=768,
encoder_depth=12,
encoder_heads=12,
decoder_dim=512,
decoder_depth=8,
decoder_heads=16,
mask_ratio=0.75
).cuda()
optimizer = torch.optim.AdamW(
mae_model.parameters(),
lr=1.5e-4,
betas=(0.9, 0.95),
weight_decay=0.05
)
def train_step(model, imgs, optimizer):
model.train()
optimizer.zero_grad()
loss, pred, mask = model(imgs, mask_ratio=0.75)
loss.backward()
optimizer.step()
return loss.item()
5. DINO
DINO(Self-DIstillation with NO labels)は2021年にFacebook AI Researchが発表。ViTをSSLで学習すると驚くべき特性が現れます。
5.1 主要な発見
DINOで学習したViTのセルフアテンションマップはラベルなしでセマンティックセグメンテーションを実行します。モデルが前景と背景を完璧に分離します。
5.2 自己蒸留アーキテクチャ
学生ネットワーク 教師ネットワーク
↑ ↑
局所ビュー(多数) グローバルビュー(2枚)
- 教師:画像全体のビューを処理、豊かなコンテキスト
- 学生:局所的なクロップを処理、教師の予測に一致するよう学習
- 教師の更新:学生のEMA(勾配なし)
5.3 センタリングとシャーペニング
崩壊を防ぐ2つのテクニック:
class DINOHead(nn.Module):
def __init__(self, in_dim, out_dim=65536, bottleneck_dim=256):
super().__init__()
layers = [nn.Linear(in_dim, 2048), nn.GELU()]
layers += [nn.Linear(2048, 2048), nn.GELU()]
layers += [nn.Linear(2048, bottleneck_dim)]
self.mlp = nn.Sequential(*layers)
self.last_layer = nn.utils.weight_norm(
nn.Linear(bottleneck_dim, out_dim, bias=False)
)
self.last_layer.weight_g.data.fill_(1)
self.last_layer.weight_g.requires_grad = False
def forward(self, x):
x = self.mlp(x)
x = F.normalize(x, dim=-1)
x = self.last_layer(x)
return x
class DINO(nn.Module):
def __init__(self, backbone, head_out_dim=65536, tau_s=0.1, tau_t=0.04, m=0.996):
super().__init__()
self.tau_s = tau_s # 学生の温度
self.tau_t = tau_t # 教師の温度
self.m = m # モメンタム係数
self.student = nn.Sequential(backbone, DINOHead(backbone.output_dim, head_out_dim))
self.teacher = copy.deepcopy(self.student)
for param in self.teacher.parameters():
param.requires_grad = False
# センタリング(教師の出力の移動平均)
self.register_buffer("center", torch.zeros(1, head_out_dim))
@torch.no_grad()
def update_teacher(self):
for s_param, t_param in zip(
self.student.parameters(),
self.teacher.parameters()
):
t_param.data = self.m * t_param.data + (1 - self.m) * s_param.data
def dino_loss(self, student_output, teacher_output):
"""DINOクロスエントロピー損失"""
student_out = student_output / self.tau_s
student_out = student_out.chunk(2) # 局所ビュー
# センタリングと温度シャーペニング
teacher_out = F.softmax(
(teacher_output - self.center) / self.tau_t, dim=-1
)
teacher_out = teacher_out.detach().chunk(2) # グローバルビュー
total_loss = 0
n_loss_terms = 0
for iq, q in enumerate(teacher_out):
for v in range(len(student_out)):
if v == iq:
continue
loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
total_loss += loss.mean()
n_loss_terms += 1
total_loss /= n_loss_terms
# センターを更新
self.center = 0.9 * self.center + 0.1 * teacher_output.mean(dim=0, keepdim=True)
return total_loss
6. CLIP
CLIP(Contrastive Language-Image Pre-training)は2021年にOpenAIが発表。画像とテキストをペアで学習するマルチモーダルSSLです。
6.1 コアコンセプト
4億枚の画像とキャプションのペアで対照学習を適用:
- 同じ画像テキストペアの埋め込みを近づける
- 異なるペアの埋め込みを遠ざける
ゼロショット能力:学習後、新しいカテゴリの画像を直接分類できます(テキストプロンプトを使用)。
6.2 完全なCLIP実装
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
def zero_shot_classify(image_path, class_labels, model_name="openai/clip-vit-base-patch32"):
"""CLIPを使用したゼロショット画像分類"""
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)
model.eval()
image = Image.open(image_path)
# テキストプロンプトを作成
text_prompts = [f"a photo of a {label}" for label in class_labels]
inputs = processor(
text=text_prompts,
images=image,
return_tensors="pt",
padding=True
)
with torch.no_grad():
outputs = model(**inputs)
logits_per_image = outputs.logits_per_image
probs = logits_per_image.softmax(dim=1)
results = {}
for label, prob in zip(class_labels, probs[0]):
results[label] = prob.item()
return results
# 使用例
labels = ["cat", "dog", "car", "airplane", "bird"]
results = zero_shot_classify("example.jpg", labels)
for label, prob in sorted(results.items(), key=lambda x: -x[1]):
print(f"{label}: {prob:.4f}")
# ---- 画像テキスト類似度検索 ----
class CLIPRetrieval:
def __init__(self, model_name="openai/clip-vit-base-patch32"):
self.model = CLIPModel.from_pretrained(model_name)
self.processor = CLIPProcessor.from_pretrained(model_name)
self.model.eval()
self.image_embeddings = None
self.image_paths = []
def index_images(self, image_paths):
"""画像データベースをインデックス化"""
self.image_paths = image_paths
embeddings = []
for path in image_paths:
image = Image.open(path)
inputs = self.processor(images=image, return_tensors="pt")
with torch.no_grad():
embedding = self.model.get_image_features(**inputs)
embeddings.append(F.normalize(embedding, dim=1))
self.image_embeddings = torch.cat(embeddings, dim=0)
print(f"{len(image_paths)}枚の画像をインデックス化しました")
def search(self, query_text, top_k=5):
"""テキストクエリで画像を検索"""
inputs = self.processor(text=[query_text], return_tensors="pt", padding=True)
with torch.no_grad():
text_embedding = self.model.get_text_features(**inputs)
text_embedding = F.normalize(text_embedding, dim=1)
similarities = torch.matmul(text_embedding, self.image_embeddings.T).squeeze()
top_indices = similarities.argsort(descending=True)[:top_k]
results = [
(self.image_paths[i], similarities[i].item())
for i in top_indices
]
return results
7. BEiT
BEiT(BERT Pre-Training of Image Transformers)は2021年にMicrosoftが発表。画像を離散トークンに変換してBERT風に学習します。
7.1 離散視覚トークン(dVAE)
BEiTのコア:生のピクセルを予測する代わりに離散トークンを予測します。
ステップ1: dVAEが画像を離散トークンに変換(語彙サイズ8192)
ステップ2: ViTがマスクされた各パッチのトークンを予測
from transformers import BeitForMaskedImageModeling, BeitImageProcessor
import torch
class BEiTPretraining:
def __init__(self, model_name="microsoft/beit-base-patch16-224"):
self.model = BeitForMaskedImageModeling.from_pretrained(model_name)
self.processor = BeitImageProcessor.from_pretrained(model_name)
def create_bool_masked_pos(self, num_patches, mask_ratio=0.4):
num_masked = int(num_patches * mask_ratio)
mask = torch.zeros(num_patches, dtype=torch.bool)
masked_indices = torch.randperm(num_patches)[:num_masked]
mask[masked_indices] = True
return mask
def forward(self, images):
inputs = self.processor(images, return_tensors="pt")
pixel_values = inputs.pixel_values
num_patches = (224 // 16) ** 2 # 196
bool_masked_pos = self.create_bool_masked_pos(num_patches)
bool_masked_pos = bool_masked_pos.unsqueeze(0).expand(
pixel_values.size(0), -1
)
outputs = self.model(
pixel_values=pixel_values,
bool_masked_pos=bool_masked_pos
)
return outputs.loss
8. 言語モデル事前学習
SSLはNLPに長い歴史があります。
8.1 GPT - 次のトークン予測
左から右へと次のトークンを予測する自己回帰的アプローチ。
import torch
import torch.nn as nn
from torch.nn import functional as F
class GPTLanguageModel(nn.Module):
def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer, dropout=0.1):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, n_embd)
self.position_embedding = nn.Embedding(block_size, n_embd)
self.blocks = nn.Sequential(*[
TransformerBlock(n_embd, n_head, dropout=dropout)
for _ in range(n_layer)
])
self.ln_f = nn.LayerNorm(n_embd)
self.lm_head = nn.Linear(n_embd, vocab_size)
self.block_size = block_size
def forward(self, idx, targets=None):
B, T = idx.shape
tok_emb = self.token_embedding(idx)
pos_emb = self.position_embedding(torch.arange(T, device=idx.device))
x = tok_emb + pos_emb
x = self.blocks(x)
x = self.ln_f(x)
logits = self.lm_head(x)
if targets is None:
return logits, None
B, T, C = logits.shape
loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
return logits, loss
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
for _ in range(max_new_tokens):
idx_cond = idx[:, -self.block_size:]
logits, _ = self(idx_cond)
logits = logits[:, -1, :] / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = float('-inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.cat((idx, idx_next), dim=1)
return idx
8.2 BERT - マスク言語モデリング
双方向コンテキストを学習します。15%のトークンがマスクされ、復元する必要があります。
from transformers import BertForMaskedLM, BertTokenizer
import torch
class BERTPretraining:
def __init__(self, model_name="bert-base-uncased"):
self.model = BertForMaskedLM.from_pretrained(model_name)
self.tokenizer = BertTokenizer.from_pretrained(model_name)
def mask_tokens(self, inputs, mlm_probability=0.15):
"""MLMのためにトークンをマスク"""
labels = inputs.clone()
probability_matrix = torch.full(labels.shape, mlm_probability)
# 特殊トークンはマスクしない
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
for val in labels.tolist()
]
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # マスクされていないトークンの損失を計算しない
# 80%は[MASK]に置換
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
# 10%はランダムなトークンに置換
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]
return inputs, labels
9. 実践的な応用
9.1 線形評価とファインチューニング
class LinearProbe(nn.Module):
"""SSL表現を評価するための線形分類器"""
def __init__(self, input_dim, num_classes):
super().__init__()
self.linear = nn.Linear(input_dim, num_classes)
def forward(self, x):
return self.linear(x)
def extract_features(model, dataloader, device='cuda'):
"""事前学習済みモデルから特徴を抽出"""
model.eval()
features_list = []
labels_list = []
with torch.no_grad():
for images, labels in dataloader:
images = images.to(device)
features = model.encode(images)
features_list.append(features.cpu())
labels_list.append(labels)
return torch.cat(features_list), torch.cat(labels_list)
def few_shot_evaluation(ssl_model, train_loader, val_loader,
n_shots=[1, 5, 10, 100], device='cuda'):
"""クラスごとのnサンプルで評価"""
ssl_model.eval()
all_features, all_labels = extract_features(ssl_model, train_loader, device)
val_features, val_labels = extract_features(ssl_model, val_loader, device)
results = {}
num_classes = len(all_labels.unique())
for n_shot in n_shots:
selected_indices = []
for cls in range(num_classes):
cls_indices = (all_labels == cls).nonzero(as_tuple=True)[0]
selected = cls_indices[torch.randperm(len(cls_indices))[:n_shot]]
selected_indices.append(selected)
selected_indices = torch.cat(selected_indices)
train_feats = all_features[selected_indices]
train_labs = all_labels[selected_indices]
probe = LinearProbe(all_features.shape[1], num_classes).to(device)
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
train_feats = train_feats.to(device)
train_labs = train_labs.to(device)
for epoch in range(100):
probe.train()
logits = probe(train_feats)
loss = criterion(logits, train_labs)
optimizer.zero_grad()
loss.backward()
optimizer.step()
probe.eval()
with torch.no_grad():
val_logits = probe(val_features.to(device))
preds = val_logits.argmax(dim=1).cpu()
acc = (preds == val_labels).float().mean().item()
results[n_shot] = acc
print(f"{n_shot}-shot精度: {acc * 100:.2f}%")
return results
def knn_evaluation(ssl_model, train_loader, val_loader, k=20, device='cuda'):
"""SSL表現の品質をk-NNで評価(パラメータフリー)"""
ssl_model.eval()
train_features, train_labels = extract_features(ssl_model, train_loader, device)
val_features, val_labels = extract_features(ssl_model, val_loader, device)
train_features = nn.functional.normalize(train_features, dim=1)
val_features = nn.functional.normalize(val_features, dim=1)
correct = 0
total = 0
batch_size = 256
for i in range(0, len(val_features), batch_size):
batch_feats = val_features[i:i+batch_size].to(device)
batch_labels = val_labels[i:i+batch_size]
sims = torch.mm(batch_feats, train_features.to(device).T)
topk_sims, topk_indices = sims.topk(k, dim=1)
topk_labels = train_labels[topk_indices.cpu()]
preds = topk_labels.mode(dim=1).values
correct += (preds == batch_labels).sum().item()
total += len(batch_labels)
accuracy = correct / total
print(f"k-NN ({k}) 精度: {accuracy * 100:.2f}%")
return accuracy
9.2 ドメイン特化SSL
# 医療画像(胸部X線の例)
class MedicalSSL(nn.Module):
"""医療画像のための自己教師あり学習"""
def __init__(self, backbone):
super().__init__()
self.backbone = backbone
# 医療画像には弱いオーグメンテーションを使用
self.augmentation = T.Compose([
T.RandomResizedCrop(224, scale=(0.8, 1.0)),
T.RandomHorizontalFlip(p=0.5),
# 医療画像では色変化を最小限に
T.RandomApply([T.ColorJitter(0.1, 0.1, 0.0, 0.0)], p=0.3),
T.ToTensor(),
T.Normalize([0.485], [0.229]) # グレースケールX線
])
def forward(self, x):
v1 = self.augmentation(x)
v2 = self.augmentation(x)
return v1, v2
# 衛星画像
class SatelliteSSL:
"""衛星画像のオーグメンテーション(地理的特性を保持)"""
def __init__(self):
self.transform = T.Compose([
T.RandomResizedCrop(256, scale=(0.4, 1.0)),
# 衛星画像:回転不変性が重要
T.RandomRotation(90),
T.RandomHorizontalFlip(),
T.RandomVerticalFlip(),
# 時刻・季節変化をシミュレート
T.ColorJitter(brightness=0.4, contrast=0.4),
T.ToTensor(),
])
9.3 HubからのSSL事前学習モデルの使用
from transformers import (
ViTModel, ViTFeatureExtractor,
CLIPModel, CLIPProcessor
)
from PIL import Image
import torch
# MAE事前学習済みViTを読み込む
def load_mae_pretrained():
model = ViTModel.from_pretrained("facebook/vit-mae-base")
feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
return model, feature_extractor
# DINO ViT特徴抽出
def extract_dino_features(image_path):
model = ViTModel.from_pretrained("facebook/dino-vits16")
processor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits16")
image = Image.open(image_path)
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
# CLSトークン(グローバル画像表現)
cls_features = outputs.last_hidden_state[:, 0, :]
# パッチごとの特徴(空間情報あり)
patch_features = outputs.last_hidden_state[:, 1:, :]
return cls_features, patch_features
# CLIPで特徴抽出
def clip_feature_extraction(images, texts=None):
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
if texts:
inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)
with torch.no_grad():
outputs = model(**inputs)
return outputs.image_embeds, outputs.text_embeds
else:
inputs = processor(images=images, return_tensors="pt")
with torch.no_grad():
image_features = model.get_image_features(**inputs)
return image_features
まとめ
自己教師あり学習はAIを変革しています。重要なポイント:
| 手法 | 年 | 核心的なアイデア | 強み |
|---|---|---|---|
| SimCLR | 2020 | 強いオーグメンテーション + 対照学習 | シンプルで強力 |
| MoCo | 2020 | モメンタムキュー | バッチサイズ非依存 |
| BYOL | 2020 | ネガティブなしの対照学習 | バッチサイズ非依存 |
| MAE | 2021 | 75%マスキング + ピクセル復元 | 効率的、ViT最適 |
| DINO | 2021 | 自己蒸留 | セグメンテーション特性 |
| CLIP | 2021 | 画像テキスト対照学習 | ゼロショット分類 |
| BEiT | 2021 | 離散トークンマスキング | BERTスタイル |
| DINOv2 | 2023 | 大規模 + iBOT | 汎用特徴抽出器 |
学習ロードマップ:
- SimCLRを実装して対照学習を理解する
- MAEを学習してマスキングアプローチを把握する
- CLIPをマルチモーダル学習に応用する
- 自分のドメインデータでファインチューニングする
SSLはLLM事前学習とコンピュータビジョンの基盤となっています。膨大なラベルなしデータを効果的に活用するアプローチとして、SSLはAI進歩の中心的な推進力であり続けるでしょう。