Self-Supervised Learning: A Complete Guide to SimCLR, MAE, DINO, and CLIP

Table of Contents

  1. What Is Self-Supervised Learning?
  2. Early SSL Methods
  3. Contrastive Learning
  4. Masked Autoencoder (MAE)
  5. DINO: Self-Distillation with No Labels
  6. CLIP: Text-Image Contrastive Learning
  7. BEiT: BERT for Images
  8. Pretraining Language Models
  9. Practical Applications

1. What Is Self-Supervised Learning?

1.1 Comparing Learning Paradigms

Let's start by understanding the three major learning paradigms in deep learning.

**Supervised learning** trains on labeled data. Humans annotated each of the more than a million images in ImageNet, and models such as ResNet and ViT are trained on this data. The problem is that collecting labels is enormously expensive: annotating a single medical image can take a specialist physician tens of minutes.

**Unsupervised learning** discovers structure in data without labels; K-Means clustering, PCA, and GANs fall into this category. Traditional unsupervised methods, however, struggled to learn representations that transfer directly to downstream tasks.

**Self-supervised learning (SSL)** combines the best of both worlds: it exploits large-scale unlabeled data while deriving supervision signals from the data itself.

The core idea: the data is its own label.

The internet contains billions of images and trillions of words. Self-supervised learning taps into this vast pool of unlabeled data to learn powerful representations.

1.2 The Label-Scarcity Problem

In the real world, collecting labels is very difficult.

| Domain | Labeling cost | Why |
|---|---|---|
| Medical imaging | Very high | Requires specialist physicians |
| Satellite imagery | High | Requires geographic expertise |
| Legal documents | High | Requires legal experts |
| Industrial defect detection | High | Defective samples are rare |
| Speech recognition | Medium | Requires transcription |

SSL addresses this problem at its root: pretrain on large unlabeled data first, then fine-tune on a small labeled set.
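The pretrain-then-fine-tune recipe can be sketched in a few lines. This is a minimal, hypothetical example: the tiny `encoder` below stands in for a large SSL-pretrained backbone, and freezing it corresponds to linear probing (unfreezing it gives full fine-tuning).

```python
import torch
import torch.nn as nn

# Toy stand-in: in practice this would be a large SSL-pretrained backbone.
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64), nn.ReLU())

# Stage 2: adapt to a small labeled set. Freezing the encoder = linear probing.
for p in encoder.parameters():
    p.requires_grad = False

head = nn.Linear(64, 10)
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

images = torch.randn(16, 3, 32, 32)     # small labeled batch
labels = torch.randint(0, 10, (16,))

for _ in range(5):
    optimizer.zero_grad()
    logits = head(encoder(images))      # frozen features -> trainable head
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()
```

Only the linear head receives gradients; the pretrained representation is reused as-is.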

1.3 Pretext Tasks

The heart of SSL is the pretext task: an artificial prediction problem constructed from the data itself that forces the model to learn meaningful representations.

A good pretext task should:

  1. Be constructible automatically, without labels
  2. Require meaningful representations to solve well
  3. Be neither too easy nor too hard

Examples:

  • Rotation prediction: rotate an image by 0°, 90°, 180°, or 270° and predict the rotation
  • Masking: hide part of an image or text and reconstruct it
  • Contrastive learning: make two views of the same image similar

1.4 Where SSL Is Used

SSL has become the backbone of modern AI:

  • GPT, BERT: text SSL → the foundation of ChatGPT and Gemini
  • CLIP: image-text SSL → the foundation of DALL-E and Stable Diffusion
  • MAE, DINO: image SSL → the foundation of medical imaging and autonomous driving systems
  • wav2vec: audio SSL → the foundation of speech recognition

2. Early SSL Methods

Early SSL research explored a variety of pretext tasks.

2.1 Rotation Prediction

Proposed by Gidaris et al. in 2018: rotate an image by one of four angles (0°, 90°, 180°, 270°) and classify which rotation was applied.

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image

class RotationSSL(nn.Module):
    def __init__(self, backbone, num_classes=4):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(backbone.output_dim, num_classes)

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)

def create_rotation_dataset(images):
    """이미지를 4가지 방향으로 회전하여 (image, label) 쌍 생성"""
    rotated_images = []
    labels = []

    for img in images:
        for k in range(4):  # 0, 90, 180, 270 degrees
            # Rotate the PIL image by k*90 degrees
            rotated = T.functional.rotate(img, k * 90)
            rotated_images.append(rotated)
            labels.append(k)

    return rotated_images, labels

# Training loop
def train_rotation_ssl(model, dataloader, optimizer, epochs=10):
    criterion = nn.CrossEntropyLoss()
    model.train()

    for epoch in range(epochs):
        total_loss = 0
        for images, labels in dataloader:
            images, labels = images.cuda(), labels.cuda()

            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

Limitation: natural images usually have a clear canonical orientation, but the task is less effective for images without one, such as sky photos or textures.

2.2 Solving Jigsaw Puzzles

Proposed by Noroozi & Favaro in 2016: cut an image into a 3x3 grid, shuffle the tiles randomly, and predict their original arrangement.

import itertools
import numpy as np

class JigsawSSL(nn.Module):
    def __init__(self, backbone, num_permutations=100):
        super().__init__()
        self.backbone = backbone
        # Process each patch independently
        self.patch_encoder = nn.Sequential(
            backbone,
            nn.Linear(backbone.output_dim, 512)
        )
        # Classify the permutation after processing all 9 patches
        self.classifier = nn.Sequential(
            nn.Linear(512 * 9, 4096),
            nn.ReLU(),
            nn.Linear(4096, num_permutations)
        )

    def forward(self, patches):
        # patches: (batch, 9, C, H, W)
        batch_size = patches.size(0)
        patch_features = []

        for i in range(9):
            feat = self.patch_encoder(patches[:, i])  # (batch, 512)
            patch_features.append(feat)

        # Concatenate the features of all patches
        combined = torch.cat(patch_features, dim=1)  # (batch, 512*9)
        return self.classifier(combined)

def create_jigsaw_dataset(image, grid_size=3):
    """이미지를 grid_size x grid_size 패치로 분할하고 섞기"""
    h, w = image.shape[-2:]
    patch_h, patch_w = h // grid_size, w // grid_size

    patches = []
    for i in range(grid_size):
        for j in range(grid_size):
            patch = image[...,
                         i*patch_h:(i+1)*patch_h,
                         j*patch_w:(j+1)*patch_w]
            patches.append(patch)

    # Sample from a predefined set of permutations
    # (PREDEFINED_PERMUTATIONS is assumed to be a precomputed list of
    #  maximally distinct orderings of range(9), e.g. 100 of them)
    permutation_idx = np.random.randint(0, len(PREDEFINED_PERMUTATIONS))
    perm = PREDEFINED_PERMUTATIONS[permutation_idx]
    shuffled = [patches[p] for p in perm]

    return torch.stack(shuffled), permutation_idx

2.3 Colorization

Proposed by Zhang et al. in 2016: take a grayscale image as input and predict its colors.

class ColorizationSSL(nn.Module):
    def __init__(self):
        super().__init__()
        # L (lightness) channel as input
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(),
        )
        # ab (color) channels as output
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 2, 3, padding=1),  # ab channels
            nn.Tanh()
        )

    def forward(self, l_channel):
        features = self.encoder(l_channel)
        ab_channels = self.decoder(features)
        return ab_channels

# Work in the LAB color space
from skimage.color import rgb2lab, lab2rgb

def prepare_colorization_data(rgb_image):
    lab = rgb2lab(rgb_image)
    l_channel = lab[:, :, 0:1]   # lightness channel
    ab_channels = lab[:, :, 1:]  # color channels
    return l_channel, ab_channels

2.4 Next-Frame Prediction

This exploits temporal continuity in video: predict the next frame from the previous ones.

class VideoSSL(nn.Module):
    def __init__(self, frame_encoder, temporal_model):
        super().__init__()
        self.frame_encoder = frame_encoder
        self.temporal_model = temporal_model  # LSTM or Transformer
        # Example decoder; channel sizes depend on frame_encoder's output
        self.decoder = nn.ConvTranspose2d(512, 3, kernel_size=4, stride=2, padding=1)

    def forward(self, frames):
        # frames: (batch, T, C, H, W)
        num_frames = frames.size(1)

        # Encode each frame
        frame_features = []
        for t in range(num_frames - 1):
            feat = self.frame_encoder(frames[:, t])
            frame_features.append(feat)

        # Temporal modeling
        features = torch.stack(frame_features, dim=1)
        next_feat = self.temporal_model(features)

        # Decode the next frame
        next_frame_pred = self.decoder(next_feat[:, -1])
        return next_frame_pred

3. Contrastive Learning

Contrastive learning is at the core of modern SSL; it developed explosively around 2020.

3.1 Core Idea

Pull similar things together, push different things apart.

Two different views of the same image (crop, flip, color jitter, etc.) should be close in embedding space, while different images should be far apart.

Image x
├── augmented view v1 → z1 (embedding)
└── augmented view v2 → z2 (embedding)

Goal: pull z1 and z2 together; push z1 away from z3, the embedding of a different image
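A toy numeric check of this intuition (the embeddings here are random stand-ins, not learned): a positive pair, made by slightly perturbing an embedding, has much higher cosine similarity than a random negative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z1 = F.normalize(torch.randn(4, 8), dim=1)                # view 1 embeddings
z2 = F.normalize(z1 + 0.1 * torch.randn(4, 8), dim=1)     # view 2: small perturbation
z3 = F.normalize(torch.randn(4, 8), dim=1)                # unrelated images

pos_sim = (z1 * z2).sum(dim=1).mean()  # positive-pair cosine similarity
neg_sim = (z1 * z3).sum(dim=1).mean()  # negative-pair cosine similarity
print(pos_sim.item(), neg_sim.item())
```

A contrastive loss simply turns this gap into a training signal.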

3.2 The InfoNCE Loss

This is the standard loss function of contrastive learning, also known as NT-Xent (Normalized Temperature-scaled Cross Entropy).

For a batch of N images with 2 views each (2N embeddings in total):

$$\mathcal{L} = -\frac{1}{2N} \sum_{i=1}^{2N} \log \frac{e^{\mathrm{sim}(z_i, z_{j(i)})/\tau}}{\sum_{k \neq i} e^{\mathrm{sim}(z_i, z_k)/\tau}}$$

where sim is cosine similarity, j(i) indexes the other view of the same image, and τ is the temperature parameter.

import torch
import torch.nn.functional as F

def info_nce_loss(z1, z2, temperature=0.5):
    """
    InfoNCE / NT-Xent Loss 구현
    z1, z2: (batch_size, embedding_dim) - 같은 이미지의 두 뷰
    """
    batch_size = z1.size(0)

    # L2-normalize the embeddings
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    # Stack both views into a 2N x D matrix:
    # [z1_1, z1_2, ..., z1_N, z2_1, z2_2, ..., z2_N]
    z = torch.cat([z1, z2], dim=0)  # (2N, D)

    # Pairwise similarities
    similarity = torch.mm(z, z.t()) / temperature  # (2N, 2N)

    # Exclude self-similarity (set the diagonal to -inf)
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    similarity.masked_fill_(mask, float('-inf'))

    # Positive pairs: i-th with (i+N)-th, and (i+N)-th with i-th
    labels = torch.cat([
        torch.arange(batch_size, 2 * batch_size),
        torch.arange(batch_size)
    ]).to(z.device)

    loss = F.cross_entropy(similarity, labels)
    return loss

# Usage example
z1 = torch.randn(32, 128)  # batch of 32, 128-dim embeddings
z2 = torch.randn(32, 128)
loss = info_nce_loss(z1, z2, temperature=0.5)
print(f"InfoNCE Loss: {loss.item():.4f}")

3.3 SimCLR

SimCLR (A Simple Framework for Contrastive Learning of Visual Representations), published by Google in 2020, is simple but powerful.

Key components:

  1. Data augmentation (strong augmentation is essential)
  2. Encoder (ResNet)
  3. Projection head (MLP)
  4. NT-Xent loss

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T

class SimCLR(nn.Module):
    def __init__(self, backbone='resnet50', projection_dim=128, temperature=0.5):
        super().__init__()
        self.temperature = temperature

        # Backbone encoder
        if backbone == 'resnet50':
            resnet = models.resnet50(weights=None)
            self.encoder = nn.Sequential(*list(resnet.children())[:-1])  # drop the FC layer
            self.feature_dim = 2048
        elif backbone == 'resnet18':
            resnet = models.resnet18(weights=None)
            self.encoder = nn.Sequential(*list(resnet.children())[:-1])
            self.feature_dim = 512

        # Projection head (important: removed for fine-tuning)
        self.projection_head = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.feature_dim, projection_dim)
        )

    def encode(self, x):
        h = self.encoder(x)
        h = h.squeeze(-1).squeeze(-1)  # (batch, feature_dim)
        return h

    def forward(self, x):
        h = self.encode(x)
        z = self.projection_head(h)
        return z

    def contrastive_loss(self, z1, z2):
        return info_nce_loss(z1, z2, self.temperature)


class SimCLRDataAugmentation:
    """SimCLR의 강한 데이터 증강"""
    def __init__(self, image_size=224, s=1.0):
        color_jitter = T.ColorJitter(
            brightness=0.8*s,
            contrast=0.8*s,
            saturation=0.8*s,
            hue=0.2*s
        )
        self.transform = T.Compose([
            T.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply([color_jitter], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.GaussianBlur(kernel_size=23),  # roughly 10% of the image size; must be odd
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        return self.transform(x), self.transform(x)


# SimCLR training
def train_simclr(model, dataloader, optimizer, epochs=200):
    model.train()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs
    )

    for epoch in range(epochs):
        total_loss = 0
        for (x1, x2), _ in dataloader:
            x1, x2 = x1.cuda(), x2.cuda()

            z1 = model(x1)
            z2 = model(x2)

            loss = model.contrastive_loss(z1, z2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        avg_loss = total_loss / len(dataloader)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")


# Full pipeline
model = SimCLR(backbone='resnet50', projection_dim=128).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)
augmentation = SimCLRDataAugmentation(image_size=224)

# Remove the projection head for fine-tuning
def get_finetuning_model(pretrained_simclr, num_classes):
    # Use only the encoder
    encoder = pretrained_simclr.encoder
    classifier = nn.Linear(pretrained_simclr.feature_dim, num_classes)
    return nn.Sequential(encoder, nn.Flatten(), classifier)

3.4 MoCo (Momentum Contrast)

Published by Facebook AI in 2020, MoCo makes a large pool of negative samples available without huge batches.

Core idea: keep a stable key encoder via momentum updates, and manage negative samples with a queue.
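Both mechanisms can be demonstrated standalone with plain tensors (toy sizes, not MoCo's defaults):

```python
import torch

# Momentum (EMA) update: the key-encoder weights drift slowly toward the
# query-encoder weights, keeping the keys stored in the queue consistent.
m = 0.999
w_q = torch.ones(4)      # query-encoder weight (updated by SGD)
w_k = torch.zeros(4)     # key-encoder weight (EMA copy)
for _ in range(3):
    w_k = m * w_k + (1 - m) * w_q   # after n steps: w_k = 1 - m**n

# Queue as a ring buffer: new keys overwrite the oldest entries.
K, dim = 8, 2
queue = torch.zeros(dim, K)
ptr = 0
for step in range(3):
    keys = torch.full((4, dim), float(step + 1))  # pretend batch of 4 keys
    queue[:, ptr:ptr + keys.shape[0]] = keys.T
    ptr = (ptr + keys.shape[0]) % K
```

The pointer wraps around, so the queue always holds the K most recent keys.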

class MoCo(nn.Module):
    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07):
        super().__init__()
        self.K = K  # queue size
        self.m = m  # momentum coefficient
        self.T = T  # temperature

        # Query encoder and key encoder
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        # The key encoder receives momentum updates only, no gradients
        for param_q, param_k in zip(
            self.encoder_q.parameters(),
            self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        # Queue of negative samples
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """키 인코더를 모멘텀으로 업데이트"""
        for param_q, param_k in zip(
            self.encoder_q.parameters(),
            self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """큐에 새 키 추가, 오래된 키 제거"""
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)

        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr

    def forward(self, im_q, im_k):
        # Query embeddings
        q = self.encoder_q(im_q)
        q = F.normalize(q, dim=1)

        # Key embeddings (no gradient)
        with torch.no_grad():
            self._momentum_update_key_encoder()
            k = self.encoder_k(im_k)
            k = F.normalize(k, dim=1)

        # Positive logits: (batch, 1)
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # Negative logits: (batch, K)
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        logits = torch.cat([l_pos, l_neg], dim=1) / self.T
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        loss = F.cross_entropy(logits, labels)

        self._dequeue_and_enqueue(k)
        return loss

3.5 BYOL (Bootstrap Your Own Latent)

A breakthrough method published by DeepMind in 2020: it works without negative samples.

Core idea: an online network learns to predict the representation of a target network, and the target is updated as a momentum average of the online network.

import copy

class BYOL(nn.Module):
    def __init__(self, backbone, projection_dim=256, tau=0.996):
        super().__init__()
        self.tau = tau  # momentum coefficient

        # Online network (assumes a 2048-dim backbone such as ResNet-50)
        self.online_encoder = backbone
        self.online_projector = nn.Sequential(
            nn.Linear(2048, 4096), nn.BatchNorm1d(4096), nn.ReLU(),
            nn.Linear(4096, projection_dim)
        )
        # The predictor's output must match the projection dimension,
        # since it is compared against the target projector's output
        self.online_predictor = nn.Sequential(
            nn.Linear(projection_dim, 4096), nn.BatchNorm1d(4096), nn.ReLU(),
            nn.Linear(4096, projection_dim)
        )

        # Target network (no gradients)
        self.target_encoder = copy.deepcopy(backbone)
        self.target_projector = copy.deepcopy(self.online_projector)

        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def update_target(self):
        """EMA 업데이트"""
        for online, target in zip(
            self.online_encoder.parameters(),
            self.target_encoder.parameters()
        ):
            target.data = self.tau * target.data + (1 - self.tau) * online.data

        for online, target in zip(
            self.online_projector.parameters(),
            self.target_projector.parameters()
        ):
            target.data = self.tau * target.data + (1 - self.tau) * online.data

    def forward(self, x1, x2):
        # Online branch predictions
        online_feat1 = self.online_encoder(x1).squeeze()
        online_proj1 = self.online_projector(online_feat1)
        online_pred1 = self.online_predictor(online_proj1)

        online_feat2 = self.online_encoder(x2).squeeze()
        online_proj2 = self.online_projector(online_feat2)
        online_pred2 = self.online_predictor(online_proj2)

        # Target branch (stop gradient)
        with torch.no_grad():
            target_feat1 = self.target_encoder(x1).squeeze()
            target_proj1 = self.target_projector(target_feat1)

            target_feat2 = self.target_encoder(x2).squeeze()
            target_proj2 = self.target_projector(target_feat2)

        # BYOL loss: maximize cosine similarity between prediction and target
        loss1 = byol_loss(online_pred1, target_proj2.detach())
        loss2 = byol_loss(online_pred2, target_proj1.detach())

        return (loss1 + loss2).mean()

def byol_loss(p, z):
    """Negative cosine similarity"""
    p = F.normalize(p, dim=-1)
    z = F.normalize(z, dim=-1)
    return -(p * z).sum(dim=-1)

3.6 SimSiam

Published by Facebook AI in 2021, SimSiam is even simpler than BYOL: it trains without collapse even without a momentum encoder.

class SimSiam(nn.Module):
    def __init__(self, backbone, dim=2048, pred_dim=512):
        super().__init__()
        self.encoder = backbone

        # Projector
        self.projector = nn.Sequential(
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim, affine=False)  # last BN without affine parameters
        )

        # Predictor
        self.predictor = nn.Sequential(
            nn.Linear(dim, pred_dim, bias=False),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),
            nn.Linear(pred_dim, dim)
        )

    def forward(self, x1, x2):
        z1 = self.projector(self.encoder(x1).squeeze())
        z2 = self.projector(self.encoder(x2).squeeze())

        p1 = self.predictor(z1)
        p2 = self.predictor(z2)

        # Stop gradient: the key to the asymmetric design
        loss = simsiam_loss(p1, z2.detach()) / 2 + \
               simsiam_loss(p2, z1.detach()) / 2
        return loss

def simsiam_loss(p, z):
    """Negative cosine similarity"""
    z = z.detach()  # Stop gradient
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    return -(p * z).sum(dim=1).mean()

4. Masked Autoencoder (MAE)

Published by Kaiming He et al. in 2021, MAE brings the BERT recipe from natural language processing to images.

4.1 Core Idea

Mask 75% of the image and reconstruct the whole image from the remaining 25%.

Why such a high masking ratio?

  • Text tokens are semantically dense, so 15% masking is enough for BERT
  • Neighboring image pixels are highly redundant, so 75% must be masked before genuine understanding is required

4.2 Asymmetric Encoder-Decoder

MAE's innovation: the encoder processes only the visible patches, while the decoder reconstructs everything.

Input image (196 patches)
    ↓ mask 75%
49 visible patches → encoder (ViT-Large) → latent
    ↓ insert mask tokens (back to 196 patches)
decoder (shallow ViT) → pixel reconstruction

The encoder sees only 25% of the sequence, which makes pretraining roughly 3-4x faster.
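The back-of-envelope arithmetic behind that speedup:

```python
num_patches = (224 // 16) ** 2           # 196 patches for a 224x224 image, 16x16 patches
visible = int(num_patches * (1 - 0.75))  # 49 patches reach the encoder

# Self-attention cost grows quadratically with sequence length, so the
# encoder's attention work shrinks by roughly (196/49)^2 = 16x. The overall
# wall-clock gain is smaller (~3-4x) because the MLP layers scale linearly
# and the decoder adds its own cost.
attn_ratio = (num_patches / visible) ** 2
print(num_patches, visible, attn_ratio)  # 196 49 16.0
```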

4.3 A Complete MAE Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    """이미지를 패치 임베딩으로 변환"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (batch, C, H, W)
        x = self.projection(x)  # (batch, embed_dim, H/p, W/p)
        x = x.flatten(2)        # (batch, embed_dim, num_patches)
        x = x.transpose(1, 2)  # (batch, num_patches, embed_dim)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., dropout=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        mlp_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # Self-attention
        norm_x = self.norm1(x)
        attn_out, _ = self.attn(norm_x, norm_x, norm_x)
        x = x + attn_out
        # MLP
        x = x + self.mlp(self.norm2(x))
        return x


class MAE(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        # Encoder config
        encoder_dim=768,
        encoder_depth=12,
        encoder_heads=12,
        # Decoder config (much smaller than the encoder)
        decoder_dim=512,
        decoder_depth=8,
        decoder_heads=16,
        # Masking ratio
        mask_ratio=0.75,
    ):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2

        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, encoder_dim)

        # [CLS] token and positional embedding (encoder)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, encoder_dim),
            requires_grad=False
        )

        # Encoder (ViT)
        self.encoder_blocks = nn.ModuleList([
            TransformerBlock(encoder_dim, encoder_heads)
            for _ in range(encoder_depth)
        ])
        self.encoder_norm = nn.LayerNorm(encoder_dim)

        # Bridge from encoder to decoder
        self.decoder_embed = nn.Linear(encoder_dim, decoder_dim, bias=True)

        # Mask token
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))

        # Decoder positional embedding
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_dim),
            requires_grad=False
        )

        # Decoder (shallow ViT)
        self.decoder_blocks = nn.ModuleList([
            TransformerBlock(decoder_dim, decoder_heads)
            for _ in range(decoder_depth)
        ])
        self.decoder_norm = nn.LayerNorm(decoder_dim)

        # Pixel prediction head
        self.decoder_pred = nn.Linear(decoder_dim, patch_size**2 * in_channels)

        self._init_weights()

    def _init_weights(self):
        """Sinusoidal 위치 임베딩 초기화"""
        # 위치 임베딩 초기화 (sinusoidal)
        pos_embed = self._get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            int(self.patch_embed.num_patches**0.5)
        )
        self.pos_embed.data.copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0)
        )
        # Remaining initialization
        nn.init.normal_(self.cls_token, std=.02)
        nn.init.normal_(self.mask_token, std=.02)

    def _get_2d_sincos_pos_embed(self, embed_dim, grid_size):
        """2D sinusoidal 위치 임베딩"""
        import numpy as np
        grid_h = np.arange(grid_size, dtype=np.float32)
        grid_w = np.arange(grid_size, dtype=np.float32)
        grid = np.meshgrid(grid_w, grid_h)
        grid = np.stack(grid, axis=0)
        grid = grid.reshape([2, 1, grid_size, grid_size])

        emb_h = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
        emb_w = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
        emb = np.concatenate([emb_h, emb_w], axis=1)
        # Prepend a zero embedding for the CLS token
        emb = np.concatenate([np.zeros([1, embed_dim]), emb], axis=0)
        return emb

    def _get_1d_sincos_pos_embed_from_grid(self, embed_dim, pos):
        import numpy as np
        omega = np.arange(embed_dim // 2, dtype=np.float32)
        omega /= embed_dim / 2.
        omega = 1. / 10000**omega
        pos = pos.reshape(-1)
        out = np.einsum('m,d->md', pos, omega)
        emb_sin = np.sin(out)
        emb_cos = np.cos(out)
        emb = np.concatenate([emb_sin, emb_cos], axis=1)
        return emb

    def random_masking(self, x, mask_ratio):
        """랜덤 마스킹: 일부 패치만 남김"""
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        # Shuffle indices generated from random noise
        noise = torch.rand(N, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep only the visible patches
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1,
            index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
        )

        # Build the mask (1: masked, 0: visible)
        mask = torch.ones(N, L, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        """인코더: 가시 패치만 처리"""
        # 패치 임베딩
        x = self.patch_embed(x)

        # Add positional embedding (excluding CLS)
        x = x + self.pos_embed[:, 1:, :]

        # Random masking
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # Prepend the CLS token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Encoder transformer blocks
        for block in self.encoder_blocks:
            x = block(x)
        x = self.encoder_norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """디코더: 마스크 토큰 포함 전체 복원"""
        # 인코더 차원 → 디코더 차원
        x = self.decoder_embed(x)

        # Append mask tokens
        mask_tokens = self.mask_token.repeat(
            x.shape[0],
            ids_restore.shape[1] + 1 - x.shape[1],
            1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # drop CLS
        x_ = torch.gather(
            x_, dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )
        x = torch.cat([x[:, :1, :], x_], dim=1)  # re-attach CLS

        # Decoder positional embedding
        x = x + self.decoder_pos_embed

        # Decoder transformer blocks
        for block in self.decoder_blocks:
            x = block(x)
        x = self.decoder_norm(x)

        # Pixel prediction
        x = self.decoder_pred(x)
        x = x[:, 1:, :]  # remove the CLS token
        return x

    def forward(self, imgs, mask_ratio=0.75):
        # Encode (with masking)
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        # Decode (pixel reconstruction)
        pred = self.forward_decoder(latent, ids_restore)
        # Compute the loss
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask

    def forward_loss(self, imgs, pred, mask):
        """마스킹된 패치에 대한 MSE Loss"""
        # 타깃: 패치화된 이미지
        target = self.patchify(imgs)

        # Per-patch normalization
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1.e-6).sqrt()

        # Loss only on the masked patches
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # mean per patch
        loss = (loss * mask).sum() / mask.sum()  # mean over masked patches
        return loss

    def patchify(self, imgs):
        """이미지를 패치 시퀀스로 변환"""
        p = self.patch_size
        h = w = imgs.shape[2] // p
        x = imgs.reshape(imgs.shape[0], 3, h, p, w, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(imgs.shape[0], h * w, p**2 * 3)
        return x


# Create and train an MAE model
mae_model = MAE(
    img_size=224,
    patch_size=16,
    encoder_dim=768,
    encoder_depth=12,
    encoder_heads=12,
    decoder_dim=512,
    decoder_depth=8,
    decoder_heads=16,
    mask_ratio=0.75
).cuda()

optimizer = torch.optim.AdamW(
    mae_model.parameters(),
    lr=1.5e-4,
    betas=(0.9, 0.95),
    weight_decay=0.05
)

# One training step
def train_step(model, imgs, optimizer):
    model.train()
    optimizer.zero_grad()

    loss, pred, mask = model(imgs, mask_ratio=0.75)
    loss.backward()
    optimizer.step()

    return loss.item()
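As a sanity check on the patch layout used by `patchify` and `forward_loss`, here is a standalone round trip. `unpatchify` is a hypothetical inverse helper (the original MAE code includes one) and is not part of the class above.

```python
import torch

def patchify(imgs, p=16):
    """(N, 3, H, W) -> (N, num_patches, p*p*3), matching MAE's target layout."""
    n, c, h, w = imgs.shape
    gh, gw = h // p, w // p
    x = imgs.reshape(n, c, gh, p, gw, p)
    x = torch.einsum('nchpwq->nhwpqc', x)
    return x.reshape(n, gh * gw, p * p * c)

def unpatchify(x, p=16):
    """Inverse of patchify, recovering (N, 3, H, W)."""
    n, l, _ = x.shape
    gh = gw = int(l ** 0.5)
    x = x.reshape(n, gh, gw, p, p, 3)
    x = torch.einsum('nhwpqc->nchpwq', x)
    return x.reshape(n, 3, gh * p, gw * p)

imgs = torch.randn(2, 3, 64, 64)
recon = unpatchify(patchify(imgs))  # exact round trip: only reshapes/permutes
```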

5. DINO

DINO (Self-DIstillation with NO labels) was published by Facebook AI Research in 2021. Training a ViT with this form of SSL produces remarkable properties.

5.1 DINO's Key Discovery

The self-attention maps of a DINO-trained ViT perform **semantic segmentation without any labels**, separating foreground from background remarkably well.

5.2 The Self-Distillation Setup

Student network                Teacher network
      ↑                              ↑
 local + global views          global views (2)

  • Teacher: processes the full-image views, seeing richer context
  • Student: also processes small local crops and learns to match the teacher's predictions
  • Teacher update: an EMA of the student (no gradients)
5.3 Centering과 Sharpening

Collapse 방지를 위한 두 가지 기법:
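Before the full implementation, the effect of both tricks can be seen on random stand-in logits: centering removes the per-dimension bias, and sharpening (a low teacher temperature) produces peaked, low-entropy targets instead of a uniform collapse.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
teacher_logits = torch.randn(16, 10)              # batch of teacher outputs
center = teacher_logits.mean(dim=0, keepdim=True)
centered = teacher_logits - center                # centering: per-dim mean is now ~0

tau = 0.04
sharp = F.softmax(centered / tau, dim=-1)         # sharpened (DINO teacher temp)
flat = F.softmax(centered, dim=-1)                # temperature 1, for comparison

# Lower temperature -> lower entropy (more peaked) target distributions
ent_sharp = -(sharp * sharp.clamp_min(1e-12).log()).sum(dim=-1).mean()
ent_flat = -(flat * flat.clamp_min(1e-12).log()).sum(dim=-1).mean()
```

Centering alone would push toward uniform outputs and sharpening alone toward a single dominant dimension; applied together they balance out.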

class DINOHead(nn.Module):
    """DINO Projection Head"""
    def __init__(self, in_dim, out_dim=65536, bottleneck_dim=256):
        super().__init__()
        layers = [nn.Linear(in_dim, 2048), nn.GELU()]
        layers += [nn.Linear(2048, 2048), nn.GELU()]
        layers += [nn.Linear(2048, bottleneck_dim)]
        self.mlp = nn.Sequential(*layers)

        # Weight-normalized final layer
        self.last_layer = nn.utils.weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)
        self.last_layer.weight_g.requires_grad = False

    def forward(self, x):
        x = self.mlp(x)
        x = F.normalize(x, dim=-1)
        x = self.last_layer(x)
        return x


import copy

class DINO(nn.Module):
    def __init__(self, backbone, out_dim=65536, teacher_temp=0.04, student_temp=0.1):
        super().__init__()
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp

        # Student network
        self.student = backbone
        self.student_head = DINOHead(self.student.embed_dim, out_dim)

        # Teacher network (EMA copy of the student)
        self.teacher = copy.deepcopy(backbone)
        self.teacher_head = copy.deepcopy(self.student_head)

        for p in self.teacher.parameters():
            p.requires_grad = False
        for p in self.teacher_head.parameters():
            p.requires_grad = False

        # Running center vector for the centering trick
        self.register_buffer("center", torch.zeros(1, out_dim))

    @torch.no_grad()
    def update_teacher(self, momentum):
        """EMA 업데이트"""
        for student_ps, teacher_ps in zip(
            self.student.parameters(), self.teacher.parameters()
        ):
            teacher_ps.data.mul_(momentum).add_(
                (1 - momentum) * student_ps.data
            )
        for student_ps, teacher_ps in zip(
            self.student_head.parameters(), self.teacher_head.parameters()
        ):
            teacher_ps.data.mul_(momentum).add_(
                (1 - momentum) * student_ps.data
            )

    @torch.no_grad()
    def update_center(self, teacher_output):
        """Centering 업데이트 (collapse 방지)"""
        batch_center = teacher_output.mean(dim=0, keepdim=True)
        self.center = self.center * 0.9 + batch_center * 0.1

    def dino_loss(self, student_output, teacher_output):
        """DINO cross-entropy loss"""
        student_out = student_output / self.student_temp
        # The teacher sees 2 global views; infer how many views the student saw
        n_student_views = 2 * student_output.shape[0] // teacher_output.shape[0]
        student_out = student_out.chunk(n_student_views)

        # Teacher: centering + sharpening
        teacher_out = F.softmax(
            (teacher_output - self.center) / self.teacher_temp, dim=-1
        )
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0

        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    continue  # skip identical views
                loss = torch.sum(
                    -q * F.log_softmax(student_out[v], dim=-1), dim=-1
                )
                total_loss += loss.mean()
                n_loss_terms += 1

        return total_loss / n_loss_terms

    def forward(self, global_views, local_views, momentum=0.996):
        # Student: processes all views
        all_views = global_views + local_views
        student_output = torch.cat([
            self.student_head(self.student(view))
            for view in all_views
        ])

        # Teacher: 글로벌 뷰만 처리
        with torch.no_grad():
            teacher_output = torch.cat([
                self.teacher_head(self.teacher(view))
                for view in global_views
            ])

        loss = self.dino_loss(student_output, teacher_output)

        # 업데이트
        self.update_center(teacher_output)
        self.update_teacher(momentum)

        return loss
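
In the paper, the `momentum` argument above is not held constant: DINO anneals the teacher momentum from 0.996 toward 1.0 with a cosine schedule over training. A minimal sketch (the helper name `dino_momentum` is ours):

```python
import math

def dino_momentum(step, total_steps, base_m=0.996):
    """Cosine schedule from base_m up to 1.0 over training."""
    return 1.0 - (1.0 - base_m) * (math.cos(math.pi * step / total_steps) + 1) / 2

print(round(dino_momentum(0, 1000), 3))     # 0.996 at the start of training
print(round(dino_momentum(1000, 1000), 3))  # 1.0 at the end
```

As momentum approaches 1.0 the teacher effectively freezes, which stabilizes the targets late in training.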


# DINO data augmentation (multi-crop)
class DINOAugmentation:
    def __init__(self, global_crops_scale=(0.4, 1.), local_crops_scale=(0.05, 0.4),
                 n_local_crops=8, image_size=224, local_size=96):
        self.n_local_crops = n_local_crops

        # Global crops (2 large views)
        self.global_transform = T.Compose([
            T.RandomResizedCrop(image_size, scale=global_crops_scale, interpolation=T.InterpolationMode.BICUBIC),
            T.RandomHorizontalFlip(),
            T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        # Local crops (default 8 small views)
        self.local_transform = T.Compose([
            T.RandomResizedCrop(local_size, scale=local_crops_scale, interpolation=T.InterpolationMode.BICUBIC),
            T.RandomHorizontalFlip(),
            T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __call__(self, image):
        global_views = [self.global_transform(image) for _ in range(2)]
        local_views = [self.local_transform(image) for _ in range(self.n_local_crops)]
        return global_views, local_views

5.4 DINOv2

DINOv2, released in 2023, is substantially stronger:

  • Trained on 142 million curated images
  • Adds the iBOT (Image BERT Pre-training with Online Tokenizer) loss
  • More stable training

6. CLIP

CLIP (Contrastive Language-Image Pretraining) is a landmark model released by OpenAI in 2021.

6.1 Text-Image Contrastive Learning

CLIP is trained on 400 million (image, text) pairs collected from the internet.

Image encoder (ViT / ResNet)
  image embedding

Text encoder (Transformer)
  text embedding

Goal: matching (image, text) pairs become similar,
      non-matching pairs become dissimilar

6.2 CLIP Loss Function

def clip_loss(image_embeddings, text_embeddings, temperature):
    """
    CLIP symmetric contrastive loss
    image_embeddings: (N, D)
    text_embeddings: (N, D)
    """
    # L2 normalize
    image_embeddings = F.normalize(image_embeddings, dim=1)
    text_embeddings = F.normalize(text_embeddings, dim=1)

    # Similarity matrix: (N, N)
    logits = torch.matmul(image_embeddings, text_embeddings.T) / temperature

    # Targets: diagonal entries (the i-th image matches the i-th text)
    labels = torch.arange(len(logits)).to(logits.device)

    # Symmetric loss: image → text and text → image
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)

    return (loss_i2t + loss_t2i) / 2
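
A quick sanity check of this loss with random embeddings (the function is restated so the snippet runs standalone; shapes and temperature are arbitrary): perfectly matched pairs should incur a much lower loss than mismatched ones.

```python
import torch
import torch.nn.functional as F

def clip_loss(image_embeddings, text_embeddings, temperature=0.07):
    image_embeddings = F.normalize(image_embeddings, dim=1)
    text_embeddings = F.normalize(text_embeddings, dim=1)
    logits = image_embeddings @ text_embeddings.T / temperature
    labels = torch.arange(len(logits))
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

torch.manual_seed(0)
img = torch.randn(8, 64)
aligned = clip_loss(img, img.clone())            # every pair matches perfectly
mismatched = clip_loss(img, torch.randn(8, 64))  # pairs share nothing
print(aligned.item() < mismatched.item())  # True: alignment lowers the loss
```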

6.3 Complete CLIP Implementation

from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests
import torch
import torch.nn.functional as F

# Load a pretrained CLIP model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# ---- Zero-shot image classification ----
def zero_shot_classify(image_path, candidate_labels):
    """
    레이블 예제 없이 이미지 분류
    """
    image = Image.open(image_path)

    # Build text prompts
    texts = [f"a photo of a {label}" for label in candidate_labels]

    # Encode image and texts
    inputs = processor(
        text=texts,
        images=image,
        return_tensors="pt",
        padding=True
    )

    with torch.no_grad():
        outputs = model(**inputs)

    # Compute similarities
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)

    results = {}
    for label, prob in zip(candidate_labels, probs[0]):
        results[label] = prob.item()

    return results

# Example usage
labels = ["cat", "dog", "car", "airplane", "bird"]
results = zero_shot_classify("example.jpg", labels)
for label, prob in sorted(results.items(), key=lambda x: -x[1]):
    print(f"{label}: {prob:.4f}")


# ---- Image-text similarity search ----
class CLIPRetrieval:
    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model.eval()

        self.image_embeddings = None
        self.image_paths = []

    def index_images(self, image_paths):
        """이미지 데이터베이스 인덱싱"""
        self.image_paths = image_paths
        embeddings = []

        for path in image_paths:
            image = Image.open(path)
            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                embedding = self.model.get_image_features(**inputs)
            embeddings.append(F.normalize(embedding, dim=1))

        self.image_embeddings = torch.cat(embeddings, dim=0)
        print(f"{len(image_paths)}개 이미지 인덱싱 완료")

    def search(self, query_text, top_k=5):
        """텍스트 쿼리로 이미지 검색"""
        inputs = self.processor(text=[query_text], return_tensors="pt", padding=True)
        with torch.no_grad():
            text_embedding = self.model.get_text_features(**inputs)
        text_embedding = F.normalize(text_embedding, dim=1)

        # Cosine similarity
        similarities = torch.matmul(text_embedding, self.image_embeddings.T).squeeze()
        top_indices = similarities.argsort(descending=True)[:top_k]

        results = [
            (self.image_paths[i], similarities[i].item())
            for i in top_indices
        ]
        return results

# Multimodal embedding visualization
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np

def visualize_clip_embeddings(images, texts):
    """CLIP 임베딩 t-SNE 시각화"""
    image_inputs = processor(images=images, return_tensors="pt")
    text_inputs = processor(text=texts, return_tensors="pt", padding=True)

    with torch.no_grad():
        image_embs = F.normalize(model.get_image_features(**image_inputs), dim=1)
        text_embs = F.normalize(model.get_text_features(**text_inputs), dim=1)

    # Concatenate image and text embeddings
    all_embs = torch.cat([image_embs, text_embs], dim=0).numpy()

    # t-SNE dimensionality reduction (perplexity must be < number of samples)
    tsne = TSNE(n_components=2, random_state=42,
                perplexity=min(30, len(all_embs) - 1))
    reduced = tsne.fit_transform(all_embs)

    n = len(images)
    plt.figure(figsize=(12, 8))
    plt.scatter(reduced[:n, 0], reduced[:n, 1], c='blue', label='Images', alpha=0.7)
    plt.scatter(reduced[n:, 0], reduced[n:, 1], c='red', label='Texts', alpha=0.7)

    for i, text in enumerate(texts):
        plt.annotate(text, reduced[n+i], fontsize=8)

    plt.legend()
    plt.title('CLIP Embeddings (t-SNE)')
    plt.savefig('clip_tsne.png', dpi=150, bbox_inches='tight')
    plt.show()

7. BEiT

BEiT (BERT Pre-Training of Image Transformers) was released by Microsoft in 2021. It converts images into discrete tokens and then trains BERT-style.

7.1 Discrete Visual Tokens (dVAE)

BEiT's key idea: predict discrete tokens rather than continuous pixel values.

Stage 1: a dVAE maps images to discrete tokens (vocabulary size 8192)
Stage 2: a ViT predicts the tokens of the masked patches

from transformers import BeitForMaskedImageModeling, BeitImageProcessor
import torch

# BEiT masked image modeling
class BEiTPretraining:
    def __init__(self, model_name="microsoft/beit-base-patch16-224"):
        self.model = BeitForMaskedImageModeling.from_pretrained(model_name)
        self.processor = BeitImageProcessor.from_pretrained(model_name)

    def create_bool_masked_pos(self, num_patches, mask_ratio=0.4):
        """Generate boolean mask positions"""
        num_masked = int(num_patches * mask_ratio)
        mask = torch.zeros(num_patches, dtype=torch.bool)
        # Random masking (BEiT itself uses blockwise masking; random is a simplification)
        masked_indices = torch.randperm(num_patches)[:num_masked]
        mask[masked_indices] = True
        return mask

    def forward(self, images):
        inputs = self.processor(images, return_tensors="pt")
        pixel_values = inputs.pixel_values

        num_patches = (224 // 16) ** 2  # 196
        bool_masked_pos = self.create_bool_masked_pos(num_patches)
        bool_masked_pos = bool_masked_pos.unsqueeze(0).expand(
            pixel_values.size(0), -1
        )

        outputs = self.model(
            pixel_values=pixel_values,
            bool_masked_pos=bool_masked_pos
        )

        return outputs.loss

8. Language Model Pretraining

SSL has a long history in natural language processing.

8.1 GPT - Next-Token Prediction

GPT is autoregressive: it predicts the next token from left to right.

import torch
import torch.nn as nn
from torch.nn import functional as F

class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[
            TransformerBlock(n_embd, n_head, dropout=dropout)
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

        self.block_size = block_size

    def forward(self, idx, targets=None):
        B, T = idx.shape

        tok_emb = self.token_embedding(idx)  # (B, T, C)
        pos_emb = self.position_embedding(
            torch.arange(T, device=idx.device)
        )  # (T, C)
        x = tok_emb + pos_emb

        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            return logits, None

        # Next-token prediction loss
        B, T, C = logits.shape
        logits_flat = logits.view(B * T, C)
        targets_flat = targets.view(B * T)
        loss = F.cross_entropy(logits_flat, targets_flat)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

8.2 BERT - Masked Language Modeling

BERT learns bidirectional context: 15% of tokens are masked and then reconstructed.

from transformers import BertForMaskedLM, BertTokenizer
import torch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

def create_mlm_input(text, mask_prob=0.15):
    """Create masked-language-model inputs"""
    tokens = tokenizer.encode(text, return_tensors="pt")
    labels = tokens.clone()

    # Randomly select 15% of tokens for masking
    rand = torch.rand(tokens.shape)
    mask_arr = (rand < mask_prob) & (tokens != tokenizer.cls_token_id) & \
               (tokens != tokenizer.sep_token_id)

    # Masking strategy:
    # - 80%: replace with the [MASK] token
    # - 10%: replace with a random token
    # - 10%: keep the original token
    for i, masked in enumerate(mask_arr[0]):
        if masked:
            prob = torch.rand(1).item()
            if prob < 0.8:
                tokens[0][i] = tokenizer.mask_token_id
            elif prob < 0.9:
                tokens[0][i] = torch.randint(len(tokenizer), (1,))
            # remaining 10%: keep the original token

    # Exclude unmasked positions from the loss
    labels[~mask_arr] = -100

    return tokens, labels

# Compute MLM loss
text = "The quick brown fox jumps over the lazy dog"
input_ids, labels = create_mlm_input(text)

with torch.no_grad():
    outputs = model(input_ids=input_ids, labels=labels)
    print(f"MLM Loss: {outputs.loss.item():.4f}")

# Predict the masked token
input_text = "The [MASK] brown fox"
inputs = tokenizer(input_text, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

mask_idx = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0][1]
top_5 = logits[0, mask_idx].topk(5)
for token_id, score in zip(top_5.indices, top_5.values):
    print(f"{tokenizer.decode([token_id])}: {score:.3f}")

8.3 T5 - Text-to-Text Transformation

from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

# T5 casts every task as text-to-text
# (t5-base's supervised mixture covers English-German/French/Romanian
#  translation and GLUE-style prefixes such as "sst2 sentence:")
tasks = {
    "Translation": "translate English to German: The weather is nice today",
    "Summarization": "summarize: Scientists have discovered a new species of deep-sea fish...",
    "Sentiment analysis": "sst2 sentence: This movie was absolutely fantastic!",
    "Question answering": "question: What is the capital of France? context: Paris is the capital..."
}

for task_name, input_text in tasks.items():
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
    with torch.no_grad():
        output_ids = model.generate(
            inputs.input_ids,
            max_length=128,
            num_beams=4,
            early_stopping=True
        )
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(f"\n[{task_name}]")
    print(f"Input: {input_text[:60]}...")
    print(f"Output: {output}")

9. Practical Applications

9.1 Few-shot Learning

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np

class LinearProbe(nn.Module):
    """SSL 특징 위에 선형 분류기"""
    def __init__(self, feature_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        return self.linear(x)


def extract_features(model, dataloader, device='cuda'):
    """사전학습된 모델로 특징 추출"""
    model.eval()
    features_list = []
    labels_list = []

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            features = model.encode(images)
            features_list.append(features.cpu())
            labels_list.append(labels)

    return torch.cat(features_list), torch.cat(labels_list)


def few_shot_evaluation(ssl_model, train_loader, val_loader,
                        n_shots=[1, 5, 10, 100], device='cuda'):
    """Few-shot 평가: n개 레이블로 선형 분류기 학습"""
    ssl_model.eval()

    # Extract features for the full train/val sets
    all_features, all_labels = extract_features(ssl_model, train_loader, device)
    val_features, val_labels = extract_features(ssl_model, val_loader, device)

    results = {}
    num_classes = len(all_labels.unique())

    for n_shot in n_shots:
        # Use only n_shot labels per class
        selected_indices = []
        for cls in range(num_classes):
            cls_indices = (all_labels == cls).nonzero(as_tuple=True)[0]
            selected = cls_indices[torch.randperm(len(cls_indices))[:n_shot]]
            selected_indices.append(selected)

        selected_indices = torch.cat(selected_indices)
        train_feats = all_features[selected_indices]
        train_labs = all_labels[selected_indices]

        # Train the linear classifier
        probe = LinearProbe(all_features.shape[1], num_classes).to(device)
        optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
        criterion = nn.CrossEntropyLoss()

        train_feats = train_feats.to(device)
        train_labs = train_labs.to(device)

        for epoch in range(100):
            probe.train()
            logits = probe(train_feats)
            loss = criterion(logits, train_labs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation
        probe.eval()
        with torch.no_grad():
            val_logits = probe(val_features.to(device))
            preds = val_logits.argmax(dim=1).cpu()
            acc = (preds == val_labels).float().mean().item()

        results[n_shot] = acc
        print(f"{n_shot}-shot 정확도: {acc * 100:.2f}%")

    return results


# k-NN evaluation (parameter-free)
def knn_evaluation(ssl_model, train_loader, val_loader, k=20, device='cuda'):
    """Evaluate SSL representation quality with a k-NN classifier"""
    ssl_model.eval()

    train_features, train_labels = extract_features(ssl_model, train_loader, device)
    val_features, val_labels = extract_features(ssl_model, val_loader, device)

    # Normalize features
    train_features = nn.functional.normalize(train_features, dim=1)
    val_features = nn.functional.normalize(val_features, dim=1)

    # Batched k-NN classification
    correct = 0
    total = 0
    batch_size = 256

    for i in range(0, len(val_features), batch_size):
        batch_feats = val_features[i:i+batch_size].to(device)
        batch_labels = val_labels[i:i+batch_size]

        # Cosine similarities
        sims = torch.mm(batch_feats, train_features.to(device).T)
        topk_sims, topk_indices = sims.topk(k, dim=1)

        # Majority vote
        topk_labels = train_labels[topk_indices.cpu()]
        preds = topk_labels.mode(dim=1).values

        correct += (preds == batch_labels).sum().item()
        total += len(batch_labels)

    accuracy = correct / total
    print(f"k-NN ({k}) 정확도: {accuracy * 100:.2f}%")
    return accuracy

9.2 Domain-Specific SSL

# SSL for medical images (e.g., CheXpert chest X-rays)
class MedicalSSL(nn.Module):
    """Self-supervised learning tailored to medical images"""
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

        # Weak augmentations suited to medical images
        self.augmentation = T.Compose([
            T.RandomResizedCrop(224, scale=(0.8, 1.0)),  # small crop variation
            T.RandomHorizontalFlip(p=0.5),
            # keep color jitter minimal for medical images
            T.RandomApply([T.ColorJitter(0.1, 0.1, 0.0, 0.0)], p=0.3),
            T.ToTensor(),
            T.Normalize([0.485], [0.229])  # single-channel X-ray
        ])

    def forward(self, x):
        v1 = self.augmentation(x)
        v2 = self.augmentation(x)
        return v1, v2


# SSL for satellite imagery
class SatelliteSSL:
    """Satellite-specific augmentations (preserve geographic structure)"""
    def __init__(self):
        self.transform = T.Compose([
            T.RandomResizedCrop(256, scale=(0.4, 1.0)),
            # rotation invariance matters for satellite imagery
            T.RandomRotation(90),
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            # simulate time-of-day/seasonal variation
            T.ColorJitter(brightness=0.4, contrast=0.4),
            T.ToTensor(),
        ])


# Compare SSL methods (sketch: assumes SimCLR, BYOL, MAE, and
# train_one_epoch are defined as in the sections above)
def compare_ssl_methods(train_loader, val_loader, num_epochs=100):
    methods = {
        'SimCLR': SimCLR(backbone='resnet50'),
        'BYOL': BYOL(backbone=resnet50()),
        'MAE': MAE(img_size=224),
    }

    results = {}
    for name, model in methods.items():
        print(f"\nTraining {name}...")
        model = model.cuda()
        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

        for epoch in range(num_epochs):
            train_one_epoch(model, train_loader, optimizer)

        # 100-shot evaluation
        acc = few_shot_evaluation(model, train_loader, val_loader, n_shots=[100])
        results[name] = acc[100]
        print(f"{name} 100-shot accuracy: {acc[100]*100:.2f}%")

    return results

9.3 Using SSL Model Hubs

# Use pretrained SSL models from Hugging Face
from transformers import (
    ViTModel, ViTFeatureExtractor,
    DeiTModel,
    CLIPModel, CLIPProcessor
)
from PIL import Image
import torch

# ViT (pretrained with MAE)
def load_mae_pretrained():
    model = ViTModel.from_pretrained("facebook/vit-mae-base")
    feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
    return model, feature_extractor

# Extract features with DINO ViT
def extract_dino_features(image_path):
    from transformers import ViTModel
    model = ViTModel.from_pretrained("facebook/dino-vits16")
    processor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits16")

    image = Image.open(image_path)
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # [CLS] token features (global image representation)
    cls_features = outputs.last_hidden_state[:, 0, :]
    # Per-patch features (retain spatial information)
    patch_features = outputs.last_hidden_state[:, 1:, :]

    return cls_features, patch_features

# Extract CLIP image features for downstream tasks
def clip_feature_extraction(images, texts=None):
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    if texts:
        inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.image_embeds, outputs.text_embeds
    else:
        inputs = processor(images=images, return_tensors="pt")
        with torch.no_grad():
            image_features = model.get_image_features(**inputs)
        return image_features

Conclusion

Self-supervised learning is reshaping AI. Key takeaways:

Method | Year | Key Idea | Strength
SimCLR | 2020 | Strong augmentation + contrastive learning | Simple and powerful
MoCo | 2020 | Momentum queue | Batch-size independent
BYOL | 2020 | Contrastive learning without negatives | Batch-size independent
MAE | 2021 | 75% masking + pixel reconstruction | Efficient, ideal for ViT
DINO | 2021 | Self-distillation | Strong segmentation properties
CLIP | 2021 | Image-text contrastive learning | Zero-shot classification
BEiT | 2021 | Discrete token masking | BERT-style pretraining
DINOv2 | 2023 | Large scale + iBOT | General-purpose features

Learning roadmap:

  1. Implement SimCLR to understand contrastive learning
  2. Understand the masking approach with MAE
  3. Apply multimodal learning with CLIP
  4. Fine-tune on your own domain data

SSL now underpins LLM pretraining and modern computer vision. By exploiting large-scale unlabeled data effectively, SSL will remain a core driver of AI progress.



Self-Supervised Learning Complete Guide: SimCLR, MAE, DINO, CLIP

Table of Contents

  1. What is Self-Supervised Learning?
  2. Early SSL Methods
  3. Contrastive Learning
  4. Masked Autoencoder (MAE)
  5. DINO - Self-Distillation with No Labels
  6. CLIP - Contrastive Language-Image Pretraining
  7. BEiT - BERT for Images
  8. Language Model Pretraining
  9. Practical Applications

1. What is Self-Supervised Learning?

1.1 Comparing Learning Paradigms

Let's start by understanding the three major learning paradigms in deep learning.

Supervised Learning trains on labeled data. ImageNet's roughly 14 million images each carry a human-annotated label, and models like ResNet and ViT are trained on this data. The problem is that label collection is extremely expensive. A single medical image can take a specialist dozens of minutes to annotate.

Unsupervised Learning discovers patterns in data without labels. K-Means clustering, PCA, and GANs fall into this category. However, traditional unsupervised learning struggles to learn representations directly usable for downstream tasks.

Self-Supervised Learning (SSL) combines the best of both worlds. It leverages large amounts of unlabeled data while generating supervisory signals from the data itself.

Core idea: The data itself is the label.

The internet contains billions of images and trillions of words. SSL exploits this vast unlabeled data to learn powerful representations.

1.2 The Label Scarcity Problem

In the real world, label collection is extremely difficult.

Domain | Label Collection Cost | Reason
Medical images | Very high | Requires specialist physicians
Satellite imagery | High | Requires geographic expertise
Legal documents | High | Requires legal experts
Industrial defect detection | High | Rare defect samples
Speech recognition | Medium | Transcription required

SSL fundamentally solves this problem. The approach is to first pretrain on large-scale unlabeled data, then fine-tune with a small amount of labeled data.
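
The recipe can be sketched in a few lines; the tiny encoder below is a hypothetical stand-in for an SSL-pretrained backbone, and only the new linear head remains trainable.

```python
import torch.nn as nn

# Hypothetical stand-in for an SSL-pretrained encoder
encoder = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU())
for p in encoder.parameters():
    p.requires_grad = False           # freeze the pretrained weights

head = nn.Linear(256, 10)             # trained on the small labeled set
model = nn.Sequential(encoder, head)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable)  # 2570: only the head's 256*10 + 10 parameters are trained
```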

1.3 Pretext Tasks

The key to SSL is the pretext task. We create artificial prediction problems from the data itself to force the model to learn meaningful representations.

Conditions for a good pretext task:

  1. Must be constructable automatically without labels
  2. Solving the task well must require meaningful representations
  3. Must not be too easy or too hard

Examples:

  • Rotation prediction: Rotate an image by 0°, 90°, 180°, 270° and predict how many degrees it was rotated
  • Masking: Cover part of an image/text and restore it
  • Contrastive learning: Make two views of the same image similar in embedding space
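
The rotation example illustrates how the supervisory signal is manufactured rather than annotated; a minimal sketch with a toy batch:

```python
import torch

# The label comes for free: rotate each image by k*90 degrees,
# and k itself becomes the classification target.
images = torch.randn(4, 3, 32, 32)   # a toy batch of unlabeled images
ks = torch.randint(0, 4, (4,))       # sampled rotations double as labels
rotated = torch.stack([
    torch.rot90(img, int(k), dims=(1, 2))
    for img, k in zip(images, ks)
])
print(rotated.shape)  # torch.Size([4, 3, 32, 32])
```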

1.4 Applications of SSL

SSL has become the foundation of modern AI:

  • GPT, BERT: Text SSL → basis for ChatGPT, Gemini
  • CLIP: Image-text SSL → basis for DALL-E, Stable Diffusion
  • MAE, DINO: Image SSL → basis for medical imaging, autonomous driving
  • wav2vec: Audio SSL → basis for speech recognition

2. Early SSL Methods

Early SSL research explored a variety of pretext tasks.

2.1 Rotation Prediction

Proposed by Gidaris et al. in 2018. The model predicts which of four rotations (0°, 90°, 180°, 270°) was applied to an image.

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image

class RotationSSL(nn.Module):
    def __init__(self, backbone, num_classes=4):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(backbone.output_dim, num_classes)

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)

def create_rotation_dataset(images):
    """Create (image, label) pairs by rotating images in 4 directions"""
    rotated_images = []
    labels = []

    for img in images:
        for k in range(4):  # 0, 90, 180, 270 degrees
            rotated = T.functional.rotate(img, k * 90)
            rotated_images.append(rotated)
            labels.append(k)

    return rotated_images, labels

def train_rotation_ssl(model, dataloader, optimizer, epochs=10):
    criterion = nn.CrossEntropyLoss()
    model.train()

    for epoch in range(epochs):
        total_loss = 0
        for images, labels in dataloader:
            images, labels = images.cuda(), labels.cuda()
            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

Limitation: rotation prediction assumes images have a canonical orientation; for textures, aerial views, or sky images with no clear "up", the pretext task provides little learning signal.

2.2 Jigsaw Puzzle Solving

Proposed by Noroozi & Favaro in 2016. An image is divided into a 3x3 grid and shuffled according to one of a predefined set of permutations; the model predicts which permutation was applied.

import itertools
import numpy as np

class JigsawSSL(nn.Module):
    def __init__(self, backbone, num_permutations=100):
        super().__init__()
        self.backbone = backbone
        self.patch_encoder = nn.Sequential(
            backbone,
            nn.Linear(backbone.output_dim, 512)
        )
        self.classifier = nn.Sequential(
            nn.Linear(512 * 9, 4096),
            nn.ReLU(),
            nn.Linear(4096, num_permutations)
        )

    def forward(self, patches):
        # patches: (batch, 9, C, H, W)
        batch_size = patches.size(0)
        patch_features = []

        for i in range(9):
            feat = self.patch_encoder(patches[:, i])  # (batch, 512)
            patch_features.append(feat)

        combined = torch.cat(patch_features, dim=1)  # (batch, 512*9)
        return self.classifier(combined)

# The paper precomputes 100 maximally-distinct permutations of the 9 patches;
# random permutations serve as a simple stand-in here
rng = np.random.default_rng(0)
PREDEFINED_PERMUTATIONS = [rng.permutation(9) for _ in range(100)]

def create_jigsaw_dataset(image, grid_size=3):
    """Split image into grid_size x grid_size patches and shuffle"""
    h, w = image.shape[-2:]
    patch_h, patch_w = h // grid_size, w // grid_size

    patches = []
    for i in range(grid_size):
        for j in range(grid_size):
            patch = image[...,
                         i*patch_h:(i+1)*patch_h,
                         j*patch_w:(j+1)*patch_w]
            patches.append(patch)

    permutation_idx = np.random.randint(0, len(PREDEFINED_PERMUTATIONS))
    perm = PREDEFINED_PERMUTATIONS[permutation_idx]
    shuffled = [patches[p] for p in perm]

    return torch.stack(shuffled), permutation_idx

2.3 Colorization

Proposed by Zhang et al. in 2016. A grayscale image is given as input; the model predicts the colors.

class ColorizationSSL(nn.Module):
    def __init__(self):
        super().__init__()
        # L channel input (lightness)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(),
        )
        # ab channel output (color)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 2, 3, padding=1),  # ab channels
            nn.Tanh()
        )

    def forward(self, l_channel):
        features = self.encoder(l_channel)
        ab_channels = self.decoder(features)
        return ab_channels

from skimage.color import rgb2lab, lab2rgb

def prepare_colorization_data(rgb_image):
    lab = rgb2lab(rgb_image)
    l_channel = lab[:, :, 0:1]   # lightness channel
    ab_channels = lab[:, :, 1:]  # color channels
    return l_channel, ab_channels

3. Contrastive Learning

Contrastive Learning is at the heart of modern SSL. The approach took off around 2020.

3.1 Core Idea

Pull similar things together, push different things apart.

Two different views (crop, flip, color jitter, etc.) of the same image should be close together in embedding space, while different images should be far apart.

Image x
├── Augmented view v1 → z1 (embedding)
└── Augmented view v2 → z2 (embedding)

Goal: z1 and z2 close together, z1 and z3 (from different image) far apart
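
"Close" and "far" are typically measured with cosine similarity on normalized embeddings; a toy 2-D illustration (the vectors are made up):

```python
import torch
import torch.nn.functional as F

z1 = F.normalize(torch.tensor([[1.0, 0.2]]), dim=1)   # view 1 of image x
z2 = F.normalize(torch.tensor([[0.9, 0.3]]), dim=1)   # view 2 of image x
z3 = F.normalize(torch.tensor([[-1.0, 0.5]]), dim=1)  # view of another image

pos = (z1 * z2).sum().item()  # cosine similarity of the positive pair
neg = (z1 * z3).sum().item()  # cosine similarity of a negative pair
print(pos > neg)  # True: the positive pair is closer
```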

3.2 InfoNCE Loss

The standard loss function for contrastive learning — also called NT-Xent (Normalized Temperature-scaled Cross Entropy).

\mathcal{L} = -\frac{1}{2N} \sum_{i=1}^{2N} \log \frac{e^{\mathrm{sim}(z_i, z_{j(i)})/\tau}}{\sum_{k \neq i} e^{\mathrm{sim}(z_i, z_k)/\tau}}

where sim is cosine similarity, τ is the temperature parameter, and j(i) indexes the other augmented view of the same image as i.

import torch
import torch.nn.functional as F

def info_nce_loss(z1, z2, temperature=0.5):
    """
    InfoNCE / NT-Xent Loss
    z1, z2: (batch_size, embedding_dim) - two views of the same images
    """
    batch_size = z1.size(0)

    # L2 normalize
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    # Concatenate: [z1_1, ..., z1_N, z2_1, ..., z2_N]
    z = torch.cat([z1, z2], dim=0)  # (2N, D)

    # Compute pairwise similarities
    similarity = torch.mm(z, z.t()) / temperature  # (2N, 2N)

    # Exclude self-similarity (diagonal = -inf)
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    similarity.masked_fill_(mask, float('-inf'))

    # Positive pairs: index i and i+N
    labels = torch.cat([
        torch.arange(batch_size, 2 * batch_size),
        torch.arange(batch_size)
    ]).to(z.device)

    loss = F.cross_entropy(similarity, labels)
    return loss

# Usage
z1 = torch.randn(32, 128)
z2 = torch.randn(32, 128)
loss = info_nce_loss(z1, z2, temperature=0.5)
print(f"InfoNCE Loss: {loss.item():.4f}")

3.3 SimCLR

Published by Google in 2020, SimCLR (a Simple Framework for Contrastive Learning of Visual Representations) is simple but remarkably effective.

Key Components:

  1. Data augmentation (strong augmentation is critical)
  2. Encoder (ResNet)
  3. Projection Head (MLP)
  4. NT-Xent Loss

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T

class SimCLR(nn.Module):
    def __init__(self, backbone='resnet50', projection_dim=128, temperature=0.5):
        super().__init__()
        self.temperature = temperature

        if backbone == 'resnet50':
            resnet = models.resnet50(pretrained=False)
            self.encoder = nn.Sequential(*list(resnet.children())[:-1])
            self.feature_dim = 2048
        elif backbone == 'resnet18':
            resnet = models.resnet18(pretrained=False)
            self.encoder = nn.Sequential(*list(resnet.children())[:-1])
            self.feature_dim = 512

        # Projection Head (removed during fine-tuning)
        self.projection_head = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.feature_dim, projection_dim)
        )

    def encode(self, x):
        h = self.encoder(x)
        h = h.squeeze(-1).squeeze(-1)
        return h

    def forward(self, x):
        h = self.encode(x)
        z = self.projection_head(h)
        return z

    def contrastive_loss(self, z1, z2):
        return info_nce_loss(z1, z2, self.temperature)


class SimCLRDataAugmentation:
    """SimCLR's strong data augmentation"""
    def __init__(self, image_size=224, s=1.0):
        color_jitter = T.ColorJitter(
            brightness=0.8*s,
            contrast=0.8*s,
            saturation=0.8*s,
            hue=0.2*s
        )
        self.transform = T.Compose([
            T.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply([color_jitter], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.GaussianBlur(kernel_size=int(0.1 * image_size) // 2 * 2 + 1),  # kernel size must be odd
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        return self.transform(x), self.transform(x)


def train_simclr(model, dataloader, optimizer, epochs=200):
    model.train()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs
    )

    for epoch in range(epochs):
        total_loss = 0
        for (x1, x2), _ in dataloader:
            x1, x2 = x1.cuda(), x2.cuda()

            z1 = model(x1)
            z2 = model(x2)

            loss = model.contrastive_loss(z1, z2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        avg_loss = total_loss / len(dataloader)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")


model = SimCLR(backbone='resnet50', projection_dim=128).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)

def get_finetuning_model(pretrained_simclr, num_classes):
    encoder = pretrained_simclr.encoder
    classifier = nn.Linear(pretrained_simclr.feature_dim, num_classes)
    return nn.Sequential(encoder, nn.Flatten(), classifier)

3.4 MoCo (Momentum Contrast)

Published by Facebook AI in 2020. Enables using many negative samples without requiring a large batch size.

Key idea: Stable key encoder via momentum update + Queue for managing negative samples

class MoCo(nn.Module):
    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07):
        super().__init__()
        self.K = K   # queue size
        self.m = m   # momentum coefficient
        self.T = T   # temperature

        # Query encoder and key encoder
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        # Key encoder: no gradients, updated by momentum
        for param_q, param_k in zip(
            self.encoder_q.parameters(),
            self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        # Negative sample queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        for param_q, param_k in zip(
            self.encoder_q.parameters(),
            self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        # Assumes K is divisible by batch_size, so a batch never wraps
        # around the end of the queue
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr

    def forward(self, im_q, im_k):
        q = self.encoder_q(im_q)
        q = F.normalize(q, dim=1)

        with torch.no_grad():
            self._momentum_update_key_encoder()
            k = self.encoder_k(im_k)
            k = F.normalize(k, dim=1)

        # Positive logits: (batch, 1)
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # Negative logits: (batch, K)
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        logits = torch.cat([l_pos, l_neg], dim=1) / self.T
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        loss = F.cross_entropy(logits, labels)

        self._dequeue_and_enqueue(k)
        return loss
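The queue update above is just ring-buffer pointer arithmetic; a standalone sketch with toy sizes (no encoder) shows how the newest keys overwrite the oldest:

```python
import torch

# Toy ring buffer: feature dim 4, queue of K=8 negatives, batches of 2.
# Mirrors MoCo's _dequeue_and_enqueue logic.
K, dim, batch_size = 8, 4, 2
queue = torch.zeros(dim, K)
ptr = 0

for step in range(6):  # 6 batches -> 12 keys, wrapping the queue once
    keys = torch.full((batch_size, dim), float(step))  # stand-in for encoder_k output
    queue[:, ptr:ptr + batch_size] = keys.T
    ptr = (ptr + batch_size) % K  # wrap around; assumes K % batch_size == 0

print(ptr)                # 4  (12 keys written, 12 % 8 = 4)
print(queue[0].tolist())  # [4.0, 4.0, 5.0, 5.0, 2.0, 2.0, 3.0, 3.0]
```

Slots 0-3 now hold the two most recent batches (4 and 5), while slots 4-7 still hold batches 2 and 3: the oldest keys (batches 0 and 1) were silently evicted.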

3.5 BYOL (Bootstrap Your Own Latent)

Published by DeepMind in 2020. A groundbreaking method that works without negative samples.

Key insight: An online network predicts the representation of a target network. The target is updated by momentum.

import copy

class BYOL(nn.Module):
    def __init__(self, backbone, projection_dim=256, tau=0.996):
        super().__init__()
        self.tau = tau

        # Online network
        self.online_encoder = backbone
        self.online_projector = nn.Sequential(
            nn.Linear(2048, 4096), nn.BatchNorm1d(4096), nn.ReLU(),
            nn.Linear(4096, projection_dim)
        )
        # The predictor must output projection_dim: its prediction is
        # compared against the target network's projection
        self.online_predictor = nn.Sequential(
            nn.Linear(projection_dim, 4096), nn.BatchNorm1d(4096), nn.ReLU(),
            nn.Linear(4096, projection_dim)
        )

        # Target network (no gradients)
        self.target_encoder = copy.deepcopy(backbone)
        self.target_projector = copy.deepcopy(self.online_projector)

        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def update_target(self):
        for online, target in zip(
            self.online_encoder.parameters(),
            self.target_encoder.parameters()
        ):
            target.data = self.tau * target.data + (1 - self.tau) * online.data

        for online, target in zip(
            self.online_projector.parameters(),
            self.target_projector.parameters()
        ):
            target.data = self.tau * target.data + (1 - self.tau) * online.data

    def forward(self, x1, x2):
        online_feat1 = self.online_encoder(x1).squeeze()
        online_proj1 = self.online_projector(online_feat1)
        online_pred1 = self.online_predictor(online_proj1)

        online_feat2 = self.online_encoder(x2).squeeze()
        online_proj2 = self.online_projector(online_feat2)
        online_pred2 = self.online_predictor(online_proj2)

        with torch.no_grad():
            target_feat1 = self.target_encoder(x1).squeeze()
            target_proj1 = self.target_projector(target_feat1)

            target_feat2 = self.target_encoder(x2).squeeze()
            target_proj2 = self.target_projector(target_feat2)

        loss1 = byol_loss(online_pred1, target_proj2.detach())
        loss2 = byol_loss(online_pred2, target_proj1.detach())

        return (loss1 + loss2).mean()

def byol_loss(p, z):
    p = F.normalize(p, dim=-1)
    z = F.normalize(z, dim=-1)
    return -(p * z).sum(dim=-1)
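MoCo and BYOL rely on the same exponential-moving-average update for their second network. A toy sketch (scalar parameter, illustrative τ = 0.99) shows how slowly the target tracks the online network:

```python
import torch

tau = 0.99
online = torch.tensor(1.0)   # stand-in for an online-network parameter
target = torch.tensor(0.0)   # target starts far away

gaps = []
for step in range(1000):
    target = tau * target + (1 - tau) * online  # the BYOL/MoCo EMA update
    gaps.append((online - target).abs().item())

# The gap decays geometrically as (1 - 0.01)^n
print(gaps[0])          # ~0.99 after one update
print(gaps[-1] < 1e-4)  # True: after 1000 updates the target has nearly converged
```

This slowness is the point: the target provides stable, slowly-moving regression targets, which is what keeps training from collapsing without negatives.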

3.6 SimSiam

Published by Facebook AI in 2021. Even simpler than BYOL: no negative samples and no momentum encoder; a stop-gradient alone is enough to prevent collapse.

class SimSiam(nn.Module):
    def __init__(self, backbone, dim=2048, pred_dim=512):
        super().__init__()
        self.encoder = backbone

        self.projector = nn.Sequential(
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim, affine=False)
        )

        self.predictor = nn.Sequential(
            nn.Linear(dim, pred_dim, bias=False),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),
            nn.Linear(pred_dim, dim)
        )

    def forward(self, x1, x2):
        z1 = self.projector(self.encoder(x1).squeeze())
        z2 = self.projector(self.encoder(x2).squeeze())

        p1 = self.predictor(z1)
        p2 = self.predictor(z2)

        # Stop gradient - the key to the asymmetric design
        loss = simsiam_loss(p1, z2.detach()) / 2 + \
               simsiam_loss(p2, z1.detach()) / 2
        return loss

def simsiam_loss(p, z):
    z = z.detach()  # Stop gradient
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    return -(p * z).sum(dim=1).mean()

4. Masked Autoencoder (MAE)

MAE, published by Kaiming He et al. in 2021, applies BERT's masking approach from NLP to images.

4.1 Core Idea

Mask 75% of an image and reconstruct the full image from the remaining 25%.

Why such a high masking ratio of 75%?

  • Text tokens carry high semantic density, so 15% masking suffices for BERT
  • Adjacent pixels in images have high redundancy — 75% masking forces genuine understanding

4.2 Asymmetric Encoder-Decoder

MAE's innovation: The encoder processes only visible patches; the decoder reconstructs the full image.

Input image (196 patches)
  → 75% masking
  → 49 visible patches → Encoder (ViT-Large) → encoded representation
  → 196 patches (encoded + mask tokens) → Decoder (shallow ViT) → pixel reconstruction

The encoder sees only 25% of the patches, which makes pretraining roughly 3-4x faster.
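The 3-4x figure follows from the token counts alone; a back-of-the-envelope sketch (ignoring constant factors and the shallow decoder):

```python
# Self-attention cost scales with n^2 (token pairs); the MLP and the
# attention projections scale linearly with n.
n_full, n_visible = 196, 49  # 224/16 = 14 -> 196 patches; 25% visible

attn_speedup = (n_full / n_visible) ** 2  # 16x fewer attention pairs
mlp_speedup = n_full / n_visible          # 4x fewer tokens in the linear layers

print(attn_speedup, mlp_speedup)  # 16.0 4.0
```

At ViT scale the linear terms dominate the encoder FLOPs, so total compute drops roughly 4x, in line with the reported 3-4x wall-clock speedup.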

4.3 Complete MAE Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        x = self.projection(x)   # (batch, embed_dim, H/p, W/p)
        x = x.flatten(2)         # (batch, embed_dim, num_patches)
        x = x.transpose(1, 2)   # (batch, num_patches, embed_dim)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., dropout=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        mlp_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        norm_x = self.norm1(x)
        attn_out, _ = self.attn(norm_x, norm_x, norm_x)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x


class MAE(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        encoder_dim=768,
        encoder_depth=12,
        encoder_heads=12,
        decoder_dim=512,
        decoder_depth=8,
        decoder_heads=16,
        mask_ratio=0.75,
    ):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, encoder_dim)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_dim))
        # Learnable positional embeddings; the MAE paper instead uses fixed
        # 2D sin-cos embeddings (frozen zeros would carry no position info)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, encoder_dim)
        )
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.encoder_blocks = nn.ModuleList([
            TransformerBlock(encoder_dim, encoder_heads)
            for _ in range(encoder_depth)
        ])
        self.encoder_norm = nn.LayerNorm(encoder_dim)

        self.decoder_embed = nn.Linear(encoder_dim, decoder_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))

        # Learnable for simplicity; the paper uses fixed sin-cos here too
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_dim)
        )
        nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)

        self.decoder_blocks = nn.ModuleList([
            TransformerBlock(decoder_dim, decoder_heads)
            for _ in range(decoder_depth)
        ])
        self.decoder_norm = nn.LayerNorm(decoder_dim)
        self.decoder_pred = nn.Linear(decoder_dim, patch_size**2 * in_channels)

    def random_masking(self, x, mask_ratio):
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1,
            index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
        )

        mask = torch.ones(N, L, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        x = self.patch_embed(x)
        x = x + self.pos_embed[:, 1:, :]
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        for block in self.encoder_blocks:
            x = block(x)
        x = self.encoder_norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        x = self.decoder_embed(x)

        mask_tokens = self.mask_token.repeat(
            x.shape[0],
            ids_restore.shape[1] + 1 - x.shape[1],
            1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(
            x_, dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )
        x = torch.cat([x[:, :1, :], x_], dim=1)

        x = x + self.decoder_pos_embed

        for block in self.decoder_blocks:
            x = block(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)
        x = x[:, 1:, :]  # Remove CLS token
        return x

    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask

    def forward_loss(self, imgs, pred, mask):
        target = self.patchify(imgs)

        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1.e-6).sqrt()

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)
        loss = (loss * mask).sum() / mask.sum()
        return loss

    def patchify(self, imgs):
        p = self.patch_size
        h = w = imgs.shape[2] // p
        x = imgs.reshape(imgs.shape[0], 3, h, p, w, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(imgs.shape[0], h * w, p**2 * 3)
        return x


# Create and train the MAE model
mae_model = MAE(
    img_size=224,
    patch_size=16,
    encoder_dim=768,
    encoder_depth=12,
    encoder_heads=12,
    decoder_dim=512,
    decoder_depth=8,
    decoder_heads=16,
    mask_ratio=0.75
).cuda()

optimizer = torch.optim.AdamW(
    mae_model.parameters(),
    lr=1.5e-4,
    betas=(0.9, 0.95),
    weight_decay=0.05
)

def train_step(model, imgs, optimizer):
    model.train()
    optimizer.zero_grad()
    loss, pred, mask = model(imgs, mask_ratio=0.75)
    loss.backward()
    optimizer.step()
    return loss.item()
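The argsort-of-noise trick inside random_masking is easy to verify on a toy tensor. A standalone sketch (batch of 1, 8 "patches" of dim 1, same 75% ratio) checks that the kept values and the restored mask agree:

```python
import torch

torch.manual_seed(0)
N, L, D = 1, 8, 1
x = torch.arange(L, dtype=torch.float).reshape(N, L, D)  # patch i holds value i
mask_ratio = 0.75
len_keep = int(L * (1 - mask_ratio))  # keep 2 of 8 patches

noise = torch.rand(N, L)
ids_shuffle = torch.argsort(noise, dim=1)        # random permutation per sample
ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation

ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))

mask = torch.ones(N, L)
mask[:, :len_keep] = 0                     # 0 = keep, 1 = masked (shuffled order)
mask = torch.gather(mask, 1, ids_restore)  # back to original patch order

print(x_masked.shape)   # torch.Size([1, 2, 1])
print(int(mask.sum()))  # 6 masked patches
# The kept patch values are exactly the positions where mask == 0
print(sorted(x_masked.flatten().tolist()) ==
      sorted((mask[0] == 0).nonzero().flatten().float().tolist()))  # True
```

Because argsort of uniform noise is a uniform random permutation, this samples masks without replacement in one vectorized pass, with no Python loop over patches.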

5. DINO

DINO (Self-DIstillation with NO labels) was published by Facebook AI Research in 2021. When ViT is trained with SSL, remarkable properties emerge.

5.1 Key Discovery

The self-attention maps of a DINO-trained ViT act like semantic segmentation masks without any labels, cleanly separating foreground objects from the background.

5.2 Self-Distillation Architecture

Student Network                 Teacher Network
      ↑                               ↑
 All Views (global + local)      Global Views (2)

  • Teacher: processes only the 2 full-image (global) views, giving richer context
  • Student: processes every view, including small local crops, and learns to match the teacher's predictions (local-to-global correspondence)
  • Teacher update: EMA of the student's weights (no gradients flow to the teacher)

5.3 Centering and Sharpening

Two techniques to prevent collapse:
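Before the implementation, a toy illustration of what the two techniques do to the teacher output (the 4-dim logits and the center are hypothetical values):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, 0.2])  # hypothetical teacher logits
center = torch.tensor([1.0, 1.0, 0.5, 0.2])  # running mean of past outputs

# Centering subtracts the running mean, suppressing any dimension the
# teacher always favors (guards against collapse to one constant output)
centered = logits - center

# Sharpening divides by a low temperature, making the distribution peaky
# (guards against collapse to the uniform distribution)
sharp = F.softmax(centered / 0.04, dim=-1)
soft = F.softmax(centered / 1.0, dim=-1)

print(sharp.max().item() > soft.max().item())  # True: low temp is peakier
print(abs(sharp.sum().item() - 1.0) < 1e-6)    # still a valid distribution
```

The two forces pull in opposite directions, and only their combination is stable: centering alone drifts toward uniform, sharpening alone toward a single dimension.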

class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim=65536, bottleneck_dim=256):
        super().__init__()
        layers = [nn.Linear(in_dim, 2048), nn.GELU()]
        layers += [nn.Linear(2048, 2048), nn.GELU()]
        layers += [nn.Linear(2048, bottleneck_dim)]
        self.mlp = nn.Sequential(*layers)

        self.last_layer = nn.utils.weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)
        self.last_layer.weight_g.requires_grad = False

    def forward(self, x):
        x = self.mlp(x)
        x = F.normalize(x, dim=-1)
        x = self.last_layer(x)
        return x


class DINO(nn.Module):
    def __init__(self, backbone, out_dim=65536, teacher_temp=0.04, student_temp=0.1):
        super().__init__()
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp

        self.student = backbone
        self.student_head = DINOHead(self.student.embed_dim, out_dim)

        self.teacher = copy.deepcopy(backbone)
        self.teacher_head = copy.deepcopy(self.student_head)

        for p in self.teacher.parameters():
            p.requires_grad = False
        for p in self.teacher_head.parameters():
            p.requires_grad = False

        self.register_buffer("center", torch.zeros(1, out_dim))

    @torch.no_grad()
    def update_teacher(self, momentum):
        for student_ps, teacher_ps in zip(
            self.student.parameters(), self.teacher.parameters()
        ):
            teacher_ps.data.mul_(momentum).add_((1 - momentum) * student_ps.data)

        for student_ps, teacher_ps in zip(
            self.student_head.parameters(), self.teacher_head.parameters()
        ):
            teacher_ps.data.mul_(momentum).add_((1 - momentum) * student_ps.data)

    @torch.no_grad()
    def update_center(self, teacher_output):
        batch_center = teacher_output.mean(dim=0, keepdim=True)
        self.center = self.center * 0.9 + batch_center * 0.1

    def dino_loss(self, student_output, teacher_output):
        student_out = student_output / self.student_temp
        # One chunk per student view (2 global + n local); the view count is
        # inferred from the teacher output, which holds exactly 2 global views
        batch_size = teacher_output.shape[0] // 2
        student_out = student_out.chunk(student_output.shape[0] // batch_size)

        # Teacher: centering + sharpening
        teacher_out = F.softmax(
            (teacher_output - self.center) / self.teacher_temp, dim=-1
        )
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0

        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    continue
                loss = torch.sum(
                    -q * F.log_softmax(student_out[v], dim=-1), dim=-1
                )
                total_loss += loss.mean()
                n_loss_terms += 1

        return total_loss / n_loss_terms

    def forward(self, global_views, local_views, momentum=0.996):
        all_views = global_views + local_views
        student_output = torch.cat([
            self.student_head(self.student(view))
            for view in all_views
        ])

        with torch.no_grad():
            teacher_output = torch.cat([
                self.teacher_head(self.teacher(view))
                for view in global_views
            ])

        loss = self.dino_loss(student_output, teacher_output)
        self.update_center(teacher_output)
        self.update_teacher(momentum)

        return loss


class DINOAugmentation:
    def __init__(self, global_crops_scale=(0.4, 1.), local_crops_scale=(0.05, 0.4),
                 n_local_crops=8, image_size=224, local_size=96):
        self.n_local_crops = n_local_crops

        self.global_transform = T.Compose([
            T.RandomResizedCrop(image_size, scale=global_crops_scale,
                               interpolation=T.InterpolationMode.BICUBIC),
            T.RandomHorizontalFlip(),
            T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        self.local_transform = T.Compose([
            T.RandomResizedCrop(local_size, scale=local_crops_scale,
                               interpolation=T.InterpolationMode.BICUBIC),
            T.RandomHorizontalFlip(),
            T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __call__(self, image):
        global_views = [self.global_transform(image) for _ in range(2)]
        local_views = [self.local_transform(image) for _ in range(self.n_local_crops)]
        return global_views, local_views

5.4 DINOv2

DINOv2, released in 2023, is even more powerful:

  • Trained on 142 million curated images
  • Added iBOT (Image BERT Pre-Training with Online Tokenizer) loss
  • More stable training dynamics

6. CLIP

CLIP (Contrastive Language-Image Pretraining) is an innovative model published by OpenAI in 2021.

6.1 Text-Image Contrastive Learning

Trained on 400 million (image, text) pairs collected from the internet.

Image Encoder (ViT / ResNet) → image embedding
Text Encoder (Transformer)   → text embedding

Goal: matching (image, text) pairs → similar embeddings
      non-matching pairs           → dissimilar embeddings

6.2 CLIP Loss

def clip_loss(image_embeddings, text_embeddings, temperature):
    """
    CLIP Symmetric Contrastive Loss
    image_embeddings: (N, D)
    text_embeddings: (N, D)
    """
    image_embeddings = F.normalize(image_embeddings, dim=1)
    text_embeddings = F.normalize(text_embeddings, dim=1)

    # Similarity matrix: (N, N)
    logits = torch.matmul(image_embeddings, text_embeddings.T) / temperature

    # Labels: diagonal (image i matches text i)
    labels = torch.arange(len(logits)).to(logits.device)

    # Bidirectional loss: image→text and text→image
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)

    return (loss_i2t + loss_t2i) / 2
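A quick sanity check on toy embeddings: perfectly aligned pairs should score a far lower loss than unrelated ones. The helper below re-states the same symmetric loss so the snippet runs standalone:

```python
import torch
import torch.nn.functional as F

def toy_clip_loss(img, txt, temperature=0.07):
    """Same symmetric contrastive loss as above, restated for a standalone run"""
    img, txt = F.normalize(img, dim=1), F.normalize(txt, dim=1)
    logits = img @ txt.T / temperature
    labels = torch.arange(len(logits))
    return (F.cross_entropy(logits, labels) +
            F.cross_entropy(logits.T, labels)) / 2

torch.manual_seed(0)
aligned = torch.randn(8, 128)
loss_aligned = toy_clip_loss(aligned, aligned.clone())     # text i == image i
loss_random = toy_clip_loss(aligned, torch.randn(8, 128))  # unrelated pairs

print(loss_aligned.item() < loss_random.item())  # True
print(loss_aligned.item() < 0.1)                 # True: diagonal dominates
```

With aligned pairs the diagonal cosine is 1, so at temperature 0.07 the correct logit towers over the rest; with random pairs the loss sits near log N, the chance level for an N-way classification.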

6.3 Full CLIP Usage

from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import torch.nn.functional as F

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# ---- Zero-shot Image Classification ----
def zero_shot_classify(image_path, candidate_labels):
    """Classify an image without any label examples"""
    image = Image.open(image_path).convert("RGB")

    # Create text prompts
    texts = [f"a photo of a {label}" for label in candidate_labels]

    inputs = processor(
        text=texts,
        images=image,
        return_tensors="pt",
        padding=True
    )

    with torch.no_grad():
        outputs = model(**inputs)

    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)

    results = {}
    for label, prob in zip(candidate_labels, probs[0]):
        results[label] = prob.item()

    return results

# Usage
labels = ["cat", "dog", "car", "airplane", "bird"]
results = zero_shot_classify("example.jpg", labels)
for label, prob in sorted(results.items(), key=lambda x: -x[1]):
    print(f"{label}: {prob:.4f}")


# ---- Image-Text Similarity Search ----
class CLIPRetrieval:
    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model.eval()

        self.image_embeddings = None
        self.image_paths = []

    def index_images(self, image_paths):
        """Index an image database"""
        self.image_paths = image_paths
        embeddings = []

        for path in image_paths:
            image = Image.open(path).convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                embedding = self.model.get_image_features(**inputs)
            embeddings.append(F.normalize(embedding, dim=1))

        self.image_embeddings = torch.cat(embeddings, dim=0)
        print(f"Indexed {len(image_paths)} images")

    def search(self, query_text, top_k=5):
        """Search images using a text query"""
        inputs = self.processor(text=[query_text], return_tensors="pt", padding=True)
        with torch.no_grad():
            text_embedding = self.model.get_text_features(**inputs)
        text_embedding = F.normalize(text_embedding, dim=1)

        similarities = torch.matmul(text_embedding, self.image_embeddings.T).squeeze()
        top_indices = similarities.argsort(descending=True)[:top_k]

        results = [
            (self.image_paths[i], similarities[i].item())
            for i in top_indices
        ]
        return results


# Visualize CLIP embeddings with t-SNE
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np

def visualize_clip_embeddings(images, texts):
    """t-SNE visualization of CLIP embeddings"""
    image_inputs = processor(images=images, return_tensors="pt")
    text_inputs = processor(text=texts, return_tensors="pt", padding=True)

    with torch.no_grad():
        image_embs = F.normalize(model.get_image_features(**image_inputs), dim=1)
        text_embs = F.normalize(model.get_text_features(**text_inputs), dim=1)

    all_embs = torch.cat([image_embs, text_embs], dim=0).numpy()

    # perplexity must be smaller than the number of samples
    tsne = TSNE(n_components=2, random_state=42,
                perplexity=min(30, len(all_embs) - 1))
    reduced = tsne.fit_transform(all_embs)

    n = len(images)
    plt.figure(figsize=(12, 8))
    plt.scatter(reduced[:n, 0], reduced[:n, 1], c='blue', label='Images', alpha=0.7)
    plt.scatter(reduced[n:, 0], reduced[n:, 1], c='red', label='Texts', alpha=0.7)

    for i, text in enumerate(texts):
        plt.annotate(text, reduced[n+i], fontsize=8)

    plt.legend()
    plt.title('CLIP Embeddings (t-SNE)')
    plt.savefig('clip_tsne.png', dpi=150, bbox_inches='tight')
    plt.show()

7. BEiT

BEiT (BERT Pre-Training of Image Transformers) was published by Microsoft in 2021. Images are converted to discrete tokens and trained in a BERT-like fashion.

7.1 Discrete Visual Tokens (dVAE)

BEiT's core: instead of predicting raw pixels, the model predicts discrete tokens.

Step 1: dVAE converts image → discrete tokens (vocabulary of 8192)
Step 2: ViT predicts the token of each masked patch
from transformers import BeitForMaskedImageModeling, BeitImageProcessor
import torch

class BEiTPretraining:
    def __init__(self, model_name="microsoft/beit-base-patch16-224-pt22k"):
        # The "-pt22k" checkpoint carries the masked-image-modeling head;
        # the plain "-224" checkpoint is fine-tuned for classification
        self.model = BeitForMaskedImageModeling.from_pretrained(model_name)
        self.processor = BeitImageProcessor.from_pretrained(model_name)

    def create_bool_masked_pos(self, num_patches, mask_ratio=0.4):
        num_masked = int(num_patches * mask_ratio)
        mask = torch.zeros(num_patches, dtype=torch.bool)
        masked_indices = torch.randperm(num_patches)[:num_masked]
        mask[masked_indices] = True
        return mask

    def forward(self, images):
        inputs = self.processor(images, return_tensors="pt")
        pixel_values = inputs.pixel_values

        num_patches = (224 // 16) ** 2  # 196
        bool_masked_pos = self.create_bool_masked_pos(num_patches)
        bool_masked_pos = bool_masked_pos.unsqueeze(0).expand(
            pixel_values.size(0), -1
        )

        outputs = self.model(
            pixel_values=pixel_values,
            bool_masked_pos=bool_masked_pos
        )

        # outputs.loss is only populated when the dVAE visual-token ids are
        # passed as labels; the logits give one prediction over the
        # 8192-token vocabulary for each patch
        return outputs.logits

8. Language Model Pretraining

SSL has a long history in NLP.

8.1 GPT - Next Token Prediction

Autoregressive approach that predicts the next token from left to right.

import torch
import torch.nn as nn
from torch.nn import functional as F

class CausalBlock(TransformerBlock):
    """TransformerBlock with a causal mask: each position attends only to itself and earlier positions"""
    def forward(self, x):
        seq_len = x.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool),
            diagonal=1
        )
        norm_x = self.norm1(x)
        attn_out, _ = self.attn(norm_x, norm_x, norm_x, attn_mask=causal_mask)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x


class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)
        # GPT must be autoregressive, so the attention is causally masked
        self.blocks = nn.Sequential(*[
            CausalBlock(n_embd, n_head, dropout=dropout)
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.block_size = block_size

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding(idx)
        pos_emb = self.position_embedding(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
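Autoregressive training only works if a position cannot peek at later tokens. A minimal sketch of the boolean causal mask (True = blocked, matching nn.MultiheadAttention's attn_mask convention) and its effect on attention weights:

```python
import torch

T = 4  # sequence length
# True above the diagonal means "blocked": position i cannot attend to j > i
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

# Applied to attention scores, blocked entries become -inf before softmax
scores = torch.zeros(T, T).masked_fill(causal_mask, float('-inf'))
weights = torch.softmax(scores, dim=-1)

print(weights[0].tolist())  # [1.0, 0.0, 0.0, 0.0]: token 0 sees only itself
print(weights[3].tolist())  # [0.25, 0.25, 0.25, 0.25]: the last token sees all
```

Without this mask, a bidirectional Transformer could trivially copy the next token from its input, and the next-token prediction objective would teach it nothing.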

8.2 BERT - Masked Language Modeling

Learns bidirectional context. 15% of tokens are masked and must be reconstructed.

from transformers import BertForMaskedLM, BertTokenizer
import torch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

def create_mlm_input(text, mask_prob=0.15):
    """Generate masked language model input"""
    tokens = tokenizer.encode(text, return_tensors="pt")
    labels = tokens.clone()

    rand = torch.rand(tokens.shape)
    mask_arr = (rand < mask_prob) & (tokens != tokenizer.cls_token_id) & \
               (tokens != tokenizer.sep_token_id)

    # Masking strategy:
    # - 80%: replace with [MASK] token
    # - 10%: replace with a random token
    # - 10%: keep original token
    for i, masked in enumerate(mask_arr[0]):
        if masked:
            prob = torch.rand(1).item()
            if prob < 0.8:
                tokens[0][i] = tokenizer.mask_token_id
            elif prob < 0.9:
                tokens[0][i] = torch.randint(len(tokenizer), (1,))

    labels[~mask_arr] = -100
    return tokens, labels

# Compute MLM loss
text = "The quick brown fox jumps over the lazy dog"
input_ids, labels = create_mlm_input(text)

with torch.no_grad():
    outputs = model(input_ids=input_ids, labels=labels)
    print(f"MLM Loss: {outputs.loss.item():.4f}")

# Predict masked tokens
input_text = "The [MASK] brown fox"
inputs = tokenizer(input_text, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

mask_idx = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0][1]
top_5 = logits[0, mask_idx].topk(5)
for token_id, score in zip(top_5.indices, top_5.values):
    print(f"{tokenizer.decode([token_id])}: {score:.3f}")

8.3 T5 - Text-to-Text Transfer

from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

# T5 unifies all tasks as text-to-text
tasks = {
    "translation": "translate English to French: The weather is nice today",
    "summarization": "summarize: Scientists have discovered a new species of deep-sea fish...",
    "sentiment": "sentiment analysis: This movie was absolutely fantastic!",
    "qa": "question: What is the capital of France? context: Paris is the capital..."
}

for task_name, input_text in tasks.items():
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
    with torch.no_grad():
        output_ids = model.generate(
            inputs.input_ids,
            max_length=128,
            num_beams=4,
            early_stopping=True
        )
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(f"\n[{task_name}]")
    print(f"Input: {input_text[:60]}...")
    print(f"Output: {output}")

9. Practical Applications

9.1 Few-shot Learning

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

class LinearProbe(nn.Module):
    """Linear classifier on top of SSL features"""
    def __init__(self, feature_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        return self.linear(x)


def extract_features(model, dataloader, device='cuda'):
    """Extract features from a pretrained model"""
    model.eval()
    features_list = []
    labels_list = []

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            # Assumes the SSL model exposes an `encode` method returning (B, D) features
            features = model.encode(images)
            features_list.append(features.cpu())
            labels_list.append(labels)

    return torch.cat(features_list), torch.cat(labels_list)


def few_shot_evaluation(ssl_model, train_loader, val_loader,
                        n_shots=[1, 5, 10, 100], device='cuda'):
    """Evaluate with n labeled examples per class"""
    ssl_model.eval()

    all_features, all_labels = extract_features(ssl_model, train_loader, device)
    val_features, val_labels = extract_features(ssl_model, val_loader, device)

    results = {}
    num_classes = len(all_labels.unique())

    for n_shot in n_shots:
        selected_indices = []
        for cls in range(num_classes):
            cls_indices = (all_labels == cls).nonzero(as_tuple=True)[0]
            selected = cls_indices[torch.randperm(len(cls_indices))[:n_shot]]
            selected_indices.append(selected)

        selected_indices = torch.cat(selected_indices)
        train_feats = all_features[selected_indices]
        train_labs = all_labels[selected_indices]

        probe = LinearProbe(all_features.shape[1], num_classes).to(device)
        optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
        criterion = nn.CrossEntropyLoss()

        train_feats = train_feats.to(device)
        train_labs = train_labs.to(device)

        for epoch in range(100):
            probe.train()
            logits = probe(train_feats)
            loss = criterion(logits, train_labs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        probe.eval()
        with torch.no_grad():
            val_logits = probe(val_features.to(device))
            preds = val_logits.argmax(dim=1).cpu()
            acc = (preds == val_labels).float().mean().item()

        results[n_shot] = acc
        print(f"{n_shot}-shot accuracy: {acc * 100:.2f}%")

    return results


def knn_evaluation(ssl_model, train_loader, val_loader, k=20, device='cuda'):
    """k-NN evaluation of SSL representation quality (parameter-free)"""
    ssl_model.eval()

    train_features, train_labels = extract_features(ssl_model, train_loader, device)
    val_features, val_labels = extract_features(ssl_model, val_loader, device)

    train_features = nn.functional.normalize(train_features, dim=1)
    val_features = nn.functional.normalize(val_features, dim=1)

    correct = 0
    total = 0
    batch_size = 256
    train_features = train_features.to(device)  # move to device once, not per batch

    for i in range(0, len(val_features), batch_size):
        batch_feats = val_features[i:i+batch_size].to(device)
        batch_labels = val_labels[i:i+batch_size]

        sims = torch.mm(batch_feats, train_features.T)
        _, topk_indices = sims.topk(k, dim=1)

        topk_labels = train_labels[topk_indices.cpu()]
        preds = topk_labels.mode(dim=1).values

        correct += (preds == batch_labels).sum().item()
        total += len(batch_labels)

    accuracy = correct / total
    print(f"k-NN ({k}) accuracy: {accuracy * 100:.2f}%")
    return accuracy

9.2 Domain-Specific SSL

# Medical images (chest X-ray example)
import torchvision.transforms as T

class MedicalSSL(nn.Module):
    """Self-supervised learning for medical images"""
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

        # Weaker augmentation for medical images
        self.augmentation = T.Compose([
            T.RandomResizedCrop(224, scale=(0.8, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            # Minimize color changes for medical images
            T.RandomApply([T.ColorJitter(0.1, 0.1, 0.0, 0.0)], p=0.3),
            T.ToTensor(),
            T.Normalize([0.485], [0.229])  # Grayscale X-ray
        ])

    def forward(self, x):
        # x: a PIL image; returns two independently augmented views
        v1 = self.augmentation(x)
        v2 = self.augmentation(x)
        return v1, v2


# Satellite images
class SatelliteSSL:
    """Augmentation for satellite images (preserving geographic properties)"""
    def __init__(self):
        self.transform = T.Compose([
            T.RandomResizedCrop(256, scale=(0.4, 1.0)),
            # Satellite images: rotation invariance is important
            T.RandomRotation(90),
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            # Simulate time-of-day / seasonal variation
            T.ColorJitter(brightness=0.4, contrast=0.4),
            T.ToTensor(),
        ])

    def __call__(self, x):
        # Return two independently augmented views of the same scene
        return self.transform(x), self.transform(x)

9.3 Using Pretrained SSL Models from the Hub

from transformers import (
    ViTModel, ViTImageProcessor,  # ViTImageProcessor replaces the deprecated ViTFeatureExtractor
    CLIPModel, CLIPProcessor
)
from PIL import Image
import torch

# Load MAE pretrained ViT
def load_mae_pretrained():
    model = ViTModel.from_pretrained("facebook/vit-mae-base")
    image_processor = ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
    return model, image_processor

# DINO ViT feature extraction
def extract_dino_features(image_path):
    model = ViTModel.from_pretrained("facebook/dino-vits16")
    processor = ViTImageProcessor.from_pretrained("facebook/dino-vits16")

    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # CLS token (global image representation)
    cls_features = outputs.last_hidden_state[:, 0, :]
    # Per-patch features (with spatial information)
    patch_features = outputs.last_hidden_state[:, 1:, :]

    return cls_features, patch_features

# Extract features with CLIP
def clip_feature_extraction(images, texts=None):
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    if texts:
        inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.image_embeds, outputs.text_embeds
    else:
        inputs = processor(images=images, return_tensors="pt")
        with torch.no_grad():
            image_features = model.get_image_features(**inputs)
        return image_features
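The embeddings returned by `clip_feature_extraction` enable zero-shot classification: normalize both sides, take cosine similarities, and softmax over the candidate labels. A minimal sketch with dummy embeddings standing in for real CLIP outputs (512-dim for ViT-B/32); the temperature of 0.01 approximates CLIP's learned logit scale of ~100:

```python
import torch
import torch.nn.functional as F

def zero_shot_classify(image_embeds, text_embeds, temperature=0.01):
    """Turn image/text embeddings into per-label probabilities."""
    img = F.normalize(image_embeds, dim=-1)
    txt = F.normalize(text_embeds, dim=-1)
    logits = img @ txt.T / temperature   # (num_images, num_labels)
    return logits.softmax(dim=-1)

# Dummy embeddings stand in for the outputs of clip_feature_extraction
image_embeds = torch.randn(2, 512)   # 2 images
text_embeds = torch.randn(3, 512)    # 3 candidate label prompts
probs = zero_shot_classify(image_embeds, text_embeds)
print(probs.shape)  # each row is a distribution over the 3 labels
```

With real CLIP embeddings, the argmax of each row is the predicted label for that image.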

Conclusion

Self-supervised learning is transforming AI. Key takeaways:

| Method | Year | Core Idea | Strength |
|--------|------|-----------|----------|
| SimCLR | 2020 | Strong augmentation + contrastive loss | Simple and powerful |
| MoCo | 2020 | Momentum encoder + negative queue | Batch-size independent |
| BYOL | 2020 | EMA target + predictor, no negatives | No negative pairs needed |
| MAE | 2021 | 75% masking + pixel reconstruction | Efficient, ViT-optimal |
| DINO | 2021 | Self-distillation | Emergent segmentation properties |
| CLIP | 2021 | Image-text contrastive | Zero-shot classification |
| BEiT | 2021 | Discrete visual token masking | BERT-style pretraining for images |
| DINOv2 | 2023 | Large-scale data + iBOT objective | Universal feature extractor |

Learning Roadmap:

  1. Implement SimCLR to understand contrastive learning
  2. Study MAE to grasp the masking approach
  3. Apply CLIP to multimodal learning
  4. Fine-tune on your domain data
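For step 1, the core of a SimCLR implementation is the NT-Xent loss. A minimal self-contained sketch (encoder, projection head, and augmentations omitted): each of the 2B projected views treats its counterpart view as the positive and the remaining 2B-2 rows as negatives.

```python
import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature=0.5):
    """NT-Xent: each row's positive is its other view; all remaining rows are negatives."""
    batch_size = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2B, D), unit-normalized
    sim = z @ z.T / temperature                         # cosine-similarity logits
    sim.fill_diagonal_(float('-inf'))                   # exclude self-similarity
    # The positive of row i is row i+B, and of row i+B is row i
    targets = torch.cat([torch.arange(batch_size) + batch_size,
                         torch.arange(batch_size)])
    return F.cross_entropy(sim, targets)

torch.manual_seed(0)
z1 = torch.randn(8, 128)
z2 = z1 + 0.01 * torch.randn(8, 128)  # near-identical "views" -> low loss
print(f"loss: {nt_xent_loss(z1, z2).item():.4f}")
```

Near-identical view pairs should produce a clearly lower loss than random, unrelated pairs; that sanity check is a quick way to validate your own implementation.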

SSL has become the foundation of LLM pretraining and computer vision. As an approach that effectively leverages vast amounts of unlabeled data, SSL will continue to be a core driver of AI progress.


References