Author: Youngju Kim (@fjvbn20031)
Table of Contents
- What is Self-Supervised Learning?
- Early SSL Methods
- Contrastive Learning
- Masked Autoencoder (MAE)
- DINO - Self-Distillation with No Labels
- CLIP - Contrastive Language-Image Pretraining
- BEiT - BERT for Images
- Language Model Pretraining
- Practical Applications
1. What is Self-Supervised Learning?
1.1 Comparing Learning Paradigms
Let's start by understanding the three major learning paradigms in deep learning.
Supervised Learning trains on labeled data. ImageNet-1k's roughly 1.28 million training images each carry a human-annotated label, and models like ResNet and ViT are trained on this data. The problem is that label collection is extremely expensive. A single medical image can take a specialist dozens of minutes to annotate.
Unsupervised Learning discovers patterns in data without labels. K-Means clustering, PCA, and GANs fall into this category. However, traditional unsupervised learning struggles to learn representations directly usable for downstream tasks.
Self-Supervised Learning (SSL) combines the best of both worlds. It leverages large amounts of unlabeled data while generating supervisory signals from the data itself.
Core idea: The data itself is the label.
The internet contains billions of images and trillions of words. SSL exploits this vast unlabeled data to learn powerful representations.
1.2 The Label Scarcity Problem
In the real world, label collection is extremely difficult.
| Domain | Label Collection Cost | Reason |
|---|---|---|
| Medical images | Very high | Requires specialist physicians |
| Satellite imagery | High | Requires geographic expertise |
| Legal documents | High | Requires legal experts |
| Industrial defect detection | High | Rare defect samples |
| Speech recognition | Medium | Transcription required |
SSL greatly reduces this problem. The approach is to first pretrain on large-scale unlabeled data, then fine-tune with a small amount of labeled data, so that labels are only needed for the final, much cheaper step.
1.3 Pretext Tasks
The key to SSL is the pretext task. We create artificial prediction problems from the data itself to force the model to learn meaningful representations.
Conditions for a good pretext task:
- Must be constructable automatically without labels
- Solving the task well must require meaningful representations
- Must not be too easy or too hard
Examples:
- Rotation prediction: Rotate an image by 0°, 90°, 180°, 270° and predict how many degrees it was rotated
- Masking: Cover part of an image/text and restore it
- Contrastive learning: Make two views of the same image similar in embedding space
1.4 Applications of SSL
SSL has become the foundation of modern AI:
- GPT, BERT: Text SSL → basis for ChatGPT, Gemini
- CLIP: Image-text SSL → basis for DALL-E, Stable Diffusion
- MAE, DINO: Image SSL → basis for medical imaging, autonomous driving
- wav2vec: Audio SSL → basis for speech recognition
2. Early SSL Methods
Early SSL research explored a variety of pretext tasks.
2.1 Rotation Prediction
Proposed by Gidaris et al. in 2018. The model predicts which of four rotations (0°, 90°, 180°, 270°) was applied to an image.
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image

class RotationSSL(nn.Module):
    def __init__(self, backbone, num_classes=4):
        super().__init__()
        self.backbone = backbone
        # `output_dim` is assumed to be an attribute exposed by the backbone
        self.classifier = nn.Linear(backbone.output_dim, num_classes)

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)

def create_rotation_dataset(images):
    """Create (image, label) pairs by rotating images in 4 directions"""
    rotated_images = []
    labels = []
    for img in images:
        for k in range(4):  # 0, 90, 180, 270 degrees
            rotated = T.functional.rotate(img, k * 90)
            rotated_images.append(rotated)
            labels.append(k)
    return rotated_images, labels

def train_rotation_ssl(model, dataloader, optimizer, epochs=10):
    criterion = nn.CrossEntropyLoss()
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, labels in dataloader:
            images, labels = images.cuda(), labels.cuda()
            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
Limitation: While natural images have clear orientations, textures and sky images may not benefit.
2.2 Jigsaw Puzzle Solving
Proposed by Noroozi & Favaro in 2016. An image is divided into a 3x3 grid and shuffled; the model must predict the original ordering.
import itertools
import numpy as np

# Fixed set of patch orderings the model must classify. This is a random sample of the
# 9! possible permutations (an assumption made to keep the example runnable);
# Noroozi & Favaro instead select a subset with large Hamming distance.
_rng = np.random.default_rng(0)
PREDEFINED_PERMUTATIONS = [tuple(_rng.permutation(9).tolist()) for _ in range(100)]

class JigsawSSL(nn.Module):
    def __init__(self, backbone, num_permutations=100):
        super().__init__()
        self.backbone = backbone
        # Shared encoder applied to each of the 9 patches
        self.patch_encoder = nn.Sequential(
            backbone,
            nn.Linear(backbone.output_dim, 512)
        )
        self.classifier = nn.Sequential(
            nn.Linear(512 * 9, 4096),
            nn.ReLU(),
            nn.Linear(4096, num_permutations)
        )

    def forward(self, patches):
        # patches: (batch, 9, C, H, W)
        patch_features = []
        for i in range(9):
            feat = self.patch_encoder(patches[:, i])  # (batch, 512)
            patch_features.append(feat)
        combined = torch.cat(patch_features, dim=1)  # (batch, 512*9)
        return self.classifier(combined)

def create_jigsaw_dataset(image, grid_size=3):
    """Split image into grid_size x grid_size patches and shuffle"""
    h, w = image.shape[-2:]
    patch_h, patch_w = h // grid_size, w // grid_size
    patches = []
    for i in range(grid_size):
        for j in range(grid_size):
            patch = image[...,
                          i*patch_h:(i+1)*patch_h,
                          j*patch_w:(j+1)*patch_w]
            patches.append(patch)
    # The label is the index of the permutation that was applied
    permutation_idx = np.random.randint(0, len(PREDEFINED_PERMUTATIONS))
    perm = PREDEFINED_PERMUTATIONS[permutation_idx]
    shuffled = [patches[p] for p in perm]
    return torch.stack(shuffled), permutation_idx
2.3 Colorization
Proposed by Zhang et al. in 2016. A grayscale image is given as input; the model predicts the colors.
class ColorizationSSL(nn.Module):
    def __init__(self):
        super().__init__()
        # L channel input (lightness)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(),
        )
        # ab channel output (color)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 2, 3, padding=1),  # ab channels
            nn.Tanh()
        )

    def forward(self, l_channel):
        features = self.encoder(l_channel)
        ab_channels = self.decoder(features)
        return ab_channels

from skimage.color import rgb2lab, lab2rgb

def prepare_colorization_data(rgb_image):
    lab = rgb2lab(rgb_image)
    # In CIELAB, L spans [0, 100] and a/b span roughly [-128, 127].
    # Rescale so the targets fall in the [-1, 1] range that the Tanh output can reach.
    l_channel = lab[:, :, 0:1] / 100.0   # lightness channel
    ab_channels = lab[:, :, 1:] / 128.0  # color channels
    return l_channel, ab_channels
3. Contrastive Learning
Contrastive Learning is the core of modern SSL. Research in this area accelerated rapidly around 2020.
3.1 Core Idea
Pull similar things together, push different things apart.
Two different views (crop, flip, color jitter, etc.) of the same image should be close together in embedding space, while different images should be far apart.
Image x
├── Augmented view v1 → z1 (embedding)
└── Augmented view v2 → z2 (embedding)
Goal: z1 and z2 close together, z1 and z3 (from different image) far apart
3.2 InfoNCE Loss
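The standard loss function for contrastive learning is InfoNCE. The variant used by SimCLR is called NT-Xent (Normalized Temperature-scaled Cross Entropy). For a positive pair (i, j) among the 2N augmented views of a batch of N images:

$$\mathcal{L}_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)}$$

where sim is cosine similarity and τ is the temperature parameter.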
import torch
import torch.nn.functional as F

def info_nce_loss(z1, z2, temperature=0.5):
    """
    InfoNCE / NT-Xent Loss
    z1, z2: (batch_size, embedding_dim) - two views of the same images
    """
    batch_size = z1.size(0)
    # L2 normalize
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    # Concatenate: [z1_1, ..., z1_N, z2_1, ..., z2_N]
    z = torch.cat([z1, z2], dim=0)  # (2N, D)
    # Compute pairwise similarities
    similarity = torch.mm(z, z.t()) / temperature  # (2N, 2N)
    # Exclude self-similarity (diagonal = -inf)
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    similarity.masked_fill_(mask, float('-inf'))
    # Positive pairs: index i and i+N
    labels = torch.cat([
        torch.arange(batch_size, 2 * batch_size),
        torch.arange(batch_size)
    ]).to(z.device)
    loss = F.cross_entropy(similarity, labels)
    return loss

# Usage
z1 = torch.randn(32, 128)
z2 = torch.randn(32, 128)
loss = info_nce_loss(z1, z2, temperature=0.5)
print(f"InfoNCE Loss: {loss.item():.4f}")
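To build intuition for the temperature τ, the sketch below evaluates the loss for a single anchor with one positive and three negatives in plain Python. The similarity values are made up for illustration, and `softmax` and `nt_xent_single` are hypothetical helpers, not part of any library. It also reports how much of the probability mass assigned to negatives lands on the hardest one.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def nt_xent_single(pos_sim, neg_sims, temperature):
    """Loss for one anchor, plus the softmax weight given to each negative."""
    logits = [pos_sim / temperature] + [s / temperature for s in neg_sims]
    probs = softmax(logits)
    return -math.log(probs[0]), probs[1:]

# Assumed cosine similarities: the first negative is a "hard" one
pos_sim = 0.8
neg_sims = [0.7, 0.1, -0.2]

for tau in (0.5, 0.07):
    loss, neg_probs = nt_xent_single(pos_sim, neg_sims, tau)
    hard_share = neg_probs[0] / sum(neg_probs)
    print(f"tau={tau}: loss={loss:.3f}, hard-negative share={hard_share:.3f}")
```

With a lower temperature, almost all of the weight on negatives concentrates on the hardest negative, which is one reason τ is such a sensitive hyperparameter.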
3.3 SimCLR
Published by Google Research in 2020, SimCLR (A Simple Framework for Contrastive Learning of Visual Representations) is simple but powerful.
Key Components:
- Data augmentation (strong augmentation is critical)
- Encoder (ResNet)
- Projection Head (MLP)
- NT-Xent Loss
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T

class SimCLR(nn.Module):
    def __init__(self, backbone='resnet50', projection_dim=128, temperature=0.5):
        super().__init__()
        self.temperature = temperature
        if backbone == 'resnet50':
            resnet = models.resnet50(weights=None)  # random init; `pretrained=` is deprecated
            self.encoder = nn.Sequential(*list(resnet.children())[:-1])
            self.feature_dim = 2048
        elif backbone == 'resnet18':
            resnet = models.resnet18(weights=None)
            self.encoder = nn.Sequential(*list(resnet.children())[:-1])
            self.feature_dim = 512
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")
        # Projection Head (removed during fine-tuning)
        self.projection_head = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.feature_dim, projection_dim)
        )

    def encode(self, x):
        h = self.encoder(x)       # (batch, feature_dim, 1, 1)
        h = torch.flatten(h, 1)   # (batch, feature_dim)
        return h

    def forward(self, x):
        h = self.encode(x)
        z = self.projection_head(h)
        return z

    def contrastive_loss(self, z1, z2):
        return info_nce_loss(z1, z2, self.temperature)

class SimCLRDataAugmentation:
    """SimCLR's strong data augmentation"""
    def __init__(self, image_size=224, s=1.0):
        color_jitter = T.ColorJitter(
            brightness=0.8*s,
            contrast=0.8*s,
            saturation=0.8*s,
            hue=0.2*s
        )
        # GaussianBlur requires an odd kernel size (about 10% of the image side)
        kernel_size = int(0.1 * image_size) // 2 * 2 + 1
        self.transform = T.Compose([
            T.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply([color_jitter], p=0.8),
            T.RandomGrayscale(p=0.2),
            # SimCLR applies the blur with probability 0.5
            T.RandomApply([T.GaussianBlur(kernel_size=kernel_size, sigma=(0.1, 2.0))], p=0.5),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        # Two independently augmented views of the same image
        return self.transform(x), self.transform(x)

def train_simclr(model, dataloader, optimizer, epochs=200):
    model.train()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs
    )
    for epoch in range(epochs):
        total_loss = 0
        for (x1, x2), _ in dataloader:
            x1, x2 = x1.cuda(), x2.cuda()
            z1 = model(x1)
            z2 = model(x2)
            loss = model.contrastive_loss(z1, z2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        avg_loss = total_loss / len(dataloader)
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

model = SimCLR(backbone='resnet50', projection_dim=128).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)

def get_finetuning_model(pretrained_simclr, num_classes):
    # Keep the encoder, drop the projection head, attach a new classifier
    encoder = pretrained_simclr.encoder
    classifier = nn.Linear(pretrained_simclr.feature_dim, num_classes)
    return nn.Sequential(encoder, nn.Flatten(), classifier)
3.4 MoCo (Momentum Contrast)
Published by Facebook AI in 2020. Enables using many negative samples without requiring a large batch size.
Key idea: Stable key encoder via momentum update + Queue for managing negative samples
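The two mechanisms can be sketched in isolation before looking at a full model. The following plain-Python toy uses lists of floats as stand-ins for encoder parameters and strings as stand-ins for key embeddings. `momentum_update`, the parameter values, and the tiny queue size are all assumptions chosen for illustration.

```python
from collections import deque

def momentum_update(key_params, query_params, m=0.999):
    """EMA update: key <- m * key + (1 - m) * query."""
    return [m * k + (1.0 - m) * q for k, q in zip(key_params, query_params)]

# Toy "parameters": the key encoder slowly follows the query encoder
query_params = [1.0, -2.0]
key_params = [0.0, 0.0]
for _ in range(1000):
    key_params = momentum_update(key_params, query_params)

# FIFO queue of negatives: once K entries are stored, the oldest keys are dropped
K = 4
queue = deque(maxlen=K)
for step in range(3):
    batch_keys = [f"key_{step}_{i}" for i in range(2)]  # stand-ins for embeddings
    queue.extend(batch_keys)

print(key_params)
print(list(queue))
```

Even after 1000 updates with m = 0.999 the key parameters have covered only part of the distance to the query parameters, which is what keeps the keys in the queue consistent with each other.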
class MoCo(nn.Module):
    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07):
        super().__init__()
        self.K = K  # queue size
        self.m = m  # momentum coefficient
        self.T = T  # temperature
        # Query encoder and key encoder
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)
        # Key encoder: no gradients, updated by momentum
        for param_q, param_k in zip(
            self.encoder_q.parameters(),
            self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        # Negative sample queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        for param_q, param_k in zip(
            self.encoder_q.parameters(),
            self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        # Otherwise the slice below would run past the end of the queue
        assert self.K % batch_size == 0
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr

    def forward(self, im_q, im_k):
        q = self.encoder_q(im_q)
        q = F.normalize(q, dim=1)
        with torch.no_grad():
            self._momentum_update_key_encoder()
            k = self.encoder_k(im_k)
            k = F.normalize(k, dim=1)
        # Positive logits: (batch, 1)
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # Negative logits: (batch, K)
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        logits = torch.cat([l_pos, l_neg], dim=1) / self.T
        # The positive is always at index 0
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        loss = F.cross_entropy(logits, labels)
        self._dequeue_and_enqueue(k)
        return loss
3.5 BYOL (Bootstrap Your Own Latent)
Published by DeepMind in 2020. A groundbreaking method that works without negative samples.
Key insight: An online network predicts the representation of a target network. The target is updated by momentum.
import copy

class BYOL(nn.Module):
    def __init__(self, backbone, projection_dim=256, prediction_dim=256, tau=0.996):
        super().__init__()
        # The online prediction is compared element-wise with the target projection,
        # so both must have the same dimension
        assert prediction_dim == projection_dim
        self.tau = tau
        # Online network (the backbone is assumed to output 2048-d features,
        # e.g. a ResNet-50 without its final fc layer)
        self.online_encoder = backbone
        self.online_projector = nn.Sequential(
            nn.Linear(2048, 4096), nn.BatchNorm1d(4096), nn.ReLU(),
            nn.Linear(4096, projection_dim)
        )
        self.online_predictor = nn.Sequential(
            nn.Linear(projection_dim, 4096), nn.BatchNorm1d(4096), nn.ReLU(),
            nn.Linear(4096, prediction_dim)
        )
        # Target network (no gradients)
        self.target_encoder = copy.deepcopy(backbone)
        self.target_projector = copy.deepcopy(self.online_projector)
        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def update_target(self):
        for online, target in zip(
            self.online_encoder.parameters(),
            self.target_encoder.parameters()
        ):
            target.data = self.tau * target.data + (1 - self.tau) * online.data
        for online, target in zip(
            self.online_projector.parameters(),
            self.target_projector.parameters()
        ):
            target.data = self.tau * target.data + (1 - self.tau) * online.data

    def forward(self, x1, x2):
        # flatten(1) keeps the batch dimension even when the batch size is 1
        online_feat1 = self.online_encoder(x1).flatten(1)
        online_proj1 = self.online_projector(online_feat1)
        online_pred1 = self.online_predictor(online_proj1)
        online_feat2 = self.online_encoder(x2).flatten(1)
        online_proj2 = self.online_projector(online_feat2)
        online_pred2 = self.online_predictor(online_proj2)
        with torch.no_grad():
            target_feat1 = self.target_encoder(x1).flatten(1)
            target_proj1 = self.target_projector(target_feat1)
            target_feat2 = self.target_encoder(x2).flatten(1)
            target_proj2 = self.target_projector(target_feat2)
        # Symmetric loss: each view predicts the target projection of the other view
        loss1 = byol_loss(online_pred1, target_proj2.detach())
        loss2 = byol_loss(online_pred2, target_proj1.detach())
        return (loss1 + loss2).mean()

def byol_loss(p, z):
    # Negative cosine similarity per sample
    p = F.normalize(p, dim=-1)
    z = F.normalize(z, dim=-1)
    return -(p * z).sum(dim=-1)
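The BYOL paper states its loss as the squared distance between L2-normalized vectors, which equals 2 − 2·cos(p, z). The negative cosine used in `byol_loss` therefore differs only by a constant offset and a factor of 2. A quick numerical check in plain Python, using arbitrary example vectors and a hypothetical `l2_normalize` helper:

```python
import math

def l2_normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

# Arbitrary example vectors (assumed values)
p = l2_normalize([0.3, -1.2, 0.5])
z = l2_normalize([0.1, -0.9, 1.1])

cosine = sum(a * b for a, b in zip(p, z))
squared_distance = sum((a - b) ** 2 for a, b in zip(p, z))

# For unit vectors: ||p - z||^2 = 2 - 2 * cos(p, z)
print(squared_distance, 2 - 2 * cosine)
```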
3.6 SimSiam
Published by Facebook AI in 2021. Even simpler than BYOL: it needs neither negative samples nor a momentum encoder, and relies on a stop-gradient operation to prevent collapse.
class SimSiam(nn.Module):
    def __init__(self, backbone, dim=2048, pred_dim=512):
        super().__init__()
        self.encoder = backbone
        self.projector = nn.Sequential(
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim, affine=False)
        )
        self.predictor = nn.Sequential(
            nn.Linear(dim, pred_dim, bias=False),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),
            nn.Linear(pred_dim, dim)
        )

    def forward(self, x1, x2):
        # flatten(1) keeps the batch dimension even when the batch size is 1
        z1 = self.projector(self.encoder(x1).flatten(1))
        z2 = self.projector(self.encoder(x2).flatten(1))
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        # Stop gradient - the key to the asymmetric design
        loss = simsiam_loss(p1, z2.detach()) / 2 + \
               simsiam_loss(p2, z1.detach()) / 2
        return loss

def simsiam_loss(p, z):
    z = z.detach()  # Stop gradient
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    return -(p * z).sum(dim=1).mean()
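Since SimSiam has no negatives, it helps to monitor whether the outputs have collapsed to a constant. The SimSiam paper tracks the per-dimension standard deviation of the L2-normalized outputs: it is near 1/sqrt(d) for well-spread d-dimensional outputs and near 0 after collapse. The plain-Python sketch below illustrates that diagnostic. `mean_channel_std` is a hypothetical helper, and the Gaussian "healthy" embeddings are synthetic stand-ins for real model outputs.

```python
import math
import random
import statistics

def mean_channel_std(batch):
    """Average per-dimension std of L2-normalized embeddings."""
    normed = []
    for v in batch:
        norm = math.sqrt(sum(x * x for x in v))
        normed.append([x / norm for x in v])
    dims = len(normed[0])
    per_dim = [statistics.pstdev([row[d] for row in normed]) for d in range(dims)]
    return sum(per_dim) / dims

random.seed(0)
d, n = 64, 256
# Synthetic stand-ins: isotropic Gaussian vectors vs. identical (collapsed) vectors
healthy = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
collapsed = [[1.0] * d for _ in range(n)]

print(f"healthy:   {mean_channel_std(healthy):.4f}")
print(f"reference: {1 / math.sqrt(d):.4f}")
print(f"collapsed: {mean_channel_std(collapsed):.4f}")
```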
4. Masked Autoencoder (MAE)
MAE, published by Kaiming He et al. in 2021, applies BERT's masking approach from NLP to images.
4.1 Core Idea
Mask 75% of an image and reconstruct the full image from the remaining 25%.
Why such a high masking ratio of 75%?
- Text tokens carry high semantic density, so 15% masking suffices for BERT
- Adjacent pixels in images have high redundancy — 75% masking forces genuine understanding
4.2 Asymmetric Encoder-Decoder
MAE's innovation: The encoder processes only visible patches; the decoder reconstructs the full image.
Input image (196 patches)
↓ 75% masking
49 visible patches → Encoder (ViT-Large) → Encoded representation
↓
196 patches (with mask tokens) → Decoder (shallow ViT) → Pixel reconstruction
Because the encoder processes only the 25% of patches that remain visible, pre-training is much cheaper. The MAE paper reports training speedups of 3x or more.
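The patch counts in the diagram, and a rough idea of where the savings come from, can be worked out directly. This is a back-of-the-envelope sketch only: it ignores the CLS token and the decoder, and the cost ratios are simple scaling arguments rather than measured numbers.

```python
img_size, patch_size, mask_ratio = 224, 16, 0.75

num_patches = (img_size // patch_size) ** 2          # 14 x 14 grid
num_visible = int(num_patches * (1 - mask_ratio))    # patches the encoder sees

# Linear layers (MLP, projections) scale linearly with sequence length
linear_cost_ratio = num_visible / num_patches
# Self-attention scales quadratically with sequence length
attention_cost_ratio = (num_visible / num_patches) ** 2

print(f"patches: {num_patches}, visible: {num_visible}")
print(f"encoder linear cost:    {linear_cost_ratio:.4f} of full")
print(f"encoder attention cost: {attention_cost_ratio:.4f} of full")
```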
4.3 Complete MAE Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        x = self.projection(x)  # (batch, embed_dim, H/p, W/p)
        x = x.flatten(2)        # (batch, embed_dim, num_patches)
        x = x.transpose(1, 2)   # (batch, num_patches, embed_dim)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., dropout=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        mlp_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        norm_x = self.norm1(x)
        attn_out, _ = self.attn(norm_x, norm_x, norm_x)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x
class MAE(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        encoder_dim=768,
        encoder_depth=12,
        encoder_heads=12,
        decoder_dim=512,
        decoder_depth=8,
        decoder_heads=16,
        mask_ratio=0.75,
    ):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2

        # Encoder
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, encoder_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, encoder_dim))
        self.encoder_blocks = nn.ModuleList([
            TransformerBlock(encoder_dim, encoder_heads)
            for _ in range(encoder_depth)
        ])
        self.encoder_norm = nn.LayerNorm(encoder_dim)

        # Decoder
        self.decoder_embed = nn.Linear(encoder_dim, decoder_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_dim))
        self.decoder_blocks = nn.ModuleList([
            TransformerBlock(decoder_dim, decoder_heads)
            for _ in range(decoder_depth)
        ])
        self.decoder_norm = nn.LayerNorm(decoder_dim)
        self.decoder_pred = nn.Linear(decoder_dim, patch_size**2 * in_channels)

        # Position embeddings are learnable here for brevity; the official MAE code
        # uses fixed 2D sin-cos embeddings. Leaving them frozen at zero would discard
        # all position information.
        for p in (self.cls_token, self.pos_embed, self.mask_token, self.decoder_pos_embed):
            nn.init.trunc_normal_(p, std=0.02)

    def random_masking(self, x, mask_ratio):
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))
        # Sort random noise per sample: the first len_keep indices are kept
        noise = torch.rand(N, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1,
            index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
        )
        # Binary mask in original patch order: 0 = keep, 1 = masked
        mask = torch.ones(N, L, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        x = self.patch_embed(x)
        x = x + self.pos_embed[:, 1:, :]
        x, mask, ids_restore = self.random_masking(x, mask_ratio)
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        for block in self.encoder_blocks:
            x = block(x)
        x = self.encoder_norm(x)
        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        x = self.decoder_embed(x)
        # Append mask tokens, then unshuffle back to the original patch order
        mask_tokens = self.mask_token.repeat(
            x.shape[0],
            ids_restore.shape[1] + 1 - x.shape[1],
            1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(
            x_, dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )
        x = torch.cat([x[:, :1, :], x_], dim=1)
        x = x + self.decoder_pos_embed
        for block in self.decoder_blocks:
            x = block(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)
        x = x[:, 1:, :]  # Remove CLS token
        return x

    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask

    def forward_loss(self, imgs, pred, mask):
        target = self.patchify(imgs)
        # Normalize pixels within each patch before computing the loss
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1.e-6).sqrt()
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)
        # Average over masked patches only
        loss = (loss * mask).sum() / mask.sum()
        return loss

    def patchify(self, imgs):
        # (N, 3, H, W) -> (N, num_patches, patch_size**2 * 3)
        p = self.patch_size
        h = w = imgs.shape[2] // p
        x = imgs.reshape(imgs.shape[0], 3, h, p, w, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(imgs.shape[0], h * w, p**2 * 3)
        return x
# Create and train the MAE model
mae_model = MAE(
    img_size=224,
    patch_size=16,
    encoder_dim=768,
    encoder_depth=12,
    encoder_heads=12,
    decoder_dim=512,
    decoder_depth=8,
    decoder_heads=16,
    mask_ratio=0.75
).cuda()

optimizer = torch.optim.AdamW(
    mae_model.parameters(),
    lr=1.5e-4,
    betas=(0.9, 0.95),
    weight_decay=0.05
)

def train_step(model, imgs, optimizer):
    model.train()
    optimizer.zero_grad()
    loss, pred, mask = model(imgs, mask_ratio=0.75)
    loss.backward()
    optimizer.step()
    return loss.item()
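The least obvious step in `random_masking` is the double argsort: sorting random noise gives a random permutation, and sorting that permutation gives its inverse, which is what later puts patches back in their original order. The plain-Python sketch below reproduces the same index logic on 8 patches. The `argsort` helper and the small sizes are assumptions for illustration.

```python
import random

def argsort(values):
    """Indices that would sort `values` in ascending order."""
    return sorted(range(len(values)), key=lambda i: values[i])

random.seed(0)
num_patches, mask_ratio = 8, 0.75
len_keep = int(num_patches * (1 - mask_ratio))

noise = [random.random() for _ in range(num_patches)]
ids_shuffle = argsort(noise)        # random permutation of patch indices
ids_restore = argsort(ids_shuffle)  # inverse permutation
ids_keep = ids_shuffle[:len_keep]   # patches the encoder will see

# In shuffled order the first len_keep entries are kept (0), the rest masked (1)
mask_shuffled = [0] * len_keep + [1] * (num_patches - len_keep)
# Equivalent of torch.gather(mask, 1, ids_restore): back to original patch order
mask = [mask_shuffled[ids_restore[i]] for i in range(num_patches)]

print("kept patches:", sorted(ids_keep))
print("mask:        ", mask)
```

The zeros in `mask` sit exactly at the kept patch indices, which is why the loss can later be restricted to masked patches with a simple multiplication.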
5. DINO
DINO (Self-DIstillation with NO labels) was published by Facebook AI Research in 2021. When ViT is trained with SSL, remarkable properties emerge.
5.1 Key Discovery
The self-attention maps of a DINO-trained ViT contain explicit information about the semantic layout of an image, even though no labels were used. In practice they often separate foreground objects from the background cleanly.
5.2 Self-Distillation Architecture
Student network ← all views (2 global + several local crops)
Teacher network ← global views only (2)
- Teacher: Processes only the global views, which cover a large part of the image and provide richer context
- Student: Processes all views, including small local crops, and learns to match the teacher's predictions
- Teacher update: EMA of the student (no gradients)
5.3 Centering and Sharpening
Two techniques to prevent collapse:
class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim=65536, bottleneck_dim=256):
        super().__init__()
        layers = [nn.Linear(in_dim, 2048), nn.GELU()]
        layers += [nn.Linear(2048, 2048), nn.GELU()]
        layers += [nn.Linear(2048, bottleneck_dim)]
        self.mlp = nn.Sequential(*layers)
        self.last_layer = nn.utils.weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)
        self.last_layer.weight_g.requires_grad = False

    def forward(self, x):
        x = self.mlp(x)
        x = F.normalize(x, dim=-1)
        x = self.last_layer(x)
        return x

class DINO(nn.Module):
    def __init__(self, backbone, out_dim=65536, teacher_temp=0.04, student_temp=0.1):
        super().__init__()
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.student = backbone
        self.student_head = DINOHead(self.student.embed_dim, out_dim)
        self.teacher = copy.deepcopy(backbone)
        self.teacher_head = copy.deepcopy(self.student_head)
        for p in self.teacher.parameters():
            p.requires_grad = False
        for p in self.teacher_head.parameters():
            p.requires_grad = False
        self.register_buffer("center", torch.zeros(1, out_dim))

    @torch.no_grad()
    def update_teacher(self, momentum):
        for student_ps, teacher_ps in zip(
            self.student.parameters(), self.teacher.parameters()
        ):
            teacher_ps.data.mul_(momentum).add_((1 - momentum) * student_ps.data)
        for student_ps, teacher_ps in zip(
            self.student_head.parameters(), self.teacher_head.parameters()
        ):
            teacher_ps.data.mul_(momentum).add_((1 - momentum) * student_ps.data)

    @torch.no_grad()
    def update_center(self, teacher_output):
        # Centering: running mean of the teacher outputs
        batch_center = teacher_output.mean(dim=0, keepdim=True)
        self.center = self.center * 0.9 + batch_center * 0.1

    def dino_loss(self, student_output, teacher_output):
        # The teacher sees the 2 global views; the student sees every view,
        # so its output must be split into one chunk per view (global views first)
        batch_size = teacher_output.shape[0] // 2
        n_crops = student_output.shape[0] // batch_size
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(n_crops)
        # Teacher: centering + sharpening
        teacher_out = F.softmax(
            (teacher_output - self.center) / self.teacher_temp, dim=-1
        )
        teacher_out = teacher_out.detach().chunk(2)
        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    continue  # skip the case where both networks see the same view
                loss = torch.sum(
                    -q * F.log_softmax(student_out[v], dim=-1), dim=-1
                )
                total_loss += loss.mean()
                n_loss_terms += 1
        return total_loss / n_loss_terms

    def forward(self, global_views, local_views, momentum=0.996):
        all_views = global_views + local_views
        student_output = torch.cat([
            self.student_head(self.student(view))
            for view in all_views
        ])
        with torch.no_grad():
            teacher_output = torch.cat([
                self.teacher_head(self.teacher(view))
                for view in global_views
            ])
        loss = self.dino_loss(student_output, teacher_output)
        self.update_center(teacher_output)
        # Kept here for compactness; in a full training loop the EMA update
        # is normally applied after optimizer.step()
        self.update_teacher(momentum)
        return loss
class DINOAugmentation:
    def __init__(self, global_crops_scale=(0.4, 1.), local_crops_scale=(0.05, 0.4),
                 n_local_crops=8, image_size=224, local_size=96):
        self.n_local_crops = n_local_crops
        self.global_transform = T.Compose([
            T.RandomResizedCrop(image_size, scale=global_crops_scale,
                                interpolation=T.InterpolationMode.BICUBIC),
            T.RandomHorizontalFlip(),
            T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        self.local_transform = T.Compose([
            T.RandomResizedCrop(local_size, scale=local_crops_scale,
                                interpolation=T.InterpolationMode.BICUBIC),
            T.RandomHorizontalFlip(),
            T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __call__(self, image):
        global_views = [self.global_transform(image) for _ in range(2)]
        local_views = [self.local_transform(image) for _ in range(self.n_local_crops)]
        return global_views, local_views
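The effect of centering and sharpening on the teacher output can be seen on a toy example. In the plain-Python sketch below, the teacher logits are assumed values in which one dimension dominates every image, and the center is the plain batch mean rather than the running average used in `update_center`. The temperatures 0.04 and 0.1 match the defaults of the `DINO` class above.

```python
import math

def softmax(logits, temperature):
    """Temperature-scaled, numerically stable softmax."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    return -sum(p * math.log(p) for p in probs if p > 0)

# Assumed teacher logits for two images; dimension 0 dominates both
teacher_logits = [[2.0, 0.3, 0.1, 0.0], [2.0, 0.0, 0.4, 0.1]]
center = [sum(col) / len(col) for col in zip(*teacher_logits)]

raw_argmax, centered_argmax = [], []
for logits in teacher_logits:
    raw = softmax(logits, temperature=0.04)
    centered = softmax([l - c for l, c in zip(logits, center)], temperature=0.04)
    raw_argmax.append(raw.index(max(raw)))
    centered_argmax.append(centered.index(max(centered)))

# Sharpening: a lower temperature gives a more peaked (lower-entropy) distribution
shifted = [l - c for l, c in zip(teacher_logits[0], center)]
sharp_entropy = entropy(softmax(shifted, temperature=0.04))
soft_entropy = entropy(softmax(shifted, temperature=0.1))

print("argmax without centering:", raw_argmax)
print("argmax with centering:   ", centered_argmax)
print(f"entropy at 0.04: {sharp_entropy:.3f}, at 0.1: {soft_entropy:.3f}")
```

Without centering both images are assigned to the same dominant dimension, while centering lets image-specific differences decide. Sharpening pulls in the opposite direction by keeping the output from becoming uniform, so the two techniques balance each other.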
5.4 DINOv2
DINOv2, released in 2023, is even more powerful:
- Trained on 142 million curated images
- Added iBOT (Image BERT Pre-Training with Online Tokenizer) loss
- More stable training dynamics
6. CLIP
CLIP (Contrastive Language-Image Pretraining) is an innovative model published by OpenAI in 2021.
6.1 Text-Image Contrastive Learning
Trained on 400 million (image, text) pairs collected from the internet.
Image Encoder (ViT / ResNet)
↓
Image embedding
Text Encoder (Transformer)
↓
Text embedding
Goal: Matching (image, text) pairs → similar embeddings
Non-matching pairs → dissimilar embeddings
6.2 CLIP Loss
def clip_loss(image_embeddings, text_embeddings, temperature):
    """
    CLIP Symmetric Contrastive Loss
    image_embeddings: (N, D)
    text_embeddings: (N, D)
    """
    image_embeddings = F.normalize(image_embeddings, dim=1)
    text_embeddings = F.normalize(text_embeddings, dim=1)
    # Similarity matrix: (N, N)
    logits = torch.matmul(image_embeddings, text_embeddings.T) / temperature
    # Labels: diagonal (image i matches text i)
    labels = torch.arange(len(logits), device=logits.device)
    # Bidirectional loss: image→text and text→image
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)
    return (loss_i2t + loss_t2i) / 2
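To see how the symmetric loss reacts to matching and non-matching pairs, the same computation can be run by hand on a 3x3 similarity matrix. This plain-Python sketch mirrors `clip_loss` starting from precomputed cosine similarities. The similarity values and both helper functions are assumptions for illustration.

```python
import math

def cross_entropy_rows(logits, labels):
    """Mean cross-entropy where row i should select column labels[i]."""
    total = 0.0
    for row, target in zip(logits, labels):
        m = max(row)
        log_denom = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_denom - row[target]
    return total / len(logits)

def symmetric_contrastive_loss(similarity, temperature):
    """Average of the image->text (rows) and text->image (columns) losses."""
    logits = [[s / temperature for s in row] for row in similarity]
    logits_t = [list(col) for col in zip(*logits)]
    labels = list(range(len(similarity)))
    return (cross_entropy_rows(logits, labels) + cross_entropy_rows(logits_t, labels)) / 2

# Assumed cosine similarities: rows are images, columns are captions
aligned = [[0.9, 0.1, 0.0], [0.2, 0.8, 0.1], [0.0, 0.1, 0.9]]
# Swap captions 0 and 1 so the diagonal no longer holds the true pairs
mismatched = [[row[1], row[0], row[2]] for row in aligned]

loss_aligned = symmetric_contrastive_loss(aligned, temperature=0.07)
loss_mismatched = symmetric_contrastive_loss(mismatched, temperature=0.07)
print(f"aligned: {loss_aligned:.4f}, mismatched: {loss_mismatched:.4f}")
```

When every similarity is identical the loss equals log(N), the value of a model that cannot tell the pairs apart.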
6.3 Full CLIP Usage
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import torch.nn.functional as F

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# ---- Zero-shot Image Classification ----
def zero_shot_classify(image_path, candidate_labels):
    """Classify an image without any label examples"""
    image = Image.open(image_path).convert("RGB")
    # Create text prompts
    texts = [f"a photo of a {label}" for label in candidate_labels]
    inputs = processor(
        text=texts,
        images=image,
        return_tensors="pt",
        padding=True
    )
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)
    results = {}
    for label, prob in zip(candidate_labels, probs[0]):
        results[label] = prob.item()
    return results

# Usage
labels = ["cat", "dog", "car", "airplane", "bird"]
results = zero_shot_classify("example.jpg", labels)
for label, prob in sorted(results.items(), key=lambda x: -x[1]):
    print(f"{label}: {prob:.4f}")

# ---- Image-Text Similarity Search ----
class CLIPRetrieval:
    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model.eval()
        self.image_embeddings = None
        self.image_paths = []

    def index_images(self, image_paths):
        """Index an image database"""
        self.image_paths = image_paths
        embeddings = []
        for path in image_paths:
            image = Image.open(path).convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                embedding = self.model.get_image_features(**inputs)
            embeddings.append(F.normalize(embedding, dim=1))
        self.image_embeddings = torch.cat(embeddings, dim=0)
        print(f"Indexed {len(image_paths)} images")

    def search(self, query_text, top_k=5):
        """Search images using a text query"""
        inputs = self.processor(text=[query_text], return_tensors="pt", padding=True)
        with torch.no_grad():
            text_embedding = self.model.get_text_features(**inputs)
        text_embedding = F.normalize(text_embedding, dim=1)
        # squeeze(0) keeps a 1-D result even when only one image is indexed
        similarities = torch.matmul(text_embedding, self.image_embeddings.T).squeeze(0)
        top_indices = similarities.argsort(descending=True)[:top_k]
        results = [
            (self.image_paths[i], similarities[i].item())
            for i in top_indices.tolist()
        ]
        return results

# Visualize CLIP embeddings with t-SNE
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np

def visualize_clip_embeddings(images, texts):
    """t-SNE visualization of CLIP embeddings"""
    image_inputs = processor(images=images, return_tensors="pt")
    text_inputs = processor(text=texts, return_tensors="pt", padding=True)
    with torch.no_grad():
        image_embs = F.normalize(model.get_image_features(**image_inputs), dim=1)
        text_embs = F.normalize(model.get_text_features(**text_inputs), dim=1)
    all_embs = torch.cat([image_embs, text_embs], dim=0).numpy()
    # t-SNE requires perplexity < number of samples (the default is 30)
    tsne = TSNE(n_components=2, random_state=42,
                perplexity=min(30, len(all_embs) - 1))
    reduced = tsne.fit_transform(all_embs)
    n = len(images)
    plt.figure(figsize=(12, 8))
    plt.scatter(reduced[:n, 0], reduced[:n, 1], c='blue', label='Images', alpha=0.7)
    plt.scatter(reduced[n:, 0], reduced[n:, 1], c='red', label='Texts', alpha=0.7)
    for i, text in enumerate(texts):
        plt.annotate(text, reduced[n+i], fontsize=8)
    plt.legend()
    plt.title('CLIP Embeddings (t-SNE)')
    plt.savefig('clip_tsne.png', dpi=150, bbox_inches='tight')
    plt.show()
7. BEiT
BEiT (BERT Pre-Training of Image Transformers) was published by Microsoft in 2021. Images are converted to discrete tokens and trained in a BERT-like fashion.
7.1 Discrete Visual Tokens (dVAE)
BEiT's core: instead of predicting raw pixels, the model predicts discrete tokens.
Step 1: dVAE converts image → discrete tokens (vocabulary of 8192)
Step 2: ViT predicts the token of each masked patch
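Step 2 amounts to a classification problem over the visual vocabulary, evaluated only at masked positions. The plain-Python sketch below shows that loss on a toy case with 3 patches and a vocabulary of 4 tokens (instead of 8192). The logits, token ids, and the `masked_token_loss` helper are assumptions for illustration, not the BEiT implementation.

```python
import math

def masked_token_loss(logits, token_ids, masked):
    """Mean cross-entropy over masked patches; visible patches are ignored."""
    losses = []
    for patch_logits, token, is_masked in zip(logits, token_ids, masked):
        if not is_masked:
            continue  # visible patches contribute nothing to the loss
        m = max(patch_logits)
        log_denom = m + math.log(sum(math.exp(v - m) for v in patch_logits))
        losses.append(log_denom - patch_logits[token])
    return sum(losses) / len(losses)

# Toy setup: one row of logits per patch, one tokenizer id per patch
logits = [[2.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.0, 0.0]]
token_ids = [0, 1, 2]          # assumed output of the image tokenizer
masked = [False, True, True]   # only patches 1 and 2 are masked

loss = masked_token_loss(logits, token_ids, masked)
print(f"masked token loss: {loss:.4f}")
```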
from transformers import BeitForMaskedImageModeling, BeitImageProcessor
import torch

class BEiTPretraining:
    # The "-pt22k" checkpoint is the self-supervised pretrained model that carries the
    # masked image modeling head; "microsoft/beit-base-patch16-224" is a checkpoint
    # fine-tuned for classification.
    def __init__(self, model_name="microsoft/beit-base-patch16-224-pt22k"):
        self.model = BeitForMaskedImageModeling.from_pretrained(model_name)
        self.processor = BeitImageProcessor.from_pretrained(model_name)

    def create_bool_masked_pos(self, num_patches, mask_ratio=0.4):
        num_masked = int(num_patches * mask_ratio)
        mask = torch.zeros(num_patches, dtype=torch.bool)
        masked_indices = torch.randperm(num_patches)[:num_masked]
        mask[masked_indices] = True
        return mask

    def forward(self, images, labels=None):
        inputs = self.processor(images, return_tensors="pt")
        pixel_values = inputs.pixel_values
        num_patches = (224 // 16) ** 2  # 196
        bool_masked_pos = self.create_bool_masked_pos(num_patches)
        # The same mask is shared by every image in the batch
        bool_masked_pos = bool_masked_pos.unsqueeze(0).expand(
            pixel_values.size(0), -1
        )
        outputs = self.model(
            pixel_values=pixel_values,
            bool_masked_pos=bool_masked_pos,
            labels=labels
        )
        # outputs.loss is None unless `labels` (the dVAE visual-token ids of the
        # masked patches) are supplied; outputs.logits has shape (batch, 196, 8192)
        return outputs.loss
8. Language Model Pretraining
SSL has a long history in NLP.
8.1 GPT - Next Token Prediction
GPT is trained autoregressively: at every position, the model predicts the next token using only the tokens to its left.
import torch
import torch.nn as nn
from torch.nn import functional as F

class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)
        # TransformerBlock is assumed to be a causal (masked) self-attention block
        # defined elsewhere; without the causal mask the model could see future tokens.
        self.blocks = nn.Sequential(*[
            TransformerBlock(n_embd, n_head, dropout=dropout)
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.block_size = block_size

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding(idx)                                    # (B, T, n_embd)
        pos_emb = self.position_embedding(torch.arange(T, device=idx.device))  # (T, n_embd)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)                                               # (B, T, vocab_size)
        if targets is None:
            return logits, None
        # targets are the input tokens shifted one position to the left
        B, T, C = logits.shape
        loss = F.cross_entropy(logits.reshape(B * T, C), targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Keep only the top_k most likely tokens
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
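The model above relies on a `TransformerBlock` class. If it is not already defined in your code, one possible sketch is shown here. The pre-norm layout, the 4x MLP expansion, and the use of `nn.MultiheadAttention` are assumptions made for illustration, not the article's reference implementation. The essential part is the causal mask, which prevents each position from attending to later tokens.

```python
import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    """Pre-norm Transformer block with causal self-attention (illustrative sketch)."""

    def __init__(self, n_embd, n_head, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        T = x.size(1)
        # True above the diagonal: position i may not attend to positions j > i
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1
        )
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=causal_mask, need_weights=False)
        x = x + attn_out                 # residual connection around attention
        x = x + self.mlp(self.ln2(x))    # residual connection around the MLP
        return x
```

A quick way to check causality is to perturb the last token of a sequence and confirm that the outputs at all earlier positions stay the same.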
8.2 BERT - Masked Language Modeling
BERT learns bidirectional context: 15% of the tokens are selected for prediction, and the model must reconstruct them from the context on both sides.
from transformers import BertForMaskedLM, BertTokenizer
import torch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
model.eval()

def create_mlm_input(text, mask_prob=0.15):
    """Generate masked language model input"""
    tokens = tokenizer.encode(text, return_tensors="pt")
    labels = tokens.clone()
    special = (tokens == tokenizer.cls_token_id) | (tokens == tokenizer.sep_token_id)
    rand = torch.rand(tokens.shape)
    mask_arr = (rand < mask_prob) & ~special
    # Short inputs can end up with no selected token, which makes the loss NaN;
    # in that case, select one random non-special token.
    if not mask_arr.any():
        candidates = (~special[0]).nonzero(as_tuple=True)[0]
        choice = candidates[torch.randint(len(candidates), (1,))]
        mask_arr[0, choice] = True
    # Masking strategy for the selected tokens:
    # - 80%: replace with [MASK] token
    # - 10%: replace with a random token
    # - 10%: keep original token
    for i, masked in enumerate(mask_arr[0]):
        if masked:
            prob = torch.rand(1).item()
            if prob < 0.8:
                tokens[0, i] = tokenizer.mask_token_id
            elif prob < 0.9:
                tokens[0, i] = torch.randint(len(tokenizer), (1,)).item()
    # The loss is computed only on the selected positions (-100 is ignored)
    labels[~mask_arr] = -100
    return tokens, labels

# Compute MLM loss
text = "The quick brown fox jumps over the lazy dog"
input_ids, labels = create_mlm_input(text)
with torch.no_grad():
    outputs = model(input_ids=input_ids, labels=labels)
print(f"MLM Loss: {outputs.loss.item():.4f}")

# Predict masked tokens
input_text = "The [MASK] brown fox"
inputs = tokenizer(input_text, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
mask_idx = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0][1]
top_5 = logits[0, mask_idx].topk(5)
# Scores are raw logits, not probabilities
for token_id, score in zip(top_5.indices, top_5.values):
    print(f"{tokenizer.decode([token_id.item()])}: {score.item():.3f}")
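For batched training, the same 80/10/10 rule is usually applied with tensor operations rather than a Python loop. The following is a minimal sketch in plain PyTorch that needs no model download. `mask_tokens` and its parameters are hypothetical names chosen for this example, and the token ids in the usage lines are toy values.

```python
import torch


def mask_tokens(input_ids, special_mask, mask_token_id, vocab_size,
                mask_prob=0.15, generator=None):
    """Return (corrupted_ids, labels) following the 80/10/10 masking rule."""
    labels = input_ids.clone()
    corrupted = input_ids.clone()

    # 1) Select positions to predict, never special tokens
    probs = torch.full(input_ids.shape, mask_prob)
    probs.masked_fill_(special_mask, 0.0)
    selected = torch.bernoulli(probs, generator=generator).bool()
    labels[~selected] = -100  # ignored by cross-entropy

    # 2) 80% of the selected positions become [MASK]
    roll = torch.rand(input_ids.shape, generator=generator)
    to_mask = selected & (roll < 0.8)
    corrupted[to_mask] = mask_token_id

    # 3) 10% become a random token; the remaining 10% stay unchanged
    to_random = selected & (roll >= 0.8) & (roll < 0.9)
    random_ids = torch.randint(vocab_size, input_ids.shape, generator=generator)
    corrupted[to_random] = random_ids[to_random]
    return corrupted, labels


# Toy usage: 4 sequences of 64 token ids, first and last positions are special
g = torch.Generator().manual_seed(0)
ids = torch.randint(5, 100, (4, 64), generator=g)
special = torch.zeros_like(ids, dtype=torch.bool)
special[:, 0] = True
special[:, -1] = True
corrupted, labels = mask_tokens(ids, special, mask_token_id=103,
                                vocab_size=100, generator=g)
```

Only the positions where `labels != -100` contribute to the loss, including the ones that were deliberately left unchanged.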
8.3 T5 - Text-to-Text Transfer
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")
model.eval()

# T5 unifies all tasks as text-to-text; the task is selected by a text prefix
tasks = {
    "translation": "translate English to French: The weather is nice today",
    "summarization": "summarize: Scientists have discovered a new species of deep-sea fish...",
    # The released T5 checkpoints use the "sst2 sentence:" prefix for sentiment
    "sentiment": "sst2 sentence: This movie was absolutely fantastic!",
    "qa": "question: What is the capital of France? context: Paris is the capital..."
}
for task_name, input_text in tasks.items():
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
    with torch.no_grad():
        output_ids = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=128,
            num_beams=4,
            early_stopping=True
        )
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(f"\n[{task_name}]")
    print(f"Input: {input_text[:60]}...")
    print(f"Output: {output}")
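The snippet above uses T5 for inference only. Its self-supervised pretraining objective is span corruption: contiguous spans of the input are replaced by sentinel tokens (`<extra_id_0>`, `<extra_id_1>`, ...), and the target lists each sentinel followed by the text it replaced. The sketch below illustrates this on whitespace-split words with spans given by hand. `span_corrupt` is a hypothetical helper, not part of `transformers`, and real pretraining samples the spans randomly over SentencePiece tokens.

```python
def span_corrupt(tokens, spans):
    """Build a T5-style (input, target) pair.

    tokens: list of str
    spans:  sorted, non-overlapping (start, end) index pairs to drop (end exclusive)
    """
    inputs, targets = [], []
    cursor = 0
    for sentinel_id, (start, end) in enumerate(spans):
        sentinel = f"<extra_id_{sentinel_id}>"
        inputs.extend(tokens[cursor:start])  # keep the text before the span
        inputs.append(sentinel)              # replace the span with one sentinel
        targets.append(sentinel)             # the target repeats the sentinel...
        targets.extend(tokens[start:end])    # ...followed by the dropped tokens
        cursor = end
    inputs.extend(tokens[cursor:])
    targets.append(f"<extra_id_{len(spans)}>")  # final sentinel closes the target
    return inputs, targets


tokens = "Thank you for inviting me to your party last week .".split()
inputs, targets = span_corrupt(tokens, spans=[(2, 4), (8, 9)])
print(" ".join(inputs))   # Thank you <extra_id_0> me to your party <extra_id_1> week .
print(" ".join(targets))  # <extra_id_0> for inviting <extra_id_1> last <extra_id_2>
```

As with BERT's masking, no human labels are involved: the target is derived entirely from the original text.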
9. Practical Applications
9.1 Few-shot Learning
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

class LinearProbe(nn.Module):
    """Linear classifier on top of SSL features"""
    def __init__(self, feature_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        return self.linear(x)

def extract_features(model, dataloader, device='cuda'):
    """Extract features from a pretrained model.

    Assumes the model exposes `encode(images)` returning (batch, feature_dim).
    """
    model.eval()
    features_list = []
    labels_list = []
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            features = model.encode(images)
            features_list.append(features.cpu())
            labels_list.append(labels)
    return torch.cat(features_list), torch.cat(labels_list)

def few_shot_evaluation(ssl_model, train_loader, val_loader,
                        n_shots=(1, 5, 10, 100), device='cuda'):
    """Evaluate with n labeled examples per class (labels assumed to be 0..C-1)"""
    ssl_model.eval()
    all_features, all_labels = extract_features(ssl_model, train_loader, device)
    val_features, val_labels = extract_features(ssl_model, val_loader, device)
    results = {}
    num_classes = len(all_labels.unique())
    for n_shot in n_shots:
        # Sample up to n_shot examples per class
        selected_indices = []
        for cls in range(num_classes):
            cls_indices = (all_labels == cls).nonzero(as_tuple=True)[0]
            selected = cls_indices[torch.randperm(len(cls_indices))[:n_shot]]
            selected_indices.append(selected)
        selected_indices = torch.cat(selected_indices)
        train_feats = all_features[selected_indices].to(device)
        train_labs = all_labels[selected_indices].to(device)

        # Train a linear probe on the frozen features (full-batch updates)
        probe = LinearProbe(all_features.shape[1], num_classes).to(device)
        optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
        criterion = nn.CrossEntropyLoss()
        probe.train()
        for epoch in range(100):
            logits = probe(train_feats)
            loss = criterion(logits, train_labs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        probe.eval()
        with torch.no_grad():
            val_logits = probe(val_features.to(device))
            preds = val_logits.argmax(dim=1).cpu()
        acc = (preds == val_labels).float().mean().item()
        results[n_shot] = acc
        print(f"{n_shot}-shot accuracy: {acc * 100:.2f}%")
    return results

def knn_evaluation(ssl_model, train_loader, val_loader, k=20, device='cuda'):
    """k-NN evaluation of SSL representation quality (no trainable parameters)"""
    ssl_model.eval()
    train_features, train_labels = extract_features(ssl_model, train_loader, device)
    val_features, val_labels = extract_features(ssl_model, val_loader, device)
    # L2-normalize so that the dot product equals cosine similarity
    train_features = nn.functional.normalize(train_features, dim=1).to(device)
    val_features = nn.functional.normalize(val_features, dim=1)
    correct = 0
    total = 0
    batch_size = 256
    for i in range(0, len(val_features), batch_size):
        batch_feats = val_features[i:i+batch_size].to(device)
        batch_labels = val_labels[i:i+batch_size]
        sims = torch.mm(batch_feats, train_features.T)
        _, topk_indices = sims.topk(k, dim=1)
        topk_labels = train_labels[topk_indices.cpu()]
        # Majority vote among the k nearest neighbours
        preds = topk_labels.mode(dim=1).values
        correct += (preds == batch_labels).sum().item()
        total += len(batch_labels)
    accuracy = correct / total
    print(f"k-NN ({k}) accuracy: {accuracy * 100:.2f}%")
    return accuracy
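The evaluation helpers above call `model.encode(images)`, which is an interface assumed by this article rather than a standard PyTorch method. One way to provide it is a thin wrapper around any backbone. The sketch below uses a hypothetical toy backbone as a stand-in; in practice you would wrap a pretrained SSL encoder.

```python
import torch
import torch.nn as nn


class EncodeWrapper(nn.Module):
    """Adapts a backbone to the `encode(images) -> (batch, feature_dim)` interface."""

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def encode(self, images):
        feats = self.backbone(images)
        # Flatten any remaining spatial dimensions into one vector per image
        return feats.flatten(start_dim=1)

    def forward(self, images):
        return self.encode(images)


# Hypothetical stand-in backbone; replace with a pretrained SSL encoder
toy_backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
)
ssl_model = EncodeWrapper(toy_backbone).eval()
with torch.no_grad():
    features = ssl_model.encode(torch.randn(4, 3, 32, 32))
print(features.shape)  # torch.Size([4, 8])
```

Keeping the backbone frozen and evaluating only through `encode` ensures that the linear probe and k-NN scores measure the quality of the pretrained representation itself.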
9.2 Domain-Specific SSL
import torch.nn as nn
import torchvision.transforms as T

# Medical images (chest X-ray example)
class MedicalSSL(nn.Module):
    """Self-supervised learning for medical images"""
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        # Weaker augmentation for medical images
        self.augmentation = T.Compose([
            T.RandomResizedCrop(224, scale=(0.8, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            # Minimize color changes for medical images
            T.RandomApply([T.ColorJitter(0.1, 0.1, 0.0, 0.0)], p=0.3),
            T.ToTensor(),
            T.Normalize([0.485], [0.229])  # Grayscale X-ray
        ])

    def forward(self, x):
        # x is a single PIL image; returns two augmented views (tensors)
        # to be encoded by the backbone and compared by the SSL loss
        v1 = self.augmentation(x)
        v2 = self.augmentation(x)
        return v1, v2

# Satellite images
class SatelliteSSL:
    """Augmentation for satellite images (preserving geographic properties)"""
    def __init__(self):
        self.transform = T.Compose([
            T.RandomResizedCrop(256, scale=(0.4, 1.0)),
            # Satellite images: rotation invariance is important
            # (RandomRotation(90) samples an angle uniformly from [-90, 90] degrees)
            T.RandomRotation(90),
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            # Simulate time-of-day / seasonal variation
            T.ColorJitter(brightness=0.4, contrast=0.4),
            T.ToTensor(),
        ])
9.3 Using Pretrained SSL Models from the Hub
from transformers import (
    ViTModel, ViTMAEModel, ViTImageProcessor,
    CLIPModel, CLIPProcessor
)
from PIL import Image
import torch

# Load MAE pretrained ViT
def load_mae_pretrained():
    # MAE checkpoints use the ViTMAE architecture. mask_ratio=0.0 disables the
    # random patch masking so that all patches are encoded for feature extraction.
    model = ViTMAEModel.from_pretrained("facebook/vit-mae-base", mask_ratio=0.0)
    image_processor = ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
    return model, image_processor

# DINO ViT feature extraction
def extract_dino_features(image_path):
    model = ViTModel.from_pretrained("facebook/dino-vits16")
    processor = ViTImageProcessor.from_pretrained("facebook/dino-vits16")
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # CLS token (global image representation)
    cls_features = outputs.last_hidden_state[:, 0, :]
    # Per-patch features (with spatial information)
    patch_features = outputs.last_hidden_state[:, 1:, :]
    return cls_features, patch_features

# Extract features with CLIP
def clip_feature_extraction(images, texts=None):
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    if texts:
        inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.image_embeds, outputs.text_embeds
    else:
        inputs = processor(images=images, return_tensors="pt")
        with torch.no_grad():
            image_features = model.get_image_features(**inputs)
        return image_features
Conclusion
Self-supervised learning turns unlabeled data into a training signal. The key methods covered in this guide:
| Method | Year | Core Idea | Strength |
|---|---|---|---|
| SimCLR | 2020 | Strong augmentation + contrastive | Simple and powerful |
| MoCo | 2020 | Momentum encoder + queue of negatives | Number of negatives decoupled from batch size |
| BYOL | 2020 | Online/target networks, no negative pairs | Less sensitive to batch size than SimCLR |
| MAE | 2021 | 75% masking + pixel reconstruction | Efficient ViT pretraining |
| DINO | 2021 | Self-distillation | Emergent segmentation in attention maps |
| CLIP | 2021 | Image-text contrastive | Zero-shot classification |
| BEiT | 2021 | Discrete token masking | BERT-style |
| DINOv2 | 2023 | Large-scale + iBOT | Universal feature extractor |
Learning Roadmap:
- Implement SimCLR to understand contrastive learning
- Study MAE to grasp the masking approach
- Apply CLIP to multimodal learning
- Fine-tune on your domain data
SSL has become the foundation of LLM pretraining and computer vision. As an approach that effectively leverages vast amounts of unlabeled data, SSL will continue to be a core driver of AI progress.
References
- SimCLR Paper - Chen et al., 2020
- MAE Paper - He et al., 2021
- DINO Paper - Caron et al., 2021
- CLIP Paper - Radford et al., 2021
- BYOL Paper - Grill et al., 2020