- Authors

- Name
- Youngju Kim
- @fjvbn20031
메타러닝과 퓨샷 학습 완전 가이드: MAML, Prototypical Networks, In-Context Learning
인간은 새로운 개념을 단 몇 개의 예시만으로도 빠르게 학습한다. 어린아이에게 "이게 얼룩말이야"라고 보여주면, 그 다음 날 처음 보는 얼룩말 사진도 바로 알아본다. 하지만 전통적인 딥러닝 모델은 얼룩말 분류기 하나를 만들기 위해 수천 장의 사진이 필요하다.
이 간극을 좁히는 것이 **메타러닝(Meta-Learning)**과 **퓨샷 학습(Few-Shot Learning)**이다. 메타러닝의 핵심 아이디어는 "학습하는 방법을 학습(Learning to Learn)"하는 것이다. 모델이 다양한 태스크를 경험하면서 새로운 태스크에 빠르게 적응하는 능력 자체를 학습한다.
이 가이드에서는 메타러닝의 이론적 기반부터 최신 In-Context Learning까지 완전히 다룬다.
1. 메타러닝 기초
1.1 학습하는 방법을 학습 (Learning to Learn)
전통적인 머신러닝에서 모델은 단일 태스크를 위해 학습된다. 고양이 분류기는 고양이 데이터로만 학습하고, 새로운 동물(예: 오셀롯)을 분류해야 할 때는 처음부터 다시 학습해야 한다.
메타러닝에서는 관점을 바꾼다. 모델이 학습해야 할 것은 "고양이를 분류하는 방법"이 아니라 "새로운 동물을 빠르게 분류하는 방법을 배우는 방법"이다.
이를 위해 메타러닝은 두 수준의 학습을 가진다.
- 메타 학습 (Outer Loop): 여러 태스크를 경험하면서 좋은 초기화 또는 학습 알고리즘을 학습
- 태스크 학습 (Inner Loop): 각 특정 태스크에서 소수의 예시로 빠르게 적응
1.2 기존 학습의 한계
전통적 학습의 한계는 구체적으로 다음과 같다.
데이터 효율성: ImageNet으로 학습된 모델은 수백만 개의 이미지가 필요하다. 새로운 클래스를 추가하려면 수천 장이 필요하다.
일반화 능력 부족: 훈련 분포에서 크게 벗어난 새로운 태스크에는 성능이 급격히 떨어진다.
지속적 학습 문제: 새로운 태스크를 학습하면 이전 태스크를 잊어버리는 "재앙적 망각(catastrophic forgetting)" 현상이 발생한다.
1.3 태스크 분포 (Task Distribution)
메타러닝에서 핵심 개념은 태스크 분포 p(T) 이다. 메타 학습은 단순히 데이터로부터 학습하는 것이 아니라, 태스크들의 분포로부터 학습한다.
각 태스크 T는 다음을 포함한다.
- 입력-출력 쌍의 분포 p(x, y)
- 태스크 손실 함수 L
메타 학습 목표:
min over theta: E over T ~ p(T) [L_T(f_theta)]
1.4 Support Set vs Query Set
퓨샷 학습에서 데이터는 두 가지 역할로 나뉜다.
Support Set (지원 집합): 모델이 새로운 태스크를 학습할 때 참조하는 소수의 예시들. 전통적 학습의 훈련 데이터에 해당하지만, 극히 적다 (예: 클래스당 1~5개).
Query Set (쿼리 집합): 모델의 성능을 평가하는 데이터. 전통적 학습의 테스트 데이터에 해당한다.
1.5 N-way K-shot 설정
퓨샷 학습에서 가장 중요한 설정은 N-way K-shot이다.
- N-way: 분류해야 할 클래스 수
- K-shot: 각 클래스당 support 예시 수
예를 들어, 5-way 1-shot은 5개 클래스를 각각 1개의 예시만으로 분류하는 태스크이다. 5-way 5-shot은 각 클래스당 5개의 예시를 사용한다.
import torch
import torch.nn as nn
import numpy as np
from typing import List, Tuple, Dict
def create_episode(
dataset,
n_way: int,
k_shot: int,
n_query: int,
classes: List[int] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
N-way K-shot 에피소드 생성
Returns: (support_x, support_y, query_x, query_y)
"""
# 클래스 선택
all_classes = list(set(dataset.targets.tolist()))
if classes is None:
selected_classes = np.random.choice(
all_classes, n_way, replace=False
)
else:
selected_classes = classes
support_x, support_y = [], []
query_x, query_y = [], []
for new_label, cls in enumerate(selected_classes):
# 해당 클래스의 모든 인덱스
cls_indices = (dataset.targets == cls).nonzero(as_tuple=True)[0]
chosen = np.random.choice(
len(cls_indices),
k_shot + n_query,
replace=False
)
for i, idx in enumerate(cls_indices[chosen]):
x, _ = dataset[idx.item()]
if i < k_shot:
support_x.append(x)
support_y.append(new_label)
else:
query_x.append(x)
query_y.append(new_label)
support_x = torch.stack(support_x)
support_y = torch.tensor(support_y)
query_x = torch.stack(query_x)
query_y = torch.tensor(query_y)
return support_x, support_y, query_x, query_y
2. 거리 기반 메타러닝
2.1 Matching Networks
Vinyals 등이 2016년에 제안한 Matching Networks는 어텐션 메커니즘과 kNN의 아이디어를 결합한다. 핵심 아이디어는 쿼리 샘플을 support 셋의 레이블의 어텐션 가중 합으로 예측하는 것이다.
예측 공식
y_hat = sum over i: a(x_hat, x_i) * y_i
여기서 a(x_hat, x_i)는 쿼리 x_hat과 support 샘플 x_i 사이의 어텐션 가중치다. 코사인 유사도의 소프트맥스를 사용한다.
class MatchingNetworks(nn.Module):
"""Matching Networks 구현"""
def __init__(self, encoder: nn.Module, use_fce: bool = False):
"""
encoder: 특성 추출기
use_fce: Full Context Embedding 사용 여부
"""
super().__init__()
self.encoder = encoder
self.use_fce = use_fce
def cosine_similarity(
self,
query: torch.Tensor,
support: torch.Tensor
) -> torch.Tensor:
"""
코사인 유사도 계산
query: (n_query, embed_dim)
support: (n_support, embed_dim)
Returns: (n_query, n_support)
"""
query_norm = nn.functional.normalize(query, dim=-1)
support_norm = nn.functional.normalize(support, dim=-1)
return torch.mm(query_norm, support_norm.t())
def forward(
self,
support_x: torch.Tensor,
support_y: torch.Tensor,
query_x: torch.Tensor,
n_way: int
) -> torch.Tensor:
"""
support_x: (n_way * k_shot, C, H, W)
support_y: (n_way * k_shot,)
query_x: (n_query, C, H, W)
"""
n_support = support_x.size(0)
n_query = query_x.size(0)
# 인코딩
support_emb = self.encoder(support_x) # (n_support, D)
query_emb = self.encoder(query_x) # (n_query, D)
# 유사도 계산
similarities = self.cosine_similarity(
query_emb, support_emb
) # (n_query, n_support)
# 소프트맥스 어텐션
attention = nn.functional.softmax(similarities, dim=-1)
# One-hot 레이블로 변환
support_labels_one_hot = nn.functional.one_hot(
support_y, n_way
).float() # (n_support, n_way)
# 어텐션 가중 합으로 예측
# (n_query, n_support) x (n_support, n_way) = (n_query, n_way)
logits = torch.mm(attention, support_labels_one_hot)
return logits # log 확률 형태로 반환
2.2 Prototypical Networks
Snell 등이 2017년 발표한 Prototypical Networks는 메타러닝 알고리즘 중 가장 우아하고 직관적인 방법이다. 핵심 아이디어: 각 클래스를 임베딩 공간에서 하나의 프로토타입(중심점)으로 표현한다.
프로토타입 계산
클래스 c의 프로토타입은 해당 클래스의 support 샘플들의 임베딩 평균이다.
p_c = (1/|S_c|) sum over (x_i, y_i) in S_c: f_phi(x_i)
분류
쿼리 샘플 x를 가장 가까운 프로토타입으로 분류한다.
p(y=c | x) = softmax(-d(f_phi(x), p_c))
여기서 d는 유클리드 거리이다.
class ConvEncoder(nn.Module):
"""퓨샷 학습용 4층 CNN 인코더"""
def __init__(
self,
in_channels: int = 1,
hidden_dim: int = 64,
out_dim: int = 64
):
super().__init__()
def conv_block(in_ch, out_ch):
return nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.net = nn.Sequential(
conv_block(in_channels, hidden_dim),
conv_block(hidden_dim, hidden_dim),
conv_block(hidden_dim, hidden_dim),
conv_block(hidden_dim, out_dim),
nn.Flatten()
)
def forward(self, x):
return self.net(x)
class PrototypicalNetworks(nn.Module):
"""Prototypical Networks 완전 구현"""
def __init__(self, encoder: nn.Module):
super().__init__()
self.encoder = encoder
def compute_prototypes(
self,
support_x: torch.Tensor,
support_y: torch.Tensor,
n_way: int
) -> torch.Tensor:
"""
각 클래스의 프로토타입 계산 (임베딩 평균)
support_x: (n_way * k_shot, C, H, W)
support_y: (n_way * k_shot,)
Returns: (n_way, embed_dim)
"""
support_emb = self.encoder(support_x) # (n_support, D)
prototypes = []
for cls in range(n_way):
mask = (support_y == cls)
cls_embeddings = support_emb[mask]
prototype = cls_embeddings.mean(dim=0)
prototypes.append(prototype)
return torch.stack(prototypes) # (n_way, D)
def euclidean_dist(
self,
x: torch.Tensor,
y: torch.Tensor
) -> torch.Tensor:
"""
유클리드 거리 계산
x: (n, D)
y: (m, D)
Returns: (n, m)
"""
# ||x - y||^2 = ||x||^2 + ||y||^2 - 2*x·y
n = x.size(0)
m = y.size(0)
x_sq = (x ** 2).sum(dim=1, keepdim=True).expand(n, m)
y_sq = (y ** 2).sum(dim=1, keepdim=True).expand(m, n).t()
xy = torch.mm(x, y.t())
dist = x_sq + y_sq - 2 * xy
return dist.clamp(min=0).sqrt()
def forward(
self,
support_x: torch.Tensor,
support_y: torch.Tensor,
query_x: torch.Tensor,
n_way: int
) -> torch.Tensor:
"""
Prototypical Networks 순전파
Returns: 쿼리 샘플에 대한 로그 확률 (n_query, n_way)
"""
# 프로토타입 계산
prototypes = self.compute_prototypes(
support_x, support_y, n_way
) # (n_way, D)
# 쿼리 임베딩
query_emb = self.encoder(query_x) # (n_query, D)
# 거리 계산
dists = self.euclidean_dist(
query_emb, prototypes
) # (n_query, n_way)
# 음의 거리를 로짓으로 사용 (가까울수록 높은 확률)
log_probs = nn.functional.log_softmax(-dists, dim=-1)
return log_probs
# ========== 훈련 루프 ==========
def train_prototypical(
model: PrototypicalNetworks,
train_dataset,
n_way: int = 5,
k_shot: int = 5,
n_query: int = 15,
n_episodes: int = 100,
lr: float = 1e-3,
device: str = 'cpu'
):
"""Prototypical Networks 훈련"""
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.NLLLoss()
model.train()
episode_losses = []
episode_accs = []
for episode in range(n_episodes):
# 에피소드 생성
support_x, support_y, query_x, query_y = create_episode(
train_dataset, n_way, k_shot, n_query
)
support_x = support_x.to(device)
support_y = support_y.to(device)
query_x = query_x.to(device)
query_y = query_y.to(device)
optimizer.zero_grad()
log_probs = model(support_x, support_y, query_x, n_way)
loss = criterion(log_probs, query_y)
loss.backward()
optimizer.step()
# 정확도 계산
preds = log_probs.argmax(dim=-1)
acc = (preds == query_y).float().mean().item()
episode_losses.append(loss.item())
episode_accs.append(acc)
if (episode + 1) % 100 == 0:
print(
f"에피소드 {episode+1}/{n_episodes} | "
f"손실: {np.mean(episode_losses[-100:]):.4f} | "
f"정확도: {np.mean(episode_accs[-100:]):.4f}"
)
return episode_losses, episode_accs
2.3 Relation Networks
Sung 등의 Relation Networks는 Prototypical Networks와 비슷하지만, 거리 함수를 학습 가능한 신경망으로 대체한다. 쿼리 임베딩과 클래스 프로토타입을 연결(concatenate)한 후 관계 점수를 계산하는 네트워크를 학습한다.
class RelationNetwork(nn.Module):
"""Relation Networks: 학습 가능한 거리 함수"""
def __init__(self, encoder: nn.Module, embed_dim: int = 64):
super().__init__()
self.encoder = encoder
# 관계 모듈: 두 임베딩을 연결하여 관계 점수 출력
self.relation_module = nn.Sequential(
nn.Linear(embed_dim * 2, 256),
nn.ReLU(),
nn.Linear(256, 64),
nn.ReLU(),
nn.Linear(64, 1),
nn.Sigmoid()
)
def forward(
self,
support_x: torch.Tensor,
support_y: torch.Tensor,
query_x: torch.Tensor,
n_way: int
) -> torch.Tensor:
"""
Returns: 관계 점수 (n_query, n_way)
"""
support_emb = self.encoder(support_x)
query_emb = self.encoder(query_x)
# 클래스별 프로토타입
prototypes = []
for cls in range(n_way):
mask = (support_y == cls)
proto = support_emb[mask].mean(dim=0)
prototypes.append(proto)
prototypes = torch.stack(prototypes) # (n_way, D)
n_query = query_emb.size(0)
# 각 쿼리와 각 프로토타입 쌍에 대한 관계 점수
query_expanded = query_emb.unsqueeze(1).expand(
n_query, n_way, -1
) # (n_query, n_way, D)
proto_expanded = prototypes.unsqueeze(0).expand(
n_query, n_way, -1
) # (n_query, n_way, D)
# 연결
pairs = torch.cat(
[query_expanded, proto_expanded], dim=-1
) # (n_query, n_way, 2D)
# 관계 점수 계산
scores = self.relation_module(
pairs.view(-1, pairs.size(-1))
).view(n_query, n_way)
return scores
3. 최적화 기반 메타러닝: MAML
3.1 MAML의 핵심 아이디어
MAML(Model-Agnostic Meta-Learning)은 Finn 등이 2017년 제안한 알고리즘으로, 메타러닝에서 가장 영향력 있는 연구 중 하나이다.
MAML의 목표는 **"빠르게 적응할 수 있는 초기 파라미터 theta를 찾는 것"**이다.
구체적으로, 새로운 태스크가 주어졌을 때 단 몇 번의 그래디언트 업데이트만으로 좋은 성능을 낼 수 있는 초기점을 찾는다.
3.2 내부 루프 vs 외부 루프
MAML은 두 루프로 구성된다.
내부 루프 (Inner Loop / Task-specific adaptation)
각 태스크 T_i에 대해:
theta_i' = theta - alpha * grad_theta L_{T_i}(f_theta)
support 셋으로 1~5번의 그래디언트 업데이트를 수행하여 태스크별 파라미터 theta_i'를 얻는다.
외부 루프 (Outer Loop / Meta-update)
theta = theta - beta * grad_theta sum_i L_{T_i}(f_{theta_i'})
태스크별로 적응된 파라미터 theta_i'를 사용하여 query 셋에서 손실을 계산하고, 이를 바탕으로 메타 파라미터 theta를 업데이트한다.
3.3 이중 역전파 (Second-order Gradients)
MAML의 핵심 기술적 도전은 외부 루프의 그래디언트가 **내부 루프를 통한 이중 역전파(second-order gradients)**를 요구한다는 것이다.
grad_theta L(f_{theta_i'}) = grad_theta L(f_{theta - alpha * grad L(theta)})
이는 theta에 대한 그래디언트를 계산할 때, 내부 루프 업데이트도 미분해야 한다는 의미다 (헤시안 행렬이 포함됨). 계산 비용이 높다.
실용적으로는 **FOMAML (First-Order MAML)**을 사용한다. 이차 미분 항을 무시하고 근사 그래디언트를 사용한다.
3.4 완전한 MAML 구현
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from copy import deepcopy
from typing import List, Tuple
class MAML:
"""
MAML: Model-Agnostic Meta-Learning
Finn et al., 2017 (arXiv:1703.03400)
"""
def __init__(
self,
model: nn.Module,
inner_lr: float = 0.01, # alpha: 내부 루프 학습률
outer_lr: float = 0.001, # beta: 외부 루프 학습률
n_inner_steps: int = 5, # 내부 루프 업데이트 횟수
first_order: bool = False, # FOMAML 사용 여부
device: str = 'cpu'
):
self.model = model.to(device)
self.inner_lr = inner_lr
self.outer_lr = outer_lr
self.n_inner_steps = n_inner_steps
self.first_order = first_order
self.device = device
self.meta_optimizer = torch.optim.Adam(
self.model.parameters(), lr=outer_lr
)
def inner_loop(
self,
support_x: torch.Tensor,
support_y: torch.Tensor,
model_params=None
) -> dict:
"""
내부 루프: support 셋으로 태스크별 적응
Returns: 적응된 파라미터 딕셔너리
"""
# 현재 파라미터 복사
if model_params is None:
params = {
name: param.clone()
for name, param in self.model.named_parameters()
}
else:
params = {k: v.clone() for k, v in model_params.items()}
for step in range(self.n_inner_steps):
# 현재 파라미터로 순전파 (functional API 사용)
logits = self._forward_with_params(support_x, params)
loss = F.cross_entropy(logits, support_y)
# 파라미터에 대한 그래디언트
grads = torch.autograd.grad(
loss,
params.values(),
create_graph=not self.first_order # 이차 미분을 위해
)
# 파라미터 업데이트 (SGD)
params = {
name: param - self.inner_lr * grad
for (name, param), grad in zip(params.items(), grads)
}
return params
def _forward_with_params(
self,
x: torch.Tensor,
params: dict
) -> torch.Tensor:
"""
특정 파라미터를 사용한 순전파
torch.nn.utils.stateless 또는 functorch 사용
"""
# 모델의 현재 파라미터를 임시로 교체
original_params = {}
for name, param in self.model.named_parameters():
original_params[name] = param.data
param.data = params[name].data if name in params else param.data
output = self.model(x)
# 원래 파라미터 복원 (create_graph=True인 경우 필요 없음)
for name, param in self.model.named_parameters():
if name in original_params:
param.data = original_params[name]
return output
def meta_train_step(
self,
tasks: List[Tuple]
) -> float:
"""
하나의 메타 훈련 스텝 (여러 태스크에 대해)
tasks: [(support_x, support_y, query_x, query_y), ...]
"""
self.meta_optimizer.zero_grad()
meta_loss = 0.0
for support_x, support_y, query_x, query_y in tasks:
support_x = support_x.to(self.device)
support_y = support_y.to(self.device)
query_x = query_x.to(self.device)
query_y = query_y.to(self.device)
# 내부 루프: 태스크별 적응
adapted_params = self.inner_loop(support_x, support_y)
# 외부 루프: 적응된 파라미터로 query 셋 평가
query_logits = self._forward_with_params(
query_x, adapted_params
)
query_loss = F.cross_entropy(query_logits, query_y)
meta_loss += query_loss
# 태스크 수로 평균
meta_loss /= len(tasks)
# 메타 파라미터 업데이트
meta_loss.backward()
self.meta_optimizer.step()
return meta_loss.item()
def fine_tune(
self,
support_x: torch.Tensor,
support_y: torch.Tensor,
n_steps: int = None
) -> nn.Module:
"""
새로운 태스크에 대한 파인튜닝 (추론 시 사용)
"""
n_steps = n_steps or self.n_inner_steps
model_copy = deepcopy(self.model)
optimizer = torch.optim.SGD(
model_copy.parameters(), lr=self.inner_lr
)
model_copy.train()
for step in range(n_steps):
optimizer.zero_grad()
logits = model_copy(support_x.to(self.device))
loss = F.cross_entropy(logits, support_y.to(self.device))
loss.backward()
optimizer.step()
return model_copy
def evaluate(
self,
tasks: List[Tuple],
n_fine_tune_steps: int = 5
) -> Tuple[float, float]:
"""
메타 테스트: 새로운 태스크들에 대해 적응 후 평가
"""
total_loss = 0.0
total_acc = 0.0
for support_x, support_y, query_x, query_y in tasks:
# 파인튜닝
adapted_model = self.fine_tune(support_x, support_y, n_fine_tune_steps)
adapted_model.eval()
with torch.no_grad():
query_logits = adapted_model(query_x.to(self.device))
loss = F.cross_entropy(query_logits, query_y.to(self.device))
preds = query_logits.argmax(dim=-1)
acc = (preds == query_y.to(self.device)).float().mean()
total_loss += loss.item()
total_acc += acc.item()
return total_loss / len(tasks), total_acc / len(tasks)
# ========== MAML 훈련 루프 ==========
def train_maml(
maml: MAML,
dataset,
n_way: int = 5,
k_shot: int = 1,
n_query: int = 15,
meta_batch_size: int = 32,
n_iterations: int = 60000,
device: str = 'cpu'
):
"""MAML 훈련"""
print(f"MAML 훈련 시작: {n_way}-way {k_shot}-shot")
print(f"메타 배치 크기: {meta_batch_size}")
print(f"총 이터레이션: {n_iterations}")
losses = []
for iteration in range(n_iterations):
# 메타 배치 생성
tasks = []
for _ in range(meta_batch_size):
task = create_episode(dataset, n_way, k_shot, n_query)
tasks.append(task)
# 메타 훈련 스텝
meta_loss = maml.meta_train_step(tasks)
losses.append(meta_loss)
if (iteration + 1) % 1000 == 0:
avg_loss = np.mean(losses[-1000:])
print(f"이터레이션 {iteration+1}/{n_iterations} | 메타 손실: {avg_loss:.4f}")
return losses
3.5 Reptile: MAML의 단순화
Reptile(Nichol et al., 2018)은 MAML을 크게 단순화한 알고리즘이다. 이차 미분이 필요 없고 구현이 훨씬 간단하다.
핵심 아이디어: 태스크에 대해 SGD를 여러 번 수행한 후, 최종 파라미터 방향으로 메타 파라미터를 이동시킨다.
theta = theta + epsilon * (W_k - theta)
여기서 W_k는 태스크 T에서 k번 SGD 업데이트 후의 파라미터이다.
class Reptile:
"""
Reptile: A Scalable Meta-learning Algorithm
Nichol et al., 2018 (arXiv:1803.02999)
"""
def __init__(
self,
model: nn.Module,
inner_lr: float = 0.02,
outer_lr: float = 0.001,
n_inner_steps: int = 5,
device: str = 'cpu'
):
self.model = model.to(device)
self.inner_lr = inner_lr
self.outer_lr = outer_lr # epsilon
self.n_inner_steps = n_inner_steps
self.device = device
def inner_train(
self,
support_x: torch.Tensor,
support_y: torch.Tensor
) -> dict:
"""태스크별 내부 훈련 (SGD k번)"""
model_copy = deepcopy(self.model)
optimizer = torch.optim.SGD(
model_copy.parameters(), lr=self.inner_lr
)
model_copy.train()
for step in range(self.n_inner_steps):
optimizer.zero_grad()
logits = model_copy(support_x)
loss = F.cross_entropy(logits, support_y)
loss.backward()
optimizer.step()
return dict(model_copy.named_parameters())
def meta_update(self, task_params_list: List[dict]):
"""
Reptile 메타 업데이트:
theta += epsilon * (mean(W_k) - theta)
"""
with torch.no_grad():
for name, param in self.model.named_parameters():
# 태스크별 파라미터의 평균
task_mean = torch.stack([
task_params[name].data
for task_params in task_params_list
]).mean(dim=0)
# Reptile 업데이트
param.data += self.outer_lr * (task_mean - param.data)
def train(
self,
dataset,
n_way: int = 5,
k_shot: int = 5,
meta_batch_size: int = 5,
n_iterations: int = 100000
):
"""Reptile 전체 훈련 루프"""
print(f"Reptile 훈련 시작: {n_way}-way {k_shot}-shot")
for iteration in range(n_iterations):
task_params_list = []
for _ in range(meta_batch_size):
support_x, support_y, _, _ = create_episode(
dataset, n_way, k_shot, n_query=0
)
support_x = support_x.to(self.device)
support_y = support_y.to(self.device)
task_params = self.inner_train(support_x, support_y)
task_params_list.append(task_params)
self.meta_update(task_params_list)
if (iteration + 1) % 10000 == 0:
print(f"이터레이션 {iteration+1}/{n_iterations} 완료")
4. LLM의 In-Context Learning
4.1 In-Context Learning이란?
In-Context Learning(ICL)은 대형 언어 모델(LLM)이 프롬프트 내의 예시들로부터 새로운 태스크를 수행하는 능력이다. 모델 파라미터를 업데이트하지 않고, 오직 입력 컨텍스트(프롬프트)만으로 태스크를 학습한다.
GPT-3의 등장과 함께 이 능력이 크게 주목받았다. 예를 들어:
영어→프랑스어 번역:
sea otter => loutre de mer
peppermint => menthe poivrée
plush giraffe => girafe en peluche
cheese => ?
이 형식의 프롬프트를 주면 GPT-3는 "fromage"라고 답한다. 프랑스어를 특별히 훈련하지 않은 것처럼 보이지만, 프리트레이닝에서 이미 패턴을 학습한 것이다.
4.2 왜 효과적인가?
ICL이 효과적인 이유에 대한 설명은 여전히 활발히 연구 중이지만, 주요 가설은 다음과 같다.
패턴 완성 관점: LLM은 다음 토큰 예측으로 훈련되었다. 프롬프트의 패턴(입력→출력 쌍)을 보고 그 패턴을 계속하도록 학습되어 있다.
잠재 개념 추론 관점: Brown 등의 연구에 따르면, ICL은 모델이 프롬프트에서 잠재 개념(latent concept)을 추론하는 베이지안 추론과 유사하다.
그래디언트 하강 메타포: Akyürek 등은 Transformer의 어텐션 메커니즘이 암묵적으로 그래디언트 하강과 유사한 연산을 수행한다는 것을 보였다.
4.3 효과적인 퓨샷 프롬프트 전략
from typing import List, Dict, Any
import openai
import anthropic
class FewShotPromptBuilder:
"""퓨샷 프롬프트 생성기"""
def __init__(self):
self.examples = []
self.instruction = ""
self.template = "{input} => {output}"
def set_instruction(self, instruction: str):
"""태스크 지시문 설정"""
self.instruction = instruction
return self
def add_example(self, input_text: str, output_text: str):
"""예시 추가"""
self.examples.append({
'input': input_text,
'output': output_text
})
return self
def build_prompt(self, query: str) -> str:
"""완전한 퓨샷 프롬프트 생성"""
parts = []
if self.instruction:
parts.append(self.instruction)
parts.append("")
for ex in self.examples:
parts.append(self.template.format(
input=ex['input'],
output=ex['output']
))
# 쿼리 (출력은 비워둠)
parts.append(f"{query} =>")
return "\n".join(parts)
def build_chat_messages(
self,
query: str,
system_prompt: str = None
) -> List[Dict]:
"""Chat 형식 메시지 생성 (GPT-4, Claude 등에 사용)"""
messages = []
if system_prompt:
messages.append({
'role': 'system',
'content': system_prompt
})
# 예시를 대화 형식으로 변환
for ex in self.examples:
messages.append({
'role': 'user',
'content': ex['input']
})
messages.append({
'role': 'assistant',
'content': ex['output']
})
# 실제 쿼리
messages.append({
'role': 'user',
'content': query
})
return messages
class DynamicExampleSelector:
"""
동적 예시 선택기
쿼리와 가장 유사한 예시를 선택하여 더 효과적인 퓨샷 학습
"""
def __init__(self, examples: List[Dict], encoder=None):
self.examples = examples
self.encoder = encoder # sentence-transformers 등
def select_similar(
self,
query: str,
n_examples: int = 3
) -> List[Dict]:
"""
쿼리와 가장 유사한 n개의 예시 선택
"""
if self.encoder is None:
# 인코더가 없으면 랜덤 선택
return np.random.choice(
self.examples, n_examples, replace=False
).tolist()
# 의미적 유사도 기반 선택
query_emb = self.encoder.encode(query)
example_embs = self.encoder.encode(
[ex['input'] for ex in self.examples]
)
# 코사인 유사도
similarities = np.dot(example_embs, query_emb) / (
np.linalg.norm(example_embs, axis=1)
* np.linalg.norm(query_emb)
)
top_indices = np.argsort(similarities)[-n_examples:][::-1]
return [self.examples[i] for i in top_indices]
def select_diverse(
self,
n_examples: int = 3
) -> List[Dict]:
"""
다양성을 최대화하는 예시 선택 (MMR 알고리즘)
"""
if len(self.examples) <= n_examples:
return self.examples
if self.encoder is None:
return np.random.choice(
self.examples, n_examples, replace=False
).tolist()
embeddings = self.encoder.encode(
[ex['input'] for ex in self.examples]
)
selected = [0] # 첫 번째 예시 선택
remaining = list(range(1, len(self.examples)))
while len(selected) < n_examples:
# 선택된 예시들과의 최대 유사도
selected_embs = embeddings[selected]
best_idx = None
best_score = float('-inf')
for idx in remaining:
# MMR: 쿼리와의 유사도 - 선택된 예시들과의 유사도
sim_to_selected = np.max(
np.dot(selected_embs, embeddings[idx]) / (
np.linalg.norm(selected_embs, axis=1)
* np.linalg.norm(embeddings[idx])
)
)
score = -sim_to_selected # 다양성 최대화
if score > best_score:
best_score = score
best_idx = idx
selected.append(best_idx)
remaining.remove(best_idx)
return [self.examples[i] for i in selected]
# ========== 실전 사용 예시 ==========
def sentiment_analysis_few_shot():
"""감성 분석 퓨샷 예시"""
builder = FewShotPromptBuilder()
builder.set_instruction("다음 영화 리뷰의 감성을 분석하세요. 긍정(Positive) 또는 부정(Negative)으로 답하세요.")
builder.add_example(
"이 영화는 정말 감동적이었고 배우들의 연기가 훌륭했다.",
"Positive"
)
builder.add_example(
"스토리가 너무 지루하고 결말이 실망스러웠다.",
"Negative"
)
builder.add_example(
"특수효과는 멋있었지만 시나리오가 너무 허술했다.",
"Negative"
)
builder.add_example(
"오랜만에 가족과 함께 볼 수 있는 따뜻한 영화였다.",
"Positive"
)
query = "연출이 독특하고 음악이 영상과 완벽하게 어우러졌다."
prompt = builder.build_prompt(query)
print("=== 퓨샷 프롬프트 ===")
print(prompt)
return prompt
def korean_nlp_few_shot():
"""한국어 NLP 퓨샷 예시: 개체명 인식"""
builder = FewShotPromptBuilder()
builder.set_instruction(
"다음 문장에서 인물(PER), 장소(LOC), 기관(ORG) 개체를 찾아 태깅하세요."
)
builder.add_example(
"이순신 장군이 한산도에서 일본군을 격파했다.",
"이순신[PER] 장군이 한산도[LOC]에서 일본군[ORG]을 격파했다."
)
builder.add_example(
"삼성전자는 수원 사업장에서 반도체를 생산한다.",
"삼성전자[ORG]는 수원[LOC] 사업장에서 반도체를 생산한다."
)
query = "박지성이 맨체스터 유나이티드에서 활약했다."
prompt = builder.build_prompt(query)
print("=== 한국어 NLP 퓨샷 프롬프트 ===")
print(prompt)
return prompt
4.4 Cross-lingual Few-shot
다국어 모델은 제로샷 크로스링구얼 전이(zero-shot cross-lingual transfer) 능력을 보인다. 영어로만 훈련된 태스크를 한국어나 다른 언어에 적용할 수 있다.
class CrossLingualFewShot:
"""
크로스링구얼 퓨샷 학습
영어 예시 + 타겟 언어 쿼리
"""
def __init__(self, model_name: str = "xlm-roberta-large"):
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=2
)
def encode_text(self, text: str) -> torch.Tensor:
"""텍스트를 임베딩으로 인코딩"""
inputs = self.tokenizer(
text,
return_tensors='pt',
padding=True,
truncation=True,
max_length=128
)
with torch.no_grad():
outputs = self.model(**inputs, output_hidden_states=True)
# [CLS] 토큰의 마지막 은닉 상태
embedding = outputs.hidden_states[-1][:, 0, :]
return embedding
def classify_zero_shot(
self,
query_ko: str,
class_descriptions_en: List[str]
) -> int:
"""
영어 클래스 설명으로 한국어 쿼리 분류
"""
query_emb = self.encode_text(query_ko)
class_embs = torch.cat([
self.encode_text(desc) for desc in class_descriptions_en
])
# 코사인 유사도
similarities = F.cosine_similarity(
query_emb.expand(len(class_descriptions_en), -1),
class_embs
)
return similarities.argmax().item()
def few_shot_classify(
self,
support_texts: List[str],
support_labels: List[int],
query_text: str,
n_classes: int
) -> int:
"""
퓨샷 분류 (프로토타입 방식)
support와 query는 다른 언어 가능
"""
support_embs = torch.cat([
self.encode_text(t) for t in support_texts
])
query_emb = self.encode_text(query_text)
# 클래스별 프로토타입
prototypes = []
for cls in range(n_classes):
mask = torch.tensor([l == cls for l in support_labels])
proto = support_embs[mask].mean(dim=0, keepdim=True)
prototypes.append(proto)
prototypes = torch.cat(prototypes)
# 유클리드 거리로 분류
dists = torch.cdist(query_emb, prototypes)
return dists.argmin().item()
5. 실전 활용: 의료 이미지 퓨샷 분류
5.1 희귀 질환 진단 시스템
임상 현장에서 희귀 질환은 훈련 데이터가 매우 부족하다. 퓨샷 학습을 사용하면 소수의 확진 사례만으로도 새로운 질환 패턴을 인식할 수 있다.
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import numpy as np
class MedicalImageEncoder(nn.Module):
"""
의료 이미지용 인코더
ResNet-18 기반, 의료 이미지 특성에 맞게 수정
"""
def __init__(self, embed_dim: int = 512, pretrained: bool = True):
super().__init__()
# ResNet-18 백본
backbone = models.resnet18(pretrained=pretrained)
# FC 레이어 제거
self.backbone = nn.Sequential(*list(backbone.children())[:-1])
# 임베딩 헤드
self.embed_head = nn.Sequential(
nn.Flatten(),
nn.Linear(512, embed_dim),
nn.LayerNorm(embed_dim),
nn.ReLU(),
nn.Linear(embed_dim, embed_dim)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
features = self.backbone(x)
return self.embed_head(features)
class MedicalFewShotClassifier:
"""
의료 이미지 퓨샷 분류기
Prototypical Networks 기반
"""
def __init__(
self,
encoder: MedicalImageEncoder,
device: str = 'cpu'
):
self.encoder = encoder.to(device)
self.device = device
self.prototypes = {}
self.disease_names = {}
def register_disease(
self,
disease_id: int,
disease_name: str,
support_images: List,
transform=None
):
"""
새로운 질환 등록 (소수의 예시로)
"""
self.disease_names[disease_id] = disease_name
self.encoder.eval()
embeddings = []
with torch.no_grad():
for img in support_images:
if transform:
img_tensor = transform(img).unsqueeze(0).to(self.device)
else:
img_tensor = img.unsqueeze(0).to(self.device)
emb = self.encoder(img_tensor)
embeddings.append(emb)
prototype = torch.cat(embeddings).mean(dim=0)
self.prototypes[disease_id] = prototype
print(
f"질환 등록 완료: {disease_name} "
f"(예시 {len(support_images)}개)"
)
def diagnose(
self,
query_image: torch.Tensor,
top_k: int = 3
) -> List[Dict]:
"""
쿼리 이미지 진단
Returns: 상위 k개 질환과 유사도 점수
"""
self.encoder.eval()
with torch.no_grad():
query_emb = self.encoder(
query_image.unsqueeze(0).to(self.device)
)
results = []
for disease_id, prototype in self.prototypes.items():
# 코사인 유사도
similarity = F.cosine_similarity(
query_emb, prototype.unsqueeze(0)
).item()
results.append({
'disease_id': disease_id,
'disease_name': self.disease_names[disease_id],
'similarity': similarity
})
# 유사도 순으로 정렬
results.sort(key=lambda x: x['similarity'], reverse=True)
return results[:top_k]
def update_prototype(
self,
disease_id: int,
new_image: torch.Tensor,
momentum: float = 0.9
):
"""
새로운 확진 사례로 프로토타입 업데이트 (온라인 학습)
"""
self.encoder.eval()
with torch.no_grad():
new_emb = self.encoder(
new_image.unsqueeze(0).to(self.device)
).squeeze(0)
if disease_id in self.prototypes:
# 지수 이동 평균으로 업데이트
self.prototypes[disease_id] = (
momentum * self.prototypes[disease_id]
+ (1 - momentum) * new_emb
)
print(f"프로토타입 업데이트 완료: {self.disease_names[disease_id]}")
else:
print(f"경고: 질환 {disease_id}이 등록되지 않았습니다.")
def demo_medical_few_shot():
"""의료 퓨샷 분류 데모"""
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
encoder = MedicalImageEncoder(embed_dim=512, pretrained=True)
classifier = MedicalFewShotClassifier(encoder)
print("=== 의료 이미지 퓨샷 분류 시스템 ===")
print("이 시스템은 소수의 확진 사례만으로 새로운 질환을 인식합니다.")
print()
# 실제 사용 예시
# 1. 각 질환에 대해 5~10장의 레퍼런스 이미지로 등록
# 2. 새로운 환자 이미지를 입력하면 유사한 질환 진단
# 3. 새로운 확진 사례가 나올 때마다 프로토타입 업데이트
print("사용 방법:")
print("1. classifier.register_disease(id, name, support_images)")
print("2. results = classifier.diagnose(patient_image)")
print("3. classifier.update_prototype(disease_id, new_confirmed_image)")
6. learn2learn 라이브러리 활용
6.1 learn2learn 소개
learn2learn은 MAML, ProtoNet 등 메타러닝 알고리즘을 쉽게 구현할 수 있는 라이브러리이다.
pip install learn2learn
6.2 learn2learn으로 MAML 구현
import learn2learn as l2l
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
def build_l2l_maml(
model: nn.Module,
lr: float = 0.01,
first_order: bool = False
) -> l2l.algorithms.MAML:
"""learn2learn MAML 래퍼 생성"""
return l2l.algorithms.MAML(
model,
lr=lr,
first_order=first_order,
allow_unused=True
)
def train_with_l2l(
maml_model: l2l.algorithms.MAML,
tasksets,
n_way: int = 5,
k_shot: int = 1,
n_query: int = 15,
meta_lr: float = 0.003,
n_iterations: int = 1000,
adaptation_steps: int = 1,
device: str = 'cpu'
):
"""
learn2learn을 사용한 MAML 훈련
"""
maml_model = maml_model.to(device)
meta_optimizer = torch.optim.Adam(
maml_model.parameters(), lr=meta_lr
)
criterion = nn.CrossEntropyLoss(reduction='mean')
for iteration in range(n_iterations):
meta_optimizer.zero_grad()
meta_train_loss = 0.0
meta_train_acc = 0.0
# 미니 메타 배치
for task in range(4): # 4개 태스크 동시 처리
# 태스크 샘플링
X, y = tasksets.train.sample()
X, y = X.to(device), y.to(device)
# 분류기 클론 생성 (메타러닝을 위한 복사본)
learner = maml_model.clone()
# Support / Query 분리
support_indices = torch.zeros(X.size(0), dtype=torch.bool)
for cls in range(n_way):
cls_idx = (y == cls).nonzero(as_tuple=True)[0]
support_idx = cls_idx[:k_shot]
support_indices[support_idx] = True
query_indices = ~support_indices
support_x, support_y = X[support_indices], y[support_indices]
query_x, query_y = X[query_indices], y[query_indices]
# 내부 루프: adaptation
for step in range(adaptation_steps):
support_logits = learner(support_x)
support_loss = criterion(support_logits, support_y)
learner.adapt(support_loss)
# 외부 루프: meta-gradient
query_logits = learner(query_x)
query_loss = criterion(query_logits, query_y)
meta_train_loss += query_loss
# 정확도
preds = query_logits.argmax(dim=-1)
acc = (preds == query_y).float().mean()
meta_train_acc += acc
meta_train_loss /= 4
meta_train_acc /= 4
meta_train_loss.backward()
meta_optimizer.step()
if (iteration + 1) % 100 == 0:
print(
f"이터레이션 {iteration+1}/{n_iterations} | "
f"메타 손실: {meta_train_loss.item():.4f} | "
f"메타 정확도: {meta_train_acc.item():.4f}"
)
def setup_omniglot_maml():
"""
Omniglot 데이터셋으로 MAML 설정
Omniglot: 50개 언어의 알파벳 문자 (1623 클래스, 각 20개 샘플)
"""
# learn2learn의 벤치마크 데이터셋 사용
tasksets = l2l.vision.benchmarks.get_tasksets(
'omniglot',
train_ways=5,
train_samples=2 * 1 + 2 * 15, # k_shot + n_query
test_ways=5,
test_samples=2 * 1 + 2 * 15,
root='./data',
device='cpu'
)
# CNN 모델
model = l2l.vision.models.OmniglotCNN(
output_size=5,
hidden_size=64,
layers=4
)
# MAML 래퍼
maml = build_l2l_maml(model, lr=0.4, first_order=False)
return maml, tasksets
def evaluate_l2l(
maml_model: l2l.algorithms.MAML,
tasksets,
n_way: int = 5,
k_shot: int = 1,
n_query: int = 15,
n_test_tasks: int = 600,
adaptation_steps: int = 3,
device: str = 'cpu'
) -> Tuple[float, float]:
"""learn2learn 메타 평가"""
criterion = nn.CrossEntropyLoss()
total_loss = 0.0
total_acc = 0.0
maml_model.eval()
for _ in range(n_test_tasks):
X, y = tasksets.test.sample()
X, y = X.to(device), y.to(device)
learner = maml_model.clone()
support_indices = torch.zeros(X.size(0), dtype=torch.bool)
for cls in range(n_way):
cls_idx = (y == cls).nonzero(as_tuple=True)[0]
support_idx = cls_idx[:k_shot]
support_indices[support_idx] = True
query_indices = ~support_indices
support_x, support_y = X[support_indices], y[support_indices]
query_x, query_y = X[query_indices], y[query_indices]
# 테스트 시 더 많은 적응 스텝 사용
for step in range(adaptation_steps):
support_loss = criterion(learner(support_x), support_y)
learner.adapt(support_loss)
with torch.no_grad():
query_logits = learner(query_x)
loss = criterion(query_logits, query_y)
acc = (query_logits.argmax(dim=-1) == query_y).float().mean()
total_loss += loss.item()
total_acc += acc.item()
return total_loss / n_test_tasks, total_acc / n_test_tasks
7. 메타러닝 벤치마크
7.1 주요 벤치마크 데이터셋
Omniglot
50개 알파벳 체계의 1623개 문자 클래스. 각 클래스당 20개 샘플. 주로 20-way 1-shot 설정에서 평가한다.
Mini-ImageNet
ImageNet의 100개 클래스 서브셋. 각 클래스당 600개 이미지 (84x84). 5-way 1/5-shot 설정이 표준이다.
tieredImageNet
Mini-ImageNet보다 더 어려운 버전. 상위 개념으로 그룹화된 클래스를 사용하여 메타-훈련과 메타-테스트 클래스 간의 의미적 격차를 크게 한다.
CIFAR-FS
CIFAR-100에서 파생된 퓨샷 벤치마크. Mini-ImageNet보다 빠르게 실험할 수 있다.
7.2 평가 프로토콜
def standard_few_shot_evaluation(
model,
test_dataset,
n_way: int = 5,
k_shot: int = 1,
n_query: int = 15,
n_episodes: int = 600,
confidence_interval: bool = True
) -> Dict:
"""
표준 퓨샷 평가 프로토콜
600개 에피소드의 평균과 95% 신뢰 구간
"""
accs = []
model.eval()
for episode in range(n_episodes):
support_x, support_y, query_x, query_y = create_episode(
test_dataset, n_way, k_shot, n_query
)
with torch.no_grad():
log_probs = model(support_x, support_y, query_x, n_way)
preds = log_probs.argmax(dim=-1)
acc = (preds == query_y).float().mean().item()
accs.append(acc)
mean_acc = np.mean(accs)
std_acc = np.std(accs)
if confidence_interval:
# 95% 신뢰 구간
ci = 1.96 * std_acc / np.sqrt(n_episodes)
return {
'mean_accuracy': mean_acc,
'std': std_acc,
'confidence_interval_95': ci,
'result_string': f"{mean_acc*100:.2f} ± {ci*100:.2f}%"
}
return {'mean_accuracy': mean_acc, 'std': std_acc}
8. 마무리: 메타러닝의 미래
메타러닝은 AI 연구의 핵심 패러다임 중 하나로 자리잡고 있다. 특히 주목할 트렌드는 다음과 같다.
LLM과의 융합
GPT-4, Claude 같은 대형 언어 모델은 강력한 ICL 능력을 보여준다. 이 모델들이 임의의 도메인에서 퓨샷 학습을 수행하는 메타 러너로 작동한다는 관점에서 연구가 활발하다.
멀티모달 퓨샷 학습
텍스트, 이미지, 음성을 통합한 멀티모달 퓨샷 학습. GPT-4V, Gemini Ultra 등이 시각 퓨샷 태스크에서 인상적인 성능을 보이고 있다.
지속 학습(Continual Learning)과의 결합
메타러닝으로 초기화된 모델은 새로운 태스크를 학습할 때 이전 지식을 덜 잊어버린다는 연구 결과가 있다. 지속 학습과 메타러닝의 결합이 활발히 연구되고 있다.
참고 자료
- Finn, C., et al. (2017). Model-Agnostic Meta-Learning for Fast Adaptation. ICML 2017. arXiv:1703.03400
- Snell, J., et al. (2017). Prototypical Networks for Few-shot Learning. NeurIPS 2017. arXiv:1703.05175
- Vinyals, O., et al. (2016). Matching Networks for One Shot Learning. NeurIPS 2016. arXiv:1606.04080
- Nichol, A., et al. (2018). On First-Order Meta-Learning Algorithms. arXiv:1803.02999
- Brown, T., et al. (2020). Language Models are Few-Shot Learners (GPT-3). NeurIPS 2020.
- learn2learn library: https://github.com/learnables/learn2learn