A Complete Guide to Self-Supervised Learning: SimCLR, MAE, DINO, CLIP
Table of Contents
- What is self-supervised learning?
- Early SSL methods
- Contrastive learning
- Masked Autoencoder (MAE)
- DINO - self-distillation with no labels
- CLIP - text-image contrastive learning
- BEiT - the image version of BERT
- Pretraining language models
- Practical usage
1. What is Self-Supervised Learning?
1.1 Comparing Learning Paradigms
Let's start by understanding the three major learning paradigms in deep learning.
**Supervised learning** trains on labeled data. Humans individually labeled the roughly 14 million images in ImageNet, and models like ResNet and ViT are trained on this data. The problem is that label collection is enormously expensive. Annotating a single medical image can take a specialist physician tens of minutes.
**Unsupervised learning** discovers structure in data without labels. K-Means clustering, PCA, and GANs fall into this category. Traditional unsupervised learning, however, struggled to learn representations directly usable for downstream tasks.
**Self-supervised learning (SSL)** combines the best of both worlds. It exploits large-scale unlabeled data while deriving supervision signals from the data itself.
Core idea: the data itself is the label.
The internet contains billions of images and trillions of words. Self-supervised learning leverages this massive pool of unlabeled data to learn powerful representations.
1.2 The Label Scarcity Problem
In the real world, collecting labels is very hard.
| Domain | Labeling cost | Reason |
|---|---|---|
| Medical imaging | Very high | Requires specialist physicians |
| Satellite imagery | High | Requires geographic expertise |
| Legal documents | High | Requires legal experts |
| Industrial defect detection | High | Defective samples are rare |
| Speech recognition | Medium | Requires transcription |
SSL addresses this problem at its root: pretrain on large-scale unlabeled data first, then fine-tune on a small amount of labeled data.
1.3 Pretext Tasks
The heart of SSL is the pretext task: an artificial prediction problem constructed from the data itself, designed so that solving it forces the model to learn meaningful representations.
Properties of a good pretext task:
- It can be generated automatically, without labels
- Solving it well must require meaningful representations
- It should be neither too easy nor too hard
Examples:
- Rotation prediction: rotate an image by 0°, 90°, 180°, or 270°, then predict the rotation
- Masking: hide part of an image/text and reconstruct it
- Contrastive learning: make two views of the same image similar
1.4 Applications of SSL
SSL has become the foundation of modern AI:
- GPT, BERT: text SSL → the basis of ChatGPT and Gemini
- CLIP: image-text SSL → the basis of DALL-E and Stable Diffusion
- MAE, DINO: image SSL → the basis of medical imaging and autonomous-driving systems
- wav2vec: audio SSL → the basis of speech recognition
2. Early SSL Methods
Early SSL research explored a variety of pretext tasks.
2.1 Rotation Prediction
Proposed by Gidaris et al. in 2018. The image is rotated by one of four angles (0°, 90°, 180°, 270°), and the model classifies which rotation was applied.
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
class RotationSSL(nn.Module):
def __init__(self, backbone, num_classes=4):
super().__init__()
self.backbone = backbone
self.classifier = nn.Linear(backbone.output_dim, num_classes)
def forward(self, x):
features = self.backbone(x)
return self.classifier(features)
def create_rotation_dataset(images):
    """Rotate each image into four orientations to create (image, label) pairs"""
rotated_images = []
labels = []
for img in images:
        for k in range(4):  # 0, 90, 180, 270 degrees
            # rotate the PIL image by k*90 degrees
rotated = T.functional.rotate(img, k * 90)
rotated_images.append(rotated)
labels.append(k)
return rotated_images, labels
# training loop
def train_rotation_ssl(model, dataloader, optimizer, epochs=10):
criterion = nn.CrossEntropyLoss()
model.train()
for epoch in range(epochs):
total_loss = 0
for images, labels in dataloader:
images, labels = images.cuda(), labels.cuda()
optimizer.zero_grad()
logits = model(images)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
Limitation: natural images have a clear canonical orientation, but the task is far less effective for sky photos or texture images.
2.2 Solving Jigsaw Puzzles
Proposed by Noroozi & Favaro in 2016. The image is cut into a 3x3 grid, the tiles are shuffled, and the model predicts the original ordering.
import itertools
import numpy as np
class JigsawSSL(nn.Module):
def __init__(self, backbone, num_permutations=100):
super().__init__()
self.backbone = backbone
        # encode each patch independently
self.patch_encoder = nn.Sequential(
backbone,
nn.Linear(backbone.output_dim, 512)
)
        # classify the permutation after processing all 9 patches
self.classifier = nn.Sequential(
nn.Linear(512 * 9, 4096),
nn.ReLU(),
nn.Linear(4096, num_permutations)
)
def forward(self, patches):
# patches: (batch, 9, C, H, W)
batch_size = patches.size(0)
patch_features = []
for i in range(9):
feat = self.patch_encoder(patches[:, i]) # (batch, 512)
patch_features.append(feat)
        # concatenate all patch features
combined = torch.cat(patch_features, dim=1) # (batch, 512*9)
return self.classifier(combined)
def create_jigsaw_dataset(image, grid_size=3):
    """Split an image into grid_size x grid_size patches and shuffle them"""
h, w = image.shape[-2:]
patch_h, patch_w = h // grid_size, w // grid_size
patches = []
for i in range(grid_size):
for j in range(grid_size):
patch = image[...,
i*patch_h:(i+1)*patch_h,
j*patch_w:(j+1)*patch_w]
patches.append(patch)
    # choose from a predefined set of permutations
permutation_idx = np.random.randint(0, len(PREDEFINED_PERMUTATIONS))
perm = PREDEFINED_PERMUTATIONS[permutation_idx]
shuffled = [patches[p] for p in perm]
return torch.stack(shuffled), permutation_idx
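The snippet above references a `PREDEFINED_PERMUTATIONS` set that is never defined. Noroozi & Favaro select about 100 permutations that are maximally distinct in Hamming distance; the sketch below is a simplified stand-in that samples candidates randomly instead of following the paper's exact algorithm (the function name is our own):

```python
import random

def build_permutation_set(n_tiles=9, n_perms=100, seed=0):
    """Greedily pick permutations that are far apart in Hamming distance.

    Simplified approximation of the max-Hamming selection used by
    Noroozi & Favaro; samples random candidates instead of enumerating
    all 9! permutations.
    """
    rng = random.Random(seed)
    chosen = [tuple(range(n_tiles))]

    def hamming(a, b):
        return sum(x != y for x, y in zip(a, b))

    while len(chosen) < n_perms:
        # sample a batch of candidates, keep the one farthest from the chosen set
        candidates = []
        for _ in range(50):
            p = list(range(n_tiles))
            rng.shuffle(p)
            candidates.append(tuple(p))
        best = max(candidates, key=lambda c: min(hamming(c, q) for q in chosen))
        if best not in chosen:
            chosen.append(best)
    return chosen

PREDEFINED_PERMUTATIONS = build_permutation_set()
```

With this set defined, `create_jigsaw_dataset` above becomes runnable; the permutation index doubles as the classification label.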
2.3 Colorization
Proposed by Zhang et al. in 2016. The model takes a grayscale image and predicts its colors.
class ColorizationSSL(nn.Module):
def __init__(self):
super().__init__()
        # input: L channel (lightness)
self.encoder = nn.Sequential(
nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(),
)
        # output: ab channels (color)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 2, 3, padding=1),  # ab channels
nn.Tanh()
)
def forward(self, l_channel):
features = self.encoder(l_channel)
ab_channels = self.decoder(features)
return ab_channels
# work in the LAB color space
from skimage.color import rgb2lab, lab2rgb
def prepare_colorization_data(rgb_image):
lab = rgb2lab(rgb_image)
    l_channel = lab[:, :, 0:1]  # lightness channel
    ab_channels = lab[:, :, 1:]  # color channels
return l_channel, ab_channels
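One detail worth noting: the decoder above ends with Tanh, whose output lies in [-1, 1], while ab channels in LAB space roughly span [-128, 127]. A small scaling helper (our own addition, not part of the original paper code) keeps predictions and targets in the same range:

```python
import numpy as np

AB_RANGE = 128.0  # approximate half-range of the ab channels in LAB space

def ab_to_tanh_range(ab):
    """Scale ab channels from roughly [-128, 127] into [-1, 1]."""
    return np.clip(ab / AB_RANGE, -1.0, 1.0)

def tanh_to_ab_range(t):
    """Invert the scaling before converting back with lab2rgb."""
    return t * AB_RANGE

ab = np.array([[-128.0, 0.0, 127.0]])
scaled = ab_to_tanh_range(ab)
restored = tanh_to_ab_range(scaled)
```

Targets would be scaled with `ab_to_tanh_range` during training, and predictions mapped back with `tanh_to_ab_range` for visualization.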
2.4 Next-Frame Prediction
This exploits temporal continuity in video: predict the next frame from the previous frames.
class VideoSSL(nn.Module):
def __init__(self, frame_encoder, temporal_model):
super().__init__()
self.frame_encoder = frame_encoder
        self.temporal_model = temporal_model  # LSTM or Transformer
        self.decoder = nn.ConvTranspose2d(256, 3, kernel_size=4, stride=2, padding=1)  # illustrative dims
def forward(self, frames):
# frames: (batch, T, C, H, W)
T = frames.size(1)
        # encode each frame
frame_features = []
for t in range(T - 1):
feat = self.frame_encoder(frames[:, t])
frame_features.append(feat)
        # temporal modeling
features = torch.stack(frame_features, dim=1)
next_feat = self.temporal_model(features)
        # decode the next frame
next_frame_pred = self.decoder(next_feat[:, -1])
return next_frame_pred
3. Contrastive Learning
Contrastive learning is the core of modern SSL; development exploded around 2020.
3.1 Core Idea
Pull the same thing together, push different things apart.
Two different views of the same image (crop, flip, color jitter, etc.) should be close in embedding space, while different images should be far apart.
Image x
├── augmented view v1 → z1 (embedding)
└── augmented view v2 → z2 (embedding)
Goal: pull z1 and z2 together; push z1 away from z3 of a different image
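The objective can be illustrated with toy embeddings: the cosine similarity between two views of the same image should exceed the similarity to a different image (the numbers below are purely illustrative):

```python
import numpy as np

def cosine(a, b):
    """Cosine similarity between two vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# toy embeddings: z1 and z2 are two augmented views of the same image,
# z3 comes from a different image
z1 = np.array([1.0, 0.9, 0.1])
z2 = np.array([0.9, 1.0, 0.0])
z3 = np.array([-0.8, 0.1, 1.0])

pos_sim = cosine(z1, z2)  # high: same underlying image
neg_sim = cosine(z1, z3)  # low: different image
```

Contrastive losses like InfoNCE, introduced next, turn exactly this gap into a training signal.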
3.2 InfoNCE Loss
This is the standard loss for contrastive learning, also called NT-Xent (Normalized Temperature-scaled Cross Entropy).
For a batch of N images with two views each (2N samples total), the loss for a positive pair (i, j) is:

$$\ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)}$$

where sim is cosine similarity and τ is the temperature parameter.
import torch
import torch.nn.functional as F
def info_nce_loss(z1, z2, temperature=0.5):
"""
    InfoNCE / NT-Xent loss implementation
    z1, z2: (batch_size, embedding_dim) - two views of the same images
"""
batch_size = z1.size(0)
    # L2 normalization
z1 = F.normalize(z1, dim=1)
z2 = F.normalize(z2, dim=1)
    # stack the two views into a 2N x D matrix
    # [z1_1, z1_2, ..., z1_N, z2_1, z2_2, ..., z2_N]
z = torch.cat([z1, z2], dim=0) # (2N, D)
    # pairwise similarities
similarity = torch.mm(z, z.t()) / temperature # (2N, 2N)
    # exclude self-similarity (set the diagonal to -inf)
mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
similarity.masked_fill_(mask, float('-inf'))
    # positive pairs: i-th with (i+N)-th, and (i+N)-th with i-th
labels = torch.cat([
torch.arange(batch_size, 2 * batch_size),
torch.arange(batch_size)
]).to(z.device)
loss = F.cross_entropy(similarity, labels)
return loss
# usage example
z1 = torch.randn(32, 128)  # batch of 32, 128-dim embeddings
z2 = torch.randn(32, 128)
loss = info_nce_loss(z1, z2, temperature=0.5)
print(f"InfoNCE Loss: {loss.item():.4f}")
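The temperature τ deserves a closer look: it controls how sharply the softmax inside the loss concentrates on hard negatives. A quick numpy check with illustrative similarity values shows a low temperature producing a near one-hot distribution and a high temperature a much flatter one:

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax."""
    e = np.exp(x - x.max())
    return e / e.sum()

# similarity of a query to three candidates (illustrative values)
sims = np.array([0.9, 0.5, 0.1])

sharp = softmax(sims / 0.1)   # low temperature: near one-hot
smooth = softmax(sims / 1.0)  # high temperature: closer to uniform
```

This is why τ is a sensitive hyperparameter: SimCLR uses around 0.5, MoCo 0.07, and the best value depends on the embedding normalization and batch size.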
3.3 SimCLR
SimCLR (a Simple Framework for Contrastive Learning of visual representations) was published by Google in 2020. It is simple but powerful.
Key components:
- Data augmentation (strong augmentation is crucial)
- Encoder (ResNet)
- Projection head (MLP)
- NT-Xent loss
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
class SimCLR(nn.Module):
def __init__(self, backbone='resnet50', projection_dim=128, temperature=0.5):
super().__init__()
self.temperature = temperature
        # backbone encoder
if backbone == 'resnet50':
resnet = models.resnet50(pretrained=False)
            self.encoder = nn.Sequential(*list(resnet.children())[:-1])  # drop the FC layer
self.feature_dim = 2048
elif backbone == 'resnet18':
resnet = models.resnet18(pretrained=False)
self.encoder = nn.Sequential(*list(resnet.children())[:-1])
self.feature_dim = 512
        # projection head (important: removed for fine-tuning)
self.projection_head = nn.Sequential(
nn.Linear(self.feature_dim, self.feature_dim),
nn.ReLU(inplace=True),
nn.Linear(self.feature_dim, projection_dim)
)
def encode(self, x):
h = self.encoder(x)
h = h.squeeze(-1).squeeze(-1) # (batch, feature_dim)
return h
def forward(self, x):
h = self.encode(x)
z = self.projection_head(h)
return z
def contrastive_loss(self, z1, z2):
return info_nce_loss(z1, z2, self.temperature)
class SimCLRDataAugmentation:
    """SimCLR's strong data augmentation"""
def __init__(self, image_size=224, s=1.0):
color_jitter = T.ColorJitter(
brightness=0.8*s,
contrast=0.8*s,
saturation=0.8*s,
hue=0.2*s
)
self.transform = T.Compose([
T.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
T.RandomHorizontalFlip(p=0.5),
T.RandomApply([color_jitter], p=0.8),
T.RandomGrayscale(p=0.2),
            T.GaussianBlur(kernel_size=int(0.1 * image_size) // 2 * 2 + 1),  # kernel size must be odd
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def __call__(self, x):
return self.transform(x), self.transform(x)
# SimCLR training
def train_simclr(model, dataloader, optimizer, epochs=200):
model.train()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=epochs
)
for epoch in range(epochs):
total_loss = 0
for (x1, x2), _ in dataloader:
x1, x2 = x1.cuda(), x2.cuda()
z1 = model(x1)
z2 = model(x2)
loss = model.contrastive_loss(z1, z2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
scheduler.step()
avg_loss = total_loss / len(dataloader)
if (epoch + 1) % 10 == 0:
print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")
# full pipeline
model = SimCLR(backbone='resnet50', projection_dim=128).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)
augmentation = SimCLRDataAugmentation(image_size=224)
# remove the projection head for fine-tuning
def get_finetuning_model(pretrained_simclr, num_classes):
    # use only the encoder
encoder = pretrained_simclr.encoder
classifier = nn.Linear(pretrained_simclr.feature_dim, num_classes)
return nn.Sequential(encoder, nn.Flatten(), classifier)
3.4 MoCo (Momentum Contrast)
Published by Facebook AI in 2020, MoCo provides a large pool of negative samples without requiring a huge batch.
Core idea: keep a stable key encoder via momentum updates, and manage negative samples with a queue.
class MoCo(nn.Module):
def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07):
super().__init__()
        self.K = K  # queue size
        self.m = m  # momentum coefficient
        self.T = T  # temperature
        # query encoder and key encoder
self.encoder_q = base_encoder(num_classes=dim)
self.encoder_k = base_encoder(num_classes=dim)
        # the key encoder gets momentum updates, not gradients
for param_q, param_k in zip(
self.encoder_q.parameters(),
self.encoder_k.parameters()
):
param_k.data.copy_(param_q.data)
param_k.requires_grad = False
        # queue of negative samples
self.register_buffer("queue", torch.randn(dim, K))
self.queue = F.normalize(self.queue, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
@torch.no_grad()
def _momentum_update_key_encoder(self):
        """Momentum update of the key encoder"""
for param_q, param_k in zip(
self.encoder_q.parameters(),
self.encoder_k.parameters()
):
param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
        """Enqueue new keys, dropping the oldest ones"""
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # keeps the pointer arithmetic simple
        self.queue[:, ptr:ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.K
self.queue_ptr[0] = ptr
def forward(self, im_q, im_k):
        # query embeddings
q = self.encoder_q(im_q)
q = F.normalize(q, dim=1)
        # key embeddings (no gradients)
with torch.no_grad():
self._momentum_update_key_encoder()
k = self.encoder_k(im_k)
k = F.normalize(k, dim=1)
        # positive logits: (batch, 1)
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: (batch, K)
l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
logits = torch.cat([l_pos, l_neg], dim=1) / self.T
labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
loss = F.cross_entropy(logits, labels)
self._dequeue_and_enqueue(k)
return loss
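The queue in `_dequeue_and_enqueue` is a ring buffer: a pointer advances modulo K, so the oldest keys are overwritten first. The same pointer arithmetic can be sketched in numpy (toy sizes; like the official MoCo code, this assumes K is divisible by the batch size):

```python
import numpy as np

K, dim, batch = 8, 4, 2
queue = np.zeros((dim, K))
ptr = 0

def enqueue(keys, queue, ptr, K):
    """Write keys (batch, dim) at ptr, overwriting the oldest columns."""
    b = keys.shape[0]
    queue[:, ptr:ptr + b] = keys.T
    return (ptr + b) % K

# push 5 batches of 2 keys into a queue of size 8 -> the pointer wraps once
for step in range(5):
    keys = np.full((batch, dim), float(step))
    ptr = enqueue(keys, queue, ptr, K)
```

After five steps the pointer has wrapped: the columns written at step 0 have been overwritten by step 4, while steps 1-3 are still in the queue.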
3.5 BYOL (Bootstrap Your Own Latent)
A breakthrough method published by DeepMind in 2020. It works without any negative samples.
Core idea: an online network predicts the target network's representation; the target is updated by momentum.
import copy

class BYOL(nn.Module):
    def __init__(self, backbone, projection_dim=256, tau=0.996):
        super().__init__()
        self.tau = tau  # momentum coefficient
        # online network
        self.online_encoder = backbone
        self.online_projector = nn.Sequential(
            nn.Linear(2048, 4096), nn.BatchNorm1d(4096), nn.ReLU(),
            nn.Linear(4096, projection_dim)
        )
        # the predictor must output projection_dim so its output can be
        # compared with the target projection
        self.online_predictor = nn.Sequential(
            nn.Linear(projection_dim, 4096), nn.BatchNorm1d(4096), nn.ReLU(),
            nn.Linear(4096, projection_dim)
        )
        # target network (no gradients)
self.target_encoder = copy.deepcopy(backbone)
self.target_projector = copy.deepcopy(self.online_projector)
for param in self.target_encoder.parameters():
param.requires_grad = False
for param in self.target_projector.parameters():
param.requires_grad = False
@torch.no_grad()
def update_target(self):
        """EMA update"""
for online, target in zip(
self.online_encoder.parameters(),
self.target_encoder.parameters()
):
target.data = self.tau * target.data + (1 - self.tau) * online.data
for online, target in zip(
self.online_projector.parameters(),
self.target_projector.parameters()
):
target.data = self.tau * target.data + (1 - self.tau) * online.data
def forward(self, x1, x2):
        # online predictions
online_feat1 = self.online_encoder(x1).squeeze()
online_proj1 = self.online_projector(online_feat1)
online_pred1 = self.online_predictor(online_proj1)
online_feat2 = self.online_encoder(x2).squeeze()
online_proj2 = self.online_projector(online_feat2)
online_pred2 = self.online_predictor(online_proj2)
        # targets (stop gradient)
with torch.no_grad():
target_feat1 = self.target_encoder(x1).squeeze()
target_proj1 = self.target_projector(target_feat1)
target_feat2 = self.target_encoder(x2).squeeze()
target_proj2 = self.target_projector(target_feat2)
        # BYOL loss: maximize cosine similarity
loss1 = byol_loss(online_pred1, target_proj2.detach())
loss2 = byol_loss(online_pred2, target_proj1.detach())
return (loss1 + loss2).mean()
def byol_loss(p, z):
"""Negative cosine similarity"""
p = F.normalize(p, dim=-1)
z = F.normalize(z, dim=-1)
return -(p * z).sum(dim=-1)
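The EMA update can be sanity-checked with plain scalars: with τ = 0.996 the target parameter moves only 0.4% of the way toward the online parameter per step, which is exactly what keeps BYOL's targets slow-moving and stable (toy numbers):

```python
tau = 0.996
online, target = 1.0, 0.0

history = []
for _ in range(1000):
    # the BYOL/MoCo-style EMA rule: target <- tau*target + (1-tau)*online
    target = tau * target + (1 - tau) * online
    history.append(target)

after_10 = history[9]     # barely moved after 10 steps
after_1000 = history[-1]  # nearly converged after 1000 steps
```

In closed form, after n steps the target equals 1 - τ^n, so convergence toward the online value takes hundreds of updates.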
3.6 SimSiam
Published by Facebook AI in 2021, SimSiam is even simpler than BYOL: it trains without collapse even without a momentum encoder.
class SimSiam(nn.Module):
def __init__(self, backbone, dim=2048, pred_dim=512):
super().__init__()
self.encoder = backbone
        # projector
self.projector = nn.Sequential(
nn.Linear(dim, dim, bias=False),
nn.BatchNorm1d(dim),
nn.ReLU(inplace=True),
nn.Linear(dim, dim, bias=False),
nn.BatchNorm1d(dim),
nn.ReLU(inplace=True),
nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim, affine=False)  # final BN: affine=False
)
        # predictor
self.predictor = nn.Sequential(
nn.Linear(dim, pred_dim, bias=False),
nn.BatchNorm1d(pred_dim),
nn.ReLU(inplace=True),
nn.Linear(pred_dim, dim)
)
def forward(self, x1, x2):
z1 = self.projector(self.encoder(x1).squeeze())
z2 = self.projector(self.encoder(x2).squeeze())
p1 = self.predictor(z1)
p2 = self.predictor(z2)
        # stop gradient: the asymmetry that prevents collapse
loss = simsiam_loss(p1, z2.detach()) / 2 + \
simsiam_loss(p2, z1.detach()) / 2
return loss
def simsiam_loss(p, z):
"""Negative cosine similarity"""
    z = z.detach()  # stop gradient
p = F.normalize(p, dim=1)
z = F.normalize(z, dim=1)
return -(p * z).sum(dim=1).mean()
4. Masked Autoencoder (MAE)
MAE, published by Kaiming He et al. in 2021, applies the BERT idea from NLP to images.
4.1 Core Idea
Mask 75% of the image and reconstruct the whole image from the remaining 25%.
Why such a high masking ratio?
- Text tokens carry high semantic density, so masking 15% is enough
- Images have high redundancy between neighboring pixels → around 75% must be masked before reconstruction requires genuine understanding
4.2 Asymmetric Encoder-Decoder
MAE's innovation: the encoder processes only the visible patches, while the decoder reconstructs everything.
Input image (196 patches)
    ↓ mask 75%
49 visible patches → encoder (ViT-Large) → encodings
    ↓
196 patches (with mask tokens) → decoder (shallow ViT) → pixel reconstruction
The encoder processes only 25% of the original sequence → roughly 3-4x faster training
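The token counts behind that speedup can be checked with a few lines of arithmetic: with 224×224 images and 16×16 patches, a 75% mask leaves the encoder only 49 of 196 patches (plus the CLS token), while the decoder still sees every position:

```python
img_size, patch_size, mask_ratio = 224, 16, 0.75

num_patches = (img_size // patch_size) ** 2          # 14 x 14 grid
num_visible = int(num_patches * (1 - mask_ratio))    # patches the encoder sees
num_masked = num_patches - num_visible               # patches replaced by mask tokens

encoder_tokens = num_visible + 1  # +1 for the CLS token
decoder_tokens = num_patches + 1  # the decoder sees every position
```

Since self-attention cost grows quadratically in sequence length, shrinking the encoder's input from 197 to 50 tokens is where most of the wall-clock savings come from.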
4.3 A Complete MAE Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
class PatchEmbedding(nn.Module):
    """Convert an image into patch embeddings"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
self.projection = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
# x: (batch, C, H, W)
x = self.projection(x) # (batch, embed_dim, H/p, W/p)
x = x.flatten(2) # (batch, embed_dim, num_patches)
x = x.transpose(1, 2) # (batch, num_patches, embed_dim)
return x
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., dropout=0.):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
self.norm2 = nn.LayerNorm(dim)
mlp_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
# Self-attention
norm_x = self.norm1(x)
attn_out, _ = self.attn(norm_x, norm_x, norm_x)
x = x + attn_out
# MLP
x = x + self.mlp(self.norm2(x))
return x
class MAE(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_channels=3,
        # encoder config
encoder_dim=768,
encoder_depth=12,
encoder_heads=12,
        # decoder config (much smaller than the encoder)
decoder_dim=512,
decoder_depth=8,
decoder_heads=16,
        # masking ratio
mask_ratio=0.75,
):
super().__init__()
self.mask_ratio = mask_ratio
self.patch_size = patch_size
num_patches = (img_size // patch_size) ** 2
        # patch embedding
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, encoder_dim)
        # [CLS] token and positional embedding (encoder)
self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_dim))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, encoder_dim),
requires_grad=False
)
        # encoder (ViT)
self.encoder_blocks = nn.ModuleList([
TransformerBlock(encoder_dim, encoder_heads)
for _ in range(encoder_depth)
])
self.encoder_norm = nn.LayerNorm(encoder_dim)
        # encoder-to-decoder projection
self.decoder_embed = nn.Linear(encoder_dim, decoder_dim, bias=True)
        # mask token
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
        # decoder positional embedding
self.decoder_pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, decoder_dim),
requires_grad=False
)
        # decoder (shallow ViT)
self.decoder_blocks = nn.ModuleList([
TransformerBlock(decoder_dim, decoder_heads)
for _ in range(decoder_depth)
])
self.decoder_norm = nn.LayerNorm(decoder_dim)
        # pixel prediction head
self.decoder_pred = nn.Linear(decoder_dim, patch_size**2 * in_channels)
self._init_weights()
def _init_weights(self):
        """Initialize positional embeddings and tokens"""
        # sinusoidal positional embedding
pos_embed = self._get_2d_sincos_pos_embed(
self.pos_embed.shape[-1],
int(self.patch_embed.num_patches**0.5)
)
self.pos_embed.data.copy_(
torch.from_numpy(pos_embed).float().unsqueeze(0)
)
        # remaining initializations
nn.init.normal_(self.cls_token, std=.02)
nn.init.normal_(self.mask_token, std=.02)
def _get_2d_sincos_pos_embed(self, embed_dim, grid_size):
        """2D sinusoidal positional embedding"""
import numpy as np
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h)
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
emb_h = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
emb_w = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
emb = np.concatenate([emb_h, emb_w], axis=1)
        # prepend a zero embedding for the CLS token
emb = np.concatenate([np.zeros([1, embed_dim]), emb], axis=0)
return emb
def _get_1d_sincos_pos_embed_from_grid(self, embed_dim, pos):
import numpy as np
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega
pos = pos.reshape(-1)
out = np.einsum('m,d->md', pos, omega)
emb_sin = np.sin(out)
emb_cos = np.cos(out)
emb = np.concatenate([emb_sin, emb_cos], axis=1)
return emb
def random_masking(self, x, mask_ratio):
        """Random masking: keep only a subset of patches"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
        # generate shuffle indices from random noise
noise = torch.rand(N, L, device=x.device)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
        # keep only the visible patches
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(
x, dim=1,
index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
)
        # build the mask (1: masked, 0: visible)
mask = torch.ones(N, L, device=x.device)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
def forward_encoder(self, x, mask_ratio):
        """Encoder: processes only the visible patches"""
        # patch embedding
x = self.patch_embed(x)
        # add positional embedding (excluding CLS)
x = x + self.pos_embed[:, 1:, :]
        # random masking
x, mask, ids_restore = self.random_masking(x, mask_ratio)
        # prepend the CLS token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
        # encoder transformer blocks
for block in self.encoder_blocks:
x = block(x)
x = self.encoder_norm(x)
return x, mask, ids_restore
def forward_decoder(self, x, ids_restore):
        """Decoder: reconstructs the full sequence including mask tokens"""
        # project encoder dim -> decoder dim
x = self.decoder_embed(x)
        # append mask tokens
mask_tokens = self.mask_token.repeat(
x.shape[0],
ids_restore.shape[1] + 1 - x.shape[1],
1
)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # exclude CLS
x_ = torch.gather(
x_, dim=1,
index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
)
        x = torch.cat([x[:, :1, :], x_], dim=1)  # re-attach CLS
        # decoder positional embedding
x = x + self.decoder_pos_embed
        # decoder transformer blocks
for block in self.decoder_blocks:
x = block(x)
x = self.decoder_norm(x)
        # pixel prediction
x = self.decoder_pred(x)
        x = x[:, 1:, :]  # drop the CLS token
return x
def forward(self, imgs, mask_ratio=0.75):
        # encode (with masking)
latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        # decode (pixel reconstruction)
pred = self.forward_decoder(latent, ids_restore)
        # compute the loss
loss = self.forward_loss(imgs, pred, mask)
return loss, pred, mask
def forward_loss(self, imgs, pred, mask):
        """MSE loss over the masked patches"""
        # target: the patchified image
target = self.patchify(imgs)
        # per-patch normalization
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.e-6).sqrt()
        # compute the loss only on the masked patches
loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # mean per patch
        loss = (loss * mask).sum() / mask.sum()  # mean over masked patches only
return loss
def patchify(self, imgs):
        """Convert an image into a sequence of patches"""
p = self.patch_size
h = w = imgs.shape[2] // p
x = imgs.reshape(imgs.shape[0], 3, h, p, w, p)
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(imgs.shape[0], h * w, p**2 * 3)
return x
# create and train the MAE model
mae_model = MAE(
img_size=224,
patch_size=16,
encoder_dim=768,
encoder_depth=12,
encoder_heads=12,
decoder_dim=512,
decoder_depth=8,
decoder_heads=16,
mask_ratio=0.75
).cuda()
optimizer = torch.optim.AdamW(
mae_model.parameters(),
lr=1.5e-4,
betas=(0.9, 0.95),
weight_decay=0.05
)
# one training step
def train_step(model, imgs, optimizer):
model.train()
optimizer.zero_grad()
loss, pred, mask = model(imgs, mask_ratio=0.75)
loss.backward()
optimizer.step()
return loss.item()
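The `patchify` transform above is lossless, which can be verified by writing its inverse. The numpy sketch below mirrors the same einsum; `unpatchify` is our own helper name (the official MAE code has an equivalent):

```python
import numpy as np

def patchify(imgs, p):
    """(N, 3, H, W) -> (N, num_patches, p*p*3), same einsum as the MAE code."""
    n, c, H, W = imgs.shape
    h, w = H // p, W // p
    x = imgs.reshape(n, c, h, p, w, p)
    x = np.einsum('nchpwq->nhwpqc', x)
    return x.reshape(n, h * w, p * p * c)

def unpatchify(x, p, c=3):
    """(N, num_patches, p*p*3) -> (N, 3, H, W), exact inverse of patchify."""
    n, L, _ = x.shape
    h = w = int(L ** 0.5)
    x = x.reshape(n, h, w, p, p, c)
    x = np.einsum('nhwpqc->nchpwq', x)
    return x.reshape(n, c, h * p, w * p)

imgs = np.random.rand(2, 3, 32, 32)
recon = unpatchify(patchify(imgs, 8), 8)
```

A round trip reproduces the input exactly, which is also how reconstructed MAE predictions get reassembled into an image for visualization.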
5. DINO
DINO (Self-DIstillation with NO labels) was published by Facebook AI Research in 2021. Training a ViT with this form of SSL reveals remarkable properties.
5.1 DINO's Key Finding
The self-attention maps of a DINO-trained ViT perform **semantic segmentation without any labels**, separating foreground from background remarkably well.
5.2 Self-Distillation Architecture
Student Network                  Teacher Network
       ↑                                ↑
Local views (several)            Global views (2)
- Teacher: processes full-image views, richer context
- Student: processes local crops and learns to match the teacher's predictions
- Teacher update: EMA of the student (no gradients)
5.3 Centering and Sharpening
Two techniques prevent collapse:
class DINOHead(nn.Module):
"""DINO Projection Head"""
def __init__(self, in_dim, out_dim=65536, bottleneck_dim=256):
super().__init__()
layers = [nn.Linear(in_dim, 2048), nn.GELU()]
layers += [nn.Linear(2048, 2048), nn.GELU()]
layers += [nn.Linear(2048, bottleneck_dim)]
self.mlp = nn.Sequential(*layers)
        # weight-normalized final layer
self.last_layer = nn.utils.weight_norm(
nn.Linear(bottleneck_dim, out_dim, bias=False)
)
self.last_layer.weight_g.data.fill_(1)
self.last_layer.weight_g.requires_grad = False
def forward(self, x):
x = self.mlp(x)
x = F.normalize(x, dim=-1)
x = self.last_layer(x)
return x
class DINO(nn.Module):
def __init__(self, backbone, out_dim=65536, teacher_temp=0.04, student_temp=0.1):
super().__init__()
self.student_temp = student_temp
self.teacher_temp = teacher_temp
        # student network
self.student = backbone
self.student_head = DINOHead(self.student.embed_dim, out_dim)
        # teacher network (EMA)
self.teacher = copy.deepcopy(backbone)
self.teacher_head = copy.deepcopy(self.student_head)
for p in self.teacher.parameters():
p.requires_grad = False
for p in self.teacher_head.parameters():
p.requires_grad = False
        # center vector used for centering
self.register_buffer("center", torch.zeros(1, out_dim))
@torch.no_grad()
def update_teacher(self, momentum):
        """EMA update"""
for student_ps, teacher_ps in zip(
self.student.parameters(), self.teacher.parameters()
):
teacher_ps.data.mul_(momentum).add_(
(1 - momentum) * student_ps.data
)
for student_ps, teacher_ps in zip(
self.student_head.parameters(), self.teacher_head.parameters()
):
teacher_ps.data.mul_(momentum).add_(
(1 - momentum) * student_ps.data
)
@torch.no_grad()
def update_center(self, teacher_output):
        """Update the center (prevents collapse)"""
batch_center = teacher_output.mean(dim=0, keepdim=True)
self.center = self.center * 0.9 + batch_center * 0.1
    def dino_loss(self, student_output, teacher_output, n_views):
        """DINO cross-entropy loss"""
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(n_views)  # all crops (global + local)
        # teacher: centering + sharpening
        teacher_out = F.softmax(
            (teacher_output - self.center) / self.teacher_temp, dim=-1
        )
        teacher_out = teacher_out.detach().chunk(2)  # the two global crops
        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    continue  # skip the matching view
                loss = torch.sum(
                    -q * F.log_softmax(student_out[v], dim=-1), dim=-1
                )
                total_loss += loss.mean()
                n_loss_terms += 1
        return total_loss / n_loss_terms

    def forward(self, global_views, local_views, momentum=0.996):
        # student: processes every view
        all_views = global_views + local_views
        student_output = torch.cat([
            self.student_head(self.student(view))
            for view in all_views
        ])
        # teacher: processes only the global views
        with torch.no_grad():
            teacher_output = torch.cat([
                self.teacher_head(self.teacher(view))
                for view in global_views
            ])
        loss = self.dino_loss(student_output, teacher_output, len(all_views))
        # updates
        self.update_center(teacher_output)
        self.update_teacher(momentum)
        return loss
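The effect of centering on the teacher targets can be seen on toy logits: with the very sharp teacher temperature (0.04), a teacher biased toward one output dimension produces a near one-hot target for every image, which is exactly the collapse mode; subtracting the running mean removes the shared bias (all numbers below are illustrative):

```python
import numpy as np

def softmax(x, temp):
    """Temperature-scaled, numerically stable softmax."""
    e = np.exp((x - x.max()) / temp)
    return e / e.sum()

teacher_temp = 0.04
# teacher logits biased toward dimension 0 for every image
logits = np.array([5.0, 1.0, 1.0, 1.0])
# running mean of teacher outputs, assumed to capture the same bias
center = np.array([4.5, 0.5, 0.5, 0.5])

collapsed = softmax(logits, teacher_temp)          # without centering: one-hot
centered = softmax(logits - center, teacher_temp)  # with centering: uniform
```

Sharpening (the low temperature) pushes toward one-hot targets and centering pushes toward uniform ones; DINO stays stable because the two forces balance.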
# DINO data augmentation (multi-crop)
class DINOAugmentation:
def __init__(self, global_crops_scale=(0.4, 1.), local_crops_scale=(0.05, 0.4),
n_local_crops=8, image_size=224, local_size=96):
self.n_local_crops = n_local_crops
        # global crops (2 large views)
self.global_transform = T.Compose([
T.RandomResizedCrop(image_size, scale=global_crops_scale, interpolation=T.InterpolationMode.BICUBIC),
T.RandomHorizontalFlip(),
T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
T.RandomGrayscale(p=0.2),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
        # local crops (8 small views)
self.local_transform = T.Compose([
T.RandomResizedCrop(local_size, scale=local_crops_scale, interpolation=T.InterpolationMode.BICUBIC),
T.RandomHorizontalFlip(),
T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
T.RandomGrayscale(p=0.2),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def __call__(self, image):
global_views = [self.global_transform(image) for _ in range(2)]
local_views = [self.local_transform(image) for _ in range(self.n_local_crops)]
return global_views, local_views
5.4 DINOv2
DINOv2, released in 2023, is even stronger:
- Trained on 142 million curated images
- Adds the iBOT (Image BERT Pre-training with Online Tokenizer) loss
- More stable training
6. CLIP
CLIP (Contrastive Language-Image Pretraining) is a groundbreaking model published by OpenAI in 2021.
6.1 Text-Image Contrastive Learning
CLIP is trained on 400 million (image, text) pairs collected from the internet.
Image encoder (ViT / ResNet)
    ↓
Image embedding
Text encoder (Transformer)
    ↓
Text embedding
Goal: matching (image, text) pairs become similar,
      non-matching pairs become dissimilar
6.2 The CLIP Loss
def clip_loss(image_embeddings, text_embeddings, temperature):
"""
    Symmetric CLIP contrastive loss
image_embeddings: (N, D)
text_embeddings: (N, D)
"""
    # L2 normalization
image_embeddings = F.normalize(image_embeddings, dim=1)
text_embeddings = F.normalize(text_embeddings, dim=1)
    # similarity matrix: (N, N)
logits = torch.matmul(image_embeddings, text_embeddings.T) / temperature
    # targets: the diagonal entries (i-th image matches i-th text)
labels = torch.arange(len(logits)).to(logits.device)
    # bidirectional loss: image-to-text and text-to-image
loss_i2t = F.cross_entropy(logits, labels)
loss_t2i = F.cross_entropy(logits.T, labels)
return (loss_i2t + loss_t2i) / 2
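A quick sanity check of this symmetric loss: when the i-th image and i-th text embeddings match exactly (and differ across pairs), both cross-entropy terms approach zero, while a mismatched pairing yields a large loss. The numpy re-implementation below exists purely for checking the math:

```python
import numpy as np

def clip_loss_np(img, txt, temperature):
    """Numpy version of the symmetric CLIP loss, for verification only."""
    img = img / np.linalg.norm(img, axis=1, keepdims=True)
    txt = txt / np.linalg.norm(txt, axis=1, keepdims=True)
    logits = img @ txt.T / temperature
    n = len(logits)

    def ce(l):
        # cross-entropy with diagonal (matched-pair) labels
        l = l - l.max(axis=1, keepdims=True)
        logp = l - np.log(np.exp(l).sum(axis=1, keepdims=True))
        return -logp[np.arange(n), np.arange(n)].mean()

    return (ce(logits) + ce(logits.T)) / 2

matched = np.eye(4)  # image i and text i share an embedding axis
loss_matched = clip_loss_np(matched, matched, 0.05)        # near zero
loss_shuffled = clip_loss_np(matched, matched[::-1], 0.05) # large
```

In the real model, `temperature` is a learned parameter (stored as a logit scale), and the diagonal-labels trick is exactly what `F.cross_entropy` computes in the PyTorch version above.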
6.3 A Complete CLIP Example
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests
import torch
import torch.nn.functional as F
# load a pretrained CLIP model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# ---- Zero-shot image classification ----
def zero_shot_classify(image_path, candidate_labels):
    """Classify an image with no labeled examples (zero-shot)."""
image = Image.open(image_path)
    # build text prompts
texts = [f"a photo of a {label}" for label in candidate_labels]
    # encode
inputs = processor(
text=texts,
images=image,
return_tensors="pt",
padding=True
)
with torch.no_grad():
outputs = model(**inputs)
    # compute similarities
logits_per_image = outputs.logits_per_image
probs = logits_per_image.softmax(dim=1)
results = {}
for label, prob in zip(candidate_labels, probs[0]):
results[label] = prob.item()
return results
# usage example
labels = ["cat", "dog", "car", "airplane", "bird"]
results = zero_shot_classify("example.jpg", labels)
for label, prob in sorted(results.items(), key=lambda x: -x[1]):
print(f"{label}: {prob:.4f}")
# ---- Image-text retrieval ----
class CLIPRetrieval:
def __init__(self, model_name="openai/clip-vit-base-patch32"):
self.model = CLIPModel.from_pretrained(model_name)
self.processor = CLIPProcessor.from_pretrained(model_name)
self.model.eval()
self.image_embeddings = None
self.image_paths = []
def index_images(self, image_paths):
        """Index an image database"""
self.image_paths = image_paths
embeddings = []
for path in image_paths:
image = Image.open(path)
inputs = self.processor(images=image, return_tensors="pt")
with torch.no_grad():
embedding = self.model.get_image_features(**inputs)
embeddings.append(F.normalize(embedding, dim=1))
self.image_embeddings = torch.cat(embeddings, dim=0)
        print(f"Indexed {len(image_paths)} images")
def search(self, query_text, top_k=5):
        """Search images with a text query"""
inputs = self.processor(text=[query_text], return_tensors="pt", padding=True)
with torch.no_grad():
text_embedding = self.model.get_text_features(**inputs)
text_embedding = F.normalize(text_embedding, dim=1)
        # cosine similarity
similarities = torch.matmul(text_embedding, self.image_embeddings.T).squeeze()
top_indices = similarities.argsort(descending=True)[:top_k]
results = [
(self.image_paths[i], similarities[i].item())
for i in top_indices
]
return results
# visualizing multimodal embeddings
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
def visualize_clip_embeddings(images, texts):
    """t-SNE visualization of CLIP embeddings"""
image_inputs = processor(images=images, return_tensors="pt")
text_inputs = processor(text=texts, return_tensors="pt", padding=True)
with torch.no_grad():
image_embs = F.normalize(model.get_image_features(**image_inputs), dim=1)
text_embs = F.normalize(model.get_text_features(**text_inputs), dim=1)
    # concatenate image and text embeddings
all_embs = torch.cat([image_embs, text_embs], dim=0).numpy()
    # t-SNE dimensionality reduction (perplexity must be < number of samples)
    tsne = TSNE(n_components=2, random_state=42,
                perplexity=min(30, len(all_embs) - 1))
reduced = tsne.fit_transform(all_embs)
n = len(images)
plt.figure(figsize=(12, 8))
plt.scatter(reduced[:n, 0], reduced[:n, 1], c='blue', label='Images', alpha=0.7)
plt.scatter(reduced[n:, 0], reduced[n:, 1], c='red', label='Texts', alpha=0.7)
for i, text in enumerate(texts):
plt.annotate(text, reduced[n+i], fontsize=8)
plt.legend()
plt.title('CLIP Embeddings (t-SNE)')
plt.savefig('clip_tsne.png', dpi=150, bbox_inches='tight')
plt.show()
7. BEiT
BEiT (BERT Pre-Training of Image Transformers) was released by Microsoft in 2021. It converts images into discrete tokens and then trains BERT-style.
7.1 Discrete Visual Tokens (dVAE)
BEiT's key idea: predict images as discrete tokens rather than continuous pixels.
Stage 1: a dVAE maps the image to discrete tokens (vocabulary of 8192)
Stage 2: a ViT predicts the tokens of the masked patches
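Stage 1 is easiest to see in isolation. The sketch below uses a hypothetical `ToyVisualTokenizer` with a random codebook (in BEiT the codebook is learned by a trained dVAE); it only illustrates the idea of mapping patches to discrete ids out of an 8192-entry vocabulary.

```python
import torch
import torch.nn as nn

class ToyVisualTokenizer(nn.Module):
    """Toy stand-in for stage 1: map each patch embedding to the id of the
    nearest codebook entry (vocabulary of 8192 discrete tokens)."""
    def __init__(self, num_tokens=8192, dim=64):
        super().__init__()
        # Random codebook for illustration; BEiT's comes from a trained dVAE
        self.codebook = nn.Parameter(torch.randn(num_tokens, dim))

    def forward(self, patch_embs):  # (batch, num_patches, dim)
        # Distance from every patch embedding to every codebook entry
        dists = torch.cdist(patch_embs, self.codebook.expand(patch_embs.size(0), -1, -1))
        return dists.argmin(dim=-1)  # (batch, num_patches) discrete token ids

tokenizer = ToyVisualTokenizer()
token_ids = tokenizer(torch.randn(2, 196, 64))  # 2 images, 196 patches each
print(token_ids.shape)  # torch.Size([2, 196])
```

The ViT in stage 2 (code below) then only has to classify each masked patch into one of these 8192 ids, exactly like BERT predicting word ids.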
from transformers import BeitForMaskedImageModeling, BeitImageProcessor
import torch
# BEiT masked image modeling
class BEiTPretraining:
def __init__(self, model_name="microsoft/beit-base-patch16-224"):
self.model = BeitForMaskedImageModeling.from_pretrained(model_name)
self.processor = BeitImageProcessor.from_pretrained(model_name)
def create_bool_masked_pos(self, num_patches, mask_ratio=0.4):
        """Generate masked positions"""
num_masked = int(num_patches * mask_ratio)
mask = torch.zeros(num_patches, dtype=torch.bool)
        # Random masking (BEiT itself uses blockwise masking; simplified here)
masked_indices = torch.randperm(num_patches)[:num_masked]
mask[masked_indices] = True
return mask
def forward(self, images):
inputs = self.processor(images, return_tensors="pt")
pixel_values = inputs.pixel_values
num_patches = (224 // 16) ** 2 # 196
bool_masked_pos = self.create_bool_masked_pos(num_patches)
bool_masked_pos = bool_masked_pos.unsqueeze(0).expand(
pixel_values.size(0), -1
)
outputs = self.model(
pixel_values=pixel_values,
bool_masked_pos=bool_masked_pos
)
return outputs.loss
8. Language Model Pretraining
SSL has a long history in natural language processing.
8.1 GPT - Next-Token Prediction
GPT is autoregressive: it predicts the next token from left to right.
import torch
import torch.nn as nn
from torch.nn import functional as F
class GPTLanguageModel(nn.Module):
def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer, dropout=0.1):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, n_embd)
self.position_embedding = nn.Embedding(block_size, n_embd)
self.blocks = nn.Sequential(*[
TransformerBlock(n_embd, n_head, dropout=dropout)
for _ in range(n_layer)
])
self.ln_f = nn.LayerNorm(n_embd)
self.lm_head = nn.Linear(n_embd, vocab_size)
self.block_size = block_size
def forward(self, idx, targets=None):
B, T = idx.shape
tok_emb = self.token_embedding(idx) # (B, T, C)
pos_emb = self.position_embedding(
torch.arange(T, device=idx.device)
) # (T, C)
x = tok_emb + pos_emb
x = self.blocks(x)
x = self.ln_f(x)
logits = self.lm_head(x) # (B, T, vocab_size)
if targets is None:
return logits, None
        # Next-token prediction loss
B, T, C = logits.shape
logits_flat = logits.view(B * T, C)
targets_flat = targets.view(B * T)
loss = F.cross_entropy(logits_flat, targets_flat)
return logits, loss
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
for _ in range(max_new_tokens):
idx_cond = idx[:, -self.block_size:]
logits, _ = self(idx_cond)
logits = logits[:, -1, :] / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = float('-inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.cat((idx, idx_next), dim=1)
return idx
8.2 BERT - Masked Language Modeling
BERT learns bidirectional context: 15% of the tokens are masked and then reconstructed.
from transformers import BertForMaskedLM, BertTokenizer
import torch
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
def create_mlm_input(text, mask_prob=0.15):
    """Create masked language model inputs"""
tokens = tokenizer.encode(text, return_tensors="pt")
labels = tokens.clone()
    # Randomly mask 15% of tokens
rand = torch.rand(tokens.shape)
mask_arr = (rand < mask_prob) & (tokens != tokenizer.cls_token_id) & \
(tokens != tokenizer.sep_token_id)
    # Masking strategy:
    # - 80%: replace with the [MASK] token
    # - 10%: replace with a random token
    # - 10%: keep the original token
for i, masked in enumerate(mask_arr[0]):
if masked:
prob = torch.rand(1).item()
if prob < 0.8:
tokens[0][i] = tokenizer.mask_token_id
elif prob < 0.9:
tokens[0][i] = torch.randint(len(tokenizer), (1,))
            # remaining 10%: keep the original token
    # Unmasked positions are excluded from the loss
labels[~mask_arr] = -100
return tokens, labels
# Compute the MLM loss
text = "The quick brown fox jumps over the lazy dog"
input_ids, labels = create_mlm_input(text)
with torch.no_grad():
outputs = model(input_ids=input_ids, labels=labels)
print(f"MLM Loss: {outputs.loss.item():.4f}")
# Predict the masked token
input_text = "The [MASK] brown fox"
inputs = tokenizer(input_text, return_tensors="pt")
with torch.no_grad():
logits = model(**inputs).logits
mask_idx = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0][1]
top_5 = logits[0, mask_idx].topk(5)
for token_id, score in zip(top_5.indices, top_5.values):
print(f"{tokenizer.decode([token_id])}: {score:.3f}")
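A quick standalone check (toy numbers, independent of BERT itself) that the 80/10/10 branching in `create_mlm_input` above produces the intended proportions among masked positions:

```python
import torch

# Simulate the per-token uniform draw used in the masking strategy above
torch.manual_seed(0)
prob = torch.rand(100_000)
mask_frac = (prob < 0.8).float().mean().item()                    # -> [MASK]
rand_frac = ((prob >= 0.8) & (prob < 0.9)).float().mean().item()  # -> random token
keep_frac = (prob >= 0.9).float().mean().item()                   # -> unchanged
print(round(mask_frac, 2), round(rand_frac, 2), round(keep_frac, 2))  # ~ 0.8 0.1 0.1
```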
8.3 T5 - Text-to-Text Transfer
T5 casts every NLP task as text-to-text generation; it is pretrained with span corruption, a masking variant in which contiguous spans are replaced by sentinel tokens.
from transformers import T5ForConditionalGeneration, T5Tokenizer
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")
# T5 unifies every task as text-to-text (prefixes follow t5-base's pretraining tasks)
tasks = {
    "Translation": "translate English to German: The weather is nice today",
    "Summarization": "summarize: Scientists have discovered a new species of deep-sea fish...",
    "Sentiment analysis": "sst2 sentence: This movie was absolutely fantastic!",
    "Question answering": "question: What is the capital of France? context: Paris is the capital..."
}
for task_name, input_text in tasks.items():
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
with torch.no_grad():
output_ids = model.generate(
inputs.input_ids,
max_length=128,
num_beams=4,
early_stopping=True
)
output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(f"\n[{task_name}]")
    print(f"Input: {input_text[:60]}...")
    print(f"Output: {output}")
9. Practical Applications
9.1 Few-shot Learning
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
class LinearProbe(nn.Module):
    """Linear classifier on top of SSL features"""
def __init__(self, feature_dim, num_classes):
super().__init__()
self.linear = nn.Linear(feature_dim, num_classes)
def forward(self, x):
return self.linear(x)
def extract_features(model, dataloader, device='cuda'):
    """Extract features with a pretrained model"""
model.eval()
features_list = []
labels_list = []
with torch.no_grad():
for images, labels in dataloader:
images = images.to(device)
features = model.encode(images)
features_list.append(features.cpu())
labels_list.append(labels)
return torch.cat(features_list), torch.cat(labels_list)
def few_shot_evaluation(ssl_model, train_loader, val_loader,
n_shots=[1, 5, 10, 100], device='cuda'):
    """Few-shot evaluation: train a linear classifier with n labels per class"""
ssl_model.eval()
    # Extract features for the full dataset
all_features, all_labels = extract_features(ssl_model, train_loader, device)
val_features, val_labels = extract_features(ssl_model, val_loader, device)
results = {}
num_classes = len(all_labels.unique())
for n_shot in n_shots:
        # Use only n_shot labels per class
selected_indices = []
for cls in range(num_classes):
cls_indices = (all_labels == cls).nonzero(as_tuple=True)[0]
selected = cls_indices[torch.randperm(len(cls_indices))[:n_shot]]
selected_indices.append(selected)
selected_indices = torch.cat(selected_indices)
train_feats = all_features[selected_indices]
train_labs = all_labels[selected_indices]
        # Train a linear classifier
probe = LinearProbe(all_features.shape[1], num_classes).to(device)
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
train_feats = train_feats.to(device)
train_labs = train_labs.to(device)
for epoch in range(100):
probe.train()
logits = probe(train_feats)
loss = criterion(logits, train_labs)
optimizer.zero_grad()
loss.backward()
optimizer.step()
        # Validation
probe.eval()
with torch.no_grad():
val_logits = probe(val_features.to(device))
preds = val_logits.argmax(dim=1).cpu()
acc = (preds == val_labels).float().mean().item()
results[n_shot] = acc
        print(f"{n_shot}-shot accuracy: {acc * 100:.2f}%")
return results
# k-NN evaluation (parameter-free)
def knn_evaluation(ssl_model, train_loader, val_loader, k=20, device='cuda'):
    """Evaluate SSL representation quality with a k-NN classifier"""
ssl_model.eval()
train_features, train_labels = extract_features(ssl_model, train_loader, device)
val_features, val_labels = extract_features(ssl_model, val_loader, device)
    # Normalize
train_features = nn.functional.normalize(train_features, dim=1)
val_features = nn.functional.normalize(val_features, dim=1)
    # k-NN classification in batches
correct = 0
total = 0
batch_size = 256
for i in range(0, len(val_features), batch_size):
batch_feats = val_features[i:i+batch_size].to(device)
batch_labels = val_labels[i:i+batch_size]
        # Compute cosine similarities
sims = torch.mm(batch_feats, train_features.to(device).T)
topk_sims, topk_indices = sims.topk(k, dim=1)
        # Majority vote
topk_labels = train_labels[topk_indices.cpu()]
preds = topk_labels.mode(dim=1).values
correct += (preds == batch_labels).sum().item()
total += len(batch_labels)
accuracy = correct / total
    print(f"k-NN (k={k}) accuracy: {accuracy * 100:.2f}%")
return accuracy
9.2 Domain-Specific SSL
# SSL for medical images (CheXpert chest X-ray example)
class MedicalSSL(nn.Module):
    """Self-supervised learning tailored to medical images"""
def __init__(self, backbone):
super().__init__()
self.backbone = backbone
        # Medical-image-specific augmentation (kept deliberately weak)
        self.augmentation = T.Compose([
            T.RandomResizedCrop(224, scale=(0.8, 1.0)),  # small crop variation
            T.RandomHorizontalFlip(p=0.5),
            # Minimize color changes for medical images
            T.RandomApply([T.ColorJitter(0.1, 0.1, 0.0, 0.0)], p=0.3),
            T.ToTensor(),
            T.Normalize([0.485], [0.229])  # grayscale X-ray
])
def forward(self, x):
v1 = self.augmentation(x)
v2 = self.augmentation(x)
return v1, v2
# SSL for satellite imagery
class SatelliteSSL:
    """Satellite-imagery-specific augmentation (preserves geographic structure)"""
def __init__(self):
self.transform = T.Compose([
T.RandomResizedCrop(256, scale=(0.4, 1.0)),
            # Satellite images: rotation invariance matters
T.RandomRotation(90),
T.RandomHorizontalFlip(),
T.RandomVerticalFlip(),
            # Simulate time-of-day/seasonal variation
T.ColorJitter(brightness=0.4, contrast=0.4),
T.ToTensor(),
])
# Compare SSL methods on the same data
def compare_ssl_methods(train_loader, val_loader, num_epochs=100):
    methods = {
        'SimCLR': SimCLR(backbone='resnet50'),
        'BYOL': BYOL(backbone=resnet50()),
        'MAE': MAE(img_size=224),
    }
    results = {}
    for name, model in methods.items():
        print(f"\nTraining {name}...")
        model = model.cuda()
        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
        for epoch in range(num_epochs):
            train_one_epoch(model, train_loader, optimizer)  # assumes a train_one_epoch helper
        # 100-shot evaluation
        acc = few_shot_evaluation(model, train_loader, val_loader, n_shots=[100])
        results[name] = acc[100]
        print(f"{name} 100-shot accuracy: {acc[100]*100:.2f}%")
    return results
9.3 Using SSL Model Hubs
# Use pretrained SSL models from Hugging Face
from transformers import (
ViTModel, ViTFeatureExtractor,
DeiTModel,
CLIPModel, CLIPProcessor
)
from PIL import Image
import torch
# ViT (MAE-pretrained version)
def load_mae_pretrained():
model = ViTModel.from_pretrained("facebook/vit-mae-base")
feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
return model, feature_extractor
# DINO ViT feature extraction
def extract_dino_features(image_path):
from transformers import ViTModel
model = ViTModel.from_pretrained("facebook/dino-vits16")
processor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits16")
image = Image.open(image_path)
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
    # [CLS] token features (global image representation)
    cls_features = outputs.last_hidden_state[:, 0, :]
    # Per-patch features (preserve spatial information)
patch_features = outputs.last_hidden_state[:, 1:, :]
return cls_features, patch_features
# Extract CLIP image features for downstream tasks
def clip_feature_extraction(images, texts=None):
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
if texts:
inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)
with torch.no_grad():
outputs = model(**inputs)
return outputs.image_embeds, outputs.text_embeds
else:
inputs = processor(images=images, return_tensors="pt")
with torch.no_grad():
image_features = model.get_image_features(**inputs)
return image_features
Conclusion
Self-supervised learning is reshaping AI. Key takeaways:
| Method | Year | Key Idea | Strength |
|---|---|---|---|
| SimCLR | 2020 | Strong augmentation + contrastive learning | Simple and powerful |
| MoCo | 2020 | Momentum queue | Batch-size independent |
| BYOL | 2020 | Contrastive-style learning without negatives | Batch-size independent |
| MAE | 2021 | 75% masking + pixel reconstruction | Efficient, ideal for ViT |
| DINO | 2021 | Self-distillation | Strong segmentation properties |
| CLIP | 2021 | Image-text contrastive learning | Zero-shot classification |
| BEiT | 2021 | Discrete token masking | BERT-style |
| DINOv2 | 2023 | Large scale + iBOT | General-purpose features |
Learning roadmap:
- Implement SimCLR to understand contrastive learning
- Understand the masking approach with MAE
- Apply multimodal learning with CLIP
- Fine-tune on your own domain data
SSL now underpins LLM pretraining and modern computer vision. By effectively exploiting large-scale unlabeled data, it will remain a core driving force of AI progress.
Self-Supervised Learning Complete Guide: SimCLR, MAE, DINO, CLIP
Table of Contents
- What is Self-Supervised Learning?
- Early SSL Methods
- Contrastive Learning
- Masked Autoencoder (MAE)
- DINO - Self-Distillation with No Labels
- CLIP - Contrastive Language-Image Pretraining
- BEiT - BERT for Images
- Language Model Pretraining
- Practical Applications
1. What is Self-Supervised Learning?
1.1 Comparing Learning Paradigms
Let's start by understanding the three major learning paradigms in deep learning.
Supervised Learning trains on labeled data. ImageNet's roughly 14 million images each carry a human-annotated label, and models like ResNet and ViT are trained on this data. The problem is that label collection is extremely expensive. A single medical image can take a specialist dozens of minutes to annotate.
Unsupervised Learning discovers patterns in data without labels. K-Means clustering, PCA, and GANs fall into this category. However, traditional unsupervised learning struggles to learn representations directly usable for downstream tasks.
Self-Supervised Learning (SSL) combines the best of both worlds. It leverages large amounts of unlabeled data while generating supervisory signals from the data itself.
Core idea: The data itself is the label.
The internet contains billions of images and trillions of words. SSL exploits this vast unlabeled data to learn powerful representations.
1.2 The Label Scarcity Problem
In the real world, label collection is extremely difficult.
| Domain | Label Collection Cost | Reason |
|---|---|---|
| Medical images | Very high | Requires specialist physicians |
| Satellite imagery | High | Requires geographic expertise |
| Legal documents | High | Requires legal experts |
| Industrial defect detection | High | Rare defect samples |
| Speech recognition | Medium | Transcription required |
SSL fundamentally solves this problem. The approach is to first pretrain on large-scale unlabeled data, then fine-tune with a small amount of labeled data.
1.3 Pretext Tasks
The key to SSL is the pretext task. We create artificial prediction problems from the data itself to force the model to learn meaningful representations.
Conditions for a good pretext task:
- Must be constructable automatically without labels
- Solving the task well must require meaningful representations
- Must not be too easy or too hard
Examples:
- Rotation prediction: Rotate an image by 0°, 90°, 180°, 270° and predict how many degrees it was rotated
- Masking: Cover part of an image/text and restore it
- Contrastive learning: Make two views of the same image similar in embedding space
1.4 Applications of SSL
SSL has become the foundation of modern AI:
- GPT, BERT: Text SSL → basis for ChatGPT, Gemini
- CLIP: Image-text SSL → basis for DALL-E, Stable Diffusion
- MAE, DINO: Image SSL → basis for medical imaging, autonomous driving
- wav2vec: Audio SSL → basis for speech recognition
2. Early SSL Methods
Early SSL research explored a variety of pretext tasks.
2.1 Rotation Prediction
Proposed by Gidaris et al. in 2018. The model predicts which of four rotations (0°, 90°, 180°, 270°) was applied to an image.
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
class RotationSSL(nn.Module):
def __init__(self, backbone, num_classes=4):
super().__init__()
self.backbone = backbone
self.classifier = nn.Linear(backbone.output_dim, num_classes)
def forward(self, x):
features = self.backbone(x)
return self.classifier(features)
def create_rotation_dataset(images):
"""Create (image, label) pairs by rotating images in 4 directions"""
rotated_images = []
labels = []
for img in images:
for k in range(4): # 0, 90, 180, 270 degrees
rotated = T.functional.rotate(img, k * 90)
rotated_images.append(rotated)
labels.append(k)
return rotated_images, labels
def train_rotation_ssl(model, dataloader, optimizer, epochs=10):
criterion = nn.CrossEntropyLoss()
model.train()
for epoch in range(epochs):
total_loss = 0
for images, labels in dataloader:
images, labels = images.cuda(), labels.cuda()
optimizer.zero_grad()
logits = model(images)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
Limitation: natural images usually have a canonical orientation, but textures or sky images do not, so rotation prediction yields little learning signal for them.
2.2 Jigsaw Puzzle Solving
Proposed by Noroozi & Favaro in 2016. An image is divided into a 3x3 grid and shuffled; the model must predict the original ordering.
import itertools
import numpy as np
class JigsawSSL(nn.Module):
def __init__(self, backbone, num_permutations=100):
super().__init__()
self.backbone = backbone
self.patch_encoder = nn.Sequential(
backbone,
nn.Linear(backbone.output_dim, 512)
)
self.classifier = nn.Sequential(
nn.Linear(512 * 9, 4096),
nn.ReLU(),
nn.Linear(4096, num_permutations)
)
def forward(self, patches):
# patches: (batch, 9, C, H, W)
batch_size = patches.size(0)
patch_features = []
for i in range(9):
feat = self.patch_encoder(patches[:, i]) # (batch, 512)
patch_features.append(feat)
combined = torch.cat(patch_features, dim=1) # (batch, 512*9)
return self.classifier(combined)
# Fixed permutation vocabulary: the paper selects 100 maximally distinct
# permutations of the 9 positions; taking the first 100 here is a simplification
PREDEFINED_PERMUTATIONS = list(itertools.permutations(range(9)))[:100]
def create_jigsaw_dataset(image, grid_size=3):
"""Split image into grid_size x grid_size patches and shuffle"""
h, w = image.shape[-2:]
patch_h, patch_w = h // grid_size, w // grid_size
patches = []
for i in range(grid_size):
for j in range(grid_size):
patch = image[...,
i*patch_h:(i+1)*patch_h,
j*patch_w:(j+1)*patch_w]
patches.append(patch)
permutation_idx = np.random.randint(0, len(PREDEFINED_PERMUTATIONS))
perm = PREDEFINED_PERMUTATIONS[permutation_idx]
shuffled = [patches[p] for p in perm]
return torch.stack(shuffled), permutation_idx
2.3 Colorization
Proposed by Zhang et al. in 2016. A grayscale image is given as input; the model predicts the colors.
class ColorizationSSL(nn.Module):
def __init__(self):
super().__init__()
# L channel input (lightness)
self.encoder = nn.Sequential(
nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(),
)
# ab channel output (color)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 2, 3, padding=1),  # ab channels
            nn.Tanh()  # assumes ab targets are normalized to [-1, 1]
)
def forward(self, l_channel):
features = self.encoder(l_channel)
ab_channels = self.decoder(features)
return ab_channels
from skimage.color import rgb2lab, lab2rgb
def prepare_colorization_data(rgb_image):
lab = rgb2lab(rgb_image)
l_channel = lab[:, :, 0:1] # lightness channel
ab_channels = lab[:, :, 1:] # color channels
return l_channel, ab_channels
3. Contrastive Learning
Contrastive Learning is the core of modern SSL; its development exploded around 2020.
3.1 Core Idea
Pull similar things together, push different things apart.
Two different views (crop, flip, color jitter, etc.) of the same image should be close together in embedding space, while different images should be far apart.
Image x
├── Augmented view v1 → z1 (embedding)
└── Augmented view v2 → z2 (embedding)
Goal: z1 and z2 close together, z1 and z3 (from different image) far apart
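This goal can be checked with toy embeddings (made-up 2-D vectors, for illustration only): after normalization, the positive pair's cosine similarity should exceed that of any negative pair.

```python
import torch
import torch.nn.functional as F

# Made-up 2-D embeddings, purely illustrative
z1 = F.normalize(torch.tensor([1.0, 0.2]), dim=0)   # view 1 of image A
z2 = F.normalize(torch.tensor([0.9, 0.3]), dim=0)   # view 2 of image A (positive)
z3 = F.normalize(torch.tensor([-0.8, 0.6]), dim=0)  # a different image (negative)
pos_sim = (z1 @ z2).item()
neg_sim = (z1 @ z3).item()
print(pos_sim > neg_sim)  # True: training pushes this gap wider
```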
3.2 InfoNCE Loss
The standard loss function for contrastive learning, also called NT-Xent (Normalized Temperature-scaled Cross Entropy):
loss(i, j) = -log( exp(sim(z_i, z_j) / τ) / Σ_{k=1}^{2N} 1[k ≠ i] · exp(sim(z_i, z_k) / τ) )
where sim is cosine similarity, τ is the temperature parameter, and the sum runs over all 2N augmented samples in the batch except i itself.
import torch
import torch.nn.functional as F
def info_nce_loss(z1, z2, temperature=0.5):
"""
InfoNCE / NT-Xent Loss
z1, z2: (batch_size, embedding_dim) - two views of the same images
"""
batch_size = z1.size(0)
# L2 normalize
z1 = F.normalize(z1, dim=1)
z2 = F.normalize(z2, dim=1)
# Concatenate: [z1_1, ..., z1_N, z2_1, ..., z2_N]
z = torch.cat([z1, z2], dim=0) # (2N, D)
# Compute pairwise similarities
similarity = torch.mm(z, z.t()) / temperature # (2N, 2N)
# Exclude self-similarity (diagonal = -inf)
mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
similarity.masked_fill_(mask, float('-inf'))
# Positive pairs: index i and i+N
labels = torch.cat([
torch.arange(batch_size, 2 * batch_size),
torch.arange(batch_size)
]).to(z.device)
loss = F.cross_entropy(similarity, labels)
return loss
# Usage
z1 = torch.randn(32, 128)
z2 = torch.randn(32, 128)
loss = info_nce_loss(z1, z2, temperature=0.5)
print(f"InfoNCE Loss: {loss.item():.4f}")
3.3 SimCLR
Published by Google in 2020, SimCLR (a Simple Framework for Contrastive Learning of Visual Representations) is simple but powerful.
Key Components:
- Data augmentation (strong augmentation is critical)
- Encoder (ResNet)
- Projection Head (MLP)
- NT-Xent Loss
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
class SimCLR(nn.Module):
def __init__(self, backbone='resnet50', projection_dim=128, temperature=0.5):
super().__init__()
self.temperature = temperature
if backbone == 'resnet50':
resnet = models.resnet50(pretrained=False)
self.encoder = nn.Sequential(*list(resnet.children())[:-1])
self.feature_dim = 2048
elif backbone == 'resnet18':
resnet = models.resnet18(pretrained=False)
self.encoder = nn.Sequential(*list(resnet.children())[:-1])
self.feature_dim = 512
# Projection Head (removed during fine-tuning)
self.projection_head = nn.Sequential(
nn.Linear(self.feature_dim, self.feature_dim),
nn.ReLU(inplace=True),
nn.Linear(self.feature_dim, projection_dim)
)
def encode(self, x):
h = self.encoder(x)
h = h.squeeze(-1).squeeze(-1)
return h
def forward(self, x):
h = self.encode(x)
z = self.projection_head(h)
return z
def contrastive_loss(self, z1, z2):
return info_nce_loss(z1, z2, self.temperature)
class SimCLRDataAugmentation:
"""SimCLR's strong data augmentation"""
def __init__(self, image_size=224, s=1.0):
color_jitter = T.ColorJitter(
brightness=0.8*s,
contrast=0.8*s,
saturation=0.8*s,
hue=0.2*s
)
self.transform = T.Compose([
T.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
T.RandomHorizontalFlip(p=0.5),
T.RandomApply([color_jitter], p=0.8),
T.RandomGrayscale(p=0.2),
            T.GaussianBlur(kernel_size=23),  # ~10% of image size; kernel size must be odd
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def __call__(self, x):
return self.transform(x), self.transform(x)
def train_simclr(model, dataloader, optimizer, epochs=200):
model.train()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=epochs
)
for epoch in range(epochs):
total_loss = 0
for (x1, x2), _ in dataloader:
x1, x2 = x1.cuda(), x2.cuda()
z1 = model(x1)
z2 = model(x2)
loss = model.contrastive_loss(z1, z2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
scheduler.step()
avg_loss = total_loss / len(dataloader)
if (epoch + 1) % 10 == 0:
print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")
model = SimCLR(backbone='resnet50', projection_dim=128).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)
def get_finetuning_model(pretrained_simclr, num_classes):
encoder = pretrained_simclr.encoder
classifier = nn.Linear(pretrained_simclr.feature_dim, num_classes)
return nn.Sequential(encoder, nn.Flatten(), classifier)
3.4 MoCo (Momentum Contrast)
Published by Facebook AI in 2020. Enables using many negative samples without requiring a large batch size.
Key idea: Stable key encoder via momentum update + Queue for managing negative samples
class MoCo(nn.Module):
def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07):
super().__init__()
self.K = K # queue size
self.m = m # momentum coefficient
self.T = T # temperature
# Query encoder and key encoder
self.encoder_q = base_encoder(num_classes=dim)
self.encoder_k = base_encoder(num_classes=dim)
# Key encoder: no gradients, updated by momentum
for param_q, param_k in zip(
self.encoder_q.parameters(),
self.encoder_k.parameters()
):
param_k.data.copy_(param_q.data)
param_k.requires_grad = False
# Negative sample queue
self.register_buffer("queue", torch.randn(dim, K))
self.queue = F.normalize(self.queue, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
@torch.no_grad()
def _momentum_update_key_encoder(self):
for param_q, param_k in zip(
self.encoder_q.parameters(),
self.encoder_k.parameters()
):
param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # simplification: batch size must divide the queue size
        self.queue[:, ptr:ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.K
self.queue_ptr[0] = ptr
def forward(self, im_q, im_k):
q = self.encoder_q(im_q)
q = F.normalize(q, dim=1)
with torch.no_grad():
self._momentum_update_key_encoder()
k = self.encoder_k(im_k)
k = F.normalize(k, dim=1)
# Positive logits: (batch, 1)
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
# Negative logits: (batch, K)
l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
logits = torch.cat([l_pos, l_neg], dim=1) / self.T
labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
loss = F.cross_entropy(logits, labels)
self._dequeue_and_enqueue(k)
return loss
3.5 BYOL (Bootstrap Your Own Latent)
Published by DeepMind in 2020. A groundbreaking method that works without negative samples.
Key insight: An online network predicts the representation of a target network. The target is updated by momentum.
import copy
class BYOL(nn.Module):
    def __init__(self, backbone, projection_dim=256, hidden_dim=4096, tau=0.996):
        super().__init__()
        self.tau = tau
        # Online network
        self.online_encoder = backbone
        self.online_projector = nn.Sequential(
            nn.Linear(2048, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim)
        )
        # The predictor's output must match the target projector's dimension
        self.online_predictor = nn.Sequential(
            nn.Linear(projection_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim)
        )
# Target network (no gradients)
self.target_encoder = copy.deepcopy(backbone)
self.target_projector = copy.deepcopy(self.online_projector)
for param in self.target_encoder.parameters():
param.requires_grad = False
for param in self.target_projector.parameters():
param.requires_grad = False
@torch.no_grad()
def update_target(self):
for online, target in zip(
self.online_encoder.parameters(),
self.target_encoder.parameters()
):
target.data = self.tau * target.data + (1 - self.tau) * online.data
for online, target in zip(
self.online_projector.parameters(),
self.target_projector.parameters()
):
target.data = self.tau * target.data + (1 - self.tau) * online.data
def forward(self, x1, x2):
online_feat1 = self.online_encoder(x1).squeeze()
online_proj1 = self.online_projector(online_feat1)
online_pred1 = self.online_predictor(online_proj1)
online_feat2 = self.online_encoder(x2).squeeze()
online_proj2 = self.online_projector(online_feat2)
online_pred2 = self.online_predictor(online_proj2)
with torch.no_grad():
target_feat1 = self.target_encoder(x1).squeeze()
target_proj1 = self.target_projector(target_feat1)
target_feat2 = self.target_encoder(x2).squeeze()
target_proj2 = self.target_projector(target_feat2)
loss1 = byol_loss(online_pred1, target_proj2.detach())
loss2 = byol_loss(online_pred2, target_proj1.detach())
return (loss1 + loss2).mean()
def byol_loss(p, z):
p = F.normalize(p, dim=-1)
z = F.normalize(z, dim=-1)
return -(p * z).sum(dim=-1)
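A small numeric check, not stated in the text above but easy to verify: on L2-normalized vectors, the negative dot product used in `byol_loss` equals the paper's MSE objective up to an additive constant, because ||p - z||² = 2 - 2·(p·z) for unit vectors.

```python
import torch
import torch.nn.functional as F

p = torch.randn(4, 256)
z = torch.randn(4, 256)
p_n, z_n = F.normalize(p, dim=-1), F.normalize(z, dim=-1)
mse_form = ((p_n - z_n) ** 2).sum(dim=-1)   # the paper's 2 - 2·cos form
dot_form = 2 - 2 * (p_n * z_n).sum(dim=-1)  # constant shift of -(p·z)
print(torch.allclose(mse_form, dot_form, atol=1e-5))  # True
```

So minimizing `-(p * z).sum(dim=-1)` and minimizing the normalized MSE give identical gradients.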
3.6 SimSiam
Published by Facebook AI in 2021. Even simpler than BYOL: a stop-gradient alone prevents collapse, with no momentum encoder needed.
class SimSiam(nn.Module):
def __init__(self, backbone, dim=2048, pred_dim=512):
super().__init__()
self.encoder = backbone
self.projector = nn.Sequential(
nn.Linear(dim, dim, bias=False),
nn.BatchNorm1d(dim),
nn.ReLU(inplace=True),
nn.Linear(dim, dim, bias=False),
nn.BatchNorm1d(dim),
nn.ReLU(inplace=True),
nn.Linear(dim, dim, bias=False),
nn.BatchNorm1d(dim, affine=False)
)
self.predictor = nn.Sequential(
nn.Linear(dim, pred_dim, bias=False),
nn.BatchNorm1d(pred_dim),
nn.ReLU(inplace=True),
nn.Linear(pred_dim, dim)
)
def forward(self, x1, x2):
z1 = self.projector(self.encoder(x1).squeeze())
z2 = self.projector(self.encoder(x2).squeeze())
p1 = self.predictor(z1)
p2 = self.predictor(z2)
# Stop gradient - the key to the asymmetric design
loss = simsiam_loss(p1, z2.detach()) / 2 + \
simsiam_loss(p2, z1.detach()) / 2
return loss
def simsiam_loss(p, z):
z = z.detach() # Stop gradient
p = F.normalize(p, dim=1)
z = F.normalize(z, dim=1)
return -(p * z).sum(dim=1).mean()
4. Masked Autoencoder (MAE)
MAE, published by Kaiming He et al. in 2021, applies BERT's masking approach from NLP to images.
4.1 Core Idea
Mask 75% of an image and reconstruct the full image from the remaining 25%.
Why such a high masking ratio of 75%?
- Text tokens carry high semantic density, so 15% masking suffices for BERT
- Adjacent pixels in images have high redundancy — 75% masking forces genuine understanding
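A quick arithmetic sketch of what 75% masking means in patch terms (assuming 224×224 inputs with 16×16 patches, the setup used elsewhere in this post); since self-attention cost grows quadratically with token count, this ratio also underlies the encoder speedup described in the next section.

```python
# Patch counting for a 224x224 image with 16x16 patches and 75% masking
num_patches = (224 // 16) ** 2             # 196 patches per image
visible = int(num_patches * (1 - 0.75))    # patches that reach the encoder
attn_ratio = (visible / num_patches) ** 2  # quadratic attention-cost ratio
print(visible)     # 49
print(attn_ratio)  # 0.0625
```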
4.2 Asymmetric Encoder-Decoder
MAE's innovation: The encoder processes only visible patches; the decoder reconstructs the full image.
Input image (196 patches)
↓ 75% masking
49 visible patches → Encoder (ViT-Large) → Encoded representation
↓
196 patches (with mask tokens) → Decoder (shallow ViT) → Pixel reconstruction
Because the encoder processes only 25% of the tokens, training is 3-4x faster.
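Before the full model, the index bookkeeping used for random masking (the argsort-of-argsort trick in `random_masking` below) can be demonstrated in isolation on a tiny example:

```python
import torch

torch.manual_seed(0)
N, L = 1, 8                                      # tiny example: 1 image, 8 patches
len_keep = int(L * (1 - 0.75))                   # 2 visible patches
noise = torch.rand(N, L)
ids_shuffle = torch.argsort(noise, dim=1)        # random permutation of patch indices
ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation (unshuffles later)
ids_keep = ids_shuffle[:, :len_keep]             # indices the encoder will see
mask = torch.ones(N, L)
mask[:, :len_keep] = 0                           # 0 = keep, 1 = masked (shuffled order)
mask = torch.gather(mask, dim=1, index=ids_restore)  # back to original patch order
print(int(mask.sum()))  # 6 of the 8 patches are masked
```

The same `ids_restore` is what the decoder later uses to put mask tokens back in their original spatial positions.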
4.3 Complete MAE Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
class PatchEmbedding(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
self.projection = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
x = self.projection(x) # (batch, embed_dim, H/p, W/p)
x = x.flatten(2) # (batch, embed_dim, num_patches)
x = x.transpose(1, 2) # (batch, num_patches, embed_dim)
return x
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., dropout=0.):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
self.norm2 = nn.LayerNorm(dim)
mlp_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
norm_x = self.norm1(x)
attn_out, _ = self.attn(norm_x, norm_x, norm_x)
x = x + attn_out
x = x + self.mlp(self.norm2(x))
return x
class MAE(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_channels=3,
encoder_dim=768,
encoder_depth=12,
encoder_heads=12,
decoder_dim=512,
decoder_depth=8,
decoder_heads=16,
mask_ratio=0.75,
):
super().__init__()
self.mask_ratio = mask_ratio
self.patch_size = patch_size
num_patches = (img_size // patch_size) ** 2
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, encoder_dim)
self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, encoder_dim),
            requires_grad=False  # fixed; the paper fills this with 2-D sin-cos values
        )
self.encoder_blocks = nn.ModuleList([
TransformerBlock(encoder_dim, encoder_heads)
for _ in range(encoder_depth)
])
self.encoder_norm = nn.LayerNorm(encoder_dim)
self.decoder_embed = nn.Linear(encoder_dim, decoder_dim, bias=True)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_dim),
            requires_grad=False  # fixed; sin-cos initialization as in the paper
        )
self.decoder_blocks = nn.ModuleList([
TransformerBlock(decoder_dim, decoder_heads)
for _ in range(decoder_depth)
])
self.decoder_norm = nn.LayerNorm(decoder_dim)
self.decoder_pred = nn.Linear(decoder_dim, patch_size**2 * in_channels)
def random_masking(self, x, mask_ratio):
N, L, D = x.shape
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=x.device)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(
x, dim=1,
index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
)
mask = torch.ones(N, L, device=x.device)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
def forward_encoder(self, x, mask_ratio):
x = self.patch_embed(x)
x = x + self.pos_embed[:, 1:, :]
x, mask, ids_restore = self.random_masking(x, mask_ratio)
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
for block in self.encoder_blocks:
x = block(x)
x = self.encoder_norm(x)
return x, mask, ids_restore
def forward_decoder(self, x, ids_restore):
x = self.decoder_embed(x)
mask_tokens = self.mask_token.repeat(
x.shape[0],
ids_restore.shape[1] + 1 - x.shape[1],
1
)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
x_ = torch.gather(
x_, dim=1,
index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
)
x = torch.cat([x[:, :1, :], x_], dim=1)
x = x + self.decoder_pos_embed
for block in self.decoder_blocks:
x = block(x)
x = self.decoder_norm(x)
x = self.decoder_pred(x)
x = x[:, 1:, :] # Remove CLS token
return x
def forward(self, imgs, mask_ratio=0.75):
latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
pred = self.forward_decoder(latent, ids_restore)
loss = self.forward_loss(imgs, pred, mask)
return loss, pred, mask
def forward_loss(self, imgs, pred, mask):
target = self.patchify(imgs)
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.e-6).sqrt()
loss = (pred - target) ** 2
loss = loss.mean(dim=-1)
loss = (loss * mask).sum() / mask.sum()
return loss
def patchify(self, imgs):
p = self.patch_size
h = w = imgs.shape[2] // p
x = imgs.reshape(imgs.shape[0], 3, h, p, w, p)
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(imgs.shape[0], h * w, p**2 * 3)
return x
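As a sanity check, the patch layout produced by `patchify` can be verified in isolation. The sketch below restates the method as a standalone function; shapes assume the default 224×224 input and 16×16 patches:

```python
import torch

def patchify(imgs, p=16):
    # (N, C, H, W) -> (N, num_patches, p*p*C), same layout as MAE.patchify
    n, c, h, w = imgs.shape
    h_p, w_p = h // p, w // p
    x = imgs.reshape(n, c, h_p, p, w_p, p)
    x = torch.einsum('nchpwq->nhwpqc', x)
    return x.reshape(n, h_p * w_p, p * p * c)

patches = patchify(torch.randn(2, 3, 224, 224))
print(patches.shape)  # torch.Size([2, 196, 768]): 14*14 patches, 16*16*3 values each
```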
# Create and train the MAE model
mae_model = MAE(
img_size=224,
patch_size=16,
encoder_dim=768,
encoder_depth=12,
encoder_heads=12,
decoder_dim=512,
decoder_depth=8,
decoder_heads=16,
mask_ratio=0.75
).cuda()
optimizer = torch.optim.AdamW(
mae_model.parameters(),
lr=1.5e-4,
betas=(0.9, 0.95),
weight_decay=0.05
)
def train_step(model, imgs, optimizer):
model.train()
optimizer.zero_grad()
loss, pred, mask = model(imgs, mask_ratio=0.75)
loss.backward()
optimizer.step()
return loss.item()
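The masking logic is also easy to sanity-check in isolation. This standalone sketch of the `random_masking` method confirms that with the default ratio of 0.75, only a quarter of the patches reach the encoder:

```python
import torch

def random_masking(x, mask_ratio):
    # Keep a random subset of patches; a 1 in the mask marks a removed patch
    n, l, d = x.shape
    len_keep = int(l * (1 - mask_ratio))
    noise = torch.rand(n, l)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).repeat(1, 1, d))
    mask = torch.ones(n, l)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return x_masked, mask

x_masked, mask = random_masking(torch.randn(2, 196, 768), mask_ratio=0.75)
print(x_masked.shape)   # torch.Size([2, 49, 768]): only 25% of patches are encoded
print(mask.sum(dim=1))  # tensor([147., 147.]): 147 of 196 patches are masked
```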
5. DINO
DINO (Self-DIstillation with NO labels) was published by Facebook AI Research in 2021. When a ViT is trained with this form of self-supervision, remarkable properties emerge.
5.1 Key Discovery
The self-attention maps of a DINO-trained ViT act as label-free semantic segmentation: the CLS token's attention cleanly separates foreground objects from the background, even though the model never saw a segmentation label.
5.2 Self-Distillation Architecture
Student Network                        Teacher Network
       ↑                                     ↑
All views (2 global + N local)         Global views (2)
- Teacher: processes only the two full-image (global) views, giving it richer context
- Student: processes all views, including the small local crops, and learns to match the teacher's predictions
- Teacher update: EMA of the student weights (no gradients flow to the teacher)
5.3 Centering and Sharpening
Two techniques prevent collapse, i.e. the network emitting the same output for every input. Centering subtracts a running mean of the teacher's outputs so that no single dimension can dominate; sharpening applies a low softmax temperature so that the teacher's output does not degenerate into a uniform distribution. Either one alone still collapses; together they balance each other.
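A tiny numeric illustration of both mechanisms (a sketch using the same temperatures as the implementation below, 0.04 for the teacher and 0.1 for the student):

```python
import torch
import torch.nn.functional as F

# Sharpening: a lower temperature pushes the teacher's softmax toward one-hot,
# preventing collapse to a uniform distribution
logits = torch.tensor([1.0, 0.5, 0.2, 0.1])
sharp = F.softmax(logits / 0.04, dim=-1)   # teacher temperature
smooth = F.softmax(logits / 0.1, dim=-1)   # student temperature
print(sharp.max() > smooth.max())  # True

# Centering: subtracting a per-dimension mean over the batch prevents a single
# dimension from dominating every sample (collapse to one "class")
batch = torch.tensor([[5.0, 1.0, 1.0],
                      [5.2, 0.9, 1.1]])    # dim 0 dominates both rows
center = batch.mean(dim=0, keepdim=True)
centered = F.softmax((batch - center) / 0.04, dim=-1)
print(F.softmax(batch / 0.04, dim=-1).argmax(dim=-1))  # tensor([0, 0]): collapsed
print(centered.argmax(dim=-1))                          # tensor([1, 0]): differences matter again
```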
class DINOHead(nn.Module):
def __init__(self, in_dim, out_dim=65536, bottleneck_dim=256):
super().__init__()
layers = [nn.Linear(in_dim, 2048), nn.GELU()]
layers += [nn.Linear(2048, 2048), nn.GELU()]
layers += [nn.Linear(2048, bottleneck_dim)]
self.mlp = nn.Sequential(*layers)
self.last_layer = nn.utils.weight_norm(
nn.Linear(bottleneck_dim, out_dim, bias=False)
)
self.last_layer.weight_g.data.fill_(1)
self.last_layer.weight_g.requires_grad = False
def forward(self, x):
x = self.mlp(x)
x = F.normalize(x, dim=-1)
x = self.last_layer(x)
return x
class DINO(nn.Module):
def __init__(self, backbone, out_dim=65536, teacher_temp=0.04, student_temp=0.1):
super().__init__()
self.student_temp = student_temp
self.teacher_temp = teacher_temp
self.student = backbone
self.student_head = DINOHead(self.student.embed_dim, out_dim)
self.teacher = copy.deepcopy(backbone)
self.teacher_head = copy.deepcopy(self.student_head)
for p in self.teacher.parameters():
p.requires_grad = False
for p in self.teacher_head.parameters():
p.requires_grad = False
self.register_buffer("center", torch.zeros(1, out_dim))
@torch.no_grad()
def update_teacher(self, momentum):
for student_ps, teacher_ps in zip(
self.student.parameters(), self.teacher.parameters()
):
teacher_ps.data.mul_(momentum).add_((1 - momentum) * student_ps.data)
for student_ps, teacher_ps in zip(
self.student_head.parameters(), self.teacher_head.parameters()
):
teacher_ps.data.mul_(momentum).add_((1 - momentum) * student_ps.data)
@torch.no_grad()
def update_center(self, teacher_output):
batch_center = teacher_output.mean(dim=0, keepdim=True)
self.center = self.center * 0.9 + batch_center * 0.1
    def dino_loss(self, student_output, teacher_output):
        student_out = student_output / self.student_temp
        # One chunk per view: the student processed every crop, while the
        # teacher processed only the 2 global views
        n_views = 2 * student_output.shape[0] // teacher_output.shape[0]
        student_out = student_out.chunk(n_views)
# Teacher: centering + sharpening
teacher_out = F.softmax(
(teacher_output - self.center) / self.teacher_temp, dim=-1
)
teacher_out = teacher_out.detach().chunk(2)
total_loss = 0
n_loss_terms = 0
for iq, q in enumerate(teacher_out):
for v in range(len(student_out)):
if v == iq:
continue
loss = torch.sum(
-q * F.log_softmax(student_out[v], dim=-1), dim=-1
)
total_loss += loss.mean()
n_loss_terms += 1
return total_loss / n_loss_terms
def forward(self, global_views, local_views, momentum=0.996):
all_views = global_views + local_views
student_output = torch.cat([
self.student_head(self.student(view))
for view in all_views
])
with torch.no_grad():
teacher_output = torch.cat([
self.teacher_head(self.teacher(view))
for view in global_views
])
loss = self.dino_loss(student_output, teacher_output)
self.update_center(teacher_output)
self.update_teacher(momentum)
return loss
class DINOAugmentation:
def __init__(self, global_crops_scale=(0.4, 1.), local_crops_scale=(0.05, 0.4),
n_local_crops=8, image_size=224, local_size=96):
self.n_local_crops = n_local_crops
self.global_transform = T.Compose([
T.RandomResizedCrop(image_size, scale=global_crops_scale,
interpolation=T.InterpolationMode.BICUBIC),
T.RandomHorizontalFlip(),
T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
T.RandomGrayscale(p=0.2),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
self.local_transform = T.Compose([
T.RandomResizedCrop(local_size, scale=local_crops_scale,
interpolation=T.InterpolationMode.BICUBIC),
T.RandomHorizontalFlip(),
T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
T.RandomGrayscale(p=0.2),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def __call__(self, image):
global_views = [self.global_transform(image) for _ in range(2)]
local_views = [self.local_transform(image) for _ in range(self.n_local_crops)]
return global_views, local_views
5.4 DINOv2
DINOv2, released in 2023, is even more powerful:
- Trained on 142 million curated images
- Added iBOT (Image BERT Pre-Training with Online Tokenizer) loss
- More stable training dynamics
6. CLIP
CLIP (Contrastive Language-Image Pre-training) is a model published by OpenAI in 2021 that learns a joint embedding space for images and natural language.
6.1 Text-Image Contrastive Learning
Trained on 400 million (image, text) pairs collected from the internet.
Image Encoder (ViT / ResNet)
↓
Image embedding
Text Encoder (Transformer)
↓
Text embedding
Goal: Matching (image, text) pairs → similar embeddings
Non-matching pairs → dissimilar embeddings
6.2 CLIP Loss
def clip_loss(image_embeddings, text_embeddings, temperature):
"""
CLIP Symmetric Contrastive Loss
image_embeddings: (N, D)
text_embeddings: (N, D)
"""
image_embeddings = F.normalize(image_embeddings, dim=1)
text_embeddings = F.normalize(text_embeddings, dim=1)
# Similarity matrix: (N, N)
logits = torch.matmul(image_embeddings, text_embeddings.T) / temperature
# Labels: diagonal (image i matches text i)
labels = torch.arange(len(logits)).to(logits.device)
# Bidirectional loss: image→text and text→image
loss_i2t = F.cross_entropy(logits, labels)
loss_t2i = F.cross_entropy(logits.T, labels)
return (loss_i2t + loss_t2i) / 2
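A quick sanity check of the loss above (self-contained, so the function is restated here): perfectly matched pairs should score far lower than unrelated ones, whose loss hovers near log(N).

```python
import torch
import torch.nn.functional as F

def clip_loss(image_embeddings, text_embeddings, temperature=0.07):
    # Symmetric contrastive loss, as defined above
    image_embeddings = F.normalize(image_embeddings, dim=1)
    text_embeddings = F.normalize(text_embeddings, dim=1)
    logits = image_embeddings @ text_embeddings.T / temperature
    labels = torch.arange(len(logits))
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

torch.manual_seed(0)
emb = torch.randn(8, 64)
loss_aligned = clip_loss(emb, emb)                 # image i matches text i exactly
loss_random = clip_loss(emb, torch.randn(8, 64))   # unrelated pairs, loss near log(8)
print(loss_aligned.item() < loss_random.item())  # True
```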
6.3 Full CLIP Usage
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import torch.nn.functional as F
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# ---- Zero-shot Image Classification ----
def zero_shot_classify(image_path, candidate_labels):
"""Classify an image without any label examples"""
image = Image.open(image_path)
# Create text prompts
texts = [f"a photo of a {label}" for label in candidate_labels]
inputs = processor(
text=texts,
images=image,
return_tensors="pt",
padding=True
)
with torch.no_grad():
outputs = model(**inputs)
logits_per_image = outputs.logits_per_image
probs = logits_per_image.softmax(dim=1)
results = {}
for label, prob in zip(candidate_labels, probs[0]):
results[label] = prob.item()
return results
# Usage
labels = ["cat", "dog", "car", "airplane", "bird"]
results = zero_shot_classify("example.jpg", labels)
for label, prob in sorted(results.items(), key=lambda x: -x[1]):
print(f"{label}: {prob:.4f}")
# ---- Image-Text Similarity Search ----
class CLIPRetrieval:
def __init__(self, model_name="openai/clip-vit-base-patch32"):
self.model = CLIPModel.from_pretrained(model_name)
self.processor = CLIPProcessor.from_pretrained(model_name)
self.model.eval()
self.image_embeddings = None
self.image_paths = []
def index_images(self, image_paths):
"""Index an image database"""
self.image_paths = image_paths
embeddings = []
for path in image_paths:
image = Image.open(path)
inputs = self.processor(images=image, return_tensors="pt")
with torch.no_grad():
embedding = self.model.get_image_features(**inputs)
embeddings.append(F.normalize(embedding, dim=1))
self.image_embeddings = torch.cat(embeddings, dim=0)
print(f"Indexed {len(image_paths)} images")
def search(self, query_text, top_k=5):
"""Search images using a text query"""
inputs = self.processor(text=[query_text], return_tensors="pt", padding=True)
with torch.no_grad():
text_embedding = self.model.get_text_features(**inputs)
text_embedding = F.normalize(text_embedding, dim=1)
similarities = torch.matmul(text_embedding, self.image_embeddings.T).squeeze()
top_indices = similarities.argsort(descending=True)[:top_k]
results = [
(self.image_paths[i], similarities[i].item())
for i in top_indices
]
return results
# Visualize CLIP embeddings with t-SNE
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
def visualize_clip_embeddings(images, texts):
"""t-SNE visualization of CLIP embeddings"""
image_inputs = processor(images=images, return_tensors="pt")
text_inputs = processor(text=texts, return_tensors="pt", padding=True)
with torch.no_grad():
image_embs = F.normalize(model.get_image_features(**image_inputs), dim=1)
text_embs = F.normalize(model.get_text_features(**text_inputs), dim=1)
all_embs = torch.cat([image_embs, text_embs], dim=0).numpy()
tsne = TSNE(n_components=2, random_state=42)
reduced = tsne.fit_transform(all_embs)
n = len(images)
plt.figure(figsize=(12, 8))
plt.scatter(reduced[:n, 0], reduced[:n, 1], c='blue', label='Images', alpha=0.7)
plt.scatter(reduced[n:, 0], reduced[n:, 1], c='red', label='Texts', alpha=0.7)
for i, text in enumerate(texts):
plt.annotate(text, reduced[n+i], fontsize=8)
plt.legend()
plt.title('CLIP Embeddings (t-SNE)')
plt.savefig('clip_tsne.png', dpi=150, bbox_inches='tight')
plt.show()
7. BEiT
BEiT (BERT Pre-Training of Image Transformers) was published by Microsoft in 2021. It converts images into discrete tokens and trains the model to predict masked tokens, in the spirit of BERT.
7.1 Discrete Visual Tokens (dVAE)
BEiT's core: instead of predicting raw pixels, the model predicts discrete tokens.
Step 1: dVAE converts image → discrete tokens (vocabulary of 8192)
Step 2: ViT predicts the token of each masked patch
from transformers import BeitForMaskedImageModeling, BeitImageProcessor
import torch
class BEiTPretraining:
    def __init__(self, model_name="microsoft/beit-base-patch16-224-pt22k"):
        # The -pt22k checkpoint is the self-supervised MIM model;
        # the plain "beit-base-patch16-224" name is a fine-tuned classifier
        self.model = BeitForMaskedImageModeling.from_pretrained(model_name)
        self.processor = BeitImageProcessor.from_pretrained(model_name)
def create_bool_masked_pos(self, num_patches, mask_ratio=0.4):
num_masked = int(num_patches * mask_ratio)
mask = torch.zeros(num_patches, dtype=torch.bool)
masked_indices = torch.randperm(num_patches)[:num_masked]
mask[masked_indices] = True
return mask
def forward(self, images):
inputs = self.processor(images, return_tensors="pt")
pixel_values = inputs.pixel_values
num_patches = (224 // 16) ** 2 # 196
bool_masked_pos = self.create_bool_masked_pos(num_patches)
bool_masked_pos = bool_masked_pos.unsqueeze(0).expand(
pixel_values.size(0), -1
)
        outputs = self.model(
            pixel_values=pixel_values,
            bool_masked_pos=bool_masked_pos
        )
        # outputs.logits predicts a visual-token id for each masked patch;
        # computing the MIM loss also requires token labels from the dVAE tokenizer
        return outputs.logits
8. Language Model Pretraining
SSL has a long history in NLP.
8.1 GPT - Next Token Prediction
GPT takes an autoregressive approach: it predicts the next token, moving strictly from left to right.
import torch
import torch.nn as nn
from torch.nn import functional as F
class GPTLanguageModel(nn.Module):
def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer, dropout=0.1):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, n_embd)
self.position_embedding = nn.Embedding(block_size, n_embd)
self.blocks = nn.Sequential(*[
TransformerBlock(n_embd, n_head, dropout=dropout)
for _ in range(n_layer)
])
self.ln_f = nn.LayerNorm(n_embd)
self.lm_head = nn.Linear(n_embd, vocab_size)
self.block_size = block_size
def forward(self, idx, targets=None):
B, T = idx.shape
tok_emb = self.token_embedding(idx)
pos_emb = self.position_embedding(torch.arange(T, device=idx.device))
x = tok_emb + pos_emb
x = self.blocks(x)
x = self.ln_f(x)
logits = self.lm_head(x)
if targets is None:
return logits, None
B, T, C = logits.shape
loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
return logits, loss
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
for _ in range(max_new_tokens):
idx_cond = idx[:, -self.block_size:]
logits, _ = self(idx_cond)
logits = logits[:, -1, :] / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = float('-inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.cat((idx, idx_next), dim=1)
return idx
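The `targets` expected by the model above are simply the inputs shifted left by one position. A toy example (the token values are hypothetical, for illustration only):

```python
import torch

# A toy token sequence and block size
tokens = torch.tensor([5, 2, 9, 7, 1, 3])
block_size = 4

# One (input, target) training pair taken from position 0
x = tokens[0:block_size]        # tensor([5, 2, 9, 7])
y = tokens[1:block_size + 1]    # tensor([2, 9, 7, 1])
# At each position t, the model sees x[:t+1] and must predict y[t],
# so one sequence yields block_size next-token predictions at once
print(x.tolist(), y.tolist())
```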
8.2 BERT - Masked Language Modeling
BERT learns bidirectional context: 15% of tokens are selected for prediction, and the model must reconstruct them using both the left and right context.
from transformers import BertForMaskedLM, BertTokenizer
import torch
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
def create_mlm_input(text, mask_prob=0.15):
"""Generate masked language model input"""
tokens = tokenizer.encode(text, return_tensors="pt")
labels = tokens.clone()
rand = torch.rand(tokens.shape)
mask_arr = (rand < mask_prob) & (tokens != tokenizer.cls_token_id) & \
(tokens != tokenizer.sep_token_id)
# Masking strategy:
# - 80%: replace with [MASK] token
# - 10%: replace with a random token
# - 10%: keep original token
for i, masked in enumerate(mask_arr[0]):
if masked:
prob = torch.rand(1).item()
if prob < 0.8:
tokens[0][i] = tokenizer.mask_token_id
elif prob < 0.9:
tokens[0][i] = torch.randint(len(tokenizer), (1,))
labels[~mask_arr] = -100
return tokens, labels
# Compute MLM loss
text = "The quick brown fox jumps over the lazy dog"
input_ids, labels = create_mlm_input(text)
with torch.no_grad():
outputs = model(input_ids=input_ids, labels=labels)
print(f"MLM Loss: {outputs.loss.item():.4f}")
# Predict masked tokens
input_text = "The [MASK] brown fox"
inputs = tokenizer(input_text, return_tensors="pt")
with torch.no_grad():
logits = model(**inputs).logits
mask_idx = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0][1]
top_5 = logits[0, mask_idx].topk(5)
for token_id, score in zip(top_5.indices, top_5.values):
print(f"{tokenizer.decode([token_id])}: {score:.3f}")
8.3 T5 - Text-to-Text Transfer
from transformers import T5ForConditionalGeneration, T5Tokenizer
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")
# T5 unifies all tasks as text-to-text
tasks = {
"translation": "translate English to French: The weather is nice today",
"summarization": "summarize: Scientists have discovered a new species of deep-sea fish...",
"sentiment": "sentiment analysis: This movie was absolutely fantastic!",
"qa": "question: What is the capital of France? context: Paris is the capital..."
}
for task_name, input_text in tasks.items():
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
with torch.no_grad():
output_ids = model.generate(
inputs.input_ids,
max_length=128,
num_beams=4,
early_stopping=True
)
output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(f"\n[{task_name}]")
print(f"Input: {input_text[:60]}...")
print(f"Output: {output}")
9. Practical Applications
9.1 Few-shot Learning
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
class LinearProbe(nn.Module):
"""Linear classifier on top of SSL features"""
def __init__(self, feature_dim, num_classes):
super().__init__()
self.linear = nn.Linear(feature_dim, num_classes)
def forward(self, x):
return self.linear(x)
def extract_features(model, dataloader, device='cuda'):
"""Extract features from a pretrained model"""
model.eval()
features_list = []
labels_list = []
with torch.no_grad():
for images, labels in dataloader:
images = images.to(device)
features = model.encode(images)
features_list.append(features.cpu())
labels_list.append(labels)
return torch.cat(features_list), torch.cat(labels_list)
def few_shot_evaluation(ssl_model, train_loader, val_loader,
n_shots=[1, 5, 10, 100], device='cuda'):
"""Evaluate with n labeled examples per class"""
ssl_model.eval()
all_features, all_labels = extract_features(ssl_model, train_loader, device)
val_features, val_labels = extract_features(ssl_model, val_loader, device)
results = {}
num_classes = len(all_labels.unique())
for n_shot in n_shots:
selected_indices = []
for cls in range(num_classes):
cls_indices = (all_labels == cls).nonzero(as_tuple=True)[0]
selected = cls_indices[torch.randperm(len(cls_indices))[:n_shot]]
selected_indices.append(selected)
selected_indices = torch.cat(selected_indices)
train_feats = all_features[selected_indices]
train_labs = all_labels[selected_indices]
probe = LinearProbe(all_features.shape[1], num_classes).to(device)
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
train_feats = train_feats.to(device)
train_labs = train_labs.to(device)
for epoch in range(100):
probe.train()
logits = probe(train_feats)
loss = criterion(logits, train_labs)
optimizer.zero_grad()
loss.backward()
optimizer.step()
probe.eval()
with torch.no_grad():
val_logits = probe(val_features.to(device))
preds = val_logits.argmax(dim=1).cpu()
acc = (preds == val_labels).float().mean().item()
results[n_shot] = acc
print(f"{n_shot}-shot accuracy: {acc * 100:.2f}%")
return results
def knn_evaluation(ssl_model, train_loader, val_loader, k=20, device='cuda'):
"""k-NN evaluation of SSL representation quality (parameter-free)"""
ssl_model.eval()
train_features, train_labels = extract_features(ssl_model, train_loader, device)
val_features, val_labels = extract_features(ssl_model, val_loader, device)
train_features = nn.functional.normalize(train_features, dim=1)
val_features = nn.functional.normalize(val_features, dim=1)
correct = 0
total = 0
batch_size = 256
for i in range(0, len(val_features), batch_size):
batch_feats = val_features[i:i+batch_size].to(device)
batch_labels = val_labels[i:i+batch_size]
sims = torch.mm(batch_feats, train_features.to(device).T)
topk_sims, topk_indices = sims.topk(k, dim=1)
topk_labels = train_labels[topk_indices.cpu()]
preds = topk_labels.mode(dim=1).values
correct += (preds == batch_labels).sum().item()
total += len(batch_labels)
accuracy = correct / total
print(f"k-NN ({k}) accuracy: {accuracy * 100:.2f}%")
return accuracy
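The majority vote in `knn_evaluation` relies on `torch.Tensor.mode`, which returns the most frequent value along a dimension; a minimal check:

```python
import torch

# Labels of the k=5 nearest neighbors for two validation samples
topk_labels = torch.tensor([[3, 3, 7, 3, 1],
                            [2, 5, 5, 5, 5]])
preds = topk_labels.mode(dim=1).values  # most frequent label per row
print(preds)  # tensor([3, 5])
```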
9.2 Domain-Specific SSL
# Medical images (chest X-ray example)
class MedicalSSL(nn.Module):
"""Self-supervised learning for medical images"""
def __init__(self, backbone):
super().__init__()
self.backbone = backbone
# Weaker augmentation for medical images
self.augmentation = T.Compose([
T.RandomResizedCrop(224, scale=(0.8, 1.0)),
T.RandomHorizontalFlip(p=0.5),
# Minimize color changes for medical images
T.RandomApply([T.ColorJitter(0.1, 0.1, 0.0, 0.0)], p=0.3),
T.ToTensor(),
T.Normalize([0.485], [0.229]) # Grayscale X-ray
])
def forward(self, x):
v1 = self.augmentation(x)
v2 = self.augmentation(x)
return v1, v2
# Satellite images
class SatelliteSSL:
"""Augmentation for satellite images (preserving geographic properties)"""
def __init__(self):
self.transform = T.Compose([
T.RandomResizedCrop(256, scale=(0.4, 1.0)),
# Satellite images: rotation invariance is important
T.RandomRotation(90),
T.RandomHorizontalFlip(),
T.RandomVerticalFlip(),
# Simulate time-of-day / seasonal variation
T.ColorJitter(brightness=0.4, contrast=0.4),
T.ToTensor(),
])
9.3 Using Pretrained SSL Models from the Hub
from transformers import (
ViTModel, ViTFeatureExtractor,
CLIPModel, CLIPProcessor
)
from PIL import Image
import torch
# Load MAE pretrained ViT
def load_mae_pretrained():
model = ViTModel.from_pretrained("facebook/vit-mae-base")
feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
return model, feature_extractor
# DINO ViT feature extraction
def extract_dino_features(image_path):
model = ViTModel.from_pretrained("facebook/dino-vits16")
processor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits16")
image = Image.open(image_path)
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
# CLS token (global image representation)
cls_features = outputs.last_hidden_state[:, 0, :]
# Per-patch features (with spatial information)
patch_features = outputs.last_hidden_state[:, 1:, :]
return cls_features, patch_features
# Extract features with CLIP
def clip_feature_extraction(images, texts=None):
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
if texts:
inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)
with torch.no_grad():
outputs = model(**inputs)
return outputs.image_embeds, outputs.text_embeds
else:
inputs = processor(images=images, return_tensors="pt")
with torch.no_grad():
image_features = model.get_image_features(**inputs)
return image_features
Conclusion
Self-supervised learning is transforming AI. Key takeaways:
| Method | Year | Core Idea | Strength |
|---|---|---|---|
| SimCLR | 2020 | Strong augmentation + contrastive | Simple and powerful |
| MoCo | 2020 | Momentum queue | Batch-size independent |
| BYOL | 2020 | Self-prediction without negatives | No negative pairs needed |
| MAE | 2021 | 75% masking + pixel reconstruction | Efficient; well suited to ViT |
| DINO | 2021 | Self-distillation | Segmentation properties |
| CLIP | 2021 | Image-text contrastive | Zero-shot classification |
| BEiT | 2021 | Discrete token masking | BERT-style |
| DINOv2 | 2023 | Large-scale + iBOT | Universal feature extractor |
Learning Roadmap:
- Implement SimCLR to understand contrastive learning
- Study MAE to grasp the masking approach
- Apply CLIP to multimodal learning
- Fine-tune on your domain data
SSL has become the foundation of LLM pretraining and computer vision. As an approach that effectively leverages vast amounts of unlabeled data, SSL will continue to be a core driver of AI progress.
References
- SimCLR Paper - Chen et al., 2020
- MAE Paper - He et al., 2021
- DINO Paper - Caron et al., 2021
- CLIP Paper - Radford et al., 2021
- BYOL Paper - Grill et al., 2020