
Self-Supervised Learning Complete Guide: SimCLR, MAE, DINO, CLIP


Table of Contents

  1. What is Self-Supervised Learning?
  2. Early SSL Methods
  3. Contrastive Learning
  4. Masked Autoencoder (MAE)
  5. DINO - Self-Distillation with No Labels
  6. CLIP - Contrastive Language-Image Pretraining
  7. BEiT - BERT for Images
  8. Language Model Pretraining
  9. Practical Applications

1. What is Self-Supervised Learning?

1.1 Comparing Learning Paradigms

Let's start by understanding the three major learning paradigms in deep learning.

Supervised Learning trains on labeled data. ImageNet's roughly 1.3 million training images (ILSVRC) each carry a human-annotated label, and models like ResNet and ViT are trained on this data. The problem is that label collection is extremely expensive: a single medical image can take a specialist tens of minutes to annotate.

Unsupervised Learning discovers patterns in data without labels. K-Means clustering, PCA, and GANs fall into this category. However, traditional unsupervised learning struggles to learn representations directly usable for downstream tasks.

Self-Supervised Learning (SSL) combines the best of both worlds. It leverages large amounts of unlabeled data while generating supervisory signals from the data itself.

Core idea: The data itself is the label.

The internet contains billions of images and trillions of words. SSL exploits this vast unlabeled data to learn powerful representations.

1.2 The Label Scarcity Problem

In the real world, label collection is extremely difficult.

| Domain | Label Collection Cost | Reason |
|---|---|---|
| Medical images | Very high | Requires specialist physicians |
| Satellite imagery | High | Requires geographic expertise |
| Legal documents | High | Requires legal experts |
| Industrial defect detection | High | Rare defect samples |
| Speech recognition | Medium | Transcription required |

SSL fundamentally solves this problem. The approach is to first pretrain on large-scale unlabeled data, then fine-tune with a small amount of labeled data.
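The two-stage recipe can be sketched with toy modules (everything below is illustrative; the encoder, pretext head, and shapes are placeholders, not a real architecture):

```python
import torch
import torch.nn as nn

# Stage 1: pretrain an encoder on a pretext task (here a dummy 4-way
# rotation-prediction head over flattened 32x32 RGB images)
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64), nn.ReLU())
pretext_head = nn.Linear(64, 4)
unlabeled = torch.randn(8, 3, 32, 32)          # "unlabeled" toy batch
pretext_logits = pretext_head(encoder(unlabeled))
print(pretext_logits.shape)                    # torch.Size([8, 4])

# Stage 2: reuse the pretrained encoder and fine-tune a small labeled head
task_head = nn.Linear(64, 10)
task_logits = task_head(encoder(unlabeled))
print(task_logits.shape)                       # torch.Size([8, 10])
```

The point is that the encoder's weights are shared across both stages; only the small task head needs labeled data.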

1.3 Pretext Tasks

The key to SSL is the pretext task. We create artificial prediction problems from the data itself to force the model to learn meaningful representations.

Conditions for a good pretext task:

  1. Must be constructable automatically without labels
  2. Solving the task well must require meaningful representations
  3. Must not be too easy or too hard

Examples:

  • Rotation prediction: Rotate an image by 0°, 90°, 180°, 270° and predict how many degrees it was rotated
  • Masking: Cover part of an image/text and restore it
  • Contrastive learning: Make two views of the same image similar in embedding space

1.4 Applications of SSL

SSL has become the foundation of modern AI:

  • GPT, BERT: Text SSL → basis for ChatGPT, Gemini
  • CLIP: Image-text SSL → basis for DALL-E, Stable Diffusion
  • MAE, DINO: Image SSL → basis for medical imaging, autonomous driving
  • wav2vec: Audio SSL → basis for speech recognition

2. Early SSL Methods

Early SSL research explored a variety of pretext tasks.

2.1 Rotation Prediction

Proposed by Gidaris et al. in 2018. The model predicts which of four rotations (0°, 90°, 180°, 270°) was applied to an image.

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image

class RotationSSL(nn.Module):
    def __init__(self, backbone, num_classes=4):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(backbone.output_dim, num_classes)

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)

def create_rotation_dataset(images):
    """Create (image, label) pairs by rotating images in 4 directions"""
    rotated_images = []
    labels = []

    for img in images:
        for k in range(4):  # 0, 90, 180, 270 degrees
            rotated = T.functional.rotate(img, k * 90)
            rotated_images.append(rotated)
            labels.append(k)

    return rotated_images, labels

def train_rotation_ssl(model, dataloader, optimizer, epochs=10):
    criterion = nn.CrossEntropyLoss()
    model.train()

    for epoch in range(epochs):
        total_loss = 0
        for images, labels in dataloader:
            images, labels = images.cuda(), labels.cuda()
            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

Limitation: Rotation prediction assumes images have a canonical orientation. For textures, aerial views, or sky images with no clear "up", the pretext task provides little learning signal.

2.2 Jigsaw Puzzle Solving

Proposed by Noroozi & Favaro in 2016. An image is divided into a 3x3 grid and shuffled; the model must predict the original ordering.

import itertools
import numpy as np

class JigsawSSL(nn.Module):
    def __init__(self, backbone, num_permutations=100):
        super().__init__()
        self.backbone = backbone
        self.patch_encoder = nn.Sequential(
            backbone,
            nn.Linear(backbone.output_dim, 512)
        )
        self.classifier = nn.Sequential(
            nn.Linear(512 * 9, 4096),
            nn.ReLU(),
            nn.Linear(4096, num_permutations)
        )

    def forward(self, patches):
        # patches: (batch, 9, C, H, W)
        batch_size = patches.size(0)
        patch_features = []

        for i in range(9):
            feat = self.patch_encoder(patches[:, i])  # (batch, 512)
            patch_features.append(feat)

        combined = torch.cat(patch_features, dim=1)  # (batch, 512*9)
        return self.classifier(combined)

# A fixed set of patch permutations. The original paper selects 100 maximally
# distinct permutations via a Hamming-distance criterion; random sampling here
# is a simplification.
rng = np.random.default_rng(0)
PREDEFINED_PERMUTATIONS = [rng.permutation(9) for _ in range(100)]

def create_jigsaw_dataset(image, grid_size=3):
    """Split image into grid_size x grid_size patches and shuffle"""
    h, w = image.shape[-2:]
    patch_h, patch_w = h // grid_size, w // grid_size

    patches = []
    for i in range(grid_size):
        for j in range(grid_size):
            patch = image[...,
                         i*patch_h:(i+1)*patch_h,
                         j*patch_w:(j+1)*patch_w]
            patches.append(patch)

    permutation_idx = np.random.randint(0, len(PREDEFINED_PERMUTATIONS))
    perm = PREDEFINED_PERMUTATIONS[permutation_idx]
    shuffled = [patches[p] for p in perm]

    return torch.stack(shuffled), permutation_idx

2.3 Colorization

Proposed by Zhang et al. in 2016. A grayscale image is given as input; the model predicts the colors.

class ColorizationSSL(nn.Module):
    def __init__(self):
        super().__init__()
        # L channel input (lightness)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(),
        )
        # ab channel output (color)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 2, 3, padding=1),  # ab channels
            nn.Tanh()
        )

    def forward(self, l_channel):
        features = self.encoder(l_channel)
        ab_channels = self.decoder(features)
        return ab_channels

from skimage.color import rgb2lab, lab2rgb

def prepare_colorization_data(rgb_image):
    lab = rgb2lab(rgb_image)
    l_channel = lab[:, :, 0:1]   # lightness channel
    ab_channels = lab[:, :, 1:]  # color channels
    return l_channel, ab_channels

3. Contrastive Learning

Contrastive Learning is the core of modern SSL. It exploded in development around 2020.

3.1 Core Idea

Pull similar things together, push different things apart.

Two different views (crop, flip, color jitter, etc.) of the same image should be close together in embedding space, while different images should be far apart.

Image x
├── Augmented view v1 → z1 (embedding)
└── Augmented view v2 → z2 (embedding)

Goal: z1 and z2 close together, z1 and z3 (from different image) far apart

3.2 InfoNCE Loss

The standard loss function for contrastive learning — also called NT-Xent (Normalized Temperature-scaled Cross Entropy).

$$\mathcal{L} = -\frac{1}{2N} \sum_{i=1}^{2N} \log \frac{e^{\mathrm{sim}(z_i, z_{j(i)})/\tau}}{\sum_{k \neq i} e^{\mathrm{sim}(z_i, z_k)/\tau}}$$

where sim is cosine similarity and τ is the temperature parameter.

import torch
import torch.nn.functional as F

def info_nce_loss(z1, z2, temperature=0.5):
    """
    InfoNCE / NT-Xent Loss
    z1, z2: (batch_size, embedding_dim) - two views of the same images
    """
    batch_size = z1.size(0)

    # L2 normalize
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    # Concatenate: [z1_1, ..., z1_N, z2_1, ..., z2_N]
    z = torch.cat([z1, z2], dim=0)  # (2N, D)

    # Compute pairwise similarities
    similarity = torch.mm(z, z.t()) / temperature  # (2N, 2N)

    # Exclude self-similarity (diagonal = -inf)
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    similarity.masked_fill_(mask, float('-inf'))

    # Positive pairs: index i and i+N
    labels = torch.cat([
        torch.arange(batch_size, 2 * batch_size),
        torch.arange(batch_size)
    ]).to(z.device)

    loss = F.cross_entropy(similarity, labels)
    return loss

# Usage
z1 = torch.randn(32, 128)
z2 = torch.randn(32, 128)
loss = info_nce_loss(z1, z2, temperature=0.5)
print(f"InfoNCE Loss: {loss.item():.4f}")

3.3 SimCLR

Published by Google in 2020. Simple Framework for Contrastive Learning. Simple but powerful.

Key Components:

  1. Data augmentation (strong augmentation is critical)
  2. Encoder (ResNet)
  3. Projection Head (MLP)
  4. NT-Xent Loss

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T

class SimCLR(nn.Module):
    def __init__(self, backbone='resnet50', projection_dim=128, temperature=0.5):
        super().__init__()
        self.temperature = temperature

        if backbone == 'resnet50':
            resnet = models.resnet50(weights=None)  # train from scratch
            self.encoder = nn.Sequential(*list(resnet.children())[:-1])
            self.feature_dim = 2048
        elif backbone == 'resnet18':
            resnet = models.resnet18(weights=None)
            self.encoder = nn.Sequential(*list(resnet.children())[:-1])
            self.feature_dim = 512

        # Projection Head (removed during fine-tuning)
        self.projection_head = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.feature_dim, projection_dim)
        )

    def encode(self, x):
        h = self.encoder(x)
        h = h.squeeze(-1).squeeze(-1)
        return h

    def forward(self, x):
        h = self.encode(x)
        z = self.projection_head(h)
        return z

    def contrastive_loss(self, z1, z2):
        return info_nce_loss(z1, z2, self.temperature)


class SimCLRDataAugmentation:
    """SimCLR's strong data augmentation"""
    def __init__(self, image_size=224, s=1.0):
        color_jitter = T.ColorJitter(
            brightness=0.8*s,
            contrast=0.8*s,
            saturation=0.8*s,
            hue=0.2*s
        )
        self.transform = T.Compose([
            T.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply([color_jitter], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.GaussianBlur(kernel_size=int(0.1 * image_size) // 2 * 2 + 1),  # kernel size must be odd
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        return self.transform(x), self.transform(x)


def train_simclr(model, dataloader, optimizer, epochs=200):
    model.train()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs
    )

    for epoch in range(epochs):
        total_loss = 0
        for (x1, x2), _ in dataloader:
            x1, x2 = x1.cuda(), x2.cuda()

            z1 = model(x1)
            z2 = model(x2)

            loss = model.contrastive_loss(z1, z2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        avg_loss = total_loss / len(dataloader)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")


model = SimCLR(backbone='resnet50', projection_dim=128).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)

def get_finetuning_model(pretrained_simclr, num_classes):
    encoder = pretrained_simclr.encoder
    classifier = nn.Linear(pretrained_simclr.feature_dim, num_classes)
    return nn.Sequential(encoder, nn.Flatten(), classifier)
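A common way to evaluate the pretrained encoder is linear probing: freeze every encoder parameter and train only the classifier. A minimal sketch with a toy stand-in encoder (the real encoder would come from the pretrained SimCLR model above):

```python
import torch.nn as nn

# Toy stand-in for a pretrained encoder (illustrative only)
encoder = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
for p in encoder.parameters():
    p.requires_grad = False        # freeze: linear probing trains only the head

head = nn.Linear(16, 3)
model = nn.Sequential(encoder, head)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)                   # ['1.weight', '1.bias']
```

Only the head's weight and bias receive gradients; the encoder's representations are evaluated as-is.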

3.4 MoCo (Momentum Contrast)

Published by Facebook AI in 2020. Enables using many negative samples without requiring a large batch size.

Key idea: Stable key encoder via momentum update + Queue for managing negative samples

class MoCo(nn.Module):
    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07):
        super().__init__()
        self.K = K   # queue size
        self.m = m   # momentum coefficient
        self.T = T   # temperature

        # Query encoder and key encoder
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        # Key encoder: no gradients, updated by momentum
        for param_q, param_k in zip(
            self.encoder_q.parameters(),
            self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        # Negative sample queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        for param_q, param_k in zip(
            self.encoder_q.parameters(),
            self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        # Overwrite the oldest keys (assumes K is divisible by batch_size)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr

    def forward(self, im_q, im_k):
        q = self.encoder_q(im_q)
        q = F.normalize(q, dim=1)

        with torch.no_grad():
            self._momentum_update_key_encoder()
            k = self.encoder_k(im_k)
            k = F.normalize(k, dim=1)

        # Positive logits: (batch, 1)
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # Negative logits: (batch, K)
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        logits = torch.cat([l_pos, l_neg], dim=1) / self.T
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

        loss = F.cross_entropy(logits, labels)

        self._dequeue_and_enqueue(k)
        return loss
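The ring-buffer behavior of `_dequeue_and_enqueue` can be traced with a toy queue (note the divisibility assumption: the queue size must be a multiple of the batch size):

```python
import torch

K, dim, batch = 8, 4, 2
queue = torch.zeros(dim, K)        # negative-sample queue, one key per column
ptr = 0
for step in range(3):
    keys = torch.full((batch, dim), float(step + 1))  # fake key batch
    queue[:, ptr:ptr + batch] = keys.T                # overwrite oldest slots
    ptr = (ptr + batch) % K                           # advance, wrapping around
print(ptr, queue[0].tolist())  # 6 [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0, 0.0]
```

After three steps of batch size 2, columns 0-5 hold the three enqueued batches in order and the pointer sits at slot 6.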

3.5 BYOL (Bootstrap Your Own Latent)

Published by DeepMind in 2020. A groundbreaking method that works without negative samples.

Key insight: An online network predicts the representation of a target network. The target is updated by momentum.

import copy

class BYOL(nn.Module):
    def __init__(self, backbone, projection_dim=256, hidden_dim=4096, tau=0.996):
        super().__init__()
        self.tau = tau

        # Online network
        self.online_encoder = backbone
        self.online_projector = nn.Sequential(
            nn.Linear(2048, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim)
        )
        # The predictor must output projection_dim so its prediction is
        # comparable with the target projector's output
        self.online_predictor = nn.Sequential(
            nn.Linear(projection_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim)
        )

        # Target network (no gradients)
        self.target_encoder = copy.deepcopy(backbone)
        self.target_projector = copy.deepcopy(self.online_projector)

        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def update_target(self):
        for online, target in zip(
            self.online_encoder.parameters(),
            self.target_encoder.parameters()
        ):
            target.data = self.tau * target.data + (1 - self.tau) * online.data

        for online, target in zip(
            self.online_projector.parameters(),
            self.target_projector.parameters()
        ):
            target.data = self.tau * target.data + (1 - self.tau) * online.data

    def forward(self, x1, x2):
        online_feat1 = self.online_encoder(x1).squeeze()
        online_proj1 = self.online_projector(online_feat1)
        online_pred1 = self.online_predictor(online_proj1)

        online_feat2 = self.online_encoder(x2).squeeze()
        online_proj2 = self.online_projector(online_feat2)
        online_pred2 = self.online_predictor(online_proj2)

        with torch.no_grad():
            target_feat1 = self.target_encoder(x1).squeeze()
            target_proj1 = self.target_projector(target_feat1)

            target_feat2 = self.target_encoder(x2).squeeze()
            target_proj2 = self.target_projector(target_feat2)

        loss1 = byol_loss(online_pred1, target_proj2.detach())
        loss2 = byol_loss(online_pred2, target_proj1.detach())

        return (loss1 + loss2).mean()

def byol_loss(p, z):
    p = F.normalize(p, dim=-1)
    z = F.normalize(z, dim=-1)
    return -(p * z).sum(dim=-1)
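A quick sanity check on the loss (restated here so the snippet runs standalone): it is the negative cosine similarity between prediction and target, so it is bounded in [-1, 1].

```python
import torch
import torch.nn.functional as F

def byol_loss(p, z):
    p = F.normalize(p, dim=-1)
    z = F.normalize(z, dim=-1)
    return -(p * z).sum(dim=-1)

p = torch.tensor([[1.0, 0.0]])
aligned = byol_loss(p, torch.tensor([[2.0, 0.0]]))     # same direction
orthogonal = byol_loss(p, torch.tensor([[0.0, 3.0]]))  # perpendicular
print(aligned.item(), abs(orthogonal.item()))  # -1.0 0.0
```

Minimizing this loss therefore pulls the online prediction into alignment with the target projection, with -1 as the best achievable value per pair.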

3.6 SimSiam

Published by Facebook AI in 2021. Even simpler than BYOL — no momentum needed to prevent collapse.

class SimSiam(nn.Module):
    def __init__(self, backbone, dim=2048, pred_dim=512):
        super().__init__()
        self.encoder = backbone

        self.projector = nn.Sequential(
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim, affine=False)
        )

        self.predictor = nn.Sequential(
            nn.Linear(dim, pred_dim, bias=False),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),
            nn.Linear(pred_dim, dim)
        )

    def forward(self, x1, x2):
        z1 = self.projector(self.encoder(x1).squeeze())
        z2 = self.projector(self.encoder(x2).squeeze())

        p1 = self.predictor(z1)
        p2 = self.predictor(z2)

        # Stop gradient - the key to the asymmetric design
        loss = simsiam_loss(p1, z2.detach()) / 2 + \
               simsiam_loss(p2, z1.detach()) / 2
        return loss

def simsiam_loss(p, z):
    z = z.detach()  # Stop gradient
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    return -(p * z).sum(dim=1).mean()

4. Masked Autoencoder (MAE)

MAE, published by Kaiming He et al. in 2021, applies BERT's masking approach from NLP to images.

4.1 Core Idea

Mask 75% of an image and reconstruct the full image from the remaining 25%.

Why such a high masking ratio of 75%?

  • Text tokens carry high semantic density, so 15% masking suffices for BERT
  • Adjacent pixels in images have high redundancy — 75% masking forces genuine understanding

4.2 Asymmetric Encoder-Decoder

MAE's innovation: The encoder processes only visible patches; the decoder reconstructs the full image.

Input image (196 patches)
        │ 75% masking
        ▼
49 visible patches → Encoder (ViT-Large) → Encoded representation
        │ insert mask tokens
        ▼
196 patches (with mask tokens) → Decoder (shallow ViT) → Pixel reconstruction

Because the encoder sees only 25% of the patches, pretraining is roughly 3x faster and uses far less memory.
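The speedup can be estimated with back-of-the-envelope token counts (a rough sketch that ignores the decoder and MLP costs):

```python
# 224x224 image, 16x16 patches, 75% masking
num_patches = (224 // 16) ** 2            # 196 patches in the full image
visible = int(num_patches * (1 - 0.75))   # the encoder sees only 49 tokens

# Self-attention cost grows roughly quadratically with token count
attn_ratio = (visible / num_patches) ** 2
print(num_patches, visible, attn_ratio)   # 196 49 0.0625
```

So the encoder's attention cost drops to about 1/16 of the full-image cost, even before accounting for the smaller per-token MLP work.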

4.3 Complete MAE Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        x = self.projection(x)   # (batch, embed_dim, H/p, W/p)
        x = x.flatten(2)         # (batch, embed_dim, num_patches)
        x = x.transpose(1, 2)   # (batch, num_patches, embed_dim)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., dropout=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        mlp_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        norm_x = self.norm1(x)
        attn_out, _ = self.attn(norm_x, norm_x, norm_x)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x


class MAE(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        encoder_dim=768,
        encoder_depth=12,
        encoder_heads=12,
        decoder_dim=512,
        decoder_depth=8,
        decoder_heads=16,
        mask_ratio=0.75,
    ):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, encoder_dim)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, encoder_dim),
            requires_grad=False
        )

        self.encoder_blocks = nn.ModuleList([
            TransformerBlock(encoder_dim, encoder_heads)
            for _ in range(encoder_depth)
        ])
        self.encoder_norm = nn.LayerNorm(encoder_dim)

        self.decoder_embed = nn.Linear(encoder_dim, decoder_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))

        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_dim),
            requires_grad=False
        )

        self.decoder_blocks = nn.ModuleList([
            TransformerBlock(decoder_dim, decoder_heads)
            for _ in range(decoder_depth)
        ])
        self.decoder_norm = nn.LayerNorm(decoder_dim)
        self.decoder_pred = nn.Linear(decoder_dim, patch_size**2 * in_channels)

    def random_masking(self, x, mask_ratio):
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1,
            index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
        )

        mask = torch.ones(N, L, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        x = self.patch_embed(x)
        x = x + self.pos_embed[:, 1:, :]
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        for block in self.encoder_blocks:
            x = block(x)
        x = self.encoder_norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        x = self.decoder_embed(x)

        mask_tokens = self.mask_token.repeat(
            x.shape[0],
            ids_restore.shape[1] + 1 - x.shape[1],
            1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(
            x_, dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )
        x = torch.cat([x[:, :1, :], x_], dim=1)

        x = x + self.decoder_pos_embed

        for block in self.decoder_blocks:
            x = block(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)
        x = x[:, 1:, :]  # Remove CLS token
        return x

    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask

    def forward_loss(self, imgs, pred, mask):
        target = self.patchify(imgs)

        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1.e-6).sqrt()

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)
        loss = (loss * mask).sum() / mask.sum()
        return loss

    def patchify(self, imgs):
        p = self.patch_size
        h = w = imgs.shape[2] // p
        x = imgs.reshape(imgs.shape[0], 3, h, p, w, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(imgs.shape[0], h * w, p**2 * 3)
        return x


# Create and train the MAE model
mae_model = MAE(
    img_size=224,
    patch_size=16,
    encoder_dim=768,
    encoder_depth=12,
    encoder_heads=12,
    decoder_dim=512,
    decoder_depth=8,
    decoder_heads=16,
    mask_ratio=0.75
).cuda()

optimizer = torch.optim.AdamW(
    mae_model.parameters(),
    lr=1.5e-4,
    betas=(0.9, 0.95),
    weight_decay=0.05
)

def train_step(model, imgs, optimizer):
    model.train()
    optimizer.zero_grad()
    loss, pred, mask = model(imgs, mask_ratio=0.75)
    loss.backward()
    optimizer.step()
    return loss.item()

5. DINO - Self-Distillation with No Labels

DINO (Self-DIstillation with NO labels) was published by Facebook AI Research in 2021. When ViT is trained with SSL, remarkable properties emerge.

5.1 Key Discovery

DINO-trained ViTs produce self-attention maps that act like unsupervised semantic segmentation: without ever seeing a label, the model cleanly separates foreground objects from the background.

5.2 Self-Distillation Architecture

Student Network                 Teacher Network
      ↑                               ↑
All views (global + local)      Global Views (2)

  • Teacher: processes only the two global (full-image) views, giving it richer context
  • Student: processes all views, including small local crops, and learns to match the teacher's predictions
  • Teacher update: EMA of the student's weights (no gradients flow to the teacher)

5.3 Centering and Sharpening

Two techniques to prevent collapse:

class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim=65536, bottleneck_dim=256):
        super().__init__()
        layers = [nn.Linear(in_dim, 2048), nn.GELU()]
        layers += [nn.Linear(2048, 2048), nn.GELU()]
        layers += [nn.Linear(2048, bottleneck_dim)]
        self.mlp = nn.Sequential(*layers)

        self.last_layer = nn.utils.weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)
        self.last_layer.weight_g.requires_grad = False

    def forward(self, x):
        x = self.mlp(x)
        x = F.normalize(x, dim=-1)
        x = self.last_layer(x)
        return x


class DINO(nn.Module):
    def __init__(self, backbone, out_dim=65536, teacher_temp=0.04, student_temp=0.1):
        super().__init__()
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp

        self.student = backbone
        self.student_head = DINOHead(self.student.embed_dim, out_dim)

        self.teacher = copy.deepcopy(backbone)
        self.teacher_head = copy.deepcopy(self.student_head)

        for p in self.teacher.parameters():
            p.requires_grad = False
        for p in self.teacher_head.parameters():
            p.requires_grad = False

        self.register_buffer("center", torch.zeros(1, out_dim))

    @torch.no_grad()
    def update_teacher(self, momentum):
        for student_ps, teacher_ps in zip(
            self.student.parameters(), self.teacher.parameters()
        ):
            teacher_ps.data.mul_(momentum).add_((1 - momentum) * student_ps.data)

        for student_ps, teacher_ps in zip(
            self.student_head.parameters(), self.teacher_head.parameters()
        ):
            teacher_ps.data.mul_(momentum).add_((1 - momentum) * student_ps.data)

    @torch.no_grad()
    def update_center(self, teacher_output):
        batch_center = teacher_output.mean(dim=0, keepdim=True)
        self.center = self.center * 0.9 + batch_center * 0.1

    def dino_loss(self, student_output, teacher_output, n_student_views):
        student_out = student_output / self.student_temp
        # The student processes every view (2 global + n local)
        student_out = student_out.chunk(n_student_views)

        # Teacher: centering + sharpening (2 global views only)
        teacher_out = F.softmax(
            (teacher_output - self.center) / self.teacher_temp, dim=-1
        )
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0

        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    continue  # skip matching a global view against itself
                loss = torch.sum(
                    -q * F.log_softmax(student_out[v], dim=-1), dim=-1
                )
                total_loss += loss.mean()
                n_loss_terms += 1

        return total_loss / n_loss_terms

    def forward(self, global_views, local_views, momentum=0.996):
        all_views = global_views + local_views
        student_output = torch.cat([
            self.student_head(self.student(view))
            for view in all_views
        ])

        with torch.no_grad():
            teacher_output = torch.cat([
                self.teacher_head(self.teacher(view))
                for view in global_views
            ])

        loss = self.dino_loss(student_output, teacher_output, len(all_views))
        self.update_center(teacher_output)
        self.update_teacher(momentum)

        return loss


class DINOAugmentation:
    def __init__(self, global_crops_scale=(0.4, 1.), local_crops_scale=(0.05, 0.4),
                 n_local_crops=8, image_size=224, local_size=96):
        self.n_local_crops = n_local_crops

        self.global_transform = T.Compose([
            T.RandomResizedCrop(image_size, scale=global_crops_scale,
                               interpolation=T.InterpolationMode.BICUBIC),
            T.RandomHorizontalFlip(),
            T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        self.local_transform = T.Compose([
            T.RandomResizedCrop(local_size, scale=local_crops_scale,
                               interpolation=T.InterpolationMode.BICUBIC),
            T.RandomHorizontalFlip(),
            T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __call__(self, image):
        global_views = [self.global_transform(image) for _ in range(2)]
        local_views = [self.local_transform(image) for _ in range(self.n_local_crops)]
        return global_views, local_views
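The sharpening half of the recipe can be illustrated in isolation (a toy example, not part of the DINO code above): dividing logits by a small temperature pushes the teacher's softmax toward a one-hot distribution, which combats the uniform-output collapse that centering alone would encourage.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5])
flat = F.softmax(logits / 1.0, dim=-1)    # high temperature: smooth targets
sharp = F.softmax(logits / 0.04, dim=-1)  # DINO's teacher temperature: near one-hot
print(flat.max().item() < 0.7, sharp.max().item() > 0.99)  # True True
```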

5.4 DINOv2

DINOv2, released in 2023, is even more powerful:

  • Trained on 142 million curated images
  • Added iBOT (Image BERT Pre-Training with Online Tokenizer) loss
  • More stable training dynamics

6. CLIP - Contrastive Language-Image Pretraining

CLIP (Contrastive Language-Image Pretraining), published by OpenAI in 2021, connects images and natural language in a single embedding space.

6.1 Text-Image Contrastive Learning

CLIP is trained on 400 million (image, text) pairs collected from the internet. Two separate encoders map each modality into a shared embedding space:

Image Encoder (ViT / ResNet)
  Image embedding

Text Encoder (Transformer)
  Text embedding

Goal: Matching (image, text) pairs → similar embeddings
      Non-matching pairs → dissimilar embeddings

6.2 CLIP Loss

import torch
import torch.nn.functional as F

def clip_loss(image_embeddings, text_embeddings, temperature):
    """
    CLIP Symmetric Contrastive Loss
    image_embeddings: (N, D)
    text_embeddings: (N, D)
    """
    image_embeddings = F.normalize(image_embeddings, dim=1)
    text_embeddings = F.normalize(text_embeddings, dim=1)

    # Similarity matrix: (N, N)
    logits = torch.matmul(image_embeddings, text_embeddings.T) / temperature

    # Labels: diagonal (image i matches text i)
    labels = torch.arange(len(logits)).to(logits.device)

    # Bidirectional loss: image→text and text→image
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)

    return (loss_i2t + loss_t2i) / 2
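
As a sanity check, matched (image, text) pairs should yield a much lower loss than shuffled pairs. The snippet below repeats the loss definition so it runs standalone, with random embeddings as stand-ins and a temperature of 0.07 (a typical value near CLIP's learned temperature):

```python
import torch
import torch.nn.functional as F

def clip_loss(image_embeddings, text_embeddings, temperature=0.07):
    # Same symmetric loss as above, repeated so this snippet is self-contained
    image_embeddings = F.normalize(image_embeddings, dim=1)
    text_embeddings = F.normalize(text_embeddings, dim=1)
    logits = image_embeddings @ text_embeddings.T / temperature
    labels = torch.arange(len(logits))
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

torch.manual_seed(0)
emb = torch.randn(8, 32)
aligned = clip_loss(emb, emb)                      # perfectly matched pairs
shuffled = clip_loss(emb, emb[torch.randperm(8)])  # mismatched pairs

print(aligned.item() < shuffled.item())  # True: matched pairs score lower
```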

6.3 Full CLIP Usage

from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import torch.nn.functional as F

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# ---- Zero-shot Image Classification ----
def zero_shot_classify(image_path, candidate_labels):
    """Classify an image without any label examples"""
    image = Image.open(image_path)

    # Create text prompts
    texts = [f"a photo of a {label}" for label in candidate_labels]

    inputs = processor(
        text=texts,
        images=image,
        return_tensors="pt",
        padding=True
    )

    with torch.no_grad():
        outputs = model(**inputs)

    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)

    results = {}
    for label, prob in zip(candidate_labels, probs[0]):
        results[label] = prob.item()

    return results

# Usage
labels = ["cat", "dog", "car", "airplane", "bird"]
results = zero_shot_classify("example.jpg", labels)
for label, prob in sorted(results.items(), key=lambda x: -x[1]):
    print(f"{label}: {prob:.4f}")


# ---- Image-Text Similarity Search ----
class CLIPRetrieval:
    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model.eval()

        self.image_embeddings = None
        self.image_paths = []

    def index_images(self, image_paths):
        """Index an image database"""
        self.image_paths = image_paths
        embeddings = []

        for path in image_paths:
            image = Image.open(path)
            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                embedding = self.model.get_image_features(**inputs)
            embeddings.append(F.normalize(embedding, dim=1))

        self.image_embeddings = torch.cat(embeddings, dim=0)
        print(f"Indexed {len(image_paths)} images")

    def search(self, query_text, top_k=5):
        """Search images using a text query"""
        inputs = self.processor(text=[query_text], return_tensors="pt", padding=True)
        with torch.no_grad():
            text_embedding = self.model.get_text_features(**inputs)
        text_embedding = F.normalize(text_embedding, dim=1)

        similarities = torch.matmul(text_embedding, self.image_embeddings.T).squeeze()
        top_indices = similarities.argsort(descending=True)[:top_k]

        results = [
            (self.image_paths[i], similarities[i].item())
            for i in top_indices
        ]
        return results


# Visualize CLIP embeddings with t-SNE
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np

def visualize_clip_embeddings(images, texts):
    """t-SNE visualization of CLIP embeddings"""
    image_inputs = processor(images=images, return_tensors="pt")
    text_inputs = processor(text=texts, return_tensors="pt", padding=True)

    with torch.no_grad():
        image_embs = F.normalize(model.get_image_features(**image_inputs), dim=1)
        text_embs = F.normalize(model.get_text_features(**text_inputs), dim=1)

    all_embs = torch.cat([image_embs, text_embs], dim=0).numpy()

    # t-SNE requires perplexity < number of samples; clamp for small inputs
    tsne = TSNE(n_components=2, perplexity=min(30, len(all_embs) - 1), random_state=42)
    reduced = tsne.fit_transform(all_embs)

    n = len(images)
    plt.figure(figsize=(12, 8))
    plt.scatter(reduced[:n, 0], reduced[:n, 1], c='blue', label='Images', alpha=0.7)
    plt.scatter(reduced[n:, 0], reduced[n:, 1], c='red', label='Texts', alpha=0.7)

    for i, text in enumerate(texts):
        plt.annotate(text, reduced[n+i], fontsize=8)

    plt.legend()
    plt.title('CLIP Embeddings (t-SNE)')
    plt.savefig('clip_tsne.png', dpi=150, bbox_inches='tight')
    plt.show()

7. BEiT - BERT for Images

BEiT (BERT Pre-Training of Image Transformers) was published by Microsoft in 2021. Images are converted to discrete tokens and trained in a BERT-like fashion.

7.1 Discrete Visual Tokens (dVAE)

BEiT's core idea: instead of reconstructing raw pixels as MAE does, the model predicts discrete visual tokens produced by a pretrained dVAE tokenizer.

Step 1: dVAE converts image → discrete tokens (vocabulary of 8192)
Step 2: ViT predicts the token of each masked patch

A masked-image-modeling forward pass using the Hugging Face implementation:

from transformers import BeitForMaskedImageModeling, BeitImageProcessor
import torch

class BEiTPretraining:
    def __init__(self, model_name="microsoft/beit-base-patch16-224"):
        self.model = BeitForMaskedImageModeling.from_pretrained(model_name)
        self.processor = BeitImageProcessor.from_pretrained(model_name)

    def create_bool_masked_pos(self, num_patches, mask_ratio=0.4):
        num_masked = int(num_patches * mask_ratio)
        mask = torch.zeros(num_patches, dtype=torch.bool)
        masked_indices = torch.randperm(num_patches)[:num_masked]
        mask[masked_indices] = True
        return mask

    def forward(self, images):
        inputs = self.processor(images, return_tensors="pt")
        pixel_values = inputs.pixel_values

        num_patches = (224 // 16) ** 2  # 196
        bool_masked_pos = self.create_bool_masked_pos(num_patches)
        bool_masked_pos = bool_masked_pos.unsqueeze(0).expand(
            pixel_values.size(0), -1
        )

        outputs = self.model(
            pixel_values=pixel_values,
            bool_masked_pos=bool_masked_pos
        )

        return outputs.loss
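
The masking helper can be checked without downloading the model. Note that the original BEiT paper uses blockwise masking rather than the uniform random masking above; the uniform version is a simplification:

```python
import torch

def create_bool_masked_pos(num_patches, mask_ratio=0.4):
    # Same uniform random patch masking as in BEiTPretraining above
    num_masked = int(num_patches * mask_ratio)
    mask = torch.zeros(num_patches, dtype=torch.bool)
    mask[torch.randperm(num_patches)[:num_masked]] = True
    return mask

mask = create_bool_masked_pos(196)  # 14 x 14 patches for a 224px image, 16px patches
print(mask.sum().item())            # 78 (40% of 196, rounded down)
```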

8. Language Model Pretraining

SSL has a long history in NLP.

8.1 GPT - Next Token Prediction

GPT takes an autoregressive approach: it predicts the next token from left to right, and the training signal is cross-entropy between each position's prediction and the actual next token.
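
The GPTLanguageModel below composes TransformerBlock modules that aren't defined in this excerpt. A minimal pre-norm block with causal self-attention (an illustrative sketch, not a reference implementation) might look like:

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Pre-norm Transformer block with causal self-attention."""
    def __init__(self, n_embd, n_head, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, dropout=dropout,
                                          batch_first=True)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        T = x.size(1)
        # Causal mask: True entries are blocked, so position t only sees <= t
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1
        )
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=causal_mask, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x

block = TransformerBlock(n_embd=64, n_head=4).eval()
x = torch.randn(2, 10, 64)
print(block(x).shape)  # torch.Size([2, 10, 64])
```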

import torch
import torch.nn as nn
from torch.nn import functional as F

class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[
            TransformerBlock(n_embd, n_head, dropout=dropout)
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.block_size = block_size

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding(idx)
        pos_emb = self.position_embedding(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

8.2 BERT - Masked Language Modeling

BERT learns bidirectional context: 15% of tokens are selected, masked, and must be reconstructed using the context on both sides.

from transformers import BertForMaskedLM, BertTokenizer
import torch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

def create_mlm_input(text, mask_prob=0.15):
    """Generate masked language model input"""
    tokens = tokenizer.encode(text, return_tensors="pt")
    labels = tokens.clone()

    rand = torch.rand(tokens.shape)
    mask_arr = (rand < mask_prob) & (tokens != tokenizer.cls_token_id) & \
               (tokens != tokenizer.sep_token_id)

    # Masking strategy:
    # - 80%: replace with [MASK] token
    # - 10%: replace with a random token
    # - 10%: keep original token
    for i, masked in enumerate(mask_arr[0]):
        if masked:
            prob = torch.rand(1).item()
            if prob < 0.8:
                tokens[0][i] = tokenizer.mask_token_id
            elif prob < 0.9:
                tokens[0][i] = torch.randint(len(tokenizer), (1,))

    labels[~mask_arr] = -100
    return tokens, labels
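
The 80/10/10 strategy can be exercised on raw token ids without downloading BERT (this sketch uses BERT-base's conventional special ids: [CLS]=101, [SEP]=102, [MASK]=103):

```python
import torch

def apply_mlm_masking(tokens, mask_prob=0.15, mask_token_id=103,
                      vocab_size=30522, special_ids=(101, 102)):
    """Same 80/10/10 masking strategy as above, on a 1-D tensor of token ids."""
    labels = tokens.clone()
    special = torch.zeros_like(tokens, dtype=torch.bool)
    for sid in special_ids:
        special |= tokens == sid
    mask_arr = (torch.rand(tokens.shape) < mask_prob) & ~special

    for i in torch.nonzero(mask_arr).flatten():
        p = torch.rand(1).item()
        if p < 0.8:
            tokens[i] = mask_token_id                           # 80%: [MASK]
        elif p < 0.9:
            tokens[i] = torch.randint(vocab_size, (1,)).item()  # 10%: random token
        # else: 10% keep the original token

    labels[~mask_arr] = -100  # unmasked positions are ignored by the loss
    return tokens, labels

torch.manual_seed(0)
original = torch.randint(1000, 2000, (50,))
masked, labels = apply_mlm_masking(original.clone())

# Invariant: positions not selected for prediction are left untouched
print((masked[labels == -100] == original[labels == -100]).all())  # tensor(True)
```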

# Compute MLM loss
text = "The quick brown fox jumps over the lazy dog"
input_ids, labels = create_mlm_input(text)

with torch.no_grad():
    outputs = model(input_ids=input_ids, labels=labels)
    print(f"MLM Loss: {outputs.loss.item():.4f}")

# Predict masked tokens
input_text = "The [MASK] brown fox"
inputs = tokenizer(input_text, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

mask_idx = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0][1]
top_5 = logits[0, mask_idx].topk(5)
for token_id, score in zip(top_5.indices, top_5.values):
    print(f"{tokenizer.decode([token_id])}: {score:.3f}")

8.3 T5 - Text-to-Text Transfer

T5 casts every NLP task as text generation: inputs and outputs are always plain strings, so a single model and loss function handle translation, summarization, classification, and question answering.
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

# T5 unifies all tasks as text-to-text
tasks = {
    "translation": "translate English to French: The weather is nice today",
    "summarization": "summarize: Scientists have discovered a new species of deep-sea fish...",
    "sentiment": "sentiment analysis: This movie was absolutely fantastic!",
    "qa": "question: What is the capital of France? context: Paris is the capital..."
}

for task_name, input_text in tasks.items():
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
    with torch.no_grad():
        output_ids = model.generate(
            inputs.input_ids,
            max_length=128,
            num_beams=4,
            early_stopping=True
        )
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(f"\n[{task_name}]")
    print(f"Input: {input_text[:60]}...")
    print(f"Output: {output}")
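
T5's self-supervised pretraining objective is span corruption: contiguous spans of the input are replaced with sentinel tokens, and the target reconstructs only the dropped spans. A hand-constructed example (mirroring the one in the T5 paper):

```python
# Span corruption: sentinels <extra_id_N> mark where spans were removed
original = "Thank you for inviting me to your party last week"
corrupted_input = "Thank you <extra_id_0> me to your party <extra_id_1> week"
target = "<extra_id_0> for inviting <extra_id_1> last <extra_id_2>"

# The final sentinel terminates the target sequence
print(corrupted_input.count("<extra_id_"), target.count("<extra_id_"))  # 2 3
```

Predicting only the corrupted spans (rather than the full input) keeps targets short, which makes pretraining cheaper than full reconstruction.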

9. Practical Applications

9.1 Few-shot Learning

A frozen SSL encoder plus a small classifier trained on a handful of labeled examples per class (linear probing) is the standard way to measure how much label efficiency pretraining buys.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

class LinearProbe(nn.Module):
    """Linear classifier on top of SSL features"""
    def __init__(self, feature_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        return self.linear(x)


def extract_features(model, dataloader, device='cuda'):
    """Extract features from a pretrained model"""
    model.eval()
    features_list = []
    labels_list = []

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            features = model.encode(images)
            features_list.append(features.cpu())
            labels_list.append(labels)

    return torch.cat(features_list), torch.cat(labels_list)


def few_shot_evaluation(ssl_model, train_loader, val_loader,
                        n_shots=(1, 5, 10, 100), device='cuda'):
    """Evaluate with n labeled examples per class"""
    ssl_model.eval()

    all_features, all_labels = extract_features(ssl_model, train_loader, device)
    val_features, val_labels = extract_features(ssl_model, val_loader, device)

    results = {}
    num_classes = len(all_labels.unique())

    for n_shot in n_shots:
        selected_indices = []
        for cls in range(num_classes):
            cls_indices = (all_labels == cls).nonzero(as_tuple=True)[0]
            selected = cls_indices[torch.randperm(len(cls_indices))[:n_shot]]
            selected_indices.append(selected)

        selected_indices = torch.cat(selected_indices)
        train_feats = all_features[selected_indices]
        train_labs = all_labels[selected_indices]

        probe = LinearProbe(all_features.shape[1], num_classes).to(device)
        optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
        criterion = nn.CrossEntropyLoss()

        train_feats = train_feats.to(device)
        train_labs = train_labs.to(device)

        for epoch in range(100):
            probe.train()
            logits = probe(train_feats)
            loss = criterion(logits, train_labs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        probe.eval()
        with torch.no_grad():
            val_logits = probe(val_features.to(device))
            preds = val_logits.argmax(dim=1).cpu()
            acc = (preds == val_labels).float().mean().item()

        results[n_shot] = acc
        print(f"{n_shot}-shot accuracy: {acc * 100:.2f}%")

    return results
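
The probe's mechanics can be verified end-to-end on synthetic features where the class is linearly decodable from one coordinate (random stand-in data, not real SSL features):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
labels = torch.randint(0, 2, (200,))
feats = torch.randn(200, 16)
feats[:, 0] = labels.float() * 4 - 2 + 0.5 * torch.randn(200)  # class signal in dim 0

probe = nn.Linear(16, 2)
opt = torch.optim.Adam(probe.parameters(), lr=1e-2)
for _ in range(200):
    loss = nn.functional.cross_entropy(probe(feats), labels)
    opt.zero_grad()
    loss.backward()
    opt.step()

acc = (probe(feats).argmax(dim=1) == labels).float().mean().item()
print(f"{acc:.2f}")  # close to 1.00 on linearly separable features
```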


def knn_evaluation(ssl_model, train_loader, val_loader, k=20, device='cuda'):
    """k-NN evaluation of SSL representation quality (parameter-free)"""
    ssl_model.eval()

    train_features, train_labels = extract_features(ssl_model, train_loader, device)
    val_features, val_labels = extract_features(ssl_model, val_loader, device)

    train_features = nn.functional.normalize(train_features, dim=1)
    val_features = nn.functional.normalize(val_features, dim=1)

    correct = 0
    total = 0
    batch_size = 256

    for i in range(0, len(val_features), batch_size):
        batch_feats = val_features[i:i+batch_size].to(device)
        batch_labels = val_labels[i:i+batch_size]

        sims = torch.mm(batch_feats, train_features.to(device).T)
        topk_sims, topk_indices = sims.topk(k, dim=1)

        topk_labels = train_labels[topk_indices.cpu()]
        preds = topk_labels.mode(dim=1).values

        correct += (preds == batch_labels).sum().item()
        total += len(batch_labels)

    accuracy = correct / total
    print(f"k-NN ({k}) accuracy: {accuracy * 100:.2f}%")
    return accuracy
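
The same k-NN protocol can be smoke-tested on synthetic clusters (two random unit-norm centers with small noise, standing in for real features):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
centers = F.normalize(torch.randn(2, 64), dim=1)

def sample(n_per_class, noise=0.05):
    feats = centers.repeat_interleave(n_per_class, dim=0)
    feats = F.normalize(feats + noise * torch.randn_like(feats), dim=1)
    labels = torch.arange(2).repeat_interleave(n_per_class)
    return feats, labels

train_feats, train_labels = sample(100)
val_feats, val_labels = sample(20)

sims = val_feats @ train_feats.T                # cosine similarity (unit-norm rows)
topk_labels = train_labels[sims.topk(20, dim=1).indices]
preds = topk_labels.mode(dim=1).values          # majority vote over 20 neighbors
acc = (preds == val_labels).float().mean().item()
print(f"{acc:.2f}")  # close to 1.00 on well-separated clusters
```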

9.2 Domain-Specific SSL

Augmentation policies must match the domain: transformations that destroy diagnostic signal (aggressive color jitter on X-rays) or ignore natural invariances (rotations in satellite imagery) hurt the learned representations.

# Medical images (chest X-ray example)
class MedicalSSL(nn.Module):
    """Self-supervised learning for medical images"""
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

        # Weaker augmentation for medical images
        self.augmentation = T.Compose([
            T.RandomResizedCrop(224, scale=(0.8, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            # Minimize color changes for medical images
            T.RandomApply([T.ColorJitter(0.1, 0.1, 0.0, 0.0)], p=0.3),
            T.ToTensor(),
            T.Normalize([0.485], [0.229])  # Grayscale X-ray
        ])

    def forward(self, x):
        v1 = self.augmentation(x)
        v2 = self.augmentation(x)
        return v1, v2


# Satellite images
class SatelliteSSL:
    """Augmentation for satellite images (preserving geographic properties)"""
    def __init__(self):
        self.transform = T.Compose([
            T.RandomResizedCrop(256, scale=(0.4, 1.0)),
            # Satellite images: rotation invariance is important
            T.RandomRotation(90),
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            # Simulate time-of-day / seasonal variation
            T.ColorJitter(brightness=0.4, contrast=0.4),
            T.ToTensor(),
        ])

9.3 Using Pretrained SSL Models from the Hub

from transformers import (
    ViTModel, ViTImageProcessor,
    CLIPModel, CLIPProcessor
)
from PIL import Image
import torch

# Load MAE pretrained ViT
def load_mae_pretrained():
    model = ViTModel.from_pretrained("facebook/vit-mae-base")
    image_processor = ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
    return model, image_processor

# DINO ViT feature extraction
def extract_dino_features(image_path):
    model = ViTModel.from_pretrained("facebook/dino-vits16")
    processor = ViTImageProcessor.from_pretrained("facebook/dino-vits16")

    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # CLS token (global image representation)
    cls_features = outputs.last_hidden_state[:, 0, :]
    # Per-patch features (with spatial information)
    patch_features = outputs.last_hidden_state[:, 1:, :]

    return cls_features, patch_features

# Extract features with CLIP
def clip_feature_extraction(images, texts=None):
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    if texts:
        inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.image_embeds, outputs.text_embeds
    else:
        inputs = processor(images=images, return_tensors="pt")
        with torch.no_grad():
            image_features = model.get_image_features(**inputs)
        return image_features

Conclusion

Self-supervised learning is transforming AI. Key takeaways:

| Method | Year | Core Idea | Strength |
|--------|------|-----------|----------|
| SimCLR | 2020 | Strong augmentation + contrastive loss | Simple and powerful |
| MoCo | 2020 | Momentum encoder + negative queue | Batch-size independent |
| BYOL | 2020 | Predictor + EMA target, no negatives | No negative pairs needed |
| MAE | 2021 | 75% masking + pixel reconstruction | Efficient; well suited to ViT |
| DINO | 2021 | Self-distillation with no labels | Emergent segmentation properties |
| CLIP | 2021 | Image-text contrastive learning | Zero-shot classification |
| BEiT | 2021 | Masked discrete-token prediction | BERT-style pretraining for images |
| DINOv2 | 2023 | Large-scale curated data + iBOT | General-purpose feature extractor |

Learning Roadmap:

  1. Implement SimCLR to understand contrastive learning
  2. Study MAE to grasp the masking approach
  3. Apply CLIP to multimodal learning
  4. Fine-tune on your domain data

SSL has become the foundation of LLM pretraining and computer vision. As an approach that effectively leverages vast amounts of unlabeled data, SSL will continue to be a core driver of AI progress.


References