Skip to content

필사 모드: 自己教師あり学習完全ガイド: SimCLR・MAE・DINO・CLIPまで

日本語
0%
정확도 0%
💡 왼쪽 원문을 읽으면서 오른쪽에 따라 써보세요. Tab 키로 힌트를 받을 수 있습니다.
원문 렌더가 준비되기 전까지 텍스트 가이드로 표시합니다.

目次

1. [自己教師あり学習とは?](#1-自己教師あり学習とは)

2. [初期のSSL手法](#2-初期のssl手法)

3. [対照学習](#3-対照学習)

4. [マスク自己エンコーダ(MAE)](#4-マスク自己エンコーダmae)

5. [DINO - ラベルなし自己蒸留](#5-dino)

6. [CLIP - 対照言語画像事前学習](#6-clip)

7. [BEiT - 画像のためのBERT](#7-beit)

8. [言語モデル事前学習](#8-言語モデル事前学習)

9. [実践的な応用](#9-実践的な応用)

1. 自己教師あり学習とは?

1.1 学習パラダイムの比較

深層学習における3つの主要な学習パラダイムから理解していきましょう。

**教師あり学習**はラベル付きデータで学習します。ImageNetの1000万枚の画像にはそれぞれ人間が付けたラベルがあり、ResNetやViTといったモデルはこのデータで学習されます。問題はラベル収集が非常にコストがかかることです。医療画像1枚に専門家が数十分かけてアノテーションする必要がある場合もあります。

**教師なし学習**はラベルなしでデータのパターンを発見します。K-Meansクラスタリング、PCA、GANがこのカテゴリに属します。しかし従来の教師なし学習は、下流タスクに直接利用できる表現を学習するのに苦労します。

**自己教師あり学習(SSL)**は両者の長所を組み合わせます。大量のラベルなしデータを活用しながら、データ自体から監督信号を生成します。

核心的なアイデア:**データ自体がラベルである。**

インターネットには数十億枚の画像と数兆語のテキストが存在します。SSLはこの膨大なラベルなしデータを活用して強力な表現を学習します。

1.2 ラベル不足問題

実世界ではラベル収集が非常に困難です。

| ドメイン | ラベル収集コスト | 理由 |

| ------------ | ---------------- | ------------------ |

| 医療画像 | 非常に高い | 専門医が必要 |

| 衛星画像 | 高い | 地理専門知識が必要 |

| 法律文書 | 高い | 法律専門家が必要 |

| 工業欠陥検出 | 高い | 欠陥サンプルが希少 |

| 音声認識 | 中程度 | 文字起こしが必要 |

SSLはこの問題を根本的に解決します。大規模なラベルなしデータで事前学習し、少量のラベルデータでファインチューニングするというアプローチです。

1.3 プレテキストタスク

SSLの鍵は**プレテキストタスク**です。データ自体から人工的な予測問題を作り出し、モデルに意味のある表現を学習させます。

良いプレテキストタスクの条件:

1. ラベルなしで自動的に構築できること

2. タスクをうまく解くためには意味のある表現が必要であること

3. 難しすぎず簡単すぎないこと

例:

- **回転予測**:画像を0度、90度、180度、270度回転させ、何度回転したか予測

- **マスキング**:画像やテキストの一部を隠して復元

- **対照学習**:同じ画像の2つのビューを埋め込み空間で近づける

1.4 SSLの応用

SSLは現代AIの基盤となっています:

- **GPT、BERT**:テキストSSL → ChatGPT、Geminiの基礎

- **CLIP**:画像テキストSSL → DALL-E、Stable Diffusionの基礎

- **MAE、DINO**:画像SSL → 医療画像、自動運転の基礎

- **wav2vec**:音声SSL → 音声認識の基礎

2. 初期のSSL手法

初期のSSL研究では様々なプレテキストタスクが探索されました。

2.1 回転予測

2018年にGidarisらが提案。4つの回転(0度、90度、180度、270度)のどれが適用されたかをモデルが予測します。

from PIL import Image

class RotationSSL(nn.Module):

def __init__(self, backbone, num_classes=4):

super().__init__()

self.backbone = backbone

self.classifier = nn.Linear(backbone.output_dim, num_classes)

def forward(self, x):

features = self.backbone(x)

return self.classifier(features)

def create_rotation_dataset(images):

"""4方向に回転させた(画像, ラベル)ペアを作成"""

rotated_images = []

labels = []

for img in images:

for k in range(4): # 0, 90, 180, 270度

rotated = T.functional.rotate(img, k * 90)

rotated_images.append(rotated)

labels.append(k)

return rotated_images, labels

def train_rotation_ssl(model, dataloader, optimizer, epochs=10):

criterion = nn.CrossEntropyLoss()

model.train()

for epoch in range(epochs):

total_loss = 0

for images, labels in dataloader:

images, labels = images.cuda(), labels.cuda()

optimizer.zero_grad()

logits = model(images)

loss = criterion(logits, labels)

loss.backward()

optimizer.step()

total_loss += loss.item()

avg_loss = total_loss / len(dataloader)

print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

制限:自然画像には明確な向きがありますが、テクスチャや空の画像には適さない場合があります。

2.2 ジグソーパズル解き

2016年にNoroozi & Favaroが提案。画像を3x3のグリッドに分割してシャッフルし、元の順序を予測します。

class JigsawSSL(nn.Module):

def __init__(self, backbone, num_permutations=100):

super().__init__()

self.backbone = backbone

self.patch_encoder = nn.Sequential(

backbone,

nn.Linear(backbone.output_dim, 512)

)

self.classifier = nn.Sequential(

nn.Linear(512 * 9, 4096),

nn.ReLU(),

nn.Linear(4096, num_permutations)

)

def forward(self, patches):

patches: (batch, 9, C, H, W)

batch_size = patches.size(0)

patch_features = []

for i in range(9):

feat = self.patch_encoder(patches[:, i]) # (batch, 512)

patch_features.append(feat)

combined = torch.cat(patch_features, dim=1) # (batch, 512*9)

return self.classifier(combined)

def create_jigsaw_dataset(image, grid_size=3):

"""画像をgrid_size x grid_sizeのパッチに分割してシャッフル"""

h, w = image.shape[-2:]

patch_h, patch_w = h // grid_size, w // grid_size

patches = []

for i in range(grid_size):

for j in range(grid_size):

patch = image[...,

i*patch_h:(i+1)*patch_h,

j*patch_w:(j+1)*patch_w]

patches.append(patch)

permutation_idx = np.random.randint(0, len(PREDEFINED_PERMUTATIONS))

perm = PREDEFINED_PERMUTATIONS[permutation_idx]

shuffled = [patches[p] for p in perm]

return torch.stack(shuffled), permutation_idx

2.3 彩色

2016年にZhangらが提案。グレースケール画像を入力として、モデルが色を予測します。

class ColorizationSSL(nn.Module):

def __init__(self):

super().__init__()

L チャンネル入力(明度)

self.encoder = nn.Sequential(

nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),

nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),

nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(),

)

ab チャンネル出力(色)

self.decoder = nn.Sequential(

nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),

nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),

nn.Conv2d(64, 2, 3, padding=1), # ab チャンネル

nn.Tanh()

)

def forward(self, l_channel):

features = self.encoder(l_channel)

ab_channels = self.decoder(features)

return ab_channels

from skimage.color import rgb2lab, lab2rgb

def prepare_colorization_data(rgb_image):

lab = rgb2lab(rgb_image)

l_channel = lab[:, :, 0:1] # 明度チャンネル

ab_channels = lab[:, :, 1:] # 色チャンネル

return l_channel, ab_channels

3. 対照学習

対照学習は現代SSLの中核です。2020年頃に爆発的な発展を遂げました。

3.1 核心的なアイデア

**似ているものを引き寄せ、異なるものを遠ざける。**

同じ画像の2つの異なるビュー(クロップ、フリップ、カラージッターなど)は埋め込み空間で近くにあるべきであり、異なる画像は離れているべきです。

画像 x

├── 拡張ビュー v1 → z1 (埋め込み)

└── 拡張ビュー v2 → z2 (埋め込み)

目標:z1とz2を近づけ、z1とz3(異なる画像)を遠ざける

3.2 InfoNCE損失

対照学習の標準損失関数——NT-Xent(正規化温度スケーリングクロスエントロピー)とも呼ばれます。

$$\mathcal{L} = -\frac{1}{2N} \sum_{i=1}^{N} \left[ \log \frac{e^{sim(z_i, z_{j(i)})/\tau}}{\sum_{k \neq i} e^{sim(z_i, z_k)/\tau}} \right]$$

simはコサイン類似度、τは温度パラメータです。

def info_nce_loss(z1, z2, temperature=0.5):

"""

InfoNCE / NT-Xent Loss

z1, z2: (batch_size, embedding_dim) - 同じ画像の2つのビュー

"""

batch_size = z1.size(0)

L2正規化

z1 = F.normalize(z1, dim=1)

z2 = F.normalize(z2, dim=1)

結合: [z1_1, ..., z1_N, z2_1, ..., z2_N]

z = torch.cat([z1, z2], dim=0) # (2N, D)

ペアワイズ類似度を計算

similarity = torch.mm(z, z.t()) / temperature # (2N, 2N)

自己類似度を除外(対角線 = -inf)

mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)

similarity.masked_fill_(mask, float('-inf'))

ポジティブペア:インデックス i と i+N

labels = torch.cat([

torch.arange(batch_size, 2 * batch_size),

torch.arange(batch_size)

]).to(z.device)

loss = F.cross_entropy(similarity, labels)

return loss

使用例

z1 = torch.randn(32, 128)

z2 = torch.randn(32, 128)

loss = info_nce_loss(z1, z2, temperature=0.5)

print(f"InfoNCE Loss: {loss.item():.4f}")

3.3 SimCLR

2020年にGoogleが発表。Simple Framework for Contrastive Learning。シンプルでありながら強力。

**主要コンポーネント**:

1. データ拡張(強いオーグメンテーションが重要)

2. エンコーダ(ResNet)

3. プロジェクションヘッド(MLP)

4. NT-Xent損失

class SimCLR(nn.Module):

def __init__(self, backbone='resnet50', projection_dim=128, temperature=0.5):

super().__init__()

self.temperature = temperature

if backbone == 'resnet50':

resnet = models.resnet50(pretrained=False)

self.encoder = nn.Sequential(*list(resnet.children())[:-1])

self.feature_dim = 2048

elif backbone == 'resnet18':

resnet = models.resnet18(pretrained=False)

self.encoder = nn.Sequential(*list(resnet.children())[:-1])

self.feature_dim = 512

プロジェクションヘッド(ファインチューニング時は削除)

self.projection_head = nn.Sequential(

nn.Linear(self.feature_dim, self.feature_dim),

nn.ReLU(inplace=True),

nn.Linear(self.feature_dim, projection_dim)

)

def encode(self, x):

h = self.encoder(x)

h = h.squeeze(-1).squeeze(-1)

return h

def forward(self, x):

h = self.encode(x)

z = self.projection_head(h)

return z

def contrastive_loss(self, z1, z2):

return info_nce_loss(z1, z2, self.temperature)

class SimCLRDataAugmentation:

"""SimCLRの強いデータ拡張"""

def __init__(self, image_size=224, s=1.0):

color_jitter = T.ColorJitter(

brightness=0.8*s,

contrast=0.8*s,

saturation=0.8*s,

hue=0.2*s

)

self.transform = T.Compose([

T.RandomResizedCrop(image_size, scale=(0.2, 1.0)),

T.RandomHorizontalFlip(p=0.5),

T.RandomApply([color_jitter], p=0.8),

T.RandomGrayscale(p=0.2),

T.GaussianBlur(kernel_size=int(0.1 * image_size)),

T.ToTensor(),

T.Normalize(mean=[0.485, 0.456, 0.406],

std=[0.229, 0.224, 0.225])

])

def __call__(self, x):

return self.transform(x), self.transform(x)

def train_simclr(model, dataloader, optimizer, epochs=200):

model.train()

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(

optimizer, T_max=epochs

)

for epoch in range(epochs):

total_loss = 0

for (x1, x2), _ in dataloader:

x1, x2 = x1.cuda(), x2.cuda()

z1 = model(x1)

z2 = model(x2)

loss = model.contrastive_loss(z1, z2)

optimizer.zero_grad()

loss.backward()

optimizer.step()

total_loss += loss.item()

scheduler.step()

avg_loss = total_loss / len(dataloader)

if (epoch + 1) % 10 == 0:

print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

model = SimCLR(backbone='resnet50', projection_dim=128).cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)

def get_finetuning_model(pretrained_simclr, num_classes):

encoder = pretrained_simclr.encoder

classifier = nn.Linear(pretrained_simclr.feature_dim, num_classes)

return nn.Sequential(encoder, nn.Flatten(), classifier)

3.4 MoCo(モメンタムコントラスト)

2020年にFacebook AIが発表。大きいバッチサイズを必要とせず、多数のネガティブサンプルが使えます。

核心的なアイデア:**モメンタム更新による安定したキーエンコーダ + ネガティブサンプル管理のためのキュー**

class MoCo(nn.Module):

def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07):

super().__init__()

self.K = K # キューサイズ

self.m = m # モメンタム係数

self.T = T # 温度

クエリエンコーダとキーエンコーダ

self.encoder_q = base_encoder(num_classes=dim)

self.encoder_k = base_encoder(num_classes=dim)

キーエンコーダ:勾配なし、モメンタムで更新

for param_q, param_k in zip(

self.encoder_q.parameters(),

self.encoder_k.parameters()

):

param_k.data.copy_(param_q.data)

param_k.requires_grad = False

ネガティブサンプルキュー

self.register_buffer("queue", torch.randn(dim, K))

self.queue = F.normalize(self.queue, dim=0)

self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

@torch.no_grad()

def _momentum_update_key_encoder(self):

for param_q, param_k in zip(

self.encoder_q.parameters(),

self.encoder_k.parameters()

):

param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

@torch.no_grad()

def _dequeue_and_enqueue(self, keys):

batch_size = keys.shape[0]

ptr = int(self.queue_ptr)

self.queue[:, ptr:ptr + batch_size] = keys.T

ptr = (ptr + batch_size) % self.K

self.queue_ptr[0] = ptr

def forward(self, im_q, im_k):

q = self.encoder_q(im_q)

q = F.normalize(q, dim=1)

with torch.no_grad():

self._momentum_update_key_encoder()

k = self.encoder_k(im_k)

k = F.normalize(k, dim=1)

ポジティブロジット: (batch, 1)

l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)

ネガティブロジット: (batch, K)

l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

logits = torch.cat([l_pos, l_neg], dim=1) / self.T

labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

loss = F.cross_entropy(logits, labels)

self._dequeue_and_enqueue(k)

return loss

3.5 BYOL(Bootstrap Your Own Latent)

2020年にDeepMindが発表。**ネガティブサンプルなし**で機能する画期的な手法。

核心的な洞察:オンラインネットワークがターゲットネットワークの表現を予測します。ターゲットはモメンタムで更新されます。

class BYOL(nn.Module):

def __init__(self, backbone, projection_dim=256, prediction_dim=128, tau=0.996):

super().__init__()

self.tau = tau

オンラインネットワーク

self.online_encoder = backbone

self.online_projector = nn.Sequential(

nn.Linear(backbone.output_dim, 4096),

nn.BatchNorm1d(4096),

nn.ReLU(inplace=True),

nn.Linear(4096, projection_dim)

)

self.online_predictor = nn.Sequential(

nn.Linear(projection_dim, 4096),

nn.BatchNorm1d(4096),

nn.ReLU(inplace=True),

nn.Linear(4096, prediction_dim)

)

ターゲットネットワーク(モメンタムで更新)

self.target_encoder = copy.deepcopy(backbone)

self.target_projector = copy.deepcopy(self.online_projector)

for param in self.target_encoder.parameters():

param.requires_grad = False

for param in self.target_projector.parameters():

param.requires_grad = False

@torch.no_grad()

def update_target(self):

for online, target in zip(

self.online_encoder.parameters(),

self.target_encoder.parameters()

):

target.data = self.tau * target.data + (1 - self.tau) * online.data

def forward(self, x1, x2):

online_z1 = self.online_projector(self.online_encoder(x1))

online_z2 = self.online_projector(self.online_encoder(x2))

pred1 = self.online_predictor(online_z1)

pred2 = self.online_predictor(online_z2)

with torch.no_grad():

target_z1 = self.target_projector(self.target_encoder(x1))

target_z2 = self.target_projector(self.target_encoder(x2))

BYOL損失(対称的なコサイン類似度)

loss1 = 2 - 2 * F.cosine_similarity(pred1, target_z2.detach()).mean()

loss2 = 2 - 2 * F.cosine_similarity(pred2, target_z1.detach()).mean()

return (loss1 + loss2) / 2

4. マスク自己エンコーダ(MAE)

MAE(Masked Autoencoders Are Scalable Vision Learners)はMetaのKaiming Heらが2021年に発表。ViTと組み合わせて非常に効果的です。

4.1 核心的なアイデア

画像パッチの75%をランダムにマスクし、マスクされたピクセルを復元する。

**なぜ75%もマスクするのか?**

- 画像の冗長性が高いため、低いマスク率では簡単すぎる

- 高いマスク率ではモデルがシーンの全体的な理解を学習せざるを得ない

- BERTの15%とは異なり、画像には空間的な局所性がある

4.2 完全なMAE実装

class PatchEmbedding(nn.Module):

def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):

super().__init__()

self.patch_size = patch_size

num_patches = (img_size // patch_size) ** 2

self.projection = nn.Conv2d(

in_channels, embed_dim,

kernel_size=patch_size, stride=patch_size

)

def forward(self, x):

x = self.projection(x) # (B, embed_dim, H/P, W/P)

x = x.flatten(2).transpose(1, 2) # (B, num_patches, embed_dim)

return x

class TransformerBlock(nn.Module):

def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0):

super().__init__()

self.norm1 = nn.LayerNorm(dim)

self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)

self.norm2 = nn.LayerNorm(dim)

mlp_dim = int(dim * mlp_ratio)

self.mlp = nn.Sequential(

nn.Linear(dim, mlp_dim),

nn.GELU(),

nn.Dropout(dropout),

nn.Linear(mlp_dim, dim),

nn.Dropout(dropout)

)

def forward(self, x):

x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]

x = x + self.mlp(self.norm2(x))

return x

class MAE(nn.Module):

def __init__(

self,

img_size=224,

patch_size=16,

in_channels=3,

encoder_dim=768,

encoder_depth=12,

encoder_heads=12,

decoder_dim=512,

decoder_depth=8,

decoder_heads=16,

mask_ratio=0.75,

):

super().__init__()

self.mask_ratio = mask_ratio

self.patch_size = patch_size

num_patches = (img_size // patch_size) ** 2

self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, encoder_dim)

self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_dim))

self.pos_embed = nn.Parameter(

torch.zeros(1, num_patches + 1, encoder_dim),

requires_grad=False

)

self.encoder_blocks = nn.ModuleList([

TransformerBlock(encoder_dim, encoder_heads)

for _ in range(encoder_depth)

])

self.encoder_norm = nn.LayerNorm(encoder_dim)

self.decoder_embed = nn.Linear(encoder_dim, decoder_dim, bias=True)

self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))

self.decoder_pos_embed = nn.Parameter(

torch.zeros(1, num_patches + 1, decoder_dim),

requires_grad=False

)

self.decoder_blocks = nn.ModuleList([

TransformerBlock(decoder_dim, decoder_heads)

for _ in range(decoder_depth)

])

self.decoder_norm = nn.LayerNorm(decoder_dim)

self.decoder_pred = nn.Linear(decoder_dim, patch_size**2 * in_channels)

def random_masking(self, x, mask_ratio):

N, L, D = x.shape

len_keep = int(L * (1 - mask_ratio))

noise = torch.rand(N, L, device=x.device)

ids_shuffle = torch.argsort(noise, dim=1)

ids_restore = torch.argsort(ids_shuffle, dim=1)

ids_keep = ids_shuffle[:, :len_keep]

x_masked = torch.gather(

x, dim=1,

index=ids_keep.unsqueeze(-1).repeat(1, 1, D)

)

mask = torch.ones(N, L, device=x.device)

mask[:, :len_keep] = 0

mask = torch.gather(mask, dim=1, index=ids_restore)

return x_masked, mask, ids_restore

def forward_encoder(self, x, mask_ratio):

x = self.patch_embed(x)

x = x + self.pos_embed[:, 1:, :]

x, mask, ids_restore = self.random_masking(x, mask_ratio)

cls_token = self.cls_token + self.pos_embed[:, :1, :]

cls_tokens = cls_token.expand(x.shape[0], -1, -1)

x = torch.cat((cls_tokens, x), dim=1)

for block in self.encoder_blocks:

x = block(x)

x = self.encoder_norm(x)

return x, mask, ids_restore

def forward_decoder(self, x, ids_restore):

x = self.decoder_embed(x)

mask_tokens = self.mask_token.repeat(

x.shape[0],

ids_restore.shape[1] + 1 - x.shape[1],

1

)

x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)

x_ = torch.gather(

x_, dim=1,

index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])

)

x = torch.cat([x[:, :1, :], x_], dim=1)

x = x + self.decoder_pos_embed

for block in self.decoder_blocks:

x = block(x)

x = self.decoder_norm(x)

x = self.decoder_pred(x)

x = x[:, 1:, :] # CLSトークンを削除

return x

def forward(self, imgs, mask_ratio=0.75):

latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)

pred = self.forward_decoder(latent, ids_restore)

loss = self.forward_loss(imgs, pred, mask)

return loss, pred, mask

def forward_loss(self, imgs, pred, mask):

target = self.patchify(imgs)

mean = target.mean(dim=-1, keepdim=True)

var = target.var(dim=-1, keepdim=True)

target = (target - mean) / (var + 1.e-6).sqrt()

loss = (pred - target) ** 2

loss = loss.mean(dim=-1)

loss = (loss * mask).sum() / mask.sum()

return loss

def patchify(self, imgs):

p = self.patch_size

h = w = imgs.shape[2] // p

x = imgs.reshape(imgs.shape[0], 3, h, p, w, p)

x = torch.einsum('nchpwq->nhwpqc', x)

x = x.reshape(imgs.shape[0], h * w, p**2 * 3)

return x

MAEモデルの作成と学習

mae_model = MAE(

img_size=224,

patch_size=16,

encoder_dim=768,

encoder_depth=12,

encoder_heads=12,

decoder_dim=512,

decoder_depth=8,

decoder_heads=16,

mask_ratio=0.75

).cuda()

optimizer = torch.optim.AdamW(

mae_model.parameters(),

lr=1.5e-4,

betas=(0.9, 0.95),

weight_decay=0.05

)

def train_step(model, imgs, optimizer):

model.train()

optimizer.zero_grad()

loss, pred, mask = model(imgs, mask_ratio=0.75)

loss.backward()

optimizer.step()

return loss.item()

5. DINO

DINO(Self-DIstillation with NO labels)は2021年にFacebook AI Researchが発表。ViTをSSLで学習すると驚くべき特性が現れます。

5.1 主要な発見

DINOで学習したViTのセルフアテンションマップは**ラベルなしでセマンティックセグメンテーション**を実行します。モデルが前景と背景を完璧に分離します。

5.2 自己蒸留アーキテクチャ

学生ネットワーク 教師ネットワーク

↑ ↑

局所ビュー(多数) グローバルビュー(2枚)

- **教師**:画像全体のビューを処理、豊かなコンテキスト

- **学生**:局所的なクロップを処理、教師の予測に一致するよう学習

- **教師の更新**:学生のEMA(勾配なし)

5.3 センタリングとシャーペニング

崩壊を防ぐ2つのテクニック:

class DINOHead(nn.Module):

def __init__(self, in_dim, out_dim=65536, bottleneck_dim=256):

super().__init__()

layers = [nn.Linear(in_dim, 2048), nn.GELU()]

layers += [nn.Linear(2048, 2048), nn.GELU()]

layers += [nn.Linear(2048, bottleneck_dim)]

self.mlp = nn.Sequential(*layers)

self.last_layer = nn.utils.weight_norm(

nn.Linear(bottleneck_dim, out_dim, bias=False)

)

self.last_layer.weight_g.data.fill_(1)

self.last_layer.weight_g.requires_grad = False

def forward(self, x):

x = self.mlp(x)

x = F.normalize(x, dim=-1)

x = self.last_layer(x)

return x

class DINO(nn.Module):

def __init__(self, backbone, head_out_dim=65536, tau_s=0.1, tau_t=0.04, m=0.996):

super().__init__()

self.tau_s = tau_s # 学生の温度

self.tau_t = tau_t # 教師の温度

self.m = m # モメンタム係数

self.student = nn.Sequential(backbone, DINOHead(backbone.output_dim, head_out_dim))

self.teacher = copy.deepcopy(self.student)

for param in self.teacher.parameters():

param.requires_grad = False

センタリング(教師の出力の移動平均)

self.register_buffer("center", torch.zeros(1, head_out_dim))

@torch.no_grad()

def update_teacher(self):

for s_param, t_param in zip(

self.student.parameters(),

self.teacher.parameters()

):

t_param.data = self.m * t_param.data + (1 - self.m) * s_param.data

def dino_loss(self, student_output, teacher_output):

"""DINOクロスエントロピー損失"""

student_out = student_output / self.tau_s

student_out = student_out.chunk(2) # 局所ビュー

センタリングと温度シャーペニング

teacher_out = F.softmax(

(teacher_output - self.center) / self.tau_t, dim=-1

)

teacher_out = teacher_out.detach().chunk(2) # グローバルビュー

total_loss = 0

n_loss_terms = 0

for iq, q in enumerate(teacher_out):

for v in range(len(student_out)):

if v == iq:

continue

loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)

total_loss += loss.mean()

n_loss_terms += 1

total_loss /= n_loss_terms

センターを更新

self.center = 0.9 * self.center + 0.1 * teacher_output.mean(dim=0, keepdim=True)

return total_loss

6. CLIP

CLIP(Contrastive Language-Image Pre-training)は2021年にOpenAIが発表。画像とテキストをペアで学習するマルチモーダルSSLです。

6.1 コアコンセプト

4億枚の画像とキャプションのペアで対照学習を適用:

- 同じ画像テキストペアの埋め込みを近づける

- 異なるペアの埋め込みを遠ざける

**ゼロショット能力**:学習後、新しいカテゴリの画像を直接分類できます(テキストプロンプトを使用)。

6.2 完全なCLIP実装

from transformers import CLIPModel, CLIPProcessor

from PIL import Image

def zero_shot_classify(image_path, class_labels, model_name="openai/clip-vit-base-patch32"):

"""CLIPを使用したゼロショット画像分類"""

model = CLIPModel.from_pretrained(model_name)

processor = CLIPProcessor.from_pretrained(model_name)

model.eval()

image = Image.open(image_path)

テキストプロンプトを作成

text_prompts = [f"a photo of a {label}" for label in class_labels]

inputs = processor(

text=text_prompts,

images=image,

return_tensors="pt",

padding=True

)

with torch.no_grad():

outputs = model(**inputs)

logits_per_image = outputs.logits_per_image

probs = logits_per_image.softmax(dim=1)

results = {}

for label, prob in zip(class_labels, probs[0]):

results[label] = prob.item()

return results

使用例

labels = ["cat", "dog", "car", "airplane", "bird"]

results = zero_shot_classify("example.jpg", labels)

for label, prob in sorted(results.items(), key=lambda x: -x[1]):

print(f"{label}: {prob:.4f}")

---- 画像テキスト類似度検索 ----

class CLIPRetrieval:

def __init__(self, model_name="openai/clip-vit-base-patch32"):

self.model = CLIPModel.from_pretrained(model_name)

self.processor = CLIPProcessor.from_pretrained(model_name)

self.model.eval()

self.image_embeddings = None

self.image_paths = []

def index_images(self, image_paths):

"""画像データベースをインデックス化"""

self.image_paths = image_paths

embeddings = []

for path in image_paths:

image = Image.open(path)

inputs = self.processor(images=image, return_tensors="pt")

with torch.no_grad():

embedding = self.model.get_image_features(**inputs)

embeddings.append(F.normalize(embedding, dim=1))

self.image_embeddings = torch.cat(embeddings, dim=0)

print(f"{len(image_paths)}枚の画像をインデックス化しました")

def search(self, query_text, top_k=5):

"""テキストクエリで画像を検索"""

inputs = self.processor(text=[query_text], return_tensors="pt", padding=True)

with torch.no_grad():

text_embedding = self.model.get_text_features(**inputs)

text_embedding = F.normalize(text_embedding, dim=1)

similarities = torch.matmul(text_embedding, self.image_embeddings.T).squeeze()

top_indices = similarities.argsort(descending=True)[:top_k]

results = [

(self.image_paths[i], similarities[i].item())

for i in top_indices

]

return results

7. BEiT

BEiT(BERT Pre-Training of Image Transformers)は2021年にMicrosoftが発表。画像を離散トークンに変換してBERT風に学習します。

7.1 離散視覚トークン(dVAE)

BEiTのコア:生のピクセルを予測する代わりに**離散トークン**を予測します。

ステップ1: dVAEが画像を離散トークンに変換(語彙サイズ8192)

ステップ2: ViTがマスクされた各パッチのトークンを予測

from transformers import BeitForMaskedImageModeling, BeitImageProcessor

class BEiTPretraining:

def __init__(self, model_name="microsoft/beit-base-patch16-224"):

self.model = BeitForMaskedImageModeling.from_pretrained(model_name)

self.processor = BeitImageProcessor.from_pretrained(model_name)

def create_bool_masked_pos(self, num_patches, mask_ratio=0.4):

num_masked = int(num_patches * mask_ratio)

mask = torch.zeros(num_patches, dtype=torch.bool)

masked_indices = torch.randperm(num_patches)[:num_masked]

mask[masked_indices] = True

return mask

def forward(self, images):

inputs = self.processor(images, return_tensors="pt")

pixel_values = inputs.pixel_values

num_patches = (224 // 16) ** 2 # 196

bool_masked_pos = self.create_bool_masked_pos(num_patches)

bool_masked_pos = bool_masked_pos.unsqueeze(0).expand(

pixel_values.size(0), -1

)

outputs = self.model(

pixel_values=pixel_values,

bool_masked_pos=bool_masked_pos

)

return outputs.loss

8. 言語モデル事前学習

SSLはNLPに長い歴史があります。

8.1 GPT - 次のトークン予測

左から右へと次のトークンを予測する自己回帰的アプローチ。

from torch.nn import functional as F

class GPTLanguageModel(nn.Module):

def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer, dropout=0.1):

super().__init__()

self.token_embedding = nn.Embedding(vocab_size, n_embd)

self.position_embedding = nn.Embedding(block_size, n_embd)

self.blocks = nn.Sequential(*[

TransformerBlock(n_embd, n_head, dropout=dropout)

for _ in range(n_layer)

])

self.ln_f = nn.LayerNorm(n_embd)

self.lm_head = nn.Linear(n_embd, vocab_size)

self.block_size = block_size

def forward(self, idx, targets=None):

B, T = idx.shape

tok_emb = self.token_embedding(idx)

pos_emb = self.position_embedding(torch.arange(T, device=idx.device))

x = tok_emb + pos_emb

x = self.blocks(x)

x = self.ln_f(x)

logits = self.lm_head(x)

if targets is None:

return logits, None

B, T, C = logits.shape

loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))

return logits, loss

@torch.no_grad()

def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):

for _ in range(max_new_tokens):

idx_cond = idx[:, -self.block_size:]

logits, _ = self(idx_cond)

logits = logits[:, -1, :] / temperature

if top_k is not None:

v, _ = torch.topk(logits, min(top_k, logits.size(-1)))

logits[logits < v[:, [-1]]] = float('-inf')

probs = F.softmax(logits, dim=-1)

idx_next = torch.multinomial(probs, num_samples=1)

idx = torch.cat((idx, idx_next), dim=1)

return idx

8.2 BERT - マスク言語モデリング

双方向コンテキストを学習します。15%のトークンがマスクされ、復元する必要があります。

from transformers import BertForMaskedLM, BertTokenizer

class BERTPretraining:

def __init__(self, model_name="bert-base-uncased"):

self.model = BertForMaskedLM.from_pretrained(model_name)

self.tokenizer = BertTokenizer.from_pretrained(model_name)

def mask_tokens(self, inputs, mlm_probability=0.15):

"""MLMのためにトークンをマスク"""

labels = inputs.clone()

probability_matrix = torch.full(labels.shape, mlm_probability)

特殊トークンはマスクしない

special_tokens_mask = [

self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)

for val in labels.tolist()

]

special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)

probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

masked_indices = torch.bernoulli(probability_matrix).bool()

labels[~masked_indices] = -100 # マスクされていないトークンの損失を計算しない

80%は[MASK]に置換

indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices

inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

10%はランダムなトークンに置換

indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced

random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)

inputs[indices_random] = random_words[indices_random]

return inputs, labels

9. 実践的な応用

9.1 線形評価とファインチューニング

class LinearProbe(nn.Module):

"""SSL表現を評価するための線形分類器"""

def __init__(self, input_dim, num_classes):

super().__init__()

self.linear = nn.Linear(input_dim, num_classes)

def forward(self, x):

return self.linear(x)

def extract_features(model, dataloader, device='cuda'):

"""事前学習済みモデルから特徴を抽出"""

model.eval()

features_list = []

labels_list = []

with torch.no_grad():

for images, labels in dataloader:

images = images.to(device)

features = model.encode(images)

features_list.append(features.cpu())

labels_list.append(labels)

return torch.cat(features_list), torch.cat(labels_list)

def few_shot_evaluation(ssl_model, train_loader, val_loader,

n_shots=[1, 5, 10, 100], device='cuda'):

"""クラスごとのnサンプルで評価"""

ssl_model.eval()

all_features, all_labels = extract_features(ssl_model, train_loader, device)

val_features, val_labels = extract_features(ssl_model, val_loader, device)

results = {}

num_classes = len(all_labels.unique())

for n_shot in n_shots:

selected_indices = []

for cls in range(num_classes):

cls_indices = (all_labels == cls).nonzero(as_tuple=True)[0]

selected = cls_indices[torch.randperm(len(cls_indices))[:n_shot]]

selected_indices.append(selected)

selected_indices = torch.cat(selected_indices)

train_feats = all_features[selected_indices]

train_labs = all_labels[selected_indices]

probe = LinearProbe(all_features.shape[1], num_classes).to(device)

optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)

criterion = nn.CrossEntropyLoss()

train_feats = train_feats.to(device)

train_labs = train_labs.to(device)

for epoch in range(100):

probe.train()

logits = probe(train_feats)

loss = criterion(logits, train_labs)

optimizer.zero_grad()

loss.backward()

optimizer.step()

probe.eval()

with torch.no_grad():

val_logits = probe(val_features.to(device))

preds = val_logits.argmax(dim=1).cpu()

acc = (preds == val_labels).float().mean().item()

results[n_shot] = acc

print(f"{n_shot}-shot精度: {acc * 100:.2f}%")

return results

def knn_evaluation(ssl_model, train_loader, val_loader, k=20, device='cuda'):

"""SSL表現の品質をk-NNで評価(パラメータフリー)"""

ssl_model.eval()

train_features, train_labels = extract_features(ssl_model, train_loader, device)

val_features, val_labels = extract_features(ssl_model, val_loader, device)

train_features = nn.functional.normalize(train_features, dim=1)

val_features = nn.functional.normalize(val_features, dim=1)

correct = 0

total = 0

batch_size = 256

for i in range(0, len(val_features), batch_size):

batch_feats = val_features[i:i+batch_size].to(device)

batch_labels = val_labels[i:i+batch_size]

sims = torch.mm(batch_feats, train_features.to(device).T)

topk_sims, topk_indices = sims.topk(k, dim=1)

topk_labels = train_labels[topk_indices.cpu()]

preds = topk_labels.mode(dim=1).values

correct += (preds == batch_labels).sum().item()

total += len(batch_labels)

accuracy = correct / total

print(f"k-NN ({k}) 精度: {accuracy * 100:.2f}%")

return accuracy

9.2 ドメイン特化SSL

医療画像(胸部X線の例)

class MedicalSSL(nn.Module):

"""医療画像のための自己教師あり学習"""

def __init__(self, backbone):

super().__init__()

self.backbone = backbone

医療画像には弱いオーグメンテーションを使用

self.augmentation = T.Compose([

T.RandomResizedCrop(224, scale=(0.8, 1.0)),

T.RandomHorizontalFlip(p=0.5),

医療画像では色変化を最小限に

T.RandomApply([T.ColorJitter(0.1, 0.1, 0.0, 0.0)], p=0.3),

T.ToTensor(),

T.Normalize([0.485], [0.229]) # グレースケールX線

])

def forward(self, x):

v1 = self.augmentation(x)

v2 = self.augmentation(x)

return v1, v2

衛星画像

class SatelliteSSL:

"""衛星画像のオーグメンテーション(地理的特性を保持)"""

def __init__(self):

self.transform = T.Compose([

T.RandomResizedCrop(256, scale=(0.4, 1.0)),

衛星画像:回転不変性が重要

T.RandomRotation(90),

T.RandomHorizontalFlip(),

T.RandomVerticalFlip(),

時刻・季節変化をシミュレート

T.ColorJitter(brightness=0.4, contrast=0.4),

T.ToTensor(),

])

9.3 HubからのSSL事前学習モデルの使用

from transformers import (

ViTModel, ViTFeatureExtractor,

CLIPModel, CLIPProcessor

)

from PIL import Image

MAE事前学習済みViTを読み込む

def load_mae_pretrained():

model = ViTModel.from_pretrained("facebook/vit-mae-base")

feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")

return model, feature_extractor

DINO ViT特徴抽出

def extract_dino_features(image_path):

model = ViTModel.from_pretrained("facebook/dino-vits16")

processor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits16")

image = Image.open(image_path)

inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():

outputs = model(**inputs)

CLSトークン(グローバル画像表現)

cls_features = outputs.last_hidden_state[:, 0, :]

パッチごとの特徴(空間情報あり)

patch_features = outputs.last_hidden_state[:, 1:, :]

return cls_features, patch_features

CLIPで特徴抽出

def clip_feature_extraction(images, texts=None):

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

if texts:

inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)

with torch.no_grad():

outputs = model(**inputs)

return outputs.image_embeds, outputs.text_embeds

else:

inputs = processor(images=images, return_tensors="pt")

with torch.no_grad():

image_features = model.get_image_features(**inputs)

return image_features

まとめ

自己教師あり学習はAIを変革しています。重要なポイント:

| 手法 | 年 | 核心的なアイデア | 強み |

| ------ | ---- | ----------------------------------- | ---------------------- |

| SimCLR | 2020 | 強いオーグメンテーション + 対照学習 | シンプルで強力 |

| MoCo | 2020 | モメンタムキュー | バッチサイズ非依存 |

| BYOL | 2020 | ネガティブなしの対照学習 | バッチサイズ非依存 |

| MAE | 2021 | 75%マスキング + ピクセル復元 | 効率的、ViT最適 |

| DINO | 2021 | 自己蒸留 | セグメンテーション特性 |

| CLIP | 2021 | 画像テキスト対照学習 | ゼロショット分類 |

| BEiT | 2021 | 離散トークンマスキング | BERTスタイル |

| DINOv2 | 2023 | 大規模 + iBOT | 汎用特徴抽出器 |

**学習ロードマップ**:

1. SimCLRを実装して対照学習を理解する

2. MAEを学習してマスキングアプローチを把握する

3. CLIPをマルチモーダル学習に応用する

4. 自分のドメインデータでファインチューニングする

SSLはLLM事前学習とコンピュータビジョンの基盤となっています。膨大なラベルなしデータを効果的に活用するアプローチとして、SSLはAI進歩の中心的な推進力であり続けるでしょう。

参考文献

- [SimCLR論文](https://arxiv.org/abs/2002.05709) - Chen et al., 2020

- [MAE論文](https://arxiv.org/abs/2111.06377) - He et al., 2021

- [DINO論文](https://arxiv.org/abs/2104.14294) - Caron et al., 2021

- [CLIP論文](https://arxiv.org/abs/2103.00020) - Radford et al., 2021

- [BYOL論文](https://arxiv.org/abs/2006.07733) - Grill et al., 2020

현재 단락 (1/888)

1. [自己教師あり学習とは?](#1-自己教師あり学習とは)

작성 글자: 0원문 글자: 29,350작성 단락: 0/888