Skip to content
Published on

自己教師あり学習完全ガイド: SimCLR・MAE・DINO・CLIPまで

Authors

目次

  1. 自己教師あり学習とは?
  2. 初期のSSL手法
  3. 対照学習
  4. マスク自己エンコーダ(MAE)
  5. DINO - ラベルなし自己蒸留
  6. CLIP - 対照言語画像事前学習
  7. BEiT - 画像のためのBERT
  8. 言語モデル事前学習
  9. 実践的な応用

1. 自己教師あり学習とは?

1.1 学習パラダイムの比較

深層学習における3つの主要な学習パラダイムから理解していきましょう。

教師あり学習はラベル付きデータで学習します。ImageNetの1000万枚の画像にはそれぞれ人間が付けたラベルがあり、ResNetやViTといったモデルはこのデータで学習されます。問題はラベル収集が非常にコストがかかることです。医療画像1枚に専門家が数十分かけてアノテーションする必要がある場合もあります。

教師なし学習はラベルなしでデータのパターンを発見します。K-Meansクラスタリング、PCA、GANがこのカテゴリに属します。しかし従来の教師なし学習は、下流タスクに直接利用できる表現を学習するのに苦労します。

**自己教師あり学習(SSL)**は両者の長所を組み合わせます。大量のラベルなしデータを活用しながら、データ自体から監督信号を生成します。

核心的なアイデア:データ自体がラベルである。

インターネットには数十億枚の画像と数兆語のテキストが存在します。SSLはこの膨大なラベルなしデータを活用して強力な表現を学習します。

1.2 ラベル不足問題

実世界ではラベル収集が非常に困難です。

ドメインラベル収集コスト理由
医療画像非常に高い専門医が必要
衛星画像高い地理専門知識が必要
法律文書高い法律専門家が必要
工業欠陥検出高い欠陥サンプルが希少
音声認識中程度文字起こしが必要

SSLはこの問題を根本的に解決します。大規模なラベルなしデータで事前学習し、少量のラベルデータでファインチューニングするというアプローチです。

1.3 プレテキストタスク

SSLの鍵はプレテキストタスクです。データ自体から人工的な予測問題を作り出し、モデルに意味のある表現を学習させます。

良いプレテキストタスクの条件:

  1. ラベルなしで自動的に構築できること
  2. タスクをうまく解くためには意味のある表現が必要であること
  3. 難しすぎず簡単すぎないこと

例:

  • 回転予測:画像を0度、90度、180度、270度回転させ、何度回転したか予測
  • マスキング:画像やテキストの一部を隠して復元
  • 対照学習:同じ画像の2つのビューを埋め込み空間で近づける

1.4 SSLの応用

SSLは現代AIの基盤となっています:

  • GPT、BERT:テキストSSL → ChatGPT、Geminiの基礎
  • CLIP:画像テキストSSL → DALL-E、Stable Diffusionの基礎
  • MAE、DINO:画像SSL → 医療画像、自動運転の基礎
  • wav2vec:音声SSL → 音声認識の基礎

2. 初期のSSL手法

初期のSSL研究では様々なプレテキストタスクが探索されました。

2.1 回転予測

2018年にGidarisらが提案。4つの回転(0度、90度、180度、270度)のどれが適用されたかをモデルが予測します。

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image

class RotationSSL(nn.Module):
    def __init__(self, backbone, num_classes=4):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(backbone.output_dim, num_classes)

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)

def create_rotation_dataset(images):
    """4方向に回転させた(画像, ラベル)ペアを作成"""
    rotated_images = []
    labels = []

    for img in images:
        for k in range(4):  # 0, 90, 180, 270度
            rotated = T.functional.rotate(img, k * 90)
            rotated_images.append(rotated)
            labels.append(k)

    return rotated_images, labels

def train_rotation_ssl(model, dataloader, optimizer, epochs=10):
    criterion = nn.CrossEntropyLoss()
    model.train()

    for epoch in range(epochs):
        total_loss = 0
        for images, labels in dataloader:
            images, labels = images.cuda(), labels.cuda()
            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

制限:自然画像には明確な向きがありますが、テクスチャや空の画像には適さない場合があります。

2.2 ジグソーパズル解き

2016年にNoroozi & Favaroが提案。画像を3x3のグリッドに分割してシャッフルし、元の順序を予測します。

import itertools
import numpy as np

class JigsawSSL(nn.Module):
    def __init__(self, backbone, num_permutations=100):
        super().__init__()
        self.backbone = backbone
        self.patch_encoder = nn.Sequential(
            backbone,
            nn.Linear(backbone.output_dim, 512)
        )
        self.classifier = nn.Sequential(
            nn.Linear(512 * 9, 4096),
            nn.ReLU(),
            nn.Linear(4096, num_permutations)
        )

    def forward(self, patches):
        # patches: (batch, 9, C, H, W)
        batch_size = patches.size(0)
        patch_features = []

        for i in range(9):
            feat = self.patch_encoder(patches[:, i])  # (batch, 512)
            patch_features.append(feat)

        combined = torch.cat(patch_features, dim=1)  # (batch, 512*9)
        return self.classifier(combined)

def create_jigsaw_dataset(image, grid_size=3):
    """画像をgrid_size x grid_sizeのパッチに分割してシャッフル"""
    h, w = image.shape[-2:]
    patch_h, patch_w = h // grid_size, w // grid_size

    patches = []
    for i in range(grid_size):
        for j in range(grid_size):
            patch = image[...,
                         i*patch_h:(i+1)*patch_h,
                         j*patch_w:(j+1)*patch_w]
            patches.append(patch)

    permutation_idx = np.random.randint(0, len(PREDEFINED_PERMUTATIONS))
    perm = PREDEFINED_PERMUTATIONS[permutation_idx]
    shuffled = [patches[p] for p in perm]

    return torch.stack(shuffled), permutation_idx

2.3 彩色

2016年にZhangらが提案。グレースケール画像を入力として、モデルが色を予測します。

class ColorizationSSL(nn.Module):
    def __init__(self):
        super().__init__()
        # L チャンネル入力(明度)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(),
        )
        # ab チャンネル出力(色)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 2, 3, padding=1),  # ab チャンネル
            nn.Tanh()
        )

    def forward(self, l_channel):
        features = self.encoder(l_channel)
        ab_channels = self.decoder(features)
        return ab_channels

from skimage.color import rgb2lab, lab2rgb

def prepare_colorization_data(rgb_image):
    lab = rgb2lab(rgb_image)
    l_channel = lab[:, :, 0:1]   # 明度チャンネル
    ab_channels = lab[:, :, 1:]  # 色チャンネル
    return l_channel, ab_channels

3. 対照学習

対照学習は現代SSLの中核です。2020年頃に爆発的な発展を遂げました。

3.1 核心的なアイデア

似ているものを引き寄せ、異なるものを遠ざける。

同じ画像の2つの異なるビュー(クロップ、フリップ、カラージッターなど)は埋め込み空間で近くにあるべきであり、異なる画像は離れているべきです。

画像 x
├── 拡張ビュー v1 → z1 (埋め込み)
└── 拡張ビュー v2 → z2 (埋め込み)

目標:z1とz2を近づけ、z1とz3(異なる画像)を遠ざける

3.2 InfoNCE損失

対照学習の標準損失関数——NT-Xent(正規化温度スケーリングクロスエントロピー)とも呼ばれます。

L=12Ni=1N[logesim(zi,zj(i))/τkiesim(zi,zk)/τ]\mathcal{L} = -\frac{1}{2N} \sum_{i=1}^{N} \left[ \log \frac{e^{sim(z_i, z_{j(i)})/\tau}}{\sum_{k \neq i} e^{sim(z_i, z_k)/\tau}} \right]

simはコサイン類似度、τは温度パラメータです。

import torch
import torch.nn.functional as F

def info_nce_loss(z1, z2, temperature=0.5):
    """
    InfoNCE / NT-Xent Loss
    z1, z2: (batch_size, embedding_dim) - 同じ画像の2つのビュー
    """
    batch_size = z1.size(0)

    # L2正規化
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    # 結合: [z1_1, ..., z1_N, z2_1, ..., z2_N]
    z = torch.cat([z1, z2], dim=0)  # (2N, D)

    # ペアワイズ類似度を計算
    similarity = torch.mm(z, z.t()) / temperature  # (2N, 2N)

    # 自己類似度を除外(対角線 = -inf)
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    similarity.masked_fill_(mask, float('-inf'))

    # ポジティブペア:インデックス i と i+N
    labels = torch.cat([
        torch.arange(batch_size, 2 * batch_size),
        torch.arange(batch_size)
    ]).to(z.device)

    loss = F.cross_entropy(similarity, labels)
    return loss

# 使用例
z1 = torch.randn(32, 128)
z2 = torch.randn(32, 128)
loss = info_nce_loss(z1, z2, temperature=0.5)
print(f"InfoNCE Loss: {loss.item():.4f}")

3.3 SimCLR

2020年にGoogleが発表。Simple Framework for Contrastive Learning。シンプルでありながら強力。

主要コンポーネント

  1. データ拡張(強いオーグメンテーションが重要)
  2. エンコーダ(ResNet)
  3. プロジェクションヘッド(MLP)
  4. NT-Xent損失
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T

class SimCLR(nn.Module):
    def __init__(self, backbone='resnet50', projection_dim=128, temperature=0.5):
        super().__init__()
        self.temperature = temperature

        if backbone == 'resnet50':
            resnet = models.resnet50(pretrained=False)
            self.encoder = nn.Sequential(*list(resnet.children())[:-1])
            self.feature_dim = 2048
        elif backbone == 'resnet18':
            resnet = models.resnet18(pretrained=False)
            self.encoder = nn.Sequential(*list(resnet.children())[:-1])
            self.feature_dim = 512

        # プロジェクションヘッド(ファインチューニング時は削除)
        self.projection_head = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.feature_dim, projection_dim)
        )

    def encode(self, x):
        h = self.encoder(x)
        h = h.squeeze(-1).squeeze(-1)
        return h

    def forward(self, x):
        h = self.encode(x)
        z = self.projection_head(h)
        return z

    def contrastive_loss(self, z1, z2):
        return info_nce_loss(z1, z2, self.temperature)


class SimCLRDataAugmentation:
    """SimCLRの強いデータ拡張"""
    def __init__(self, image_size=224, s=1.0):
        color_jitter = T.ColorJitter(
            brightness=0.8*s,
            contrast=0.8*s,
            saturation=0.8*s,
            hue=0.2*s
        )
        self.transform = T.Compose([
            T.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply([color_jitter], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.GaussianBlur(kernel_size=int(0.1 * image_size)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        return self.transform(x), self.transform(x)


def train_simclr(model, dataloader, optimizer, epochs=200):
    model.train()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs
    )

    for epoch in range(epochs):
        total_loss = 0
        for (x1, x2), _ in dataloader:
            x1, x2 = x1.cuda(), x2.cuda()

            z1 = model(x1)
            z2 = model(x2)

            loss = model.contrastive_loss(z1, z2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        avg_loss = total_loss / len(dataloader)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")


model = SimCLR(backbone='resnet50', projection_dim=128).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)

def get_finetuning_model(pretrained_simclr, num_classes):
    encoder = pretrained_simclr.encoder
    classifier = nn.Linear(pretrained_simclr.feature_dim, num_classes)
    return nn.Sequential(encoder, nn.Flatten(), classifier)

3.4 MoCo(モメンタムコントラスト)

2020年にFacebook AIが発表。大きいバッチサイズを必要とせず、多数のネガティブサンプルが使えます。

核心的なアイデア:モメンタム更新による安定したキーエンコーダ + ネガティブサンプル管理のためのキュー

class MoCo(nn.Module):
    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07):
        super().__init__()
        self.K = K   # キューサイズ
        self.m = m   # モメンタム係数
        self.T = T   # 温度

        # クエリエンコーダとキーエンコーダ
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        # キーエンコーダ:勾配なし、モメンタムで更新
        for param_q, param_k in zip(
            self.encoder_q.parameters(),
            self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        # ネガティブサンプルキュー
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        for param_q, param_k in zip(
            self.encoder_q.parameters(),
            self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr

    def forward(self, im_q, im_k):
        q = self.encoder_q(im_q)
        q = F.normalize(q, dim=1)

        with torch.no_grad():
            self._momentum_update_key_encoder()
            k = self.encoder_k(im_k)
            k = F.normalize(k, dim=1)

        # ポジティブロジット: (batch, 1)
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # ネガティブロジット: (batch, K)
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        logits = torch.cat([l_pos, l_neg], dim=1) / self.T
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

        loss = F.cross_entropy(logits, labels)

        self._dequeue_and_enqueue(k)
        return loss

3.5 BYOL(Bootstrap Your Own Latent)

2020年にDeepMindが発表。ネガティブサンプルなしで機能する画期的な手法。

核心的な洞察:オンラインネットワークがターゲットネットワークの表現を予測します。ターゲットはモメンタムで更新されます。

import copy

class BYOL(nn.Module):
    def __init__(self, backbone, projection_dim=256, prediction_dim=128, tau=0.996):
        super().__init__()
        self.tau = tau

        # オンラインネットワーク
        self.online_encoder = backbone
        self.online_projector = nn.Sequential(
            nn.Linear(backbone.output_dim, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, projection_dim)
        )
        self.online_predictor = nn.Sequential(
            nn.Linear(projection_dim, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, prediction_dim)
        )

        # ターゲットネットワーク(モメンタムで更新)
        self.target_encoder = copy.deepcopy(backbone)
        self.target_projector = copy.deepcopy(self.online_projector)

        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def update_target(self):
        for online, target in zip(
            self.online_encoder.parameters(),
            self.target_encoder.parameters()
        ):
            target.data = self.tau * target.data + (1 - self.tau) * online.data

    def forward(self, x1, x2):
        online_z1 = self.online_projector(self.online_encoder(x1))
        online_z2 = self.online_projector(self.online_encoder(x2))
        pred1 = self.online_predictor(online_z1)
        pred2 = self.online_predictor(online_z2)

        with torch.no_grad():
            target_z1 = self.target_projector(self.target_encoder(x1))
            target_z2 = self.target_projector(self.target_encoder(x2))

        # BYOL損失(対称的なコサイン類似度)
        loss1 = 2 - 2 * F.cosine_similarity(pred1, target_z2.detach()).mean()
        loss2 = 2 - 2 * F.cosine_similarity(pred2, target_z1.detach()).mean()

        return (loss1 + loss2) / 2

4. マスク自己エンコーダ(MAE)

MAE(Masked Autoencoders Are Scalable Vision Learners)はMetaのKaiming Heらが2021年に発表。ViTと組み合わせて非常に効果的です。

4.1 核心的なアイデア

画像パッチの75%をランダムにマスクし、マスクされたピクセルを復元する。

なぜ75%もマスクするのか?

  • 画像の冗長性が高いため、低いマスク率では簡単すぎる
  • 高いマスク率ではモデルがシーンの全体的な理解を学習せざるを得ない
  • BERTの15%とは異なり、画像には空間的な局所性がある

4.2 完全なMAE実装

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        x = self.projection(x)  # (B, embed_dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        mlp_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        x = x + self.mlp(self.norm2(x))
        return x

class MAE(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        encoder_dim=768,
        encoder_depth=12,
        encoder_heads=12,
        decoder_dim=512,
        decoder_depth=8,
        decoder_heads=16,
        mask_ratio=0.75,
    ):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, encoder_dim)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, encoder_dim),
            requires_grad=False
        )

        self.encoder_blocks = nn.ModuleList([
            TransformerBlock(encoder_dim, encoder_heads)
            for _ in range(encoder_depth)
        ])
        self.encoder_norm = nn.LayerNorm(encoder_dim)

        self.decoder_embed = nn.Linear(encoder_dim, decoder_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))

        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_dim),
            requires_grad=False
        )

        self.decoder_blocks = nn.ModuleList([
            TransformerBlock(decoder_dim, decoder_heads)
            for _ in range(decoder_depth)
        ])
        self.decoder_norm = nn.LayerNorm(decoder_dim)
        self.decoder_pred = nn.Linear(decoder_dim, patch_size**2 * in_channels)

    def random_masking(self, x, mask_ratio):
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1,
            index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
        )

        mask = torch.ones(N, L, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        x = self.patch_embed(x)
        x = x + self.pos_embed[:, 1:, :]
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        for block in self.encoder_blocks:
            x = block(x)
        x = self.encoder_norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        x = self.decoder_embed(x)

        mask_tokens = self.mask_token.repeat(
            x.shape[0],
            ids_restore.shape[1] + 1 - x.shape[1],
            1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(
            x_, dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )
        x = torch.cat([x[:, :1, :], x_], dim=1)

        x = x + self.decoder_pos_embed

        for block in self.decoder_blocks:
            x = block(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)
        x = x[:, 1:, :]  # CLSトークンを削除
        return x

    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask

    def forward_loss(self, imgs, pred, mask):
        target = self.patchify(imgs)

        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1.e-6).sqrt()

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)
        loss = (loss * mask).sum() / mask.sum()
        return loss

    def patchify(self, imgs):
        p = self.patch_size
        h = w = imgs.shape[2] // p
        x = imgs.reshape(imgs.shape[0], 3, h, p, w, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(imgs.shape[0], h * w, p**2 * 3)
        return x


# MAEモデルの作成と学習
mae_model = MAE(
    img_size=224,
    patch_size=16,
    encoder_dim=768,
    encoder_depth=12,
    encoder_heads=12,
    decoder_dim=512,
    decoder_depth=8,
    decoder_heads=16,
    mask_ratio=0.75
).cuda()

optimizer = torch.optim.AdamW(
    mae_model.parameters(),
    lr=1.5e-4,
    betas=(0.9, 0.95),
    weight_decay=0.05
)

def train_step(model, imgs, optimizer):
    model.train()
    optimizer.zero_grad()
    loss, pred, mask = model(imgs, mask_ratio=0.75)
    loss.backward()
    optimizer.step()
    return loss.item()

5. DINO

DINO(Self-DIstillation with NO labels)は2021年にFacebook AI Researchが発表。ViTをSSLで学習すると驚くべき特性が現れます。

5.1 主要な発見

DINOで学習したViTのセルフアテンションマップはラベルなしでセマンティックセグメンテーションを実行します。モデルが前景と背景を完璧に分離します。

5.2 自己蒸留アーキテクチャ

学生ネットワーク               教師ネットワーク
      ↑                               ↑
 局所ビュー(多数)             グローバルビュー(2枚)
  • 教師:画像全体のビューを処理、豊かなコンテキスト
  • 学生:局所的なクロップを処理、教師の予測に一致するよう学習
  • 教師の更新:学生のEMA(勾配なし)

5.3 センタリングとシャーペニング

崩壊を防ぐ2つのテクニック:

class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim=65536, bottleneck_dim=256):
        super().__init__()
        layers = [nn.Linear(in_dim, 2048), nn.GELU()]
        layers += [nn.Linear(2048, 2048), nn.GELU()]
        layers += [nn.Linear(2048, bottleneck_dim)]
        self.mlp = nn.Sequential(*layers)

        self.last_layer = nn.utils.weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)
        self.last_layer.weight_g.requires_grad = False

    def forward(self, x):
        x = self.mlp(x)
        x = F.normalize(x, dim=-1)
        x = self.last_layer(x)
        return x


class DINO(nn.Module):
    def __init__(self, backbone, head_out_dim=65536, tau_s=0.1, tau_t=0.04, m=0.996):
        super().__init__()
        self.tau_s = tau_s  # 学生の温度
        self.tau_t = tau_t  # 教師の温度
        self.m = m          # モメンタム係数

        self.student = nn.Sequential(backbone, DINOHead(backbone.output_dim, head_out_dim))
        self.teacher = copy.deepcopy(self.student)

        for param in self.teacher.parameters():
            param.requires_grad = False

        # センタリング(教師の出力の移動平均)
        self.register_buffer("center", torch.zeros(1, head_out_dim))

    @torch.no_grad()
    def update_teacher(self):
        for s_param, t_param in zip(
            self.student.parameters(),
            self.teacher.parameters()
        ):
            t_param.data = self.m * t_param.data + (1 - self.m) * s_param.data

    def dino_loss(self, student_output, teacher_output):
        """DINOクロスエントロピー損失"""
        student_out = student_output / self.tau_s
        student_out = student_out.chunk(2)  # 局所ビュー

        # センタリングと温度シャーペニング
        teacher_out = F.softmax(
            (teacher_output - self.center) / self.tau_t, dim=-1
        )
        teacher_out = teacher_out.detach().chunk(2)  # グローバルビュー

        total_loss = 0
        n_loss_terms = 0

        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1

        total_loss /= n_loss_terms

        # センターを更新
        self.center = 0.9 * self.center + 0.1 * teacher_output.mean(dim=0, keepdim=True)

        return total_loss

6. CLIP

CLIP(Contrastive Language-Image Pre-training)は2021年にOpenAIが発表。画像とテキストをペアで学習するマルチモーダルSSLです。

6.1 コアコンセプト

4億枚の画像とキャプションのペアで対照学習を適用:

  • 同じ画像テキストペアの埋め込みを近づける
  • 異なるペアの埋め込みを遠ざける

ゼロショット能力:学習後、新しいカテゴリの画像を直接分類できます(テキストプロンプトを使用)。

6.2 完全なCLIP実装

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPModel, CLIPProcessor
from PIL import Image


def zero_shot_classify(image_path, class_labels, model_name="openai/clip-vit-base-patch32"):
    """CLIPを使用したゼロショット画像分類"""
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    model.eval()

    image = Image.open(image_path)

    # テキストプロンプトを作成
    text_prompts = [f"a photo of a {label}" for label in class_labels]

    inputs = processor(
        text=text_prompts,
        images=image,
        return_tensors="pt",
        padding=True
    )

    with torch.no_grad():
        outputs = model(**inputs)
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)

    results = {}
    for label, prob in zip(class_labels, probs[0]):
        results[label] = prob.item()

    return results

# 使用例
labels = ["cat", "dog", "car", "airplane", "bird"]
results = zero_shot_classify("example.jpg", labels)
for label, prob in sorted(results.items(), key=lambda x: -x[1]):
    print(f"{label}: {prob:.4f}")


# ---- 画像テキスト類似度検索 ----
class CLIPRetrieval:
    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model.eval()

        self.image_embeddings = None
        self.image_paths = []

    def index_images(self, image_paths):
        """画像データベースをインデックス化"""
        self.image_paths = image_paths
        embeddings = []

        for path in image_paths:
            image = Image.open(path)
            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                embedding = self.model.get_image_features(**inputs)
            embeddings.append(F.normalize(embedding, dim=1))

        self.image_embeddings = torch.cat(embeddings, dim=0)
        print(f"{len(image_paths)}枚の画像をインデックス化しました")

    def search(self, query_text, top_k=5):
        """テキストクエリで画像を検索"""
        inputs = self.processor(text=[query_text], return_tensors="pt", padding=True)
        with torch.no_grad():
            text_embedding = self.model.get_text_features(**inputs)
        text_embedding = F.normalize(text_embedding, dim=1)

        similarities = torch.matmul(text_embedding, self.image_embeddings.T).squeeze()
        top_indices = similarities.argsort(descending=True)[:top_k]

        results = [
            (self.image_paths[i], similarities[i].item())
            for i in top_indices
        ]
        return results

7. BEiT

BEiT(BERT Pre-Training of Image Transformers)は2021年にMicrosoftが発表。画像を離散トークンに変換してBERT風に学習します。

7.1 離散視覚トークン(dVAE)

BEiTのコア:生のピクセルを予測する代わりに離散トークンを予測します。

ステップ1: dVAEが画像を離散トークンに変換(語彙サイズ8192ステップ2: ViTがマスクされた各パッチのトークンを予測
from transformers import BeitForMaskedImageModeling, BeitImageProcessor
import torch

class BEiTPretraining:
    def __init__(self, model_name="microsoft/beit-base-patch16-224"):
        self.model = BeitForMaskedImageModeling.from_pretrained(model_name)
        self.processor = BeitImageProcessor.from_pretrained(model_name)

    def create_bool_masked_pos(self, num_patches, mask_ratio=0.4):
        num_masked = int(num_patches * mask_ratio)
        mask = torch.zeros(num_patches, dtype=torch.bool)
        masked_indices = torch.randperm(num_patches)[:num_masked]
        mask[masked_indices] = True
        return mask

    def forward(self, images):
        inputs = self.processor(images, return_tensors="pt")
        pixel_values = inputs.pixel_values

        num_patches = (224 // 16) ** 2  # 196
        bool_masked_pos = self.create_bool_masked_pos(num_patches)
        bool_masked_pos = bool_masked_pos.unsqueeze(0).expand(
            pixel_values.size(0), -1
        )

        outputs = self.model(
            pixel_values=pixel_values,
            bool_masked_pos=bool_masked_pos
        )

        return outputs.loss

8. 言語モデル事前学習

SSLはNLPに長い歴史があります。

8.1 GPT - 次のトークン予測

左から右へと次のトークンを予測する自己回帰的アプローチ。

import torch
import torch.nn as nn
from torch.nn import functional as F

class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[
            TransformerBlock(n_embd, n_head, dropout=dropout)
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.block_size = block_size

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding(idx)
        pos_emb = self.position_embedding(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

8.2 BERT - マスク言語モデリング

双方向コンテキストを学習します。15%のトークンがマスクされ、復元する必要があります。

from transformers import BertForMaskedLM, BertTokenizer
import torch

class BERTPretraining:
    def __init__(self, model_name="bert-base-uncased"):
        self.model = BertForMaskedLM.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def mask_tokens(self, inputs, mlm_probability=0.15):
        """MLMのためにトークンをマスク"""
        labels = inputs.clone()
        probability_matrix = torch.full(labels.shape, mlm_probability)

        # 特殊トークンはマスクしない
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
            for val in labels.tolist()
        ]
        special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # マスクされていないトークンの損失を計算しない

        # 80%は[MASK]に置換
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10%はランダムなトークンに置換
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        return inputs, labels

9. 実践的な応用

9.1 線形評価とファインチューニング

class LinearProbe(nn.Module):
    """SSL表現を評価するための線形分類器"""
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.linear(x)


def extract_features(model, dataloader, device='cuda'):
    """事前学習済みモデルから特徴を抽出"""
    model.eval()
    features_list = []
    labels_list = []

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            features = model.encode(images)
            features_list.append(features.cpu())
            labels_list.append(labels)

    return torch.cat(features_list), torch.cat(labels_list)


def few_shot_evaluation(ssl_model, train_loader, val_loader,
                        n_shots=[1, 5, 10, 100], device='cuda'):
    """クラスごとのnサンプルで評価"""
    ssl_model.eval()

    all_features, all_labels = extract_features(ssl_model, train_loader, device)
    val_features, val_labels = extract_features(ssl_model, val_loader, device)

    results = {}
    num_classes = len(all_labels.unique())

    for n_shot in n_shots:
        selected_indices = []
        for cls in range(num_classes):
            cls_indices = (all_labels == cls).nonzero(as_tuple=True)[0]
            selected = cls_indices[torch.randperm(len(cls_indices))[:n_shot]]
            selected_indices.append(selected)

        selected_indices = torch.cat(selected_indices)
        train_feats = all_features[selected_indices]
        train_labs = all_labels[selected_indices]

        probe = LinearProbe(all_features.shape[1], num_classes).to(device)
        optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
        criterion = nn.CrossEntropyLoss()

        train_feats = train_feats.to(device)
        train_labs = train_labs.to(device)

        for epoch in range(100):
            probe.train()
            logits = probe(train_feats)
            loss = criterion(logits, train_labs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        probe.eval()
        with torch.no_grad():
            val_logits = probe(val_features.to(device))
            preds = val_logits.argmax(dim=1).cpu()
            acc = (preds == val_labels).float().mean().item()

        results[n_shot] = acc
        print(f"{n_shot}-shot精度: {acc * 100:.2f}%")

    return results


def knn_evaluation(ssl_model, train_loader, val_loader, k=20, device='cuda'):
    """SSL表現の品質をk-NNで評価(パラメータフリー)"""
    ssl_model.eval()

    train_features, train_labels = extract_features(ssl_model, train_loader, device)
    val_features, val_labels = extract_features(ssl_model, val_loader, device)

    train_features = nn.functional.normalize(train_features, dim=1)
    val_features = nn.functional.normalize(val_features, dim=1)

    correct = 0
    total = 0
    batch_size = 256

    for i in range(0, len(val_features), batch_size):
        batch_feats = val_features[i:i+batch_size].to(device)
        batch_labels = val_labels[i:i+batch_size]

        sims = torch.mm(batch_feats, train_features.to(device).T)
        topk_sims, topk_indices = sims.topk(k, dim=1)

        topk_labels = train_labels[topk_indices.cpu()]
        preds = topk_labels.mode(dim=1).values

        correct += (preds == batch_labels).sum().item()
        total += len(batch_labels)

    accuracy = correct / total
    print(f"k-NN ({k}) 精度: {accuracy * 100:.2f}%")
    return accuracy

9.2 ドメイン特化SSL

# 医療画像(胸部X線の例)
class MedicalSSL(nn.Module):
    """医療画像のための自己教師あり学習"""
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

        # 医療画像には弱いオーグメンテーションを使用
        self.augmentation = T.Compose([
            T.RandomResizedCrop(224, scale=(0.8, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            # 医療画像では色変化を最小限に
            T.RandomApply([T.ColorJitter(0.1, 0.1, 0.0, 0.0)], p=0.3),
            T.ToTensor(),
            T.Normalize([0.485], [0.229])  # グレースケールX線
        ])

    def forward(self, x):
        v1 = self.augmentation(x)
        v2 = self.augmentation(x)
        return v1, v2


# 衛星画像
class SatelliteSSL:
    """衛星画像のオーグメンテーション(地理的特性を保持)"""
    def __init__(self):
        self.transform = T.Compose([
            T.RandomResizedCrop(256, scale=(0.4, 1.0)),
            # 衛星画像:回転不変性が重要
            T.RandomRotation(90),
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            # 時刻・季節変化をシミュレート
            T.ColorJitter(brightness=0.4, contrast=0.4),
            T.ToTensor(),
        ])

9.3 HubからのSSL事前学習モデルの使用

from transformers import (
    ViTModel, ViTFeatureExtractor,
    CLIPModel, CLIPProcessor
)
from PIL import Image
import torch

# MAE事前学習済みViTを読み込む
def load_mae_pretrained():
    model = ViTModel.from_pretrained("facebook/vit-mae-base")
    feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
    return model, feature_extractor

# DINO ViT特徴抽出
def extract_dino_features(image_path):
    model = ViTModel.from_pretrained("facebook/dino-vits16")
    processor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits16")

    image = Image.open(image_path)
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # CLSトークン(グローバル画像表現)
    cls_features = outputs.last_hidden_state[:, 0, :]
    # パッチごとの特徴(空間情報あり)
    patch_features = outputs.last_hidden_state[:, 1:, :]

    return cls_features, patch_features

# CLIPで特徴抽出
def clip_feature_extraction(images, texts=None):
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    if texts:
        inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.image_embeds, outputs.text_embeds
    else:
        inputs = processor(images=images, return_tensors="pt")
        with torch.no_grad():
            image_features = model.get_image_features(**inputs)
        return image_features

まとめ

自己教師あり学習はAIを変革しています。重要なポイント:

手法核心的なアイデア強み
SimCLR2020強いオーグメンテーション + 対照学習シンプルで強力
MoCo2020モメンタムキューバッチサイズ非依存
BYOL2020ネガティブなしの対照学習バッチサイズ非依存
MAE202175%マスキング + ピクセル復元効率的、ViT最適
DINO2021自己蒸留セグメンテーション特性
CLIP2021画像テキスト対照学習ゼロショット分類
BEiT2021離散トークンマスキングBERTスタイル
DINOv22023大規模 + iBOT汎用特徴抽出器

学習ロードマップ

  1. SimCLRを実装して対照学習を理解する
  2. MAEを学習してマスキングアプローチを把握する
  3. CLIPをマルチモーダル学習に応用する
  4. 自分のドメインデータでファインチューニングする

SSLはLLM事前学習とコンピュータビジョンの基盤となっています。膨大なラベルなしデータを効果的に活用するアプローチとして、SSLはAI進歩の中心的な推進力であり続けるでしょう。


参考文献