- Authors

- Name
- Youngju Kim
- @fjvbn20031
목차
- 자기지도 학습이란?
- 초기 SSL 방법들
- 대조 학습 (Contrastive Learning)
- Masked Autoencoder (MAE)
- DINO - 라벨 없는 자기 증류
- CLIP - 텍스트-이미지 대조 학습
- BEiT - BERT의 이미지 버전
- 언어 모델의 사전학습
- 실전 활용
1. 자기지도 학습이란?
1.1 학습 패러다임의 비교
딥러닝의 세 가지 주요 학습 패러다임을 이해하는 것부터 시작하겠습니다.
**지도 학습 (Supervised Learning)**은 레이블이 붙은 데이터로 학습합니다. ImageNet 1000만 장의 이미지에 인간이 하나하나 레이블을 달았고, 이 데이터로 ResNet, ViT 같은 모델을 학습합니다. 문제는 레이블 수집 비용이 엄청나다는 것입니다. 의료 이미지 한 장에 전문 의사가 레이블을 다는 데 수십 분이 걸릴 수 있습니다.
**비지도 학습 (Unsupervised Learning)**은 레이블 없이 데이터의 구조를 파악합니다. K-Means 클러스터링, PCA, GAN이 여기에 해당합니다. 하지만 전통적인 비지도 학습은 다운스트림 태스크에 바로 쓸 수 있는 표현(representation)을 학습하기 어려웠습니다.
**자기지도 학습 (Self-Supervised Learning, SSL)**은 두 세계의 장점을 결합합니다. 레이블 없는 대규모 데이터를 활용하면서도, 데이터 자체에서 지도 신호를 만들어냅니다.
핵심 아이디어: 데이터 자체가 레이블이다.
인터넷에는 수십억 장의 이미지, 수조 개의 단어가 존재합니다. 자기지도 학습은 이 방대한 비레이블 데이터를 활용해 강력한 표현을 학습합니다.
1.2 레이블 부족 문제
현실 세계에서는 레이블 수집이 매우 어렵습니다.
| 도메인 | 레이블 수집 비용 | 이유 |
|---|---|---|
| 의료 이미지 | 매우 높음 | 전문 의사 필요 |
| 위성 이미지 | 높음 | 전문 지리 지식 필요 |
| 법률 문서 | 높음 | 법률 전문가 필요 |
| 산업 결함 감지 | 높음 | 희귀한 불량 샘플 |
| 음성 인식 | 중간 | 전사(transcription) 필요 |
SSL은 이 문제를 근본적으로 해결합니다. 레이블 없는 대규모 데이터로 먼저 사전학습(pretraining)하고, 소량의 레이블 데이터로 파인튜닝(fine-tuning)하는 방식입니다.
1.3 프리텍스트 태스크 (Pretext Task)
SSL의 핵심은 프리텍스트 태스크입니다. 데이터 자체에서 인위적인 예측 문제를 만들어 모델이 의미 있는 표현을 학습하도록 합니다.
좋은 프리텍스트 태스크의 조건:
- 레이블 없이 자동으로 만들 수 있어야 함
- 태스크를 잘 풀려면 의미 있는 표현이 필요해야 함
- 너무 쉽거나 너무 어렵지 않아야 함
예시:
- 이미지 회전 예측: 이미지를 0°, 90°, 180°, 270° 회전 후, 몇 도 회전했는지 예측
- 마스킹: 이미지/텍스트 일부를 가리고 복원
- 대조 학습: 같은 이미지의 두 뷰가 유사하도록 학습
1.4 SSL의 응용 분야
SSL은 현대 AI의 근간이 되었습니다:
- GPT, BERT: 텍스트 SSL → ChatGPT, Gemini의 기반
- CLIP: 이미지-텍스트 SSL → DALL-E, Stable Diffusion의 기반
- MAE, DINO: 이미지 SSL → 의료 영상, 자율주행의 기반
- wav2vec: 오디오 SSL → 음성 인식의 기반
2. 초기 SSL 방법들
초기 SSL 연구들은 다양한 프리텍스트 태스크를 탐구했습니다.
2.1 회전 예측 (Rotation Prediction)
2018년 Gidaris et al.이 제안한 방법입니다. 이미지를 4가지 각도(0°, 90°, 180°, 270°)로 회전시키고, 어떤 회전이 적용되었는지 분류하는 문제입니다.
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
class RotationSSL(nn.Module):
def __init__(self, backbone, num_classes=4):
super().__init__()
self.backbone = backbone
self.classifier = nn.Linear(backbone.output_dim, num_classes)
def forward(self, x):
features = self.backbone(x)
return self.classifier(features)
def create_rotation_dataset(images):
"""이미지를 4가지 방향으로 회전하여 (image, label) 쌍 생성"""
rotated_images = []
labels = []
for img in images:
for k in range(4): # 0, 90, 180, 270도
# PIL Image를 k*90도 회전
rotated = T.functional.rotate(img, k * 90)
rotated_images.append(rotated)
labels.append(k)
return rotated_images, labels
# 학습 루프
def train_rotation_ssl(model, dataloader, optimizer, epochs=10):
criterion = nn.CrossEntropyLoss()
model.train()
for epoch in range(epochs):
total_loss = 0
for images, labels in dataloader:
images, labels = images.cuda(), labels.cuda()
optimizer.zero_grad()
logits = model(images)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
이 방법의 한계: 자연 이미지에는 명확한 방향성이 있지만, 하늘 사진이나 텍스처 이미지에는 효과적이지 않습니다.
2.2 Jigsaw Puzzle 풀기
2016년 Noroozi & Favaro가 제안했습니다. 이미지를 3x3 그리드로 자른 후 무작위로 섞고, 원래 순서를 예측합니다.
import itertools
import numpy as np
class JigsawSSL(nn.Module):
def __init__(self, backbone, num_permutations=100):
super().__init__()
self.backbone = backbone
# 각 패치를 독립적으로 처리
self.patch_encoder = nn.Sequential(
backbone,
nn.Linear(backbone.output_dim, 512)
)
# 9개 패치를 모두 처리 후 순열 분류
self.classifier = nn.Sequential(
nn.Linear(512 * 9, 4096),
nn.ReLU(),
nn.Linear(4096, num_permutations)
)
def forward(self, patches):
# patches: (batch, 9, C, H, W)
batch_size = patches.size(0)
patch_features = []
for i in range(9):
feat = self.patch_encoder(patches[:, i]) # (batch, 512)
patch_features.append(feat)
# 모든 패치 특징 연결
combined = torch.cat(patch_features, dim=1) # (batch, 512*9)
return self.classifier(combined)
def create_jigsaw_dataset(image, grid_size=3):
"""이미지를 grid_size x grid_size 패치로 분할하고 섞기"""
h, w = image.shape[-2:]
patch_h, patch_w = h // grid_size, w // grid_size
patches = []
for i in range(grid_size):
for j in range(grid_size):
patch = image[...,
i*patch_h:(i+1)*patch_h,
j*patch_w:(j+1)*patch_w]
patches.append(patch)
# 미리 정의된 순열 집합에서 선택
permutation_idx = np.random.randint(0, len(PREDEFINED_PERMUTATIONS))
perm = PREDEFINED_PERMUTATIONS[permutation_idx]
shuffled = [patches[p] for p in perm]
return torch.stack(shuffled), permutation_idx
2.3 컬러화 (Colorization)
2016년 Zhang et al.이 제안한 방법입니다. 흑백 이미지를 입력받아 색상을 예측합니다.
class ColorizationSSL(nn.Module):
def __init__(self):
super().__init__()
# L 채널 입력 (밝기)
self.encoder = nn.Sequential(
nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(),
)
# ab 채널 출력 (색상)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
nn.Conv2d(64, 2, 3, padding=1), # ab 채널
nn.Tanh()
)
def forward(self, l_channel):
features = self.encoder(l_channel)
ab_channels = self.decoder(features)
return ab_channels
# LAB 색 공간에서 작업
from skimage.color import rgb2lab, lab2rgb
def prepare_colorization_data(rgb_image):
lab = rgb2lab(rgb_image)
l_channel = lab[:, :, 0:1] # 밝기 채널
ab_channels = lab[:, :, 1:] # 색상 채널
return l_channel, ab_channels
2.4 다음 프레임 예측
비디오에서 시간적 연속성을 이용합니다. 이전 프레임들로 다음 프레임을 예측합니다.
class VideoSSL(nn.Module):
def __init__(self, frame_encoder, temporal_model):
super().__init__()
self.frame_encoder = frame_encoder
self.temporal_model = temporal_model # LSTM 또는 Transformer
self.decoder = nn.ConvTranspose2d(...)
def forward(self, frames):
# frames: (batch, T, C, H, W)
T = frames.size(1)
# 각 프레임 인코딩
frame_features = []
for t in range(T - 1):
feat = self.frame_encoder(frames[:, t])
frame_features.append(feat)
# 시간 모델링
features = torch.stack(frame_features, dim=1)
next_feat = self.temporal_model(features)
# 다음 프레임 디코딩
next_frame_pred = self.decoder(next_feat[:, -1])
return next_frame_pred
3. 대조 학습
대조 학습(Contrastive Learning)은 현대 SSL의 핵심입니다. 2020년 전후로 폭발적으로 발전했습니다.
3.1 핵심 아이디어
같은 것은 가깝게, 다른 것은 멀게.
같은 이미지의 두 가지 다른 뷰(crop, flip, color jitter 등)는 임베딩 공간에서 가까워야 하고, 다른 이미지들은 멀어야 합니다.
이미지 x
├── 증강 뷰 v1 → z1 (임베딩)
└── 증강 뷰 v2 → z2 (임베딩)
목표: z1과 z2는 가깝게, z1과 다른 이미지의 z3은 멀게
3.2 InfoNCE Loss
대조 학습의 표준 손실 함수입니다. NT-Xent(Normalized Temperature-scaled Cross Entropy)라고도 합니다.
수식 (배치 내 N개 이미지, 각 이미지당 2개 뷰):
여기서 sim은 코사인 유사도, τ는 온도 파라미터입니다.
import torch
import torch.nn.functional as F
def info_nce_loss(z1, z2, temperature=0.5):
"""
InfoNCE / NT-Xent Loss 구현
z1, z2: (batch_size, embedding_dim) - 같은 이미지의 두 뷰
"""
batch_size = z1.size(0)
# L2 정규화
z1 = F.normalize(z1, dim=1)
z2 = F.normalize(z2, dim=1)
# 두 뷰를 합쳐서 2N x D 행렬 생성
# [z1_1, z1_2, ..., z1_N, z2_1, z2_2, ..., z2_N]
z = torch.cat([z1, z2], dim=0) # (2N, D)
# 모든 쌍의 유사도 계산
similarity = torch.mm(z, z.t()) / temperature # (2N, 2N)
# 자기 자신과의 유사도는 제외 (대각선 -inf)
mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
similarity.masked_fill_(mask, float('-inf'))
# 포지티브 쌍: i번째와 i+N번째, i+N번째와 i번째
labels = torch.cat([
torch.arange(batch_size, 2 * batch_size),
torch.arange(batch_size)
]).to(z.device)
loss = F.cross_entropy(similarity, labels)
return loss
# 사용 예시
z1 = torch.randn(32, 128) # 배치 32, 임베딩 128차원
z2 = torch.randn(32, 128)
loss = info_nce_loss(z1, z2, temperature=0.5)
print(f"InfoNCE Loss: {loss.item():.4f}")
3.3 SimCLR
2020년 Google이 발표한 Simple Framework for Contrastive Learning입니다. 단순하지만 강력합니다.
핵심 구성요소:
- 데이터 증강 (강한 증강이 핵심)
- 인코더 (ResNet)
- Projection Head (MLP)
- NT-Xent Loss
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
class SimCLR(nn.Module):
def __init__(self, backbone='resnet50', projection_dim=128, temperature=0.5):
super().__init__()
self.temperature = temperature
# 백본 인코더
if backbone == 'resnet50':
resnet = models.resnet50(pretrained=False)
self.encoder = nn.Sequential(*list(resnet.children())[:-1]) # FC 제거
self.feature_dim = 2048
elif backbone == 'resnet18':
resnet = models.resnet18(pretrained=False)
self.encoder = nn.Sequential(*list(resnet.children())[:-1])
self.feature_dim = 512
# Projection Head (중요: 파인튜닝 시 제거)
self.projection_head = nn.Sequential(
nn.Linear(self.feature_dim, self.feature_dim),
nn.ReLU(inplace=True),
nn.Linear(self.feature_dim, projection_dim)
)
def encode(self, x):
h = self.encoder(x)
h = h.squeeze(-1).squeeze(-1) # (batch, feature_dim)
return h
def forward(self, x):
h = self.encode(x)
z = self.projection_head(h)
return z
def contrastive_loss(self, z1, z2):
return info_nce_loss(z1, z2, self.temperature)
class SimCLRDataAugmentation:
"""SimCLR의 강한 데이터 증강"""
def __init__(self, image_size=224, s=1.0):
color_jitter = T.ColorJitter(
brightness=0.8*s,
contrast=0.8*s,
saturation=0.8*s,
hue=0.2*s
)
self.transform = T.Compose([
T.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
T.RandomHorizontalFlip(p=0.5),
T.RandomApply([color_jitter], p=0.8),
T.RandomGrayscale(p=0.2),
T.GaussianBlur(kernel_size=int(0.1 * image_size)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def __call__(self, x):
return self.transform(x), self.transform(x)
# SimCLR 학습
def train_simclr(model, dataloader, optimizer, epochs=200):
model.train()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=epochs
)
for epoch in range(epochs):
total_loss = 0
for (x1, x2), _ in dataloader:
x1, x2 = x1.cuda(), x2.cuda()
z1 = model(x1)
z2 = model(x2)
loss = model.contrastive_loss(z1, z2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
scheduler.step()
avg_loss = total_loss / len(dataloader)
if (epoch + 1) % 10 == 0:
print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")
# 전체 파이프라인
model = SimCLR(backbone='resnet50', projection_dim=128).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)
augmentation = SimCLRDataAugmentation(image_size=224)
# 파인튜닝 시 projection head 제거
def get_finetuning_model(pretrained_simclr, num_classes):
# 인코더만 사용
encoder = pretrained_simclr.encoder
classifier = nn.Linear(pretrained_simclr.feature_dim, num_classes)
return nn.Sequential(encoder, nn.Flatten(), classifier)
3.4 MoCo (Momentum Contrast)
Facebook AI가 2020년 발표했습니다. 큰 배치 없이도 많은 네거티브 샘플을 사용할 수 있는 방법입니다.
핵심 아이디어: 모멘텀 업데이트로 안정적인 키 인코더 유지 + 큐(Queue)로 네거티브 샘플 관리
class MoCo(nn.Module):
def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07):
super().__init__()
self.K = K # 큐 크기
self.m = m # 모멘텀 계수
self.T = T # 온도
# 쿼리 인코더와 키 인코더
self.encoder_q = base_encoder(num_classes=dim)
self.encoder_k = base_encoder(num_classes=dim)
# 키 인코더는 그래디언트 없이 모멘텀 업데이트
for param_q, param_k in zip(
self.encoder_q.parameters(),
self.encoder_k.parameters()
):
param_k.data.copy_(param_q.data)
param_k.requires_grad = False
# 네거티브 샘플 큐
self.register_buffer("queue", torch.randn(dim, K))
self.queue = F.normalize(self.queue, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
@torch.no_grad()
def _momentum_update_key_encoder(self):
"""키 인코더를 모멘텀으로 업데이트"""
for param_q, param_k in zip(
self.encoder_q.parameters(),
self.encoder_k.parameters()
):
param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
"""큐에 새 키 추가, 오래된 키 제거"""
batch_size = keys.shape[0]
ptr = int(self.queue_ptr)
self.queue[:, ptr:ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.K
self.queue_ptr[0] = ptr
def forward(self, im_q, im_k):
# 쿼리 임베딩
q = self.encoder_q(im_q)
q = F.normalize(q, dim=1)
# 키 임베딩 (그래디언트 없음)
with torch.no_grad():
self._momentum_update_key_encoder()
k = self.encoder_k(im_k)
k = F.normalize(k, dim=1)
# 포지티브 로짓: (batch, 1)
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
# 네거티브 로짓: (batch, K)
l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
logits = torch.cat([l_pos, l_neg], dim=1) / self.T
labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
loss = F.cross_entropy(logits, labels)
self._dequeue_and_enqueue(k)
return loss
3.5 BYOL (Bootstrap Your Own Latent)
2020년 DeepMind가 발표한 획기적인 방법입니다. 네거티브 샘플 없이 작동합니다.
핵심: 온라인 네트워크가 타깃 네트워크의 표현을 예측. 타깃은 모멘텀 업데이트.
class BYOL(nn.Module):
def __init__(self, backbone, projection_dim=256, prediction_dim=128, tau=0.996):
super().__init__()
self.tau = tau # 모멘텀 계수
# 온라인 네트워크
self.online_encoder = backbone
self.online_projector = nn.Sequential(
nn.Linear(2048, 4096), nn.BatchNorm1d(4096), nn.ReLU(),
nn.Linear(4096, projection_dim)
)
self.online_predictor = nn.Sequential(
nn.Linear(projection_dim, 4096), nn.BatchNorm1d(4096), nn.ReLU(),
nn.Linear(4096, prediction_dim)
)
# 타깃 네트워크 (그래디언트 없음)
self.target_encoder = copy.deepcopy(backbone)
self.target_projector = copy.deepcopy(self.online_projector)
for param in self.target_encoder.parameters():
param.requires_grad = False
for param in self.target_projector.parameters():
param.requires_grad = False
@torch.no_grad()
def update_target(self):
"""EMA 업데이트"""
for online, target in zip(
self.online_encoder.parameters(),
self.target_encoder.parameters()
):
target.data = self.tau * target.data + (1 - self.tau) * online.data
for online, target in zip(
self.online_projector.parameters(),
self.target_projector.parameters()
):
target.data = self.tau * target.data + (1 - self.tau) * online.data
def forward(self, x1, x2):
# 온라인 예측
online_feat1 = self.online_encoder(x1).squeeze()
online_proj1 = self.online_projector(online_feat1)
online_pred1 = self.online_predictor(online_proj1)
online_feat2 = self.online_encoder(x2).squeeze()
online_proj2 = self.online_projector(online_feat2)
online_pred2 = self.online_predictor(online_proj2)
# 타깃 (stop gradient)
with torch.no_grad():
target_feat1 = self.target_encoder(x1).squeeze()
target_proj1 = self.target_projector(target_feat1)
target_feat2 = self.target_encoder(x2).squeeze()
target_proj2 = self.target_projector(target_feat2)
# BYOL Loss: cosine similarity 최대화
loss1 = byol_loss(online_pred1, target_proj2.detach())
loss2 = byol_loss(online_pred2, target_proj1.detach())
return (loss1 + loss2).mean()
def byol_loss(p, z):
"""Negative cosine similarity"""
p = F.normalize(p, dim=-1)
z = F.normalize(z, dim=-1)
return -(p * z).sum(dim=-1)
3.6 SimSiam
2021년 Facebook AI가 발표. BYOL보다 더 단순합니다. 모멘텀 없이도 collapse 없이 학습됩니다.
class SimSiam(nn.Module):
def __init__(self, backbone, dim=2048, pred_dim=512):
super().__init__()
self.encoder = backbone
# Projector
self.projector = nn.Sequential(
nn.Linear(dim, dim, bias=False),
nn.BatchNorm1d(dim),
nn.ReLU(inplace=True),
nn.Linear(dim, dim, bias=False),
nn.BatchNorm1d(dim),
nn.ReLU(inplace=True),
nn.Linear(dim, dim, bias=False),
nn.BatchNorm1d(dim, affine=False) # 마지막 BN: affine=False
)
# Predictor
self.predictor = nn.Sequential(
nn.Linear(dim, pred_dim, bias=False),
nn.BatchNorm1d(pred_dim),
nn.ReLU(inplace=True),
nn.Linear(pred_dim, dim)
)
def forward(self, x1, x2):
z1 = self.projector(self.encoder(x1).squeeze())
z2 = self.projector(self.encoder(x2).squeeze())
p1 = self.predictor(z1)
p2 = self.predictor(z2)
# Stop gradient - 비대칭 구조의 핵심
loss = simsiam_loss(p1, z2.detach()) / 2 + \
simsiam_loss(p2, z1.detach()) / 2
return loss
def simsiam_loss(p, z):
"""Negative cosine similarity"""
z = z.detach() # Stop gradient
p = F.normalize(p, dim=1)
z = F.normalize(z, dim=1)
return -(p * z).sum(dim=1).mean()
4. Masked Autoencoder (MAE)
2021년 Kaiming He et al.이 발표한 MAE는 자연어 처리의 BERT를 이미지에 적용한 것입니다.
4.1 핵심 아이디어
이미지의 75%를 마스킹하고 나머지 25%로 전체 이미지를 복원.
왜 75%라는 높은 마스킹 비율인가?
- 텍스트는 각 토큰이 높은 의미 밀도를 가져 15% 마스킹으로 충분
- 이미지는 인접 픽셀 간 높은 중복성 → 75%는 마스킹해야 진정한 이해 필요
4.2 비대칭 인코더-디코더
MAE의 혁신: 인코더는 가시 패치만 처리, 디코더는 전체를 복원.
입력 이미지 (196개 패치)
↓ 75% 마스킹
49개 가시 패치 → 인코더 (ViT-Large) → 인코딩
↓
196개 패치 (마스크 토큰 포함) → 디코더 (얕은 ViT) → 픽셀 복원
인코더는 원래 크기의 25%만 처리 → 3-4배 빠른 학습
4.3 완전한 MAE 구현
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
class PatchEmbedding(nn.Module):
"""이미지를 패치 임베딩으로 변환"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
self.projection = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
# x: (batch, C, H, W)
x = self.projection(x) # (batch, embed_dim, H/p, W/p)
x = x.flatten(2) # (batch, embed_dim, num_patches)
x = x.transpose(1, 2) # (batch, num_patches, embed_dim)
return x
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., dropout=0.):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
self.norm2 = nn.LayerNorm(dim)
mlp_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
# Self-attention
norm_x = self.norm1(x)
attn_out, _ = self.attn(norm_x, norm_x, norm_x)
x = x + attn_out
# MLP
x = x + self.mlp(self.norm2(x))
return x
class MAE(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_channels=3,
# 인코더 설정
encoder_dim=768,
encoder_depth=12,
encoder_heads=12,
# 디코더 설정 (인코더보다 훨씬 작음)
decoder_dim=512,
decoder_depth=8,
decoder_heads=16,
# 마스킹 비율
mask_ratio=0.75,
):
super().__init__()
self.mask_ratio = mask_ratio
self.patch_size = patch_size
num_patches = (img_size // patch_size) ** 2
# 패치 임베딩
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, encoder_dim)
# [CLS] 토큰과 위치 임베딩 (인코더)
self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_dim))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, encoder_dim),
requires_grad=False
)
# 인코더 (ViT)
self.encoder_blocks = nn.ModuleList([
TransformerBlock(encoder_dim, encoder_heads)
for _ in range(encoder_depth)
])
self.encoder_norm = nn.LayerNorm(encoder_dim)
# 인코더-디코더 연결
self.decoder_embed = nn.Linear(encoder_dim, decoder_dim, bias=True)
# 마스크 토큰
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
# 디코더 위치 임베딩
self.decoder_pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, decoder_dim),
requires_grad=False
)
# 디코더 (얕은 ViT)
self.decoder_blocks = nn.ModuleList([
TransformerBlock(decoder_dim, decoder_heads)
for _ in range(decoder_depth)
])
self.decoder_norm = nn.LayerNorm(decoder_dim)
# 픽셀 예측 헤드
self.decoder_pred = nn.Linear(decoder_dim, patch_size**2 * in_channels)
self._init_weights()
def _init_weights(self):
"""Sinusoidal 위치 임베딩 초기화"""
# 위치 임베딩 초기화 (sinusoidal)
pos_embed = self._get_2d_sincos_pos_embed(
self.pos_embed.shape[-1],
int(self.patch_embed.num_patches**0.5)
)
self.pos_embed.data.copy_(
torch.from_numpy(pos_embed).float().unsqueeze(0)
)
# 나머지 초기화
nn.init.normal_(self.cls_token, std=.02)
nn.init.normal_(self.mask_token, std=.02)
def _get_2d_sincos_pos_embed(self, embed_dim, grid_size):
"""2D sinusoidal 위치 임베딩"""
import numpy as np
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h)
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
emb_h = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
emb_w = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
emb = np.concatenate([emb_h, emb_w], axis=1)
# CLS 토큰을 위한 제로 임베딩 추가
emb = np.concatenate([np.zeros([1, embed_dim]), emb], axis=0)
return emb
def _get_1d_sincos_pos_embed_from_grid(self, embed_dim, pos):
import numpy as np
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega
pos = pos.reshape(-1)
out = np.einsum('m,d->md', pos, omega)
emb_sin = np.sin(out)
emb_cos = np.cos(out)
emb = np.concatenate([emb_sin, emb_cos], axis=1)
return emb
def random_masking(self, x, mask_ratio):
"""랜덤 마스킹: 일부 패치만 남김"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
# 랜덤 노이즈로 셔플 인덱스 생성
noise = torch.rand(N, L, device=x.device)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
# 가시 패치만 선택
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(
x, dim=1,
index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
)
# 마스크 생성 (1: 마스킹됨, 0: 가시)
mask = torch.ones(N, L, device=x.device)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
def forward_encoder(self, x, mask_ratio):
"""인코더: 가시 패치만 처리"""
# 패치 임베딩
x = self.patch_embed(x)
# 위치 임베딩 추가 (CLS 제외)
x = x + self.pos_embed[:, 1:, :]
# 랜덤 마스킹
x, mask, ids_restore = self.random_masking(x, mask_ratio)
# CLS 토큰 추가
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# 인코더 트랜스포머
for block in self.encoder_blocks:
x = block(x)
x = self.encoder_norm(x)
return x, mask, ids_restore
def forward_decoder(self, x, ids_restore):
"""디코더: 마스크 토큰 포함 전체 복원"""
# 인코더 차원 → 디코더 차원
x = self.decoder_embed(x)
# 마스크 토큰 추가
mask_tokens = self.mask_token.repeat(
x.shape[0],
ids_restore.shape[1] + 1 - x.shape[1],
1
)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # CLS 제외
x_ = torch.gather(
x_, dim=1,
index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
)
x = torch.cat([x[:, :1, :], x_], dim=1) # CLS 다시 추가
# 디코더 위치 임베딩
x = x + self.decoder_pos_embed
# 디코더 트랜스포머
for block in self.decoder_blocks:
x = block(x)
x = self.decoder_norm(x)
# 픽셀 예측
x = self.decoder_pred(x)
x = x[:, 1:, :] # CLS 토큰 제거
return x
def forward(self, imgs, mask_ratio=0.75):
# 인코딩 (마스킹 포함)
latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
# 디코딩 (픽셀 복원)
pred = self.forward_decoder(latent, ids_restore)
# 손실 계산
loss = self.forward_loss(imgs, pred, mask)
return loss, pred, mask
def forward_loss(self, imgs, pred, mask):
"""마스킹된 패치에 대한 MSE Loss"""
# 타깃: 패치화된 이미지
target = self.patchify(imgs)
# 정규화 (patch normalization)
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.e-6).sqrt()
# 마스킹된 패치에 대해서만 손실 계산
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # 패치당 평균
loss = (loss * mask).sum() / mask.sum() # 마스킹된 패치 평균
return loss
def patchify(self, imgs):
"""이미지를 패치 시퀀스로 변환"""
p = self.patch_size
h = w = imgs.shape[2] // p
x = imgs.reshape(imgs.shape[0], 3, h, p, w, p)
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(imgs.shape[0], h * w, p**2 * 3)
return x
# MAE 모델 생성 및 학습
mae_model = MAE(
img_size=224,
patch_size=16,
encoder_dim=768,
encoder_depth=12,
encoder_heads=12,
decoder_dim=512,
decoder_depth=8,
decoder_heads=16,
mask_ratio=0.75
).cuda()
optimizer = torch.optim.AdamW(
mae_model.parameters(),
lr=1.5e-4,
betas=(0.9, 0.95),
weight_decay=0.05
)
# 학습 스텝
def train_step(model, imgs, optimizer):
model.train()
optimizer.zero_grad()
loss, pred, mask = model(imgs, mask_ratio=0.75)
loss.backward()
optimizer.step()
return loss.item()
5. DINO
DINO(Self-DIstillation with NO labels)는 2021년 Facebook AI Research가 발표했습니다. ViT를 SSL로 학습했을 때 놀라운 특성이 나타납니다.
5.1 DINO의 핵심 발견
DINO로 학습된 ViT의 자기 주의(self-attention) 맵은 **레이블 없이도 의미론적 분할(segmentation)**을 수행합니다. 배경과 전경을 완벽히 구분합니다.
5.2 Self-Distillation 구조
Student Network (학생) Teacher Network (선생)
↑ ↑
Local Views (여러 개) Global Views (2개)
- 선생: 전체 이미지 뷰 처리, 더 풍부한 컨텍스트
- 학생: 로컬 크롭 처리, 선생의 예측을 따라 배움
- 선생 업데이트: 학생의 EMA (그래디언트 없음)
5.3 Centering과 Sharpening
Collapse 방지를 위한 두 가지 기법:
class DINOHead(nn.Module):
"""DINO Projection Head"""
def __init__(self, in_dim, out_dim=65536, bottleneck_dim=256):
super().__init__()
layers = [nn.Linear(in_dim, 2048), nn.GELU()]
layers += [nn.Linear(2048, 2048), nn.GELU()]
layers += [nn.Linear(2048, bottleneck_dim)]
self.mlp = nn.Sequential(*layers)
# Weight normalization된 마지막 레이어
self.last_layer = nn.utils.weight_norm(
nn.Linear(bottleneck_dim, out_dim, bias=False)
)
self.last_layer.weight_g.data.fill_(1)
self.last_layer.weight_g.requires_grad = False
def forward(self, x):
x = self.mlp(x)
x = F.normalize(x, dim=-1)
x = self.last_layer(x)
return x
class DINO(nn.Module):
def __init__(self, backbone, out_dim=65536, teacher_temp=0.04, student_temp=0.1):
super().__init__()
self.student_temp = student_temp
self.teacher_temp = teacher_temp
# Student 네트워크
self.student = backbone
self.student_head = DINOHead(self.student.embed_dim, out_dim)
# Teacher 네트워크 (EMA)
self.teacher = copy.deepcopy(backbone)
self.teacher_head = copy.deepcopy(self.student_head)
for p in self.teacher.parameters():
p.requires_grad = False
for p in self.teacher_head.parameters():
p.requires_grad = False
# Centering을 위한 center 벡터
self.register_buffer("center", torch.zeros(1, out_dim))
@torch.no_grad()
def update_teacher(self, momentum):
"""EMA 업데이트"""
for student_ps, teacher_ps in zip(
self.student.parameters(), self.teacher.parameters()
):
teacher_ps.data.mul_(momentum).add_(
(1 - momentum) * student_ps.data
)
for student_ps, teacher_ps in zip(
self.student_head.parameters(), self.teacher_head.parameters()
):
teacher_ps.data.mul_(momentum).add_(
(1 - momentum) * student_ps.data
)
@torch.no_grad()
def update_center(self, teacher_output):
"""Centering 업데이트 (collapse 방지)"""
batch_center = teacher_output.mean(dim=0, keepdim=True)
self.center = self.center * 0.9 + batch_center * 0.1
def dino_loss(self, student_output, teacher_output):
"""DINO Cross-entropy Loss"""
student_out = student_output / self.student_temp
student_out = student_out.chunk(2) # global crops
# Teacher: centering + sharpening
teacher_out = F.softmax(
(teacher_output - self.center) / self.teacher_temp, dim=-1
)
teacher_out = teacher_out.detach().chunk(2)
total_loss = 0
n_loss_terms = 0
for iq, q in enumerate(teacher_out):
for v in range(len(student_out)):
if v == iq:
continue # 같은 뷰는 제외
loss = torch.sum(
-q * F.log_softmax(student_out[v], dim=-1), dim=-1
)
total_loss += loss.mean()
n_loss_terms += 1
return total_loss / n_loss_terms
def forward(self, global_views, local_views, momentum=0.996):
# Student: 모든 뷰 처리
all_views = global_views + local_views
student_output = torch.cat([
self.student_head(self.student(view))
for view in all_views
])
# Teacher: 글로벌 뷰만 처리
with torch.no_grad():
teacher_output = torch.cat([
self.teacher_head(self.teacher(view))
for view in global_views
])
loss = self.dino_loss(student_output, teacher_output)
# 업데이트
self.update_center(teacher_output)
self.update_teacher(momentum)
return loss
# DINO 데이터 증강 (멀티 크롭)
class DINOAugmentation:
def __init__(self, global_crops_scale=(0.4, 1.), local_crops_scale=(0.05, 0.4),
n_local_crops=8, image_size=224, local_size=96):
self.n_local_crops = n_local_crops
# 전역 크롭 (2개, 큰 뷰)
self.global_transform = T.Compose([
T.RandomResizedCrop(image_size, scale=global_crops_scale, interpolation=T.InterpolationMode.BICUBIC),
T.RandomHorizontalFlip(),
T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
T.RandomGrayscale(p=0.2),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# 로컬 크롭 (8개, 작은 뷰)
self.local_transform = T.Compose([
T.RandomResizedCrop(local_size, scale=local_crops_scale, interpolation=T.InterpolationMode.BICUBIC),
T.RandomHorizontalFlip(),
T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
T.RandomGrayscale(p=0.2),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def __call__(self, image):
global_views = [self.global_transform(image) for _ in range(2)]
local_views = [self.local_transform(image) for _ in range(self.n_local_crops)]
return global_views, local_views
5.4 DINOv2
2023년 발표된 DINOv2는 더 강력합니다:
- 1억 4200만 장의 큐레이션된 데이터로 학습
- iBOT(Image BERT Pre-Training with Online Tokenizer) 손실 추가
- 더 안정적인 학습
6. CLIP
CLIP(Contrastive Language-Image Pretraining)은 2021년 OpenAI가 발표한 혁신적인 모델입니다.
6.1 텍스트-이미지 대조 학습
인터넷에서 수집한 4억 쌍의 (이미지, 텍스트) 데이터로 학습합니다.
이미지 인코더 (ViT / ResNet)
↓
이미지 임베딩
텍스트 인코더 (Transformer)
↓
텍스트 임베딩
목표: 매칭되는 (이미지, 텍스트) 쌍은 유사하게
매칭 안 되는 쌍은 다르게
6.2 CLIP 손실 함수
def clip_loss(image_embeddings, text_embeddings, temperature):
"""
CLIP 대칭 대조 학습 손실
image_embeddings: (N, D)
text_embeddings: (N, D)
"""
# L2 정규화
image_embeddings = F.normalize(image_embeddings, dim=1)
text_embeddings = F.normalize(text_embeddings, dim=1)
# 유사도 행렬 계산: (N, N)
logits = torch.matmul(image_embeddings, text_embeddings.T) / temperature
# 정답: 대각선 원소 (i번째 이미지 = i번째 텍스트)
labels = torch.arange(len(logits)).to(logits.device)
# 이미지 → 텍스트, 텍스트 → 이미지 양방향 손실
loss_i2t = F.cross_entropy(logits, labels)
loss_t2i = F.cross_entropy(logits.T, labels)
return (loss_i2t + loss_t2i) / 2
6.3 CLIP 완전 구현
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests
import torch
import torch.nn.functional as F
# 사전학습된 CLIP 모델 로드
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# ---- Zero-shot 이미지 분류 ----
def zero_shot_classify(image_path, candidate_labels):
"""
레이블 예제 없이 이미지 분류
"""
image = Image.open(image_path)
# 텍스트 프롬프트 생성
texts = [f"a photo of a {label}" for label in candidate_labels]
# 인코딩
inputs = processor(
text=texts,
images=image,
return_tensors="pt",
padding=True
)
with torch.no_grad():
outputs = model(**inputs)
# 유사도 계산
logits_per_image = outputs.logits_per_image
probs = logits_per_image.softmax(dim=1)
results = {}
for label, prob in zip(candidate_labels, probs[0]):
results[label] = prob.item()
return results
# 사용 예시
labels = ["cat", "dog", "car", "airplane", "bird"]
results = zero_shot_classify("example.jpg", labels)
for label, prob in sorted(results.items(), key=lambda x: -x[1]):
print(f"{label}: {prob:.4f}")
# ---- 이미지-텍스트 유사도 검색 ----
class CLIPRetrieval:
def __init__(self, model_name="openai/clip-vit-base-patch32"):
self.model = CLIPModel.from_pretrained(model_name)
self.processor = CLIPProcessor.from_pretrained(model_name)
self.model.eval()
self.image_embeddings = None
self.image_paths = []
def index_images(self, image_paths):
"""이미지 데이터베이스 인덱싱"""
self.image_paths = image_paths
embeddings = []
for path in image_paths:
image = Image.open(path)
inputs = self.processor(images=image, return_tensors="pt")
with torch.no_grad():
embedding = self.model.get_image_features(**inputs)
embeddings.append(F.normalize(embedding, dim=1))
self.image_embeddings = torch.cat(embeddings, dim=0)
print(f"{len(image_paths)}개 이미지 인덱싱 완료")
def search(self, query_text, top_k=5):
"""텍스트 쿼리로 이미지 검색"""
inputs = self.processor(text=[query_text], return_tensors="pt", padding=True)
with torch.no_grad():
text_embedding = self.model.get_text_features(**inputs)
text_embedding = F.normalize(text_embedding, dim=1)
# 코사인 유사도
similarities = torch.matmul(text_embedding, self.image_embeddings.T).squeeze()
top_indices = similarities.argsort(descending=True)[:top_k]
results = [
(self.image_paths[i], similarities[i].item())
for i in top_indices
]
return results
# 멀티모달 임베딩 시각화
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
def visualize_clip_embeddings(images, texts):
"""CLIP 임베딩 t-SNE 시각화"""
image_inputs = processor(images=images, return_tensors="pt")
text_inputs = processor(text=texts, return_tensors="pt", padding=True)
with torch.no_grad():
image_embs = F.normalize(model.get_image_features(**image_inputs), dim=1)
text_embs = F.normalize(model.get_text_features(**text_inputs), dim=1)
# 이미지와 텍스트 임베딩 합치기
all_embs = torch.cat([image_embs, text_embs], dim=0).numpy()
# t-SNE 차원 축소
tsne = TSNE(n_components=2, random_state=42)
reduced = tsne.fit_transform(all_embs)
n = len(images)
plt.figure(figsize=(12, 8))
plt.scatter(reduced[:n, 0], reduced[:n, 1], c='blue', label='Images', alpha=0.7)
plt.scatter(reduced[n:, 0], reduced[n:, 1], c='red', label='Texts', alpha=0.7)
for i, text in enumerate(texts):
plt.annotate(text, reduced[n+i], fontsize=8)
plt.legend()
plt.title('CLIP Embeddings (t-SNE)')
plt.savefig('clip_tsne.png', dpi=150, bbox_inches='tight')
plt.show()
7. BEiT
BEiT(BERT Pre-Training of Image Transformers)는 2021년 Microsoft가 발표했습니다. 이미지를 이산 토큰으로 변환한 후 BERT처럼 학습합니다.
7.1 이산 시각 토큰 (dVAE)
BEiT의 핵심: 이미지를 연속 픽셀이 아닌 이산 토큰으로 예측합니다.
1단계: dVAE로 이미지 → 이산 토큰 변환 (어휘 8192개)
2단계: ViT가 마스킹된 패치의 토큰을 예측
from transformers import BeitForMaskedImageModeling, BeitImageProcessor
import torch
# BEiT 마스크드 이미지 모델링
class BEiTPretraining:
def __init__(self, model_name="microsoft/beit-base-patch16-224"):
self.model = BeitForMaskedImageModeling.from_pretrained(model_name)
self.processor = BeitImageProcessor.from_pretrained(model_name)
def create_bool_masked_pos(self, num_patches, mask_ratio=0.4):
"""마스킹 위치 생성"""
num_masked = int(num_patches * mask_ratio)
mask = torch.zeros(num_patches, dtype=torch.bool)
# 블록 마스킹 (BEiT v2)
masked_indices = torch.randperm(num_patches)[:num_masked]
mask[masked_indices] = True
return mask
def forward(self, images):
inputs = self.processor(images, return_tensors="pt")
pixel_values = inputs.pixel_values
num_patches = (224 // 16) ** 2 # 196
bool_masked_pos = self.create_bool_masked_pos(num_patches)
bool_masked_pos = bool_masked_pos.unsqueeze(0).expand(
pixel_values.size(0), -1
)
outputs = self.model(
pixel_values=pixel_values,
bool_masked_pos=bool_masked_pos
)
return outputs.loss
8. 언어 모델의 사전학습
자연어 처리에서 SSL은 오랜 역사를 가집니다.
8.1 GPT - 다음 단어 예측
자기회귀(Autoregressive) 방식으로 왼쪽에서 오른쪽으로 다음 토큰을 예측합니다.
import torch
import torch.nn as nn
from torch.nn import functional as F
class GPTLanguageModel(nn.Module):
def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer, dropout=0.1):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, n_embd)
self.position_embedding = nn.Embedding(block_size, n_embd)
self.blocks = nn.Sequential(*[
TransformerBlock(n_embd, n_head, dropout=dropout)
for _ in range(n_layer)
])
self.ln_f = nn.LayerNorm(n_embd)
self.lm_head = nn.Linear(n_embd, vocab_size)
self.block_size = block_size
def forward(self, idx, targets=None):
B, T = idx.shape
tok_emb = self.token_embedding(idx) # (B, T, C)
pos_emb = self.position_embedding(
torch.arange(T, device=idx.device)
) # (T, C)
x = tok_emb + pos_emb
x = self.blocks(x)
x = self.ln_f(x)
logits = self.lm_head(x) # (B, T, vocab_size)
if targets is None:
return logits, None
# 다음 토큰 예측 손실
B, T, C = logits.shape
logits_flat = logits.view(B * T, C)
targets_flat = targets.view(B * T)
loss = F.cross_entropy(logits_flat, targets_flat)
return logits, loss
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
for _ in range(max_new_tokens):
idx_cond = idx[:, -self.block_size:]
logits, _ = self(idx_cond)
logits = logits[:, -1, :] / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = float('-inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.cat((idx, idx_next), dim=1)
return idx
8.2 BERT - 마스킹 언어 모델
양방향 문맥을 학습합니다. 15%의 토큰을 마스킹하고 복원합니다.
from transformers import BertForMaskedLM, BertTokenizer
import torch
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
def create_mlm_input(text, mask_prob=0.15):
"""마스크드 언어 모델 입력 생성"""
tokens = tokenizer.encode(text, return_tensors="pt")
labels = tokens.clone()
# 15% 무작위 마스킹
rand = torch.rand(tokens.shape)
mask_arr = (rand < mask_prob) & (tokens != tokenizer.cls_token_id) & \
(tokens != tokenizer.sep_token_id)
# 마스킹 전략:
# - 80%: [MASK] 토큰으로 교체
# - 10%: 랜덤 토큰으로 교체
# - 10%: 원래 토큰 유지
for i, masked in enumerate(mask_arr[0]):
if masked:
prob = torch.rand(1).item()
if prob < 0.8:
tokens[0][i] = tokenizer.mask_token_id
elif prob < 0.9:
tokens[0][i] = torch.randint(len(tokenizer), (1,))
# 나머지 10%: 원래 토큰 유지
# 마스킹되지 않은 위치는 손실 계산에서 제외
labels[~mask_arr] = -100
return tokens, labels
# MLM 손실 계산
text = "The quick brown fox jumps over the lazy dog"
input_ids, labels = create_mlm_input(text)
with torch.no_grad():
outputs = model(input_ids=input_ids, labels=labels)
print(f"MLM Loss: {outputs.loss.item():.4f}")
# 마스킹된 토큰 예측
input_text = "The [MASK] brown fox"
inputs = tokenizer(input_text, return_tensors="pt")
with torch.no_grad():
logits = model(**inputs).logits
mask_idx = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0][1]
top_5 = logits[0, mask_idx].topk(5)
for token_id, score in zip(top_5.indices, top_5.values):
print(f"{tokenizer.decode([token_id])}: {score:.3f}")
8.3 T5 - 텍스트-텍스트 변환
from transformers import T5ForConditionalGeneration, T5Tokenizer
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")
# T5는 모든 태스크를 텍스트-텍스트로 통일
tasks = {
"번역": "translate English to Korean: The weather is nice today",
"요약": "summarize: Scientists have discovered a new species of deep-sea fish...",
"감성 분석": "sentiment analysis: This movie was absolutely fantastic!",
"질의응답": "question: What is the capital of France? context: Paris is the capital..."
}
for task_name, input_text in tasks.items():
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
with torch.no_grad():
output_ids = model.generate(
inputs.input_ids,
max_length=128,
num_beams=4,
early_stopping=True
)
output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(f"\n[{task_name}]")
print(f"입력: {input_text[:60]}...")
print(f"출력: {output}")
9. 실전 활용
9.1 소수 레이블 학습 (Few-shot Learning)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
class LinearProbe(nn.Module):
"""SSL 특징 위에 선형 분류기"""
def __init__(self, feature_dim, num_classes):
super().__init__()
self.linear = nn.Linear(feature_dim, num_classes)
def forward(self, x):
return self.linear(x)
def extract_features(model, dataloader, device='cuda'):
"""사전학습된 모델로 특징 추출"""
model.eval()
features_list = []
labels_list = []
with torch.no_grad():
for images, labels in dataloader:
images = images.to(device)
features = model.encode(images)
features_list.append(features.cpu())
labels_list.append(labels)
return torch.cat(features_list), torch.cat(labels_list)
def few_shot_evaluation(ssl_model, train_loader, val_loader,
n_shots=[1, 5, 10, 100], device='cuda'):
"""Few-shot 평가: n개 레이블로 선형 분류기 학습"""
ssl_model.eval()
# 전체 특징 추출
all_features, all_labels = extract_features(ssl_model, train_loader, device)
val_features, val_labels = extract_features(ssl_model, val_loader, device)
results = {}
num_classes = len(all_labels.unique())
for n_shot in n_shots:
# n_shot개 레이블만 사용
selected_indices = []
for cls in range(num_classes):
cls_indices = (all_labels == cls).nonzero(as_tuple=True)[0]
selected = cls_indices[torch.randperm(len(cls_indices))[:n_shot]]
selected_indices.append(selected)
selected_indices = torch.cat(selected_indices)
train_feats = all_features[selected_indices]
train_labs = all_labels[selected_indices]
# 선형 분류기 학습
probe = LinearProbe(all_features.shape[1], num_classes).to(device)
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
train_feats = train_feats.to(device)
train_labs = train_labs.to(device)
for epoch in range(100):
probe.train()
logits = probe(train_feats)
loss = criterion(logits, train_labs)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 검증
probe.eval()
with torch.no_grad():
val_logits = probe(val_features.to(device))
preds = val_logits.argmax(dim=1).cpu()
acc = (preds == val_labels).float().mean().item()
results[n_shot] = acc
print(f"{n_shot}-shot 정확도: {acc * 100:.2f}%")
return results
# k-NN 평가 (파라미터 없는 평가)
def knn_evaluation(ssl_model, train_loader, val_loader, k=20, device='cuda'):
"""k-NN 분류기로 SSL 표현 품질 평가"""
ssl_model.eval()
train_features, train_labels = extract_features(ssl_model, train_loader, device)
val_features, val_labels = extract_features(ssl_model, val_loader, device)
# 정규화
train_features = nn.functional.normalize(train_features, dim=1)
val_features = nn.functional.normalize(val_features, dim=1)
# 배치별 k-NN 분류
correct = 0
total = 0
batch_size = 256
for i in range(0, len(val_features), batch_size):
batch_feats = val_features[i:i+batch_size].to(device)
batch_labels = val_labels[i:i+batch_size]
# 코사인 유사도 계산
sims = torch.mm(batch_feats, train_features.to(device).T)
topk_sims, topk_indices = sims.topk(k, dim=1)
# 다수결 투표
topk_labels = train_labels[topk_indices.cpu()]
preds = topk_labels.mode(dim=1).values
correct += (preds == batch_labels).sum().item()
total += len(batch_labels)
accuracy = correct / total
print(f"k-NN ({k}) 정확도: {accuracy * 100:.2f}%")
return accuracy
9.2 도메인 특화 SSL
# 의료 이미지용 SSL (CheXpert 흉부 X-ray 예시)
class MedicalSSL(nn.Module):
"""의료 이미지 특화 자기지도 학습"""
def __init__(self, backbone):
super().__init__()
self.backbone = backbone
# 의료 이미지 특화 증강 (약한 증강)
self.augmentation = T.Compose([
T.RandomResizedCrop(224, scale=(0.8, 1.0)), # 작은 크롭 변화
T.RandomHorizontalFlip(p=0.5),
# 의료 이미지는 색상 변화 최소화
T.RandomApply([T.ColorJitter(0.1, 0.1, 0.0, 0.0)], p=0.3),
T.ToTensor(),
T.Normalize([0.485], [0.229]) # 흑백 X-ray
])
def forward(self, x):
v1 = self.augmentation(x)
v2 = self.augmentation(x)
return v1, v2
# 위성 이미지용 SSL
class SatelliteSSL:
"""위성 이미지 특화 증강 (지리적 특성 보존)"""
def __init__(self):
self.transform = T.Compose([
T.RandomResizedCrop(256, scale=(0.4, 1.0)),
# 위성 이미지: 회전 불변성 중요
T.RandomRotation(90),
T.RandomHorizontalFlip(),
T.RandomVerticalFlip(),
# 시간대/계절 변화 시뮬레이션
T.ColorJitter(brightness=0.4, contrast=0.4),
T.ToTensor(),
])
# 학습 성능 비교
def compare_ssl_methods(dataset, num_epochs=100):
methods = {
'SimCLR': SimCLR(backbone='resnet50'),
'BYOL': BYOL(backbone=resnet50()),
'MAE': MAE(img_size=224),
}
results = {}
for name, model in methods.items():
print(f"\n{name} 학습 중...")
model = model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for epoch in range(num_epochs):
train_one_epoch(model, dataset, optimizer)
# 100-shot 평가
acc = few_shot_evaluation(model, dataset, n_shots=[100])
results[name] = acc[100]
print(f"{name} 100-shot 정확도: {acc[100]*100:.2f}%")
return results
9.3 SSL 모델 허브 활용
# Hugging Face에서 사전학습된 SSL 모델 사용
from transformers import (
ViTModel, ViTFeatureExtractor,
DeiTModel,
CLIPModel, CLIPProcessor
)
from PIL import Image
import torch
# ViT (MAE로 사전학습된 버전)
def load_mae_pretrained():
model = ViTModel.from_pretrained("facebook/vit-mae-base")
feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
return model, feature_extractor
# DINO ViT 특징 추출
def extract_dino_features(image_path):
from transformers import ViTModel
model = ViTModel.from_pretrained("facebook/dino-vits16")
processor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits16")
image = Image.open(image_path)
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
# [CLS] 토큰 특징 (전체 이미지 표현)
cls_features = outputs.last_hidden_state[:, 0, :]
# 패치별 특징 (공간 정보 포함)
patch_features = outputs.last_hidden_state[:, 1:, :]
return cls_features, patch_features
# CLIP으로 이미지 특징 추출 후 다운스트림 태스크
def clip_feature_extraction(images, texts=None):
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
if texts:
inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)
with torch.no_grad():
outputs = model(**inputs)
return outputs.image_embeds, outputs.text_embeds
else:
inputs = processor(images=images, return_tensors="pt")
with torch.no_grad():
image_features = model.get_image_features(**inputs)
return image_features
마무리
자기지도 학습은 AI의 판도를 바꾸고 있습니다. 핵심 정리:
| 방법 | 연도 | 핵심 아이디어 | 강점 |
|---|---|---|---|
| SimCLR | 2020 | 강한 증강 + 대조 학습 | 단순하고 강력 |
| MoCo | 2020 | 모멘텀 큐 | 배치 크기 독립 |
| BYOL | 2020 | 네거티브 없는 대조 학습 | 배치 크기 독립 |
| MAE | 2021 | 75% 마스킹 + 픽셀 복원 | 효율적, ViT 최적 |
| DINO | 2021 | Self-distillation | 세그멘테이션 특성 |
| CLIP | 2021 | 이미지-텍스트 대조 | Zero-shot 분류 |
| BEiT | 2021 | 이산 토큰 마스킹 | BERT 스타일 |
| DINOv2 | 2023 | 대규모 + iBOT | 범용 특징 추출 |
학습 로드맵:
- SimCLR 구현으로 대조 학습 이해
- MAE로 마스킹 접근법 이해
- CLIP으로 멀티모달 학습 적용
- 도메인 데이터에 파인튜닝
SSL은 이제 LLM의 사전학습, 컴퓨터 비전의 근간이 되었습니다. 레이블 없는 대규모 데이터를 효과적으로 활용하는 SSL은 앞으로도 AI 발전의 핵심 동력이 될 것입니다.