Complete Guide to Meta-Learning and Few-Shot Learning: MAML, Prototypical Networks, In-Context Learning

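Humans learn new concepts from just a handful of examples. Show a child a picture and say "this is a zebra," and the next day they will recognize a zebra in a photo they have never seen. A conventional deep learning model, by contrast, needs thousands of photos to build a single zebra classifier.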

**Meta-learning** and **few-shot learning** aim to close this gap. The core idea of meta-learning is "learning to learn": by experiencing many different tasks, the model learns the ability to adapt quickly to new tasks.

This guide covers meta-learning from its theoretical foundations through modern In-Context Learning.


1. Meta-Learning Fundamentals

1.1 Learning to Learn

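In traditional machine learning, a model is trained for a single task. A cat classifier learns only from cat data, and when it has to classify a new animal (e.g., an ocelot), it must be retrained from scratch.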

Meta-learning changes the perspective. What the model should learn is not "how to classify cats" but "how to learn to classify new animals quickly."

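To achieve this, meta-learning operates at two levels.

  • Meta-learning (outer loop): learns a good initialization or learning algorithm by experiencing many tasks
  • Task learning (inner loop): adapts quickly to each specific task from a few examples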

1.2 Limitations of Conventional Learning

Concretely, conventional learning has the following limitations.

Data efficiency: A model trained on ImageNet sees over a million labeled images, and adding a new class typically requires thousands more.

Poor generalization: Performance drops sharply on new tasks that lie far from the training distribution.

Continual learning: Learning a new task makes the model forget earlier ones, a phenomenon known as "catastrophic forgetting."

1.3 Task Distribution

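The key concept in meta-learning is the task distribution p(T). Meta-learning does not simply learn from data; it learns from a distribution over tasks.

Each task T consists of:

  • a distribution over input-output pairs p(x, y)
  • a task loss function L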

The meta-learning objective:

$$\min_{\theta} \; \mathbb{E}_{T \sim p(T)} \left[ \mathcal{L}_T(f_\theta) \right]$$
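To make p(T) concrete, here is a minimal NumPy sketch of the sinusoid-regression task family that Finn et al. (2017) use as a toy benchmark. Each task is a sine wave with its own amplitude and phase, and the expectation over tasks is estimated by Monte Carlo. The function names and the MSE loss are illustrative choices and do not come from any library. The sketch evaluates a fixed predictor; meta-learning methods instead evaluate f after it has adapted to each task.

```python
import numpy as np


def sample_sinusoid_task(rng: np.random.Generator):
    """Draw one task T ~ p(T): y = A * sin(x - phase).

    Ranges follow the sinusoid benchmark of Finn et al. (2017):
    amplitude in [0.1, 5.0], phase in [0, pi], inputs in [-5, 5].
    """
    amplitude = rng.uniform(0.1, 5.0)
    phase = rng.uniform(0.0, np.pi)

    def sample_data(n: int):
        # Draw n input-output pairs from this task's p(x, y)
        x = rng.uniform(-5.0, 5.0, size=(n, 1))
        y = amplitude * np.sin(x - phase)
        return x, y

    return amplitude, phase, sample_data


def task_loss(predict, x: np.ndarray, y: np.ndarray) -> float:
    """L_T: mean squared error of a predictor on data from one task."""
    return float(np.mean((predict(x) - y) ** 2))


def estimate_meta_objective(predict, n_tasks: int = 100, n_points: int = 10, seed: int = 0) -> float:
    """Monte Carlo estimate of E_{T ~ p(T)}[L_T(f)] for a fixed predictor f."""
    rng = np.random.default_rng(seed)
    losses = []
    for _ in range(n_tasks):
        _, _, sample_data = sample_sinusoid_task(rng)
        x, y = sample_data(n_points)
        losses.append(task_loss(predict, x, y))
    return float(np.mean(losses))


# The all-zeros predictor has a strictly positive expected loss over tasks
print(estimate_meta_objective(lambda x: np.zeros_like(x)))
```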

1.4 Support Set vs. Query Set

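In few-shot learning, data plays two roles.

Support set: the handful of examples the model refers to when learning a new task. It corresponds to the training data in conventional learning, but it is tiny (e.g., 1 to 5 examples per class).

Query set: the data used to evaluate the model on the task. It corresponds to the test data in conventional learning. During meta-training, the loss on the query set is what drives the meta-update.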

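1.5 The N-way K-shot Setting

The central setting in few-shot learning is N-way K-shot.

  • N-way: the number of classes to distinguish
  • K-shot: the number of support examples per class

For example, 5-way 1-shot means classifying among 5 classes given a single example of each. 5-way 5-shot uses 5 examples per class.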

import torch
import torch.nn as nn
import numpy as np
from typing import List, Tuple, Dict


def create_episode(
    dataset,
    n_way: int,
    k_shot: int,
    n_query: int,
    classes: List[int] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
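    Build one N-way K-shot episode.
    Assumes dataset.targets is a 1-D tensor of integer class labels.
    Returns: (support_x, support_y, query_x, query_y), labels remapped to 0..n_way-1
    """
    # Pick the episode's classes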
    all_classes = list(set(dataset.targets.tolist()))
    if classes is None:
        selected_classes = np.random.choice(
            all_classes, n_way, replace=False
        )
    else:
        selected_classes = classes

    support_x, support_y = [], []
    query_x, query_y = [], []

    for new_label, cls in enumerate(selected_classes):
        # All indices belonging to this class
        cls_indices = (dataset.targets == cls).nonzero(as_tuple=True)[0]
        chosen = np.random.choice(
            len(cls_indices),
            k_shot + n_query,
            replace=False
        )

        for i, idx in enumerate(cls_indices[chosen]):
            x, _ = dataset[idx.item()]
            if i < k_shot:
                support_x.append(x)
                support_y.append(new_label)
            else:
                query_x.append(x)
                query_y.append(new_label)

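    support_x = torch.stack(support_x)
    support_y = torch.tensor(support_y)
    # torch.stack fails on an empty list, so handle n_query=0 (Reptile calls it that way)
    query_x = torch.stack(query_x) if query_x else torch.empty(0, *support_x.shape[1:])
    query_y = torch.tensor(query_y, dtype=torch.long)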

    return support_x, support_y, query_x, query_y

2. Metric-Based Meta-Learning

2.1 Matching Networks

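Matching Networks, proposed by Vinyals et al. in 2016, combine ideas from attention mechanisms and k-nearest neighbors. The core idea is to predict a query sample's label as an attention-weighted sum of the support-set labels.

Prediction formula

$$\hat{y} = \sum_{i} a(\hat{x}, x_i) \, y_i$$

Here $a(\hat{x}, x_i)$ is the attention weight between the query $\hat{x}$ and the support sample $x_i$, computed as a softmax over cosine similarities. Because each $y_i$ is a one-hot vector and the weights sum to 1, $\hat{y}$ is a probability distribution over classes.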

class MatchingNetworks(nn.Module):
    """Matching Networks implementation"""

    def __init__(self, encoder: nn.Module, use_fce: bool = False):
        """
        encoder: feature extractor
        use_fce: whether to use Full Context Embeddings (stored, but not implemented here)
        """
        super().__init__()
        self.encoder = encoder
        self.use_fce = use_fce

    def cosine_similarity(
        self,
        query: torch.Tensor,
        support: torch.Tensor
    ) -> torch.Tensor:
        """
        Cosine similarity
        query: (n_query, embed_dim)
        support: (n_support, embed_dim)
        Returns: (n_query, n_support)
        """
        query_norm = nn.functional.normalize(query, dim=-1)
        support_norm = nn.functional.normalize(support, dim=-1)
        return torch.mm(query_norm, support_norm.t())

    def forward(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """
        support_x: (n_way * k_shot, C, H, W)
        support_y: (n_way * k_shot,)
        query_x: (n_query, C, H, W)
        """
        n_support = support_x.size(0)
        n_query = query_x.size(0)

        # Encode
        support_emb = self.encoder(support_x)   # (n_support, D)
        query_emb = self.encoder(query_x)        # (n_query, D)

        # Similarities
        similarities = self.cosine_similarity(
            query_emb, support_emb
        )  # (n_query, n_support)

        # Softmax attention
        attention = nn.functional.softmax(similarities, dim=-1)

        # Convert labels to one-hot
        support_labels_one_hot = nn.functional.one_hot(
            support_y, n_way
        ).float()  # (n_support, n_way)

        # Prediction as an attention-weighted sum
        # (n_query, n_support) x (n_support, n_way) = (n_query, n_way)
        logits = torch.mm(attention, support_labels_one_hot)

        return logits  # class probabilities (rows sum to 1), not log-probs: apply torch.log before NLLLoss
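The prediction formula can be checked with a small NumPy sketch that does not depend on the PyTorch class. It shows that the attention-weighted sum of one-hot labels is already a probability distribution over classes. This is why the `forward` method above effectively returns probabilities rather than log-probabilities. The function name and the toy embeddings are illustrative.

```python
import numpy as np


def matching_predict(query_emb: np.ndarray, support_emb: np.ndarray,
                     support_y: np.ndarray, n_way: int) -> np.ndarray:
    """y_hat = sum_i a(x_hat, x_i) * y_i, with a = softmax over cosine similarities."""
    q = query_emb / np.linalg.norm(query_emb, axis=-1, keepdims=True)
    s = support_emb / np.linalg.norm(support_emb, axis=-1, keepdims=True)
    sims = q @ s.T                                   # (n_query, n_support)
    sims = sims - sims.max(axis=-1, keepdims=True)   # stabilize the softmax
    attn = np.exp(sims)
    attn = attn / attn.sum(axis=-1, keepdims=True)   # a(x_hat, x_i)
    one_hot = np.eye(n_way)[support_y]               # (n_support, n_way)
    return attn @ one_hot                            # (n_query, n_way); rows sum to 1


support_emb = np.array([[1.0, 0.0], [0.0, 1.0]])
support_y = np.array([0, 1])
probs = matching_predict(np.array([[0.0, 2.0]]), support_emb, support_y, n_way=2)
print(probs)  # class 1 gets the larger weight
```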

2.2 Prototypical Networks

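Prototypical Networks, published by Snell et al. in 2017, are among the most elegant and intuitive meta-learning algorithms. The core idea is to represent each class by a single prototype (a center point) in the embedding space.

Computing prototypes

The prototype of class c is the mean embedding of that class's support samples.

$$p_c = \frac{1}{|S_c|} \sum_{(x_i, y_i) \in S_c} f_\phi(x_i)$$

Classification

A query sample x is assigned to the nearest prototype:

$$p(y = c \mid x) = \frac{\exp\big(-d(f_\phi(x), p_c)\big)}{\sum_{c'} \exp\big(-d(f_\phi(x), p_{c'})\big)}$$

Here d is a distance function. Snell et al. use the squared Euclidean distance.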
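A small worked example with made-up 2-D embeddings walks through both steps. It also checks the observation in Snell et al. that, with squared Euclidean distance, the classifier is linear in the embedding. Expanding the square gives $-\|f_\phi(x) - p_c\|^2 = 2 p_c^\top f_\phi(x) - \|p_c\|^2 - \|f_\phi(x)\|^2$. The last term is the same for every class, so it cancels in the softmax. The NumPy functions are an illustrative sketch, separate from the PyTorch implementation below.

```python
import numpy as np


def compute_prototypes(support_emb: np.ndarray, support_y: np.ndarray, n_way: int) -> np.ndarray:
    """p_c = mean of the support embeddings of class c."""
    return np.stack([support_emb[support_y == c].mean(axis=0) for c in range(n_way)])


def log_softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def proto_log_probs(query_emb: np.ndarray, protos: np.ndarray) -> np.ndarray:
    """log p(y = c | x) = log_softmax(-||f(x) - p_c||^2)."""
    sq_dists = ((query_emb[:, None, :] - protos[None, :, :]) ** 2).sum(axis=-1)
    return log_softmax(-sq_dists)


def linear_log_probs(query_emb: np.ndarray, protos: np.ndarray) -> np.ndarray:
    """Equivalent linear classifier: weights w_c = 2 p_c, bias b_c = -||p_c||^2."""
    weights = 2.0 * protos
    bias = -(protos ** 2).sum(axis=-1)
    return log_softmax(query_emb @ weights.T + bias)


support_emb = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [12.0, 10.0]])
support_y = np.array([0, 0, 1, 1])
protos = compute_prototypes(support_emb, support_y, n_way=2)   # [[1, 0], [11, 10]]
query = np.array([[1.0, 1.0], [9.0, 9.0]])
print(proto_log_probs(query, protos).argmax(axis=-1))           # [0 1]
```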

class ConvEncoder(nn.Module):
    """4-layer CNN encoder for few-shot learning"""

    def __init__(
        self,
        in_channels: int = 1,
        hidden_dim: int = 64,
        out_dim: int = 64
    ):
        super().__init__()

        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2)
            )

        self.net = nn.Sequential(
            conv_block(in_channels, hidden_dim),
            conv_block(hidden_dim, hidden_dim),
            conv_block(hidden_dim, hidden_dim),
            conv_block(hidden_dim, out_dim),
            nn.Flatten()
        )

    def forward(self, x):
        return self.net(x)


class PrototypicalNetworks(nn.Module):
    """Prototypical Networks, full implementation"""

    def __init__(self, encoder: nn.Module):
        super().__init__()
        self.encoder = encoder

    def compute_prototypes(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """
        Compute each class prototype (mean embedding)
        support_x: (n_way * k_shot, C, H, W)
        support_y: (n_way * k_shot,)
        Returns: (n_way, embed_dim)
        """
        support_emb = self.encoder(support_x)  # (n_support, D)
        prototypes = []

        for cls in range(n_way):
            mask = (support_y == cls)
            cls_embeddings = support_emb[mask]
            prototype = cls_embeddings.mean(dim=0)
            prototypes.append(prototype)

        return torch.stack(prototypes)  # (n_way, D)

    def euclidean_dist(
        self,
        x: torch.Tensor,
        y: torch.Tensor
    ) -> torch.Tensor:
        """
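        Squared Euclidean distance (as used by Snell et al.)
        x: (n, D)
        y: (m, D)
        Returns: (n, m)
        """
        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2*x·y
        n = x.size(0)
        m = y.size(0)
        x_sq = (x ** 2).sum(dim=1, keepdim=True).expand(n, m)
        y_sq = (y ** 2).sum(dim=1, keepdim=True).expand(m, n).t()
        xy = torch.mm(x, y.t())
        dist = x_sq + y_sq - 2 * xy
        # clamp removes tiny negatives caused by round-off; skipping sqrt also
        # avoids the infinite gradient of sqrt at zero distance
        return dist.clamp(min=0)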

    def forward(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """
        Prototypical Networks forward pass
        Returns: log-probabilities for the query samples (n_query, n_way)
        """
        # Compute prototypes
        prototypes = self.compute_prototypes(
            support_x, support_y, n_way
        )  # (n_way, D)

        # Query embeddings
        query_emb = self.encoder(query_x)  # (n_query, D)

        # Distances
        dists = self.euclidean_dist(
            query_emb, prototypes
        )  # (n_query, n_way)

        # Use negative distances as logits (closer means higher probability)
        log_probs = nn.functional.log_softmax(-dists, dim=-1)

        return log_probs


# ========== Training loop ==========

def train_prototypical(
    model: PrototypicalNetworks,
    train_dataset,
    n_way: int = 5,
    k_shot: int = 5,
    n_query: int = 15,
    n_episodes: int = 100,
    lr: float = 1e-3,
    device: str = 'cpu'
):
    """Episodic training for Prototypical Networks"""
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.NLLLoss()

    model.train()
    episode_losses = []
    episode_accs = []

    for episode in range(n_episodes):
        # Sample an episode
        support_x, support_y, query_x, query_y = create_episode(
            train_dataset, n_way, k_shot, n_query
        )
        support_x = support_x.to(device)
        support_y = support_y.to(device)
        query_x = query_x.to(device)
        query_y = query_y.to(device)

        optimizer.zero_grad()

        log_probs = model(support_x, support_y, query_x, n_way)
        loss = criterion(log_probs, query_y)
        loss.backward()
        optimizer.step()

        # Accuracy
        preds = log_probs.argmax(dim=-1)
        acc = (preds == query_y).float().mean().item()

        episode_losses.append(loss.item())
        episode_accs.append(acc)

        if (episode + 1) % 100 == 0:
            print(
                f"Episode {episode+1}/{n_episodes} | "
                f"loss: {np.mean(episode_losses[-100:]):.4f} | "
                f"accuracy: {np.mean(episode_accs[-100:]):.4f}"
            )

    return episode_losses, episode_accs

2.3 Relation Networks

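Relation Networks (Sung et al., 2018) resemble Prototypical Networks but replace the fixed distance function with a learnable neural network. The query embedding is concatenated with each class representation, and a relation module is trained to output a relation score for the pair. The original paper sums the support embeddings of each class instead of averaging them; the implementation below uses the mean, as Prototypical Networks do.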

class RelationNetwork(nn.Module):
    """Relation Networks: a learnable distance function"""

    def __init__(self, encoder: nn.Module, embed_dim: int = 64):
        super().__init__()
        self.encoder = encoder

        # Relation module: maps a concatenated pair of embeddings to a relation score
        self.relation_module = nn.Sequential(
            nn.Linear(embed_dim * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """
        Returns: relation scores in [0, 1], shape (n_query, n_way)
        """
        support_emb = self.encoder(support_x)
        query_emb = self.encoder(query_x)

        # Per-class prototypes
        prototypes = []
        for cls in range(n_way):
            mask = (support_y == cls)
            proto = support_emb[mask].mean(dim=0)
            prototypes.append(proto)
        prototypes = torch.stack(prototypes)  # (n_way, D)

        n_query = query_emb.size(0)

        # Relation score for every (query, prototype) pair
        query_expanded = query_emb.unsqueeze(1).expand(
            n_query, n_way, -1
        )  # (n_query, n_way, D)
        proto_expanded = prototypes.unsqueeze(0).expand(
            n_query, n_way, -1
        )  # (n_query, n_way, D)

        # Concatenate
        pairs = torch.cat(
            [query_expanded, proto_expanded], dim=-1
        )  # (n_query, n_way, 2D)

        # Compute relation scores
        scores = self.relation_module(
            pairs.view(-1, pairs.size(-1))
        ).view(n_query, n_way)

        return scores
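Sung et al. train the relation scores as a regression problem rather than with cross-entropy. A mean-squared-error loss pushes the score for the query's true class toward 1 and all other scores toward 0. The following is a minimal sketch of that objective for the `scores` tensor returned above. The function name is our own and is not part of any library.

```python
import torch
import torch.nn.functional as F


def relation_mse_loss(scores: torch.Tensor, query_y: torch.Tensor) -> torch.Tensor:
    """MSE between relation scores in [0, 1] and one-hot targets.

    scores:  (n_query, n_way) output of the relation module (after sigmoid)
    query_y: (n_query,) integer labels in 0..n_way-1
    """
    n_way = scores.size(1)
    targets = F.one_hot(query_y, num_classes=n_way).float()
    return F.mse_loss(scores, targets)


scores = torch.tensor([[0.9, 0.1], [0.2, 0.7]], requires_grad=True)
query_y = torch.tensor([0, 1])
loss = relation_mse_loss(scores, query_y)
loss.backward()          # in real training, gradients flow back into the relation module
print(loss.item())
```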

3. Optimization-Based Meta-Learning: MAML

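3.1 The Core Idea of MAML

MAML (Model-Agnostic Meta-Learning), proposed by Finn et al. in 2017, is one of the most influential works in meta-learning.

The goal of MAML is **"to find initial parameters theta that can adapt quickly."**

Concretely, it looks for an initialization from which just a few gradient updates on a new task yield good performance.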

3.2 Inner Loop vs. Outer Loop

MAML consists of two loops.

Inner loop (task-specific adaptation)

For each task T_i:

$$\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{T_i}(f_\theta)$$

Running 1 to 5 such gradient updates on the support set yields the task-specific parameters theta_i'.

Outer loop (meta-update)

$$\theta \leftarrow \theta - \beta \nabla_\theta \sum_i \mathcal{L}_{T_i}(f_{\theta_i'})$$

The adapted parameters theta_i' are used to compute the loss on each task's query set. That loss is then used to update the meta-parameters theta.

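3.3 Double Backpropagation (Second-order Gradients)

The main technical challenge in MAML is that the outer-loop gradient requires **backpropagating through the inner-loop update (second-order gradients)**.

$$\nabla_\theta \mathcal{L}_{T_i}(f_{\theta_i'}) = \nabla_\theta \mathcal{L}_{T_i}\big(f_{\theta - \alpha \nabla_\theta \mathcal{L}_{T_i}(f_\theta)}\big)$$

So the gradient with respect to theta must also differentiate through the inner-loop update itself. For a single inner step, the chain rule gives

$$\nabla_\theta \mathcal{L}_{T_i}(f_{\theta_i'}) = \big(I - \alpha \nabla^2_\theta \mathcal{L}_{T_i}(f_\theta)\big) \, \nabla_{\theta_i'} \mathcal{L}_{T_i}(f_{\theta_i'})$$

This expression contains the Hessian of the inner loss, which makes it expensive to compute.

In practice, **FOMAML (First-Order MAML)** is often used instead. It drops the second-order term and uses the approximate gradient $\nabla_{\theta_i'} \mathcal{L}_{T_i}(f_{\theta_i'})$. Finn et al. reported that this approximation performs nearly as well as full MAML.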
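The difference between the two gradients can be checked on a one-dimensional toy problem. This sketch makes two simplifying assumptions: one quadratic loss $\mathcal{L}(\theta) = \tfrac{1}{2} a (\theta - c)^2$ stands in for both the support loss and the query loss, and the inner loop takes a single step. Analytically, the MAML gradient is $a(\theta' - c)(1 - \alpha a)$. FOMAML drops the factor $(1 - \alpha a)$, which is the one-dimensional form of the Hessian term.

```python
import torch


def meta_gradient(theta0: float, a: float, c: float, alpha: float, first_order: bool) -> float:
    """Gradient of the post-adaptation loss with respect to the initial theta."""
    theta = torch.tensor(theta0, requires_grad=True)

    # Inner loop: one SGD step on L(theta) = 0.5 * a * (theta - c)^2
    inner_loss = 0.5 * a * (theta - c) ** 2
    (grad,) = torch.autograd.grad(inner_loss, theta, create_graph=not first_order)
    theta_adapted = theta - alpha * grad   # with first_order=True, grad is treated as a constant

    # Outer loop: loss after adaptation, differentiated w.r.t. the original theta
    outer_loss = 0.5 * a * (theta_adapted - c) ** 2
    (meta_grad,) = torch.autograd.grad(outer_loss, theta)
    return meta_grad.item()


# theta0=3, a=2, c=1, alpha=0.25: inner gradient = 4, so theta' = 2
print(meta_gradient(3.0, 2.0, 1.0, 0.25, first_order=False))  # 2 * 1 * (1 - 0.5) = 1.0
print(meta_gradient(3.0, 2.0, 1.0, 0.25, first_order=True))   # 2 * 1 = 2.0
```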

3.4 A Complete MAML Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from copy import deepcopy
from typing import List, Tuple


class MAML:
    """
    MAML: Model-Agnostic Meta-Learning
    Finn et al., 2017 (arXiv:1703.03400)
    """

    def __init__(
        self,
        model: nn.Module,
        inner_lr: float = 0.01,       # alpha: inner-loop learning rate
        outer_lr: float = 0.001,      # beta: outer-loop learning rate
        n_inner_steps: int = 5,       # number of inner-loop updates
        first_order: bool = False,    # use FOMAML
        device: str = 'cpu'
    ):
        self.model = model.to(device)
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.n_inner_steps = n_inner_steps
        self.first_order = first_order
        self.device = device

        self.meta_optimizer = torch.optim.Adam(
            self.model.parameters(), lr=outer_lr
        )

    def inner_loop(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        model_params=None
    ) -> dict:
        """
        Inner loop: task-specific adaptation on the support set
        Returns: dict of adapted parameters
        """
        # Start from a differentiable copy of the current parameters
        if model_params is None:
            params = {
                name: param.clone()
                for name, param in self.model.named_parameters()
            }
        else:
            params = {k: v.clone() for k, v in model_params.items()}

        for step in range(self.n_inner_steps):
            # Forward pass with the current parameters (functional call)
            logits = self._forward_with_params(support_x, params)
            loss = F.cross_entropy(logits, support_y)

            # Gradients with respect to the parameters
            grads = torch.autograd.grad(
                loss,
                list(params.values()),
                create_graph=not self.first_order  # keep the graph for second-order terms
            )

            # Parameter update (plain SGD)
            params = {
                name: param - self.inner_lr * grad
                for (name, param), grad in zip(params.items(), grads)
            }

        return params

    def _forward_with_params(
        self,
        x: torch.Tensor,
        params: dict
    ) -> torch.Tensor:
        """
        Forward pass with the given parameters instead of the module's own.
        torch.func.functional_call (PyTorch 2.0+) keeps the autograd graph
        connected to `params`, which both the inner-loop and outer-loop
        gradients rely on. Swapping `param.data` instead would detach the
        graph, and torch.autograd.grad in the inner loop would fail.
        Parameters missing from `params` (and all buffers) come from the module.
        """
        from torch.func import functional_call

        return functional_call(self.model, params, (x,))

    def meta_train_step(
        self,
        tasks: List[Tuple]
    ) -> float:
        """
        One meta-training step over a batch of tasks
        tasks: [(support_x, support_y, query_x, query_y), ...]
        """
        self.meta_optimizer.zero_grad()

        meta_loss = 0.0

        for support_x, support_y, query_x, query_y in tasks:
            support_x = support_x.to(self.device)
            support_y = support_y.to(self.device)
            query_x = query_x.to(self.device)
            query_y = query_y.to(self.device)

            # Inner loop: task-specific adaptation
            adapted_params = self.inner_loop(support_x, support_y)

            # Outer loop: evaluate the adapted parameters on the query set
            query_logits = self._forward_with_params(
                query_x, adapted_params
            )
            query_loss = F.cross_entropy(query_logits, query_y)
            meta_loss += query_loss

        # Average over tasks
        meta_loss /= len(tasks)

        # Update the meta-parameters
        meta_loss.backward()
        self.meta_optimizer.step()

        return meta_loss.item()

    def fine_tune(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        n_steps: int = None
    ) -> nn.Module:
        """
        Fine-tune on a new task (used at meta-test time)
        """
        n_steps = n_steps or self.n_inner_steps
        model_copy = deepcopy(self.model)
        optimizer = torch.optim.SGD(
            model_copy.parameters(), lr=self.inner_lr
        )

        model_copy.train()
        for step in range(n_steps):
            optimizer.zero_grad()
            logits = model_copy(support_x.to(self.device))
            loss = F.cross_entropy(logits, support_y.to(self.device))
            loss.backward()
            optimizer.step()

        return model_copy

    def evaluate(
        self,
        tasks: List[Tuple],
        n_fine_tune_steps: int = 5
    ) -> Tuple[float, float]:
        """
        Meta-test: adapt to each new task, then evaluate on its query set
        """
        total_loss = 0.0
        total_acc = 0.0

        for support_x, support_y, query_x, query_y in tasks:
            # Fine-tune
            adapted_model = self.fine_tune(support_x, support_y, n_fine_tune_steps)
            adapted_model.eval()

            with torch.no_grad():
                query_logits = adapted_model(query_x.to(self.device))
                loss = F.cross_entropy(query_logits, query_y.to(self.device))
                preds = query_logits.argmax(dim=-1)
                acc = (preds == query_y.to(self.device)).float().mean()

            total_loss += loss.item()
            total_acc += acc.item()

        return total_loss / len(tasks), total_acc / len(tasks)


# ========== MAML training loop ==========

def train_maml(
    maml: MAML,
    dataset,
    n_way: int = 5,
    k_shot: int = 1,
    n_query: int = 15,
    meta_batch_size: int = 32,
    n_iterations: int = 60000,
    device: str = 'cpu'
):
    """Train MAML"""
    print(f"Starting MAML training: {n_way}-way {k_shot}-shot")
    print(f"Meta-batch size: {meta_batch_size}")
    print(f"Total iterations: {n_iterations}")

    losses = []

    for iteration in range(n_iterations):
        # Build a meta-batch of tasks
        tasks = []
        for _ in range(meta_batch_size):
            task = create_episode(dataset, n_way, k_shot, n_query)
            tasks.append(task)

        # One meta-training step
        meta_loss = maml.meta_train_step(tasks)
        losses.append(meta_loss)

        if (iteration + 1) % 1000 == 0:
            avg_loss = np.mean(losses[-1000:])
            print(f"Iteration {iteration+1}/{n_iterations} | meta-loss: {avg_loss:.4f}")

    return losses

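3.5 Reptile: A Simplification of MAML

Reptile (Nichol et al., 2018) greatly simplifies MAML. It needs no second-order derivatives and is much easier to implement.

Core idea: run several SGD steps on a task, then move the meta-parameters toward the resulting parameters.

$$\theta \leftarrow \theta + \epsilon \, (W_k - \theta)$$

Here W_k denotes the parameters after k SGD updates on task T.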
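Here is a minimal NumPy sketch of the Reptile update on toy one-dimensional tasks. It assumes each task has the loss $\tfrac{1}{2}(w - c_T)^2$ with its own target $c_T$. For these tasks, k SGD steps have the closed form $W_k = c_T + (1 - \text{lr})^k (\theta - c_T)$. Averaging the update over a batch of tasks, as `meta_update` does, therefore pulls theta toward the mean of the task targets. The helper names and hyperparameters are illustrative.

```python
import numpy as np


def sgd_on_task(theta: float, target: float, lr: float, k: int) -> float:
    """W_k: k SGD steps on the task loss 0.5 * (w - target)^2, starting from theta."""
    w = theta
    for _ in range(k):
        w = w - lr * (w - target)   # gradient of 0.5 * (w - target)^2 is (w - target)
    return w


def reptile(targets, theta: float = 0.0, inner_lr: float = 0.1, k: int = 5,
            epsilon: float = 0.5, n_iters: int = 100) -> float:
    """Batched Reptile: theta <- theta + epsilon * (mean over tasks of W_k - theta)."""
    for _ in range(n_iters):
        w_mean = np.mean([sgd_on_task(theta, t, inner_lr, k) for t in targets])
        theta = theta + epsilon * (w_mean - theta)
    return float(theta)


print(reptile([1.0, 2.0, 6.0]))   # converges toward the mean target, 3.0
```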

class Reptile:
    """
    Reptile: A Scalable Meta-learning Algorithm
    Nichol et al., 2018 (arXiv:1803.02999)
    """

    def __init__(
        self,
        model: nn.Module,
        inner_lr: float = 0.02,
        outer_lr: float = 0.001,
        n_inner_steps: int = 5,
        device: str = 'cpu'
    ):
        self.model = model.to(device)
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr  # epsilon
        self.n_inner_steps = n_inner_steps
        self.device = device

    def inner_train(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor
    ) -> dict:
        """Task-specific inner training (k SGD steps)"""
        model_copy = deepcopy(self.model)
        optimizer = torch.optim.SGD(
            model_copy.parameters(), lr=self.inner_lr
        )

        model_copy.train()
        for step in range(self.n_inner_steps):
            optimizer.zero_grad()
            logits = model_copy(support_x)
            loss = F.cross_entropy(logits, support_y)
            loss.backward()
            optimizer.step()

        return dict(model_copy.named_parameters())

    def meta_update(self, task_params_list: List[dict]):
        """
        Reptile meta-update:
        theta += epsilon * (mean(W_k) - theta)
        """
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                # Average of the task-specific parameters
                task_mean = torch.stack([
                    task_params[name].data
                    for task_params in task_params_list
                ]).mean(dim=0)

                # Reptile update
                param.data += self.outer_lr * (task_mean - param.data)

    def train(
        self,
        dataset,
        n_way: int = 5,
        k_shot: int = 5,
        meta_batch_size: int = 5,
        n_iterations: int = 100000
    ):
        """Full Reptile training loop"""
        print(f"Starting Reptile training: {n_way}-way {k_shot}-shot")

        for iteration in range(n_iterations):
            task_params_list = []

            for _ in range(meta_batch_size):
                support_x, support_y, _, _ = create_episode(
                    dataset, n_way, k_shot, n_query=0
                )
                support_x = support_x.to(self.device)
                support_y = support_y.to(self.device)

                task_params = self.inner_train(support_x, support_y)
                task_params_list.append(task_params)

            self.meta_update(task_params_list)

            if (iteration + 1) % 10000 == 0:
                print(f"Iteration {iteration+1}/{n_iterations} done")

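4. In-Context Learning in LLMs

4.1 What Is In-Context Learning?

In-Context Learning (ICL) is the ability of a large language model (LLM) to perform a new task from examples given in the prompt. The model's parameters are not updated; it picks up the task from the input context (the prompt) alone.

This ability attracted wide attention with the arrival of GPT-3. For example: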

English→French translation:
sea otter => loutre de mer
peppermint => menthe poivrée
plush giraffe => girafe en peluche
cheese => ?

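Given a prompt in this format, GPT-3 answers "fromage." It may look as if the model was never specifically trained on French, but it has already learned the pattern during pretraining.

4.2 Why Does It Work?

Why ICL works is still an active research question. The main hypotheses are:

Pattern completion view: LLMs are trained on next-token prediction. When a prompt shows a pattern of input→output pairs, the model is inclined to continue that pattern.

Latent concept inference view: According to Xie et al., ICL can be understood as implicit Bayesian inference. The model infers a latent concept shared by the examples in the prompt.

Implicit gradient descent view: Akyürek et al. showed that, on linear regression tasks, Transformers can implement learning algorithms such as gradient descent within their forward pass. This suggests that attention can implicitly carry out computations resembling gradient descent.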
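The simplest version of this correspondence can be verified directly. Assume linear regression with weights initialized at zero and the loss $\tfrac{1}{2}\sum_i (w^\top x_i - y_i)^2$ over the in-context examples. One gradient step with learning rate $\eta$ gives $w_1 = \eta \sum_i y_i x_i$. The prediction for a query $x_q$ is then $\eta \sum_i y_i \, (x_i^\top x_q)$, which is an unnormalized linear attention with keys $x_i$, values $y_i$ and query $x_q$. This toy only illustrates the idea. It makes no claim about what a trained LLM actually computes.

```python
import numpy as np


def one_gd_step_prediction(X: np.ndarray, y: np.ndarray, x_query: np.ndarray, lr: float) -> float:
    """Predict with w after one GD step from w = 0 on 0.5 * sum_i (w @ x_i - y_i)^2."""
    w0 = np.zeros(X.shape[1])
    grad = X.T @ (X @ w0 - y)      # gradient of the squared loss at w0
    w1 = w0 - lr * grad
    return float(w1 @ x_query)


def linear_attention_prediction(X: np.ndarray, y: np.ndarray, x_query: np.ndarray, lr: float) -> float:
    """Unnormalized linear attention: keys x_i, values y_i, query x_query (no softmax)."""
    scores = X @ x_query           # x_i . x_query for every in-context example
    return float(lr * scores @ y)


rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))        # 8 in-context examples, 3 features
y = X @ np.array([1.0, -2.0, 0.5])
x_query = rng.normal(size=3)
print(one_gd_step_prediction(X, y, x_query, lr=0.1),
      linear_attention_prediction(X, y, x_query, lr=0.1))   # the two values match
```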

4.3 Effective Few-Shot Prompting Strategies

from typing import List, Dict, Any
import numpy as np  # used by DynamicExampleSelector


class FewShotPromptBuilder:
    """Few-shot prompt builder"""

    def __init__(self):
        self.examples = []
        self.instruction = ""
        self.template = "{input} => {output}"

    def set_instruction(self, instruction: str):
        """Set the task instruction"""
        self.instruction = instruction
        return self

    def add_example(self, input_text: str, output_text: str):
        """Add an example"""
        self.examples.append({
            'input': input_text,
            'output': output_text
        })
        return self

    def build_prompt(self, query: str) -> str:
        """Build the full few-shot prompt"""
        parts = []

        if self.instruction:
            parts.append(self.instruction)
            parts.append("")

        for ex in self.examples:
            parts.append(self.template.format(
                input=ex['input'],
                output=ex['output']
            ))

        # The query (output left blank)
        parts.append(f"{query} =>")

        return "\n".join(parts)

    def build_chat_messages(
        self,
        query: str,
        system_prompt: str = None
    ) -> List[Dict]:
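        """Build chat-style messages (OpenAI Chat Completions format).
        Anthropic's Messages API takes the system prompt as a separate
        `system` parameter rather than a message with role 'system'."""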
        messages = []

        if system_prompt:
            messages.append({
                'role': 'system',
                'content': system_prompt
            })

        # Turn each example into a user/assistant exchange
        for ex in self.examples:
            messages.append({
                'role': 'user',
                'content': ex['input']
            })
            messages.append({
                'role': 'assistant',
                'content': ex['output']
            })

        # The actual query
        messages.append({
            'role': 'user',
            'content': query
        })

        return messages


class DynamicExampleSelector:
    """
    Dynamic example selector
    Picks the examples most similar to the query for more effective few-shot prompting
    """

    def __init__(self, examples: List[Dict], encoder=None):
        self.examples = examples
        self.encoder = encoder  # e.g. a sentence-transformers model

    def select_similar(
        self,
        query: str,
        n_examples: int = 3
    ) -> List[Dict]:
        """
        Select the n examples most similar to the query
        """
        if self.encoder is None:
            # No encoder: fall back to random selection
            return np.random.choice(
                self.examples, n_examples, replace=False
            ).tolist()

        # Select by semantic similarity
        query_emb = self.encoder.encode(query)
        example_embs = self.encoder.encode(
            [ex['input'] for ex in self.examples]
        )

        # Cosine similarity
        similarities = np.dot(example_embs, query_emb) / (
            np.linalg.norm(example_embs, axis=1)
            * np.linalg.norm(query_emb)
        )

        top_indices = np.argsort(similarities)[-n_examples:][::-1]
        return [self.examples[i] for i in top_indices]

    def select_diverse(
        self,
        n_examples: int = 3
    ) -> List[Dict]:
        """
        Select diverse examples by greedy max-min diversity
        (only the redundancy term of MMR; there is no query-relevance term here)
        """
        if len(self.examples) <= n_examples:
            return self.examples

        if self.encoder is None:
            return np.random.choice(
                self.examples, n_examples, replace=False
            ).tolist()

        embeddings = self.encoder.encode(
            [ex['input'] for ex in self.examples]
        )

        selected = [0]  # start from the first example
        remaining = list(range(1, len(self.examples)))

        while len(selected) < n_examples:
            # Maximum similarity to the already-selected examples
            selected_embs = embeddings[selected]
            best_idx = None
            best_score = float('-inf')

            for idx in remaining:
                # Diversity only: penalize similarity to already-selected examples
                sim_to_selected = np.max(
                    np.dot(selected_embs, embeddings[idx]) / (
                        np.linalg.norm(selected_embs, axis=1)
                        * np.linalg.norm(embeddings[idx])
                    )
                )
                score = -sim_to_selected  # maximize diversity

                if score > best_score:
                    best_score = score
                    best_idx = idx

            selected.append(best_idx)
            remaining.remove(best_idx)

        return [self.examples[i] for i in selected]


# ========== Usage examples ==========

def sentiment_analysis_few_shot():
    """Few-shot sentiment analysis example"""
    builder = FewShotPromptBuilder()

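    builder.set_instruction("Analyze the sentiment of the following movie review. Answer Positive or Negative.")

    builder.add_example(
        "This movie was truly moving, and the actors' performances were superb.",
        "Positive"
    )
    builder.add_example(
        "The story was far too boring and the ending was disappointing.",
        "Negative"
    )
    builder.add_example(
        "The special effects were impressive, but the script was far too flimsy.",
        "Negative"
    )
    builder.add_example(
        "A heartwarming movie I could finally enjoy with my family after a long time.",
        "Positive"
    )

    query = "The direction is distinctive and the music blends perfectly with the visuals."
    prompt = builder.build_prompt(query)

    print("=== Few-shot prompt ===")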
    print(prompt)
    return prompt


def korean_nlp_few_shot():
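    """Korean NLP few-shot example: named entity recognition (Korean data kept as-is)"""
    builder = FewShotPromptBuilder()

    # Instruction: "Find and tag person (PER), location (LOC), and
    # organization (ORG) entities in the following sentence."
    builder.set_instruction(
        "다음 문장에서 인물(PER), 장소(LOC), 기관(ORG) 개체를 찾아 태깅하세요."
    )

    # "Admiral Yi Sun-sin defeated the Japanese forces at Hansando."
    builder.add_example(
        "이순신 장군이 한산도에서 일본군을 격파했다.",
        "이순신[PER] 장군이 한산도[LOC]에서 일본군[ORG]을 격파했다."
    )
    # "Samsung Electronics produces semiconductors at its Suwon site."
    builder.add_example(
        "삼성전자는 수원 사업장에서 반도체를 생산한다.",
        "삼성전자[ORG]는 수원[LOC] 사업장에서 반도체를 생산한다."
    )

    # "Park Ji-sung played for Manchester United."
    query = "박지성이 맨체스터 유나이티드에서 활약했다."
    prompt = builder.build_prompt(query)

    print("=== Korean NLP few-shot prompt ===")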
    print(prompt)
    return prompt
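For comparison, standard Maximal Marginal Relevance (Carbonell and Goldstein, 1998) also rewards relevance to the query. The `select_diverse` method above uses only the redundancy term. At each step, full MMR picks the candidate that maximizes $\lambda\,\mathrm{sim}(q, e) - (1 - \lambda)\max_{s \in S}\mathrm{sim}(e, s)$. The following is a minimal NumPy sketch. It assumes the query and example embeddings have already been computed, for example with the same encoder. The function name and the default $\lambda$ are illustrative.

```python
import numpy as np
from typing import List


def mmr_select(query_emb: np.ndarray, example_embs: np.ndarray,
               n_examples: int, lam: float = 0.5) -> List[int]:
    """Greedy MMR over cosine similarities; returns indices of the chosen examples."""
    q = query_emb / np.linalg.norm(query_emb)
    E = example_embs / np.linalg.norm(example_embs, axis=1, keepdims=True)
    relevance = E @ q            # sim(query, example)
    pairwise = E @ E.T           # sim(example, example)

    # Start from the single most relevant example
    selected = [int(np.argmax(relevance))]
    remaining = [i for i in range(len(E)) if i != selected[0]]

    while remaining and len(selected) < n_examples:
        # Highest similarity of each candidate to anything already selected
        redundancy = pairwise[np.ix_(remaining, selected)].max(axis=1)
        scores = lam * relevance[remaining] - (1 - lam) * redundancy
        best = remaining[int(np.argmax(scores))]
        selected.append(best)
        remaining.remove(best)
    return selected


examples = np.array([[1.0, 0.0], [0.99, 0.141], [0.0, 1.0]])  # 0 and 1 are near-duplicates
query = np.array([1.0, 0.0])
print(mmr_select(query, examples, 2, lam=1.0))   # pure relevance: [0, 1]
print(mmr_select(query, examples, 2, lam=0.3))   # diversity wins: [0, 2]
```

With $\lambda = 1$ this reduces to plain similarity ranking, like `select_similar`. Lowering $\lambda$ trades relevance for coverage of different example types.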

4.4 Cross-lingual Few-shot

Multilingual models exhibit zero-shot cross-lingual transfer: a model trained on task data in English only can often be applied to Korean and other languages.

class CrossLingualFewShot:
    """
    Cross-lingual few-shot learning
    English examples + target-language queries
    """

    def __init__(self, model_name: str = "xlm-roberta-large"):
        from transformers import AutoTokenizer, AutoModel

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Only encoder hidden states are used below, so load the base model;
        # a sequence-classification head would be randomly initialized and unused
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    def encode_text(self, text: str) -> torch.Tensor:
        """Encode text into an embedding"""
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=128
        )
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            # Final-layer hidden state of the first token (<s> in XLM-R, the counterpart of [CLS])
            embedding = outputs.hidden_states[-1][:, 0, :]
        return embedding

    def classify_zero_shot(
        self,
        query_ko: str,
        class_descriptions_en: List[str]
    ) -> int:
        """
        Classify a Korean query using English class descriptions
        """
        query_emb = self.encode_text(query_ko)
        class_embs = torch.cat([
            self.encode_text(desc) for desc in class_descriptions_en
        ])

        # Cosine similarity
        similarities = F.cosine_similarity(
            query_emb.expand(len(class_descriptions_en), -1),
            class_embs
        )

        return similarities.argmax().item()

    def few_shot_classify(
        self,
        support_texts: List[str],
        support_labels: List[int],
        query_text: str,
        n_classes: int
    ) -> int:
        """
        Few-shot classification (prototype-based)
        Support and query texts may be in different languages
        """
        support_embs = torch.cat([
            self.encode_text(t) for t in support_texts
        ])
        query_emb = self.encode_text(query_text)

        # Per-class prototypes
        prototypes = []
        for cls in range(n_classes):
            mask = torch.tensor([l == cls for l in support_labels])
            proto = support_embs[mask].mean(dim=0, keepdim=True)
            prototypes.append(proto)
        prototypes = torch.cat(prototypes)

        # Classify by Euclidean distance
        dists = torch.cdist(query_emb, prototypes)
        return dists.argmin().item()

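5. Practical Application: Few-Shot Medical Image Classification

5.1 A Rare-Disease Diagnosis System

In clinical settings, training data for rare diseases is very scarce. Few-shot learning makes it possible to recognize the pattern of a new disease from only a handful of confirmed cases.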

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from typing import Dict, List
import os
import numpy as np


class MedicalImageEncoder(nn.Module):
    """
    Encoder for medical images
    ResNet-18 backbone with an added embedding (projection) head
    """

    def __init__(self, embed_dim: int = 512, pretrained: bool = True):
        super().__init__()

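        # ResNet-18 backbone (torchvision >= 0.13 weights API; `pretrained=` is deprecated)
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        backbone = models.resnet18(weights=weights)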

        # Drop the final FC layer
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])

        # Embedding head
        self.embed_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)
        return self.embed_head(features)


class MedicalFewShotClassifier:
    """
    Few-shot classifier for medical images,
    based on Prototypical Networks.
    """

    def __init__(
        self,
        encoder: MedicalImageEncoder,
        device: str = 'cpu'
    ):
        self.encoder = encoder.to(device)
        self.device = device
        self.prototypes = {}
        self.disease_names = {}

    def register_disease(
        self,
        disease_id: int,
        disease_name: str,
        support_images: List,
        transform=None
    ):
        """
        Register a new disease from a small number of example images.
        """
        self.disease_names[disease_id] = disease_name
        self.encoder.eval()

        embeddings = []
        with torch.no_grad():
            for img in support_images:
                if transform:
                    img_tensor = transform(img).unsqueeze(0).to(self.device)
                else:
                    img_tensor = img.unsqueeze(0).to(self.device)
                emb = self.encoder(img_tensor)
                embeddings.append(emb)

        prototype = torch.cat(embeddings).mean(dim=0)
        self.prototypes[disease_id] = prototype

        print(
            f"Registered disease: {disease_name} "
            f"({len(support_images)} examples)"
        )

    def diagnose(
        self,
        query_image: torch.Tensor,
        top_k: int = 3
    ) -> List[Dict]:
        """
        Diagnose a query image.
        Returns: the top-k diseases with their cosine similarity scores
        """
        self.encoder.eval()

        with torch.no_grad():
            query_emb = self.encoder(
                query_image.unsqueeze(0).to(self.device)
            )

        results = []
        for disease_id, prototype in self.prototypes.items():
            # Cosine similarity
            similarity = F.cosine_similarity(
                query_emb, prototype.unsqueeze(0)
            ).item()
            results.append({
                'disease_id': disease_id,
                'disease_name': self.disease_names[disease_id],
                'similarity': similarity
            })

        # Sort by similarity (descending)
        results.sort(key=lambda x: x['similarity'], reverse=True)
        return results[:top_k]

    def update_prototype(
        self,
        disease_id: int,
        new_image: torch.Tensor,
        momentum: float = 0.9
    ):
        """
        Update a prototype with a newly confirmed case (online learning)
        """
        self.encoder.eval()

        with torch.no_grad():
            new_emb = self.encoder(
                new_image.unsqueeze(0).to(self.device)
            ).squeeze(0)

        if disease_id in self.prototypes:
            # Exponential moving average update
            self.prototypes[disease_id] = (
                momentum * self.prototypes[disease_id]
                + (1 - momentum) * new_emb
            )
            print(f"Prototype updated: {self.disease_names[disease_id]}")
        else:
            print(f"Warning: disease {disease_id} is not registered.")


def demo_medical_few_shot():
    """Few-shot medical image classification demo"""
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    encoder = MedicalImageEncoder(embed_dim=512, pretrained=True)
    classifier = MedicalFewShotClassifier(encoder)

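    print("=== Few-Shot Medical Image Classification System ===")
    print("This system recognizes new diseases from only a few confirmed cases.")
    print()

    # Typical workflow
    # 1. Register each disease with 5-10 reference images
    # 2. Pass a new patient image to get the most similar diseases
    # 3. Update the prototype whenever a new case is confirmed

    print("Usage:")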
    print("사용 방법:")
    print("1. classifier.register_disease(id, name, support_images)")
    print("2. results = classifier.diagnose(patient_image)")
    print("3. classifier.update_prototype(disease_id, new_confirmed_image)")
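The prototype bookkeeping in MedicalFewShotClassifier consists of three steps: the mean of the support embeddings, cosine ranking, and the EMA update. These steps can be checked without a trained encoder. The following is a minimal NumPy sketch that assumes the embeddings are already computed. The helper names (`make_prototype`, `rank`, `ema_update`) and the 2-D toy vectors are hypothetical stand-ins and are not part of the class above.

```python
import numpy as np


def make_prototype(support_embs: np.ndarray) -> np.ndarray:
    """Prototype = mean of the support embeddings (one row per image)."""
    return support_embs.mean(axis=0)


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two 1-D vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


def rank(query: np.ndarray, prototypes: dict, top_k: int = 3) -> list:
    """(disease_id, cosine similarity) pairs, most similar first."""
    scores = [(d, cosine(query, p)) for d, p in prototypes.items()]
    scores.sort(key=lambda s: s[1], reverse=True)
    return scores[:top_k]


def ema_update(prototype: np.ndarray, new_emb: np.ndarray,
               momentum: float = 0.9) -> np.ndarray:
    """Keep `momentum` of the old prototype, move (1 - momentum) toward new_emb."""
    return momentum * prototype + (1 - momentum) * new_emb


# Hypothetical 2-D embeddings standing in for encoder outputs.
prototypes = {
    0: make_prototype(np.array([[1.0, 0.0], [0.9, 0.1]])),
    1: make_prototype(np.array([[0.0, 1.0], [0.1, 0.9]])),
}
query = np.array([0.8, 0.2])
top = rank(query, prototypes)

# A newly confirmed case for disease 0 moves its prototype 10% toward it.
prototypes[0] = ema_update(prototypes[0], np.array([0.0, 1.0]))
```

After m EMA updates, the original support mean keeps a weight of momentum^m, so older evidence fades gradually. A smaller momentum adapts faster but is more sensitive to a single noisy case.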

6. Using the learn2learn Library

6.1 Introduction to learn2learn

learn2learn is a PyTorch library that makes it easy to implement meta-learning algorithms such as MAML and ProtoNet.

pip install learn2learn

6.2 Implementing MAML with learn2learn

from typing import Tuple

import learn2learn as l2l
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def build_l2l_maml(
    model: nn.Module,
    lr: float = 0.01,
    first_order: bool = False
) -> l2l.algorithms.MAML:
    """Create a learn2learn MAML wrapper"""
    return l2l.algorithms.MAML(
        model,
        lr=lr,
        first_order=first_order,
        allow_unused=True
    )


def train_with_l2l(
    maml_model: l2l.algorithms.MAML,
    tasksets,
    n_way: int = 5,
    k_shot: int = 1,
    n_query: int = 15,
    meta_lr: float = 0.003,
    n_iterations: int = 1000,
    adaptation_steps: int = 1,
    device: str = 'cpu'
):
    """
    MAML training with learn2learn
    """
    maml_model = maml_model.to(device)
    meta_optimizer = torch.optim.Adam(
        maml_model.parameters(), lr=meta_lr
    )
    criterion = nn.CrossEntropyLoss(reduction='mean')

    for iteration in range(n_iterations):
        meta_optimizer.zero_grad()
        meta_train_loss = 0.0
        meta_train_acc = 0.0

        # Mini meta-batch
        for task in range(4):  # 4 tasks per meta-batch
            # Sample a task
            X, y = tasksets.train.sample()
            X, y = X.to(device), y.to(device)

            # Clone the wrapped model (a differentiable copy used by the inner loop)
            learner = maml_model.clone()

            # Split into support / query (the first k_shot samples of each class are support)
            support_indices = torch.zeros(
                X.size(0), dtype=torch.bool, device=X.device
            )
            for cls in range(n_way):
                cls_idx = (y == cls).nonzero(as_tuple=True)[0]
                support_idx = cls_idx[:k_shot]
                support_indices[support_idx] = True

            query_indices = ~support_indices
            support_x, support_y = X[support_indices], y[support_indices]
            query_x, query_y = X[query_indices], y[query_indices]

            # Inner loop: adaptation
            for step in range(adaptation_steps):
                support_logits = learner(support_x)
                support_loss = criterion(support_logits, support_y)
                learner.adapt(support_loss)

            # Outer loop: accumulate the query loss for the meta-gradient
            query_logits = learner(query_x)
            query_loss = criterion(query_logits, query_y)
            meta_train_loss += query_loss

            # Accuracy
            preds = query_logits.argmax(dim=-1)
            acc = (preds == query_y).float().mean()
            meta_train_acc += acc

        meta_train_loss /= 4
        meta_train_acc /= 4

        meta_train_loss.backward()
        meta_optimizer.step()

        if (iteration + 1) % 100 == 0:
            print(
                f"Iteration {iteration+1}/{n_iterations} | "
                f"Meta loss: {meta_train_loss.item():.4f} | "
                f"Meta accuracy: {meta_train_acc.item():.4f}"
            )


def setup_omniglot_maml():
    """
    Set up MAML on the Omniglot dataset.
    Omniglot: characters from 50 alphabets (1623 classes, 20 samples each)
    """
    # Use learn2learn's benchmark tasksets
    tasksets = l2l.vision.benchmarks.get_tasksets(
        'omniglot',
        train_ways=5,
        train_samples=1 + 15,  # k_shot + n_query per class (Omniglot has only 20)
        test_ways=5,
        test_samples=1 + 15,
        root='./data',
        device='cpu'
    )

    # CNN model
    model = l2l.vision.models.OmniglotCNN(
        output_size=5,
        hidden_size=64,
        layers=4
    )

    # MAML wrapper
    maml = build_l2l_maml(model, lr=0.4, first_order=False)

    return maml, tasksets


def evaluate_l2l(
    maml_model: l2l.algorithms.MAML,
    tasksets,
    n_way: int = 5,
    k_shot: int = 1,
    n_query: int = 15,
    n_test_tasks: int = 600,
    adaptation_steps: int = 3,
    device: str = 'cpu'
) -> Tuple[float, float]:
    """Meta-evaluation with learn2learn"""
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_acc = 0.0

    maml_model.eval()

    for _ in range(n_test_tasks):
        X, y = tasksets.test.sample()
        X, y = X.to(device), y.to(device)

        learner = maml_model.clone()

        support_indices = torch.zeros(X.size(0), dtype=torch.bool, device=X.device)
        for cls in range(n_way):
            cls_idx = (y == cls).nonzero(as_tuple=True)[0]
            support_idx = cls_idx[:k_shot]
            support_indices[support_idx] = True

        query_indices = ~support_indices
        support_x, support_y = X[support_indices], y[support_indices]
        query_x, query_y = X[query_indices], y[query_indices]

        # Use more adaptation steps at test time
        for step in range(adaptation_steps):
            support_loss = criterion(learner(support_x), support_y)
            learner.adapt(support_loss)

        with torch.no_grad():
            query_logits = learner(query_x)
            loss = criterion(query_logits, query_y)
            acc = (query_logits.argmax(dim=-1) == query_y).float().mean()

        total_loss += loss.item()
        total_acc += acc.item()

    return total_loss / n_test_tasks, total_acc / n_test_tasks
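The support/query masking in train_with_l2l and evaluate_l2l can be tested on its own. The NumPy sketch below reproduces that loop on a shuffled label vector. It also adds a hypothetical `samples_per_class` guard for the per-class sample budget: Omniglot has only 20 images per character, so k_shot + n_query must not exceed 20. The sketch assumes labels have already been remapped to 0..n_way-1.

```python
import numpy as np


def split_support_query(labels: np.ndarray, n_way: int, k_shot: int):
    """Mark the first k_shot samples of each class as support, the rest as query."""
    support = np.zeros(labels.shape[0], dtype=bool)
    for cls in range(n_way):
        cls_idx = np.flatnonzero(labels == cls)
        support[cls_idx[:k_shot]] = True
    return support, ~support


def samples_per_class(k_shot: int, n_query: int, available: int = 20) -> int:
    """Per-class sample budget; Omniglot provides only 20 images per character."""
    needed = k_shot + n_query
    if needed > available:
        raise ValueError(f"need {needed} samples per class, only {available} exist")
    return needed


# A shuffled 5-way task with 1 support + 15 query samples per class.
rng = np.random.default_rng(0)
labels = rng.permutation(np.repeat(np.arange(5), samples_per_class(1, 15)))
support_mask, query_mask = split_support_query(labels, n_way=5, k_shot=1)
```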

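7. Meta-Learning Benchmarks

7.1 Major Benchmark Datasets

Omniglot

1623 character classes from 50 alphabets, with 20 samples per class. It is typically evaluated in 5-way and 20-way settings with 1 or 5 shots.

Mini-ImageNet

A 100-class subset of ImageNet with 600 images (84x84) per class. The 5-way 1-shot and 5-way 5-shot settings are standard.

tieredImageNet

A harder variant of Mini-ImageNet. Classes are grouped under higher-level categories, and meta-training and meta-test classes come from different categories. This widens the semantic gap between them.

CIFAR-FS

A few-shot benchmark derived from CIFAR-100. Its 32x32 images make experiments faster than on Mini-ImageNet.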

7.2 평가 프로토콜

def standard_few_shot_evaluation(
    model,
    test_dataset,
    n_way: int = 5,
    k_shot: int = 1,
    n_query: int = 15,
    n_episodes: int = 600,
    confidence_interval: bool = True
) -> Dict:
    """
    Standard few-shot evaluation protocol:
    mean accuracy over 600 episodes with a 95% confidence interval
    """
    accs = []
    model.eval()

    for episode in range(n_episodes):
        support_x, support_y, query_x, query_y = create_episode(
            test_dataset, n_way, k_shot, n_query
        )

        with torch.no_grad():
            log_probs = model(support_x, support_y, query_x, n_way)
            preds = log_probs.argmax(dim=-1)
            acc = (preds == query_y).float().mean().item()

        accs.append(acc)

    mean_acc = np.mean(accs)
    std_acc = np.std(accs)

    if confidence_interval:
        # 95% confidence interval (normal approximation)
        ci = 1.96 * std_acc / np.sqrt(n_episodes)
        return {
            'mean_accuracy': mean_acc,
            'std': std_acc,
            'confidence_interval_95': ci,
            'result_string': f"{mean_acc*100:.2f} ± {ci*100:.2f}%"
        }

    return {'mean_accuracy': mean_acc, 'std': std_acc}
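The reported figure is mean ± 1.96·σ/√n over the episode accuracies. This is a pure-Python sketch of that arithmetic, assuming the per-episode accuracies have already been collected. `mean_and_ci95` is a hypothetical helper, and the accuracy values are made up so the numbers are easy to check.

```python
import math
import statistics


def mean_and_ci95(accs):
    """Mean accuracy and half-width of the normal-approximation 95% CI.

    Uses the population standard deviation, matching np.std's default (ddof=0).
    """
    mean = statistics.fmean(accs)
    std = statistics.pstdev(accs)
    half_width = 1.96 * std / math.sqrt(len(accs))
    return mean, half_width


# Hypothetical per-episode accuracies: 600 episodes alternating 50% and 70%.
accs = [0.5, 0.7] * 300
mean, ci = mean_and_ci95(accs)
summary = f"{mean * 100:.2f} ± {ci * 100:.2f}%"
```

With 600 episodes, the normal quantile 1.96 and the matching t-quantile are practically identical. Using the sample standard deviation (ddof=1) instead also changes the interval only marginally at this n.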

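8. Conclusion: The Future of Meta-Learning

Meta-learning has become one of the core paradigms in AI research. Trends worth watching include the following.

Convergence with LLMs

Large language models such as GPT-4 and Claude show strong in-context learning (ICL) abilities. An active line of research views these models as meta-learners that perform few-shot learning across a wide range of domains.

Multimodal Few-Shot Learning

This direction covers few-shot learning that integrates text, images, and audio. Models such as GPT-4V and Gemini Ultra have shown strong performance on visual few-shot tasks.

Combination with Continual Learning

Some studies report that models initialized with meta-learning forget less prior knowledge when they learn new tasks. Combining continual learning with meta-learning is an active research area.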


References

  • Finn, C., et al. (2017). Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. ICML 2017. arXiv:1703.03400
  • Snell, J., et al. (2017). Prototypical Networks for Few-shot Learning. NeurIPS 2017. arXiv:1703.05175
  • Vinyals, O., et al. (2016). Matching Networks for One Shot Learning. NeurIPS 2016. arXiv:1606.04080
  • Nichol, A., et al. (2018). On First-Order Meta-Learning Algorithms. arXiv:1803.02999
  • Brown, T., et al. (2020). Language Models are Few-Shot Learners (GPT-3). NeurIPS 2020. arXiv:2005.14165
  • learn2learn library: https://github.com/learnables/learn2learn

Meta-Learning and Few-Shot Learning Complete Guide: MAML, Prototypical Networks, In-Context Learning

Humans learn new concepts quickly from just a few examples. Show a child "this is a zebra" once, and they can recognize a zebra in an entirely new photograph the next day. But traditional deep learning models require thousands of images just to build a zebra classifier.

Meta-learning and few-shot learning bridge this gap. The core idea of meta-learning is "learning to learn." The model gains the ability to rapidly adapt to new tasks by experiencing a variety of tasks during training.

This guide covers everything from the theoretical foundations of meta-learning to the latest developments in In-Context Learning.


1. Meta-Learning Fundamentals

1.1 Learning to Learn

In traditional machine learning, a model is trained for a single task. A cat classifier is trained only on cat data, and when a new animal (e.g., an ocelot) needs to be classified, training must start from scratch.

Meta-learning shifts the perspective. What the model needs to learn is not "how to classify cats" but "how to learn to quickly classify new animals."

Meta-learning therefore has two levels of learning.

  • Meta-learning (Outer Loop): Learning a good initialization or learning algorithm by experiencing many tasks
  • Task learning (Inner Loop): Rapidly adapting to each specific task from a small number of examples

1.2 Limitations of Traditional Learning

The limitations of traditional learning are:

Data inefficiency: Training on ImageNet uses over a million labeled images, and adding a new class typically requires thousands of additional samples.

Lack of generalization: Performance drops sharply on new tasks that differ significantly from the training distribution.

Catastrophic forgetting: Learning new tasks causes the model to forget previously learned tasks.

1.3 Task Distribution

The central concept in meta-learning is the task distribution p(T). Meta-learning does not simply learn from data; it learns from a distribution of tasks.

Each task T includes:

  • A distribution p(x, y) of input-output pairs
  • A task loss function L

The meta-learning objective is:

min over theta: E over T ~ p(T) [L_T(f_theta)]

1.4 Support Set vs Query Set

In few-shot learning, data is divided into two roles.

Support Set: A small number of examples used as references when the model learns a new task. These correspond to training data in traditional learning, but are extremely few (e.g., 1–5 per class).

Query Set: Data used to evaluate model performance. These correspond to test data in traditional learning.

1.5 N-way K-shot Setup

The most important configuration in few-shot learning is N-way K-shot.

  • N-way: Number of classes to classify
  • K-shot: Number of support examples per class

For example, 5-way 1-shot is a task of classifying 5 classes with only 1 example per class. 5-way 5-shot uses 5 examples per class.

import torch
import torch.nn as nn
import numpy as np
from typing import List, Tuple, Dict


def create_episode(
    dataset,
    n_way: int,
    k_shot: int,
    n_query: int,
    classes: List[int] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Create N-way K-shot episode
    Returns: (support_x, support_y, query_x, query_y)
    """
    all_classes = list(set(dataset.targets.tolist()))
    if classes is None:
        selected_classes = np.random.choice(
            all_classes, n_way, replace=False
        )
    else:
        selected_classes = classes

    support_x, support_y = [], []
    query_x, query_y = [], []

    for new_label, cls in enumerate(selected_classes):
        cls_indices = (dataset.targets == cls).nonzero(as_tuple=True)[0]
        chosen = np.random.choice(
            len(cls_indices), k_shot + n_query, replace=False
        )

        for i, idx in enumerate(cls_indices[chosen]):
            x, _ = dataset[idx.item()]
            if i < k_shot:
                support_x.append(x)
                support_y.append(new_label)
            else:
                query_x.append(x)
                query_y.append(new_label)

    support_x = torch.stack(support_x)
    support_y = torch.tensor(support_y)
    query_x = torch.stack(query_x)
    query_y = torch.tensor(query_y)

    return support_x, support_y, query_x, query_y
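The sampling logic of create_episode can be exercised without a torchvision dataset. The NumPy sketch below is a hypothetical index-only variant that returns dataset indices instead of tensors. This makes it easy to check that the selected classes are remapped to 0..N-1 and that support and query samples never overlap.

```python
import numpy as np


def sample_episode_indices(targets, n_way, k_shot, n_query, rng):
    """Index-only N-way K-shot sampler (same logic as create_episode).

    Returns (support_idx, support_y, query_idx, query_y); labels are remapped
    to 0..n_way-1 in the order the classes were drawn.
    """
    targets = np.asarray(targets)
    classes = rng.choice(np.unique(targets), n_way, replace=False)
    s_idx, s_y, q_idx, q_y = [], [], [], []
    for new_label, cls in enumerate(classes):
        cls_idx = np.flatnonzero(targets == cls)
        chosen = rng.choice(cls_idx, k_shot + n_query, replace=False)
        s_idx += chosen[:k_shot].tolist()
        s_y += [new_label] * k_shot
        q_idx += chosen[k_shot:].tolist()
        q_y += [new_label] * n_query
    return np.array(s_idx), np.array(s_y), np.array(q_idx), np.array(q_y)


# Toy label vector: 10 classes x 20 samples (an Omniglot-like per-class count).
targets = np.repeat(np.arange(10), 20)
rng = np.random.default_rng(42)
s_idx, s_y, q_idx, q_y = sample_episode_indices(targets, 5, 1, 15, rng)
```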

2. Distance-Based Meta-Learning

2.1 Matching Networks

Matching Networks, proposed by Vinyals et al. in 2016, combines attention mechanisms with the kNN idea. The core idea is to predict a query sample as an attention-weighted sum of support set labels.

Prediction formula

y_hat = sum over i: a(x_hat, x_i) * y_i

Here a(x_hat, x_i) is the attention weight between query x_hat and support sample x_i, computed as a softmax over cosine similarities.

class MatchingNetworks(nn.Module):
    """Matching Networks implementation"""

    def __init__(self, encoder: nn.Module, use_fce: bool = False):
        """
        encoder: feature extractor
        use_fce: whether to use Full Context Embedding (stored only; FCE is not implemented here)
        """
        super().__init__()
        self.encoder = encoder
        self.use_fce = use_fce

    def cosine_similarity(
        self,
        query: torch.Tensor,
        support: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute cosine similarity
        query: (n_query, embed_dim)
        support: (n_support, embed_dim)
        Returns: (n_query, n_support)
        """
        query_norm = nn.functional.normalize(query, dim=-1)
        support_norm = nn.functional.normalize(support, dim=-1)
        return torch.mm(query_norm, support_norm.t())

    def forward(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """
        support_x: (n_way * k_shot, C, H, W)
        support_y: (n_way * k_shot,)
        query_x: (n_query, C, H, W)
        """
        support_emb = self.encoder(support_x)   # (n_support, D)
        query_emb = self.encoder(query_x)        # (n_query, D)

        similarities = self.cosine_similarity(
            query_emb, support_emb
        )  # (n_query, n_support)

        # Softmax attention
        attention = nn.functional.softmax(similarities, dim=-1)

        # One-hot labels
        support_labels_one_hot = nn.functional.one_hot(
            support_y, n_way
        ).float()  # (n_support, n_way)

        # Attention-weighted sum of one-hot labels. Each row sums to 1, so these
        # are class probabilities rather than unnormalized logits.
        probs = torch.mm(attention, support_labels_one_hot)
        return probs
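A tiny NumPy version of the attention formula y_hat = sum over i: a(x_hat, x_i) * y_i makes the computation concrete. It mirrors MatchingNetworks.forward but uses hand-picked 2-D embeddings (hypothetical values) instead of encoder outputs.

```python
import numpy as np


def matching_predict(query, support, support_y, n_way):
    """y_hat = sum_i a(x_hat, x_i) * y_i, with a = softmax over cosine similarities."""
    q = query / np.linalg.norm(query, axis=-1, keepdims=True)
    s = support / np.linalg.norm(support, axis=-1, keepdims=True)
    sims = q @ s.T                                  # (n_query, n_support)
    sims = sims - sims.max(axis=-1, keepdims=True)  # numerically stable softmax
    attn = np.exp(sims) / np.exp(sims).sum(axis=-1, keepdims=True)
    one_hot = np.eye(n_way)[support_y]              # (n_support, n_way)
    return attn @ one_hot                           # each row sums to 1


support = np.array([[1.0, 0.0], [0.0, 1.0]])
support_y = np.array([0, 1])
probs = matching_predict(np.array([[0.9, 0.1]]), support, support_y, n_way=2)
```

The query [0.9, 0.1] lies much closer to the first support point, yet the correct class gets only about 0.71 probability. Cosine similarities are bounded in [-1, 1], so the softmax over them cannot become very sharp. A common remedy is to multiply the similarities by a temperature or scale factor before the softmax, although that step is not part of the formula above.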

2.2 Prototypical Networks

Prototypical Networks, published by Snell et al. in 2017, is one of the most elegant and intuitive meta-learning algorithms. The core idea: represent each class as a single prototype (centroid) in embedding space.

Prototype computation

The prototype of class c is the mean of the embeddings of the support samples for that class.

p_c = (1/|S_c|) sum over (x_i, y_i) in S_c: f_phi(x_i)

Classification

Classify a query sample x to the nearest prototype.

p(y=c | x) = softmax(-d(f_phi(x), p_c))

where d is a distance in embedding space; Snell et al. use the squared Euclidean distance.

class ConvEncoder(nn.Module):
    """4-layer CNN encoder for few-shot learning"""

    def __init__(
        self,
        in_channels: int = 1,
        hidden_dim: int = 64,
        out_dim: int = 64
    ):
        super().__init__()

        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2)
            )

        self.net = nn.Sequential(
            conv_block(in_channels, hidden_dim),
            conv_block(hidden_dim, hidden_dim),
            conv_block(hidden_dim, hidden_dim),
            conv_block(hidden_dim, out_dim),
            nn.Flatten()
        )

    def forward(self, x):
        return self.net(x)


class PrototypicalNetworks(nn.Module):
    """Complete Prototypical Networks implementation"""

    def __init__(self, encoder: nn.Module):
        super().__init__()
        self.encoder = encoder

    def compute_prototypes(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """
        Compute class prototypes (mean of embeddings)
        support_x: (n_way * k_shot, C, H, W)
        support_y: (n_way * k_shot,)
        Returns: (n_way, embed_dim)
        """
        support_emb = self.encoder(support_x)  # (n_support, D)
        prototypes = []

        for cls in range(n_way):
            mask = (support_y == cls)
            cls_embeddings = support_emb[mask]
            prototype = cls_embeddings.mean(dim=0)
            prototypes.append(prototype)

        return torch.stack(prototypes)  # (n_way, D)

    def euclidean_dist(
        self,
        x: torch.Tensor,
        y: torch.Tensor
    ) -> torch.Tensor:
        """
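        Squared Euclidean distance (as in Snell et al., 2017)
        x: (n, D)
        y: (m, D)
        Returns: (n, m)
        """
        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2*x.y
        n = x.size(0)
        m = y.size(0)
        x_sq = (x ** 2).sum(dim=1, keepdim=True).expand(n, m)
        y_sq = (y ** 2).sum(dim=1, keepdim=True).expand(m, n).t()
        xy = torch.mm(x, y.t())
        dist = x_sq + y_sq - 2 * xy
        # Clamp small negatives caused by floating-point error. Returning the
        # squared distance also avoids sqrt's infinite gradient at zero.
        return dist.clamp(min=0)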

    def forward(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """
        Prototypical Networks forward pass
        Returns: log-probabilities for query samples (n_query, n_way)
        """
        prototypes = self.compute_prototypes(
            support_x, support_y, n_way
        )  # (n_way, D)

        query_emb = self.encoder(query_x)  # (n_query, D)

        dists = self.euclidean_dist(
            query_emb, prototypes
        )  # (n_query, n_way)

        log_probs = nn.functional.log_softmax(-dists, dim=-1)
        return log_probs


# ========== Training Loop ==========

def train_prototypical(
    model: PrototypicalNetworks,
    train_dataset,
    n_way: int = 5,
    k_shot: int = 5,
    n_query: int = 15,
    n_episodes: int = 100,
    lr: float = 1e-3,
    device: str = 'cpu'
):
    """Train Prototypical Networks"""
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.NLLLoss()

    model.train()
    episode_losses = []
    episode_accs = []

    for episode in range(n_episodes):
        support_x, support_y, query_x, query_y = create_episode(
            train_dataset, n_way, k_shot, n_query
        )
        support_x = support_x.to(device)
        support_y = support_y.to(device)
        query_x = query_x.to(device)
        query_y = query_y.to(device)

        optimizer.zero_grad()

        log_probs = model(support_x, support_y, query_x, n_way)
        loss = criterion(log_probs, query_y)
        loss.backward()
        optimizer.step()

        preds = log_probs.argmax(dim=-1)
        acc = (preds == query_y).float().mean().item()

        episode_losses.append(loss.item())
        episode_accs.append(acc)

        if (episode + 1) % 100 == 0:
            print(
                f"Episode {episode+1}/{n_episodes} | "
                f"Loss: {np.mean(episode_losses[-100:]):.4f} | "
                f"Acc: {np.mean(episode_accs[-100:]):.4f}"
            )

    return episode_losses, episode_accs
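To see the prototype formulas in action, the sketch below computes p_c, the squared Euclidean distances, and the log-softmax by hand in NumPy for a 2-way 2-shot toy episode. The embeddings are hypothetical values standing in for f_phi outputs, and `proto_log_probs` is an illustrative name.

```python
import numpy as np


def proto_log_probs(support_emb, support_y, query_emb, n_way):
    """log p(y=c|x) = log_softmax(-d(f(x), p_c)), with d = squared Euclidean."""
    protos = np.stack([support_emb[support_y == c].mean(axis=0) for c in range(n_way)])
    d = ((query_emb[:, None, :] - protos[None, :, :]) ** 2).sum(-1)  # (n_query, n_way)
    logits = -d
    logits = logits - logits.max(axis=-1, keepdims=True)              # stability shift
    return logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))


# 2-way 2-shot toy embeddings (hypothetical encoder outputs).
support_emb = np.array([[0.0, 0.0], [0.0, 2.0], [4.0, 0.0], [4.0, 2.0]])
support_y = np.array([0, 0, 1, 1])
query_emb = np.array([[1.0, 1.0]])
log_p = proto_log_probs(support_emb, support_y, query_emb, n_way=2)
```

The prototypes are [0, 1] and [4, 1], so the query [1, 1] has squared distances 1 and 9. The resulting log-probability gap between the two classes is exactly 8.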

2.3 Relation Networks

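Relation Networks (Sung et al., 2018) are similar to Prototypical Networks, but they replace the fixed distance function with a learnable neural network. The relation module takes the concatenation of a query embedding and a class prototype and outputs a relation score in [0, 1]. The version below is a simplified variant that uses an MLP relation module on mean prototypes.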

class RelationNetwork(nn.Module):
    """Relation Networks: learnable distance function"""

    def __init__(self, encoder: nn.Module, embed_dim: int = 64):
        super().__init__()
        self.encoder = encoder

        # Relation module: takes concatenated embeddings, outputs relation score
        self.relation_module = nn.Sequential(
            nn.Linear(embed_dim * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """
        Returns: relation scores (n_query, n_way)
        """
        support_emb = self.encoder(support_x)
        query_emb = self.encoder(query_x)

        prototypes = []
        for cls in range(n_way):
            mask = (support_y == cls)
            proto = support_emb[mask].mean(dim=0)
            prototypes.append(proto)
        prototypes = torch.stack(prototypes)  # (n_way, D)

        n_query = query_emb.size(0)

        query_expanded = query_emb.unsqueeze(1).expand(
            n_query, n_way, -1
        )  # (n_query, n_way, D)
        proto_expanded = prototypes.unsqueeze(0).expand(
            n_query, n_way, -1
        )  # (n_query, n_way, D)

        pairs = torch.cat(
            [query_expanded, proto_expanded], dim=-1
        )  # (n_query, n_way, 2D)

        scores = self.relation_module(
            pairs.view(-1, pairs.size(-1))
        ).view(n_query, n_way)

        return scores
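RelationNetwork.forward returns sigmoid scores, so it pairs naturally with a regression loss. Sung et al. train the relation score with mean squared error, pushing matched (query, class) pairs toward 1 and mismatched pairs toward 0. The NumPy sketch below shows that loss on hypothetical scores; `relation_mse_loss` is an illustrative name.

```python
import numpy as np


def relation_mse_loss(scores: np.ndarray, query_y: np.ndarray) -> float:
    """MSE between relation scores in [0, 1] and one-hot targets.

    Matched (query, class) pairs are regressed to 1, mismatched pairs to 0.
    """
    n_way = scores.shape[1]
    targets = np.eye(n_way)[query_y]   # (n_query, n_way) one-hot targets
    return float(((scores - targets) ** 2).mean())


# Hypothetical relation scores for 2 queries in a 3-way task.
scores = np.array([[0.9, 0.2, 0.1],
                   [0.3, 0.6, 0.2]])
query_y = np.array([0, 1])
loss = relation_mse_loss(scores, query_y)
```

In PyTorch, the equivalent would be `F.mse_loss(scores, F.one_hot(query_y, n_way).float())` applied to the output of RelationNetwork.forward.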

3. Optimization-Based Meta-Learning: MAML

3.1 The Core Idea of MAML

MAML (Model-Agnostic Meta-Learning), proposed by Finn et al. in 2017, is one of the most influential works in meta-learning.

MAML's goal is to "find an initial set of parameters theta from which fast adaptation is possible."

Concretely, it finds an initialization point from which only a few gradient update steps are needed to achieve good performance on a new task.

3.2 Inner Loop vs Outer Loop

MAML consists of two loops.

Inner Loop (Task-specific adaptation)

For each task T_i:

theta_i' = theta - alpha * grad_theta L_{T_i}(f_theta)

Perform 1–5 gradient updates on the support set to obtain task-specific parameters theta_i'.

Outer Loop (Meta-update)

theta = theta - beta * grad_theta sum_i L_{T_i}(f_{theta_i'})

Compute the loss on the query set using task-adapted parameters theta_i', then update the meta-parameters theta based on this loss.
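The two loops can be traced by hand on a toy problem. In this pure-Python sketch, each task is a hypothetical 1-D quadratic L_a(theta) = (theta - a)^2. One inner SGD step produces theta_a', and the outer update uses the analytic derivative through that step. The support and query losses are identical here, which is a simplification.

```python
def maml_quadratic(task_centers, theta=5.0, alpha=0.1, beta=0.05, n_iters=500):
    """Full-batch MAML on 1-D tasks with loss L_a(theta) = (theta - a) ** 2.

    Inner loop (one step): theta_a' = theta - alpha * dL_a/dtheta
    Outer loop:            theta <- theta - beta * mean_a dL_a(theta_a')/dtheta
    The outer derivative includes d theta_a'/d theta = 1 - 2 * alpha, the
    term obtained by differentiating through the inner step.
    """
    for _ in range(n_iters):
        meta_grad = 0.0
        for a in task_centers:
            theta_adapted = theta - alpha * 2 * (theta - a)          # inner loop
            meta_grad += 2 * (theta_adapted - a) * (1 - 2 * alpha)   # chain rule
        theta -= beta * meta_grad / len(task_centers)                # outer loop
    return theta


# Hypothetical task distribution: four tasks centered at -2, 0, 1 and 3.
theta_star = maml_quadratic([-2.0, 0.0, 1.0, 3.0])
```

For these tasks, the post-adaptation loss is (1 - 2·alpha)^2 (theta - a)^2. The meta-learned initialization therefore converges to the mean of the task centers, which is 0.5 here.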

3.3 Second-Order Gradients

The key technical challenge of MAML is that the outer loop gradient requires second-order gradients (backpropagation through the inner loop).

grad_theta L(f_{theta_i'}) = grad_theta L(f_{theta - alpha * grad L(theta)})

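Computing the gradient with respect to theta requires differentiating through the inner-loop update. This brings in the Hessian of the support loss, which in practice is handled through Hessian-vector products. The computation is expensive in both memory and time.

In practice, FOMAML (First-Order MAML) is often used instead. It drops the second-order terms and approximates the meta-gradient with the query-loss gradient evaluated at the adapted parameters theta_i'.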
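The difference between the exact meta-gradient and FOMAML shows up clearly in one dimension. This pure-Python sketch uses hypothetical numbers. The support loss is (theta - a)^2 and the query loss is (theta' - b)^2, where a and b stand in for two different sample sets of the same task. The exact gradient carries the factor (1 - alpha·H) with H = 2, while FOMAML drops it. A central finite difference confirms the exact value.

```python
def adapt(theta, a, alpha):
    """One inner SGD step on the support loss (theta - a) ** 2."""
    return theta - alpha * 2 * (theta - a)


def query_loss(theta, a, b, alpha):
    """Query loss (theta' - b) ** 2 evaluated at the adapted parameter theta'."""
    return (adapt(theta, a, alpha) - b) ** 2


def maml_grad(theta, a, b, alpha):
    """Exact meta-gradient: chain rule through the inner step."""
    support_hessian = 2.0  # d^2/dtheta^2 of (theta - a) ** 2
    return 2 * (adapt(theta, a, alpha) - b) * (1 - alpha * support_hessian)


def fomaml_grad(theta, a, b, alpha):
    """First-order approximation: treat d theta'/d theta as 1."""
    return 2 * (adapt(theta, a, alpha) - b)


def finite_difference(theta, a, b, alpha, eps=1e-6):
    """Central-difference estimate of d query_loss / d theta."""
    upper = query_loss(theta + eps, a, b, alpha)
    lower = query_loss(theta - eps, a, b, alpha)
    return (upper - lower) / (2 * eps)


theta, a, b, alpha = 1.0, 2.0, 3.0, 0.1
exact = maml_grad(theta, a, b, alpha)
first_order = fomaml_grad(theta, a, b, alpha)
numeric = finite_difference(theta, a, b, alpha)
```

In this example, the first-order gradient is too large by a factor of 1/(1 - 2·alpha) = 1.25. For a neural network, H becomes the Hessian matrix and the factor becomes (I - alpha·H). Setting create_graph=True in the MAML class below is what lets autograd form the corresponding Hessian-vector products.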

3.4 Complete MAML Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from copy import deepcopy
from typing import List, Tuple


class MAML:
    """
    MAML: Model-Agnostic Meta-Learning
    Finn et al., 2017 (arXiv:1703.03400)
    """

    def __init__(
        self,
        model: nn.Module,
        inner_lr: float = 0.01,       # alpha: inner loop learning rate
        outer_lr: float = 0.001,      # beta: outer loop learning rate
        n_inner_steps: int = 5,       # number of inner loop updates
        first_order: bool = False,    # whether to use FOMAML
        device: str = 'cpu'
    ):
        self.model = model.to(device)
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.n_inner_steps = n_inner_steps
        self.first_order = first_order
        self.device = device

        self.meta_optimizer = torch.optim.Adam(
            self.model.parameters(), lr=outer_lr
        )

    def inner_loop(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        model_params=None
    ) -> dict:
        """
        Inner loop: task-specific adaptation on the support set
        Returns: adapted parameter dictionary
        """
        if model_params is None:
            params = {
                name: param.clone()
                for name, param in self.model.named_parameters()
            }
        else:
            params = {k: v.clone() for k, v in model_params.items()}

        for step in range(self.n_inner_steps):
            logits = self._forward_with_params(support_x, params)
            loss = F.cross_entropy(logits, support_y)

            grads = torch.autograd.grad(
                loss,
                list(params.values()),
                create_graph=not self.first_order
            )

            params = {
                name: param - self.inner_lr * grad
                for (name, param), grad in zip(params.items(), grads)
            }

        return params

    def _forward_with_params(
        self,
        x: torch.Tensor,
        params: dict
    ) -> torch.Tensor:
        """
        Forward pass using `params` in place of the module's own parameters.
        torch.func.functional_call (PyTorch >= 2.0) keeps the autograd graph
        from `params` back to the meta-parameters. Swapping `param.data`
        instead would cut that graph and break both the inner-loop
        torch.autograd.grad call and the meta-gradient.
        """
        from torch.func import functional_call

        return functional_call(self.model, params, (x,))

    def meta_train_step(
        self,
        tasks: List[Tuple]
    ) -> float:
        """
        One meta-training step (over multiple tasks)
        tasks: [(support_x, support_y, query_x, query_y), ...]
        """
        self.meta_optimizer.zero_grad()

        meta_loss = 0.0

        for support_x, support_y, query_x, query_y in tasks:
            support_x = support_x.to(self.device)
            support_y = support_y.to(self.device)
            query_x = query_x.to(self.device)
            query_y = query_y.to(self.device)

            adapted_params = self.inner_loop(support_x, support_y)

            query_logits = self._forward_with_params(
                query_x, adapted_params
            )
            query_loss = F.cross_entropy(query_logits, query_y)
            meta_loss += query_loss

        meta_loss /= len(tasks)
        meta_loss.backward()
        self.meta_optimizer.step()

        return meta_loss.item()

    def fine_tune(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        n_steps: int = None
    ) -> nn.Module:
        """
        Fine-tune on a new task (used at inference)
        """
        n_steps = n_steps or self.n_inner_steps
        model_copy = deepcopy(self.model)
        optimizer = torch.optim.SGD(
            model_copy.parameters(), lr=self.inner_lr
        )

        model_copy.train()
        for step in range(n_steps):
            optimizer.zero_grad()
            logits = model_copy(support_x.to(self.device))
            loss = F.cross_entropy(logits, support_y.to(self.device))
            loss.backward()
            optimizer.step()

        return model_copy

    def evaluate(
        self,
        tasks: List[Tuple],
        n_fine_tune_steps: int = 5
    ) -> Tuple[float, float]:
        """
        Meta-test: adapt to new tasks and evaluate
        """
        total_loss = 0.0
        total_acc = 0.0

        for support_x, support_y, query_x, query_y in tasks:
            adapted_model = self.fine_tune(
                support_x, support_y, n_fine_tune_steps
            )
            adapted_model.eval()

            with torch.no_grad():
                query_logits = adapted_model(query_x.to(self.device))
                loss = F.cross_entropy(
                    query_logits, query_y.to(self.device)
                )
                preds = query_logits.argmax(dim=-1)
                acc = (preds == query_y.to(self.device)).float().mean()

            total_loss += loss.item()
            total_acc += acc.item()

        return total_loss / len(tasks), total_acc / len(tasks)


def train_maml(
    maml: MAML,
    dataset,
    n_way: int = 5,
    k_shot: int = 1,
    n_query: int = 15,
    meta_batch_size: int = 32,
    n_iterations: int = 60000,
    device: str = 'cpu'
):
    """MAML training loop"""
    print(f"MAML Training: {n_way}-way {k_shot}-shot")
    print(f"Meta-batch size: {meta_batch_size}")
    print(f"Total iterations: {n_iterations}")

    losses = []

    for iteration in range(n_iterations):
        tasks = []
        for _ in range(meta_batch_size):
            task = create_episode(dataset, n_way, k_shot, n_query)
            tasks.append(task)

        meta_loss = maml.meta_train_step(tasks)
        losses.append(meta_loss)

        if (iteration + 1) % 1000 == 0:
            avg_loss = np.mean(losses[-1000:])
            print(
                f"Iteration {iteration+1}/{n_iterations} | "
                f"Meta Loss: {avg_loss:.4f}"
            )

    return losses

3.5 Reptile: Simplifying MAML

Reptile (Nichol et al., 2018) is a greatly simplified version of MAML. It requires no second-order gradients and is much easier to implement.

Core idea: run several SGD steps on a task, then move the meta-parameters toward the resulting task-adapted parameters.

theta = theta + epsilon * (W_k - theta)

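where W_k is the parameter after k SGD updates on task T. With a meta-batch of several tasks, the implementation below replaces W_k with the average of the adapted parameters: theta = theta + epsilon * (mean(W_k) - theta).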

class Reptile:
    """
    Reptile: A Scalable Meta-learning Algorithm
    Nichol et al., 2018 (arXiv:1803.02999)
    """

    def __init__(
        self,
        model: nn.Module,
        inner_lr: float = 0.02,
        outer_lr: float = 0.001,  # epsilon
        n_inner_steps: int = 5,
        device: str = 'cpu'
    ):
        self.model = model.to(device)
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.n_inner_steps = n_inner_steps
        self.device = device

    def inner_train(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor
    ) -> dict:
        """Task-specific inner training (k SGD steps)"""
        model_copy = deepcopy(self.model)
        optimizer = torch.optim.SGD(
            model_copy.parameters(), lr=self.inner_lr
        )

        model_copy.train()
        for step in range(self.n_inner_steps):
            optimizer.zero_grad()
            logits = model_copy(support_x)
            loss = F.cross_entropy(logits, support_y)
            loss.backward()
            optimizer.step()

        return dict(model_copy.named_parameters())

    def meta_update(self, task_params_list: List[dict]):
        """
        Reptile meta-update:
        theta += epsilon * (mean(W_k) - theta)
        """
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                task_mean = torch.stack([
                    task_params[name].data
                    for task_params in task_params_list
                ]).mean(dim=0)

                param.data += self.outer_lr * (task_mean - param.data)

    def train(
        self,
        dataset,
        n_way: int = 5,
        k_shot: int = 5,
        meta_batch_size: int = 5,
        n_iterations: int = 100000
    ):
        """Reptile full training loop"""
        print(f"Reptile Training: {n_way}-way {k_shot}-shot")

        for iteration in range(n_iterations):
            task_params_list = []

            for _ in range(meta_batch_size):
                support_x, support_y, _, _ = create_episode(
                    dataset, n_way, k_shot, n_query=0
                )
                support_x = support_x.to(self.device)
                support_y = support_y.to(self.device)

                task_params = self.inner_train(support_x, support_y)
                task_params_list.append(task_params)

            self.meta_update(task_params_list)

            if (iteration + 1) % 10000 == 0:
                print(f"Iteration {iteration+1}/{n_iterations} complete")

4. In-Context Learning in LLMs

4.1 What Is In-Context Learning?

In-Context Learning (ICL) is the ability of large language models (LLMs) to perform new tasks from examples within the prompt. The model does not update its parameters; it "learns" solely from the input context (prompt).

This ability gained enormous attention with the arrival of GPT-3. For example:

English to French translation:
sea otter => loutre de mer
peppermint => menthe poivrée
plush giraffe => girafe en peluche
cheese => ?

Given this prompt, GPT-3 completes the last line with "fromage". The model was not fine-tuned for English-to-French translation. It picks up the task from the examples in the prompt, relying on patterns learned during pre-training.

4.2 Why Is It Effective?

The reasons ICL is effective are still being actively studied. Major hypotheses include:

Pattern completion view: LLMs are trained to predict the next token. When the prompt contains a repeated input-to-output pattern, the most likely continuation applies the same pattern to the final input.

Latent concept inference view: Xie et al. argue that ICL can be understood as implicit Bayesian inference. The model infers a latent concept (the task) shared by the examples in the prompt and uses it to predict the output.

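Implicit gradient descent view: Akyürek et al. showed, on linear regression tasks, that Transformers can implement learning algorithms such as gradient descent within their forward pass. This suggests that attention can perform optimization-like updates on the in-context examples.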
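The sketch below illustrates the intuition behind the gradient descent view under strong simplifying assumptions: linear regression, zero-initialized weights, and a single linear-attention layer without softmax. It is a simplified version of the kind of construction studied in this line of work, not a reproduction of any paper's model. Take one gradient step on the in-context examples, starting from w = 0, with L(w) = mean((Xw - y)^2) / 2. This gives w1 = (lr / N) * sum_i y_i x_i. The resulting prediction for a query x_q is therefore an attention readout with keys x_i, values y_i, and query x_q. The function names are illustrative.

```python
import numpy as np


def gd_one_step_prediction(X, y, x_query, lr):
    """Predict with w after one GD step from w = 0 on L(w) = mean((Xw - y)^2) / 2."""
    n = X.shape[0]
    w0 = np.zeros(X.shape[1])
    grad = X.T @ (X @ w0 - y) / n      # gradient of the halved mean squared error
    w1 = w0 - lr * grad
    return x_query @ w1


def linear_attention_prediction(X, y, x_query, lr):
    """Unnormalized linear attention: keys = x_i, values = y_i, query = x_query."""
    n = X.shape[0]
    scores = X @ x_query               # (n,) dot products k_i . q, no softmax
    return lr / n * scores @ y         # score-weighted sum of the values


rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))            # 8 in-context examples, 3 features
y = X @ np.array([1.0, -2.0, 0.5])     # targets from a hidden linear function
x_q = rng.normal(size=3)

print(np.isclose(gd_one_step_prediction(X, y, x_q, lr=0.1),
                 linear_attention_prediction(X, y, x_q, lr=0.1)))  # True
```

Real Transformers use softmax attention, many layers, and learned projections, so this equivalence is only a motivating special case.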

4.3 Effective Few-Shot Prompting Strategies

from typing import List, Dict, Any
import numpy as np


class FewShotPromptBuilder:
    """Few-shot prompt builder"""

    def __init__(self):
        self.examples = []
        self.instruction = ""
        self.template = "{input} => {output}"

    def set_instruction(self, instruction: str):
        """Set task instruction"""
        self.instruction = instruction
        return self

    def add_example(self, input_text: str, output_text: str):
        """Add an example"""
        self.examples.append({
            'input': input_text,
            'output': output_text
        })
        return self

    def build_prompt(self, query: str) -> str:
        """Build complete few-shot prompt"""
        parts = []

        if self.instruction:
            parts.append(self.instruction)
            parts.append("")

        for ex in self.examples:
            parts.append(self.template.format(
                input=ex['input'],
                output=ex['output']
            ))

        # Query line: the same template with an empty output slot
        parts.append(
            self.template.format(input=query, output="").rstrip()
        )

        return "\n".join(parts)

    def build_chat_messages(
        self,
        query: str,
        system_prompt: str = None
    ) -> List[Dict]:
        """Build chat-format messages (for GPT-4, Claude, etc.)"""
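        messages = []

        # Fall back to the task instruction when no system prompt is given
        system_prompt = system_prompt or self.instruction
        if system_prompt:
            messages.append({
                'role': 'system',
                'content': system_prompt
            })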

        for ex in self.examples:
            messages.append({
                'role': 'user',
                'content': ex['input']
            })
            messages.append({
                'role': 'assistant',
                'content': ex['output']
            })

        messages.append({
            'role': 'user',
            'content': query
        })

        return messages


class DynamicExampleSelector:
    """
    Dynamic example selector
    Selects the most similar examples to the query for more effective few-shot learning
    """

    def __init__(self, examples: List[Dict], encoder=None):
        self.examples = examples
        self.encoder = encoder  # sentence-transformers etc.

    def select_similar(
        self,
        query: str,
        n_examples: int = 3
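    ) -> List[Dict]:
        """
        Select the n examples most similar to the query
        (falls back to random sampling when no encoder is set)
        """
        n_examples = min(n_examples, len(self.examples))
        if self.encoder is None:
            return np.random.choice(
                self.examples, n_examples, replace=False
            ).tolist()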

        query_emb = self.encoder.encode(query)
        example_embs = self.encoder.encode(
            [ex['input'] for ex in self.examples]
        )

        similarities = np.dot(example_embs, query_emb) / (
            np.linalg.norm(example_embs, axis=1)
            * np.linalg.norm(query_emb)
        )

        top_indices = np.argsort(similarities)[-n_examples:][::-1]
        return [self.examples[i] for i in top_indices]

    def select_diverse(
        self,
        n_examples: int = 3
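    ) -> List[Dict]:
        """
        Greedily select mutually dissimilar examples. Each step adds the
        example whose maximum cosine similarity to the already selected
        ones is lowest (a diversity-only simplification of MMR that
        ignores the query)
        """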
        if len(self.examples) <= n_examples:
            return self.examples

        if self.encoder is None:
            return np.random.choice(
                self.examples, n_examples, replace=False
            ).tolist()

        embeddings = self.encoder.encode(
            [ex['input'] for ex in self.examples]
        )

        selected = [0]
        remaining = list(range(1, len(self.examples)))

        while len(selected) < n_examples:
            selected_embs = embeddings[selected]
            best_idx = None
            best_score = float('-inf')

            for idx in remaining:
                sim_to_selected = np.max(
                    np.dot(selected_embs, embeddings[idx]) / (
                        np.linalg.norm(selected_embs, axis=1)
                        * np.linalg.norm(embeddings[idx])
                    )
                )
                score = -sim_to_selected  # maximize diversity

                if score > best_score:
                    best_score = score
                    best_idx = idx

            selected.append(best_idx)
            remaining.remove(best_idx)

        return [self.examples[i] for i in selected]


# ========== Practical Examples ==========

def sentiment_analysis_few_shot():
    """Sentiment analysis few-shot example"""
    builder = FewShotPromptBuilder()

    builder.set_instruction(
        "Analyze the sentiment of the following movie review. "
        "Answer Positive or Negative."
    )

    builder.add_example(
        "This movie was truly moving and the acting was superb.",
        "Positive"
    )
    builder.add_example(
        "The story was too boring and the ending was disappointing.",
        "Negative"
    )
    builder.add_example(
        "The special effects were impressive but the plot was too weak.",
        "Negative"
    )
    builder.add_example(
        "A warm film the whole family can enjoy together.",
        "Positive"
    )

    query = "The direction was unique and the music paired perfectly with the visuals."
    prompt = builder.build_prompt(query)

    print("=== Few-Shot Prompt ===")
    print(prompt)
    return prompt


def code_generation_few_shot():
    """Code generation few-shot example"""
    builder = FewShotPromptBuilder()

    builder.set_instruction(
        "Convert the natural language description into Python code."
    )

    builder.add_example(
        "Find the maximum element in a list",
        "def find_max(lst):\n    return max(lst)"
    )
    builder.add_example(
        "Check if a string is a palindrome",
        "def is_palindrome(s):\n    return s == s[::-1]"
    )
    builder.add_example(
        "Flatten a nested list",
        "def flatten(lst):\n    return [x for sublist in lst for x in sublist]"
    )

    query = "Count the frequency of each element in a list"
    prompt = builder.build_prompt(query)

    print("=== Code Generation Few-Shot Prompt ===")
    print(prompt)
    return prompt
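`select_diverse` above ignores the query entirely. Maximal Marginal Relevance (MMR; Carbonell and Goldstein, 1998) trades off relevance to the query against redundancy with the examples already chosen. Below is a minimal NumPy sketch. It assumes the query and example embeddings are precomputed. The function name `mmr_select` and the weight `lam` are illustrative and not part of the classes above.

```python
import numpy as np


def mmr_select(
    query_emb: np.ndarray,
    example_embs: np.ndarray,
    n_examples: int,
    lam: float = 0.5,
) -> list:
    """Return example indices chosen by Maximal Marginal Relevance.

    score(i) = lam * sim(query, i) - (1 - lam) * max_{j in selected} sim(i, j)
    """
    # Normalize once so that dot products are cosine similarities
    q = query_emb / np.linalg.norm(query_emb)
    e = example_embs / np.linalg.norm(example_embs, axis=1, keepdims=True)
    relevance = e @ q                  # (n,) similarity to the query
    pairwise = e @ e.T                 # (n, n) similarity between examples

    n_examples = min(n_examples, len(e))
    selected = []
    remaining = list(range(len(e)))
    while len(selected) < n_examples:
        if selected:
            # Highest similarity of each candidate to anything already chosen
            redundancy = pairwise[np.ix_(remaining, selected)].max(axis=1)
        else:
            redundancy = np.zeros(len(remaining))
        scores = lam * relevance[remaining] - (1 - lam) * redundancy
        best = remaining[int(np.argmax(scores))]
        selected.append(best)
        remaining.remove(best)
    return selected


embs = np.array([[1.0, 0.0], [0.99, 0.1], [0.0, 1.0]])
query = np.array([1.0, 0.2])
print(mmr_select(query, embs, 2, lam=1.0))  # [1, 0]: pure relevance
print(mmr_select(query, embs, 2, lam=0.5))  # [1, 2]: skips near-duplicate 0
```

With lam=1.0, MMR reduces to `select_similar`. Lowering lam pushes the selection toward `select_diverse`, so lam controls how much redundancy among the prompt examples is tolerated.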

4.4 Cross-lingual Few-shot

Multilingual encoders such as XLM-RoBERTa support zero-shot cross-lingual transfer. A model fine-tuned on a task using only English data can often be applied to Korean or other languages without any target-language training data.

import torch
import torch.nn.functional as F
from typing import List
from transformers import AutoModel, AutoTokenizer


class CrossLingualFewShot:
    """
    Cross-lingual few-shot learning
    English examples + target-language query
    """

    def __init__(self, model_name: str = "xlm-roberta-large"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Only hidden states are used, so no classification head is loaded
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    def encode_text(self, text: str) -> torch.Tensor:
        """Encode text to an embedding (hidden state of the first token)"""
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=128
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
            # Without task fine-tuning this is only a rough sentence embedding
            embedding = outputs.last_hidden_state[:, 0, :]
        return embedding

    def classify_zero_shot(
        self,
        query: str,
        class_descriptions_en: List[str]
    ) -> int:
        """
        Classify a query using English class descriptions
        (query can be in any language)
        """
        import torch.nn.functional as F

        query_emb = self.encode_text(query)
        class_embs = torch.cat([
            self.encode_text(desc) for desc in class_descriptions_en
        ])

        similarities = F.cosine_similarity(
            query_emb.expand(len(class_descriptions_en), -1),
            class_embs
        )

        return similarities.argmax().item()

    def few_shot_classify(
        self,
        support_texts: List[str],
        support_labels: List[int],
        query_text: str,
        n_classes: int
    ) -> int:
        """
        Few-shot classification (prototype approach)
        Support and query can be in different languages
        """
        import torch
        import torch.nn.functional as F

        support_embs = torch.cat([
            self.encode_text(t) for t in support_texts
        ])
        query_emb = self.encode_text(query_text)

        prototypes = []
        for cls in range(n_classes):
            mask = torch.tensor([l == cls for l in support_labels])
            proto = support_embs[mask].mean(dim=0, keepdim=True)
            prototypes.append(proto)
        prototypes = torch.cat(prototypes)

        dists = torch.cdist(query_emb, prototypes)
        return dists.argmin().item()

5. Hands-On: Medical Image Few-Shot Classification

5.1 Rare Disease Diagnosis System

In clinical settings, rare diseases have very limited training data. Few-shot learning enables recognition of new disease patterns from only a small number of confirmed cases.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from typing import Dict, List


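class MedicalImageEncoder(nn.Module):
    """
    Encoder for medical images: a ResNet-18 backbone (optionally
    ImageNet-pretrained) followed by a small projection head
    """

    def __init__(self, embed_dim: int = 512, pretrained: bool = True):
        super().__init__()

        # torchvision >= 0.13 selects pretrained weights via `weights`
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        backbone = models.resnet18(weights=weights)
        # Drop the final fc layer; keep conv layers + global average pooling
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])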

        self.embed_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)
        return self.embed_head(features)


class MedicalFewShotClassifier:
    """
    Medical image few-shot classifier
    Based on Prototypical Networks
    """

    def __init__(
        self,
        encoder: MedicalImageEncoder,
        device: str = 'cpu'
    ):
        self.encoder = encoder.to(device)
        self.device = device
        self.prototypes = {}
        self.disease_names = {}

    def register_disease(
        self,
        disease_id: int,
        disease_name: str,
        support_images: List,
        transform=None
    ):
        """
        Register a new disease (from a small number of examples)
        """
        self.disease_names[disease_id] = disease_name
        self.encoder.eval()

        embeddings = []
        with torch.no_grad():
            for img in support_images:
                if transform:
                    img_tensor = transform(img).unsqueeze(0).to(self.device)
                else:
                    img_tensor = img.unsqueeze(0).to(self.device)
                emb = self.encoder(img_tensor)
                embeddings.append(emb)

        prototype = torch.cat(embeddings).mean(dim=0)
        self.prototypes[disease_id] = prototype

        print(
            f"Disease registered: {disease_name} "
            f"({len(support_images)} examples)"
        )

    def diagnose(
        self,
        query_image: torch.Tensor,
        top_k: int = 3
    ) -> List[Dict]:
        """
        Diagnose a query image
        Returns: top-k diseases with similarity scores
        """
        self.encoder.eval()

        with torch.no_grad():
            query_emb = self.encoder(
                query_image.unsqueeze(0).to(self.device)
            )

        results = []
        for disease_id, prototype in self.prototypes.items():
            similarity = F.cosine_similarity(
                query_emb, prototype.unsqueeze(0)
            ).item()
            results.append({
                'disease_id': disease_id,
                'disease_name': self.disease_names[disease_id],
                'similarity': similarity
            })

        results.sort(key=lambda x: x['similarity'], reverse=True)
        return results[:top_k]

    def update_prototype(
        self,
        disease_id: int,
        new_image: torch.Tensor,
        momentum: float = 0.9
    ):
        """
        Update prototype with a new confirmed case (online learning)
        """
        self.encoder.eval()

        with torch.no_grad():
            new_emb = self.encoder(
                new_image.unsqueeze(0).to(self.device)
            ).squeeze(0)

        if disease_id in self.prototypes:
            # Exponential moving average update
            self.prototypes[disease_id] = (
                momentum * self.prototypes[disease_id]
                + (1 - momentum) * new_emb
            )
            print(
                f"Prototype updated: {self.disease_names[disease_id]}"
            )
        else:
            print(f"Warning: disease {disease_id} is not registered.")


def demo_medical_few_shot():
    """Medical few-shot classification demo"""
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    encoder = MedicalImageEncoder(embed_dim=512, pretrained=True)
    classifier = MedicalFewShotClassifier(encoder)

    print("=== Medical Image Few-Shot Classification System ===")
    print("This system recognizes new diseases from only a few confirmed cases.")
    print()
    print("Usage:")
    print("1. classifier.register_disease(id, name, support_images)")
    print("2. results = classifier.diagnose(patient_image)")
    print("3. classifier.update_prototype(disease_id, new_confirmed_image)")
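`update_prototype` uses an exponential moving average. With momentum=0.9, each new confirmed case contributes only 10% of the updated prototype, even when that prototype was built from a single example. If every confirmed case should count equally, as in the Prototypical Networks definition of a prototype as the class mean, one option is to track how many embeddings each prototype averages and update incrementally. The minimal sketch below illustrates this. `RunningPrototype` is a hypothetical helper, not part of the classifier above, and NumPy arrays stand in for encoder outputs.

```python
import numpy as np


class RunningPrototype:
    """Exact running mean of embeddings: p_n = p_{n-1} + (e_n - p_{n-1}) / n."""

    def __init__(self, dim: int):
        self.mean = np.zeros(dim)
        self.count = 0

    def add(self, embedding: np.ndarray) -> None:
        self.count += 1
        # Incremental update; equals the average of all embeddings added so far
        self.mean += (embedding - self.mean) / self.count


embs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
proto = RunningPrototype(dim=2)
for e in embs:
    proto.add(e)
print(proto.mean)  # same values as embs.mean(axis=0)
```

The EMA version is still a reasonable choice when disease presentation or imaging equipment drifts over time and recent cases should weigh more. Which update fits depends on the deployment.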

6. Using the learn2learn Library

6.1 Introduction to learn2learn

learn2learn is a library that makes it easy to implement meta-learning algorithms such as MAML and ProtoNet.

pip install learn2learn

6.2 MAML with learn2learn

import learn2learn as l2l
import torch
import torch.nn as nn
from typing import Tuple


def build_l2l_maml(
    model: nn.Module,
    lr: float = 0.01,
    first_order: bool = False
) -> l2l.algorithms.MAML:
    """Create learn2learn MAML wrapper"""
    return l2l.algorithms.MAML(
        model,
        lr=lr,
        first_order=first_order,
        allow_unused=True
    )


def train_with_l2l(
    maml_model: l2l.algorithms.MAML,
    tasksets,
    n_way: int = 5,
    k_shot: int = 1,
    n_query: int = 15,
    meta_lr: float = 0.003,
    n_iterations: int = 1000,
    adaptation_steps: int = 1,
    device: str = 'cpu',
    meta_batch_size: int = 4
):
    """MAML training with learn2learn"""
    maml_model = maml_model.to(device)
    meta_optimizer = torch.optim.Adam(
        maml_model.parameters(), lr=meta_lr
    )
    criterion = nn.CrossEntropyLoss(reduction='mean')

    for iteration in range(n_iterations):
        meta_optimizer.zero_grad()
        meta_train_loss = 0.0
        meta_train_acc = 0.0

        for task in range(meta_batch_size):
            X, y = tasksets.train.sample()
            X, y = X.to(device), y.to(device)

            learner = maml_model.clone()

            # Build the mask on the same device as y so the index
            # assignment below also works on GPU
            support_indices = torch.zeros(
                X.size(0), dtype=torch.bool, device=X.device
            )
            for cls in range(n_way):
                cls_idx = (y == cls).nonzero(as_tuple=True)[0]
                support_idx = cls_idx[:k_shot]
                support_indices[support_idx] = True

            query_indices = ~support_indices
            support_x, support_y = X[support_indices], y[support_indices]
            query_x, query_y = X[query_indices], y[query_indices]

            # Inner loop: adaptation
            for step in range(adaptation_steps):
                support_logits = learner(support_x)
                support_loss = criterion(support_logits, support_y)
                learner.adapt(support_loss)

            # Outer loop: meta-gradient
            query_logits = learner(query_x)
            query_loss = criterion(query_logits, query_y)
            meta_train_loss += query_loss

            preds = query_logits.argmax(dim=-1)
            acc = (preds == query_y).float().mean()
            meta_train_acc += acc

        meta_train_loss /= meta_batch_size
        meta_train_acc /= meta_batch_size

        meta_train_loss.backward()
        meta_optimizer.step()

        if (iteration + 1) % 100 == 0:
            print(
                f"Iteration {iteration+1}/{n_iterations} | "
                f"Meta Loss: {meta_train_loss.item():.4f} | "
                f"Meta Acc: {meta_train_acc.item():.4f}"
            )


def setup_omniglot_maml():
    """
    Set up MAML with the Omniglot dataset
    Omniglot: 1623 character classes from 50 alphabets (20 samples each)
    """
    tasksets = l2l.vision.benchmarks.get_tasksets(
        'omniglot',
        train_ways=5,
        train_samples=1 + 15,  # k_shot + n_query images per class (Omniglot has 20)
        test_ways=5,
        test_samples=1 + 15,
        root='./data',
        device='cpu'
    )

    model = l2l.vision.models.OmniglotCNN(
        output_size=5,
        hidden_size=64,
        layers=4
    )

    maml = build_l2l_maml(model, lr=0.4, first_order=False)

    return maml, tasksets


def evaluate_l2l(
    maml_model: l2l.algorithms.MAML,
    tasksets,
    n_way: int = 5,
    k_shot: int = 1,
    n_query: int = 15,
    n_test_tasks: int = 600,
    adaptation_steps: int = 3,
    device: str = 'cpu'
) -> Tuple[float, float]:
    """Meta-evaluation with learn2learn"""
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_acc = 0.0

    maml_model.eval()

    for _ in range(n_test_tasks):
        X, y = tasksets.test.sample()
        X, y = X.to(device), y.to(device)

        learner = maml_model.clone()

        support_indices = torch.zeros(
            X.size(0), dtype=torch.bool, device=X.device
        )
        for cls in range(n_way):
            cls_idx = (y == cls).nonzero(as_tuple=True)[0]
            support_idx = cls_idx[:k_shot]
            support_indices[support_idx] = True

        query_indices = ~support_indices
        support_x, support_y = X[support_indices], y[support_indices]
        query_x, query_y = X[query_indices], y[query_indices]

        for step in range(adaptation_steps):
            support_loss = criterion(learner(support_x), support_y)
            learner.adapt(support_loss)

        with torch.no_grad():
            query_logits = learner(query_x)
            loss = criterion(query_logits, query_y)
            acc = (query_logits.argmax(dim=-1) == query_y).float().mean()

        total_loss += loss.item()
        total_acc += acc.item()

    return total_loss / n_test_tasks, total_acc / n_test_tasks

7. Benchmarks and Evaluation

7.1 Key Benchmark Datasets

Omniglot

1623 handwritten character classes from 50 alphabets, with 20 samples per class. It is commonly evaluated in 5-way and 20-way settings with 1 or 5 shots.

Mini-ImageNet

A 100-class subset of ImageNet. 600 images per class (84x84). The 5-way 1/5-shot setting is standard.

tieredImageNet

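A larger and harder ImageNet subset (608 classes). Classes are grouped into 34 high-level categories, and meta-train and meta-test classes come from different categories. This creates a larger semantic gap between meta-train and meta-test classes than in Mini-ImageNet.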

CIFAR-FS

A few-shot benchmark derived from CIFAR-100. Faster to experiment with than Mini-ImageNet.

7.2 Standard Evaluation Protocol

def standard_few_shot_evaluation(
    model,
    test_dataset,
    n_way: int = 5,
    k_shot: int = 1,
    n_query: int = 15,
    n_episodes: int = 600,
    confidence_interval: bool = True
) -> Dict:
    """
    Standard few-shot evaluation protocol
    Mean and 95% confidence interval over 600 episodes
    """
    accs = []
    model.eval()

    for episode in range(n_episodes):
        support_x, support_y, query_x, query_y = create_episode(
            test_dataset, n_way, k_shot, n_query
        )

        with torch.no_grad():
            log_probs = model(support_x, support_y, query_x, n_way)
            preds = log_probs.argmax(dim=-1)
            acc = (preds == query_y).float().mean().item()

        accs.append(acc)

    mean_acc = np.mean(accs)
    std_acc = np.std(accs)

    if confidence_interval:
        ci = 1.96 * std_acc / np.sqrt(n_episodes)
        return {
            'mean_accuracy': mean_acc,
            'std': std_acc,
            'confidence_interval_95': ci,
            'result_string': f"{mean_acc*100:.2f} +/- {ci*100:.2f}%"
        }

    return {'mean_accuracy': mean_acc, 'std': std_acc}
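As a worked example of the confidence interval above, suppose there are 600 episodes and the per-episode accuracy has a standard deviation of 0.10. The half-width is then 1.96 × 0.10 / √600 ≈ 0.0080, so a mean of 70% would be reported as 70.00 +/- 0.80%. The snippet below reproduces this with a synthetic accuracy list. The numbers are illustrative, not a benchmark result.

```python
import numpy as np

# 600 synthetic episode accuracies alternating 0.6 / 0.8: mean 0.7, std 0.1
accs = np.array([0.6, 0.8] * 300)

mean_acc = accs.mean()
ci = 1.96 * accs.std() / np.sqrt(len(accs))    # np.std uses ddof=0 by default
print(f"{mean_acc*100:.2f} +/- {ci*100:.2f}%")  # 70.00 +/- 0.80%
```

Because the interval shrinks with 1/√n_episodes, quadrupling the number of test episodes only halves it. This is why 600 or more episodes are typical when comparing methods that differ by a point or two.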

8. The Future of Meta-Learning

Meta-learning is establishing itself as a core paradigm in AI research. Key trends to watch:

Convergence with LLMs

Large language models like GPT-4 and Claude demonstrate strong ICL capabilities. An active line of research views these models as meta-learners that perform few-shot learning across a wide range of domains.

Multimodal Few-Shot Learning

Few-shot learning is extending to settings that combine text, images, and audio. Multimodal models such as GPT-4V and Gemini Ultra have reported strong performance on visual few-shot tasks.

Combination with Continual Learning

Some studies report that models initialized via meta-learning forget less of their previous knowledge when learning new tasks. Combining continual learning with meta-learning is an active research area.

Domain Adaptation Applications

Industrial applications where data is scarce, such as rare disease diagnosis, satellite imagery analysis, and specialized code generation, are emerging as the most practical use cases for few-shot learning.


References

  • Finn, C., et al. (2017). Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. ICML 2017. arXiv:1703.03400
  • Snell, J., et al. (2017). Prototypical Networks for Few-shot Learning. NeurIPS 2017. arXiv:1703.05175
  • Vinyals, O., et al. (2016). Matching Networks for One Shot Learning. NeurIPS 2016. arXiv:1606.04080
  • Nichol, A., et al. (2018). On First-Order Meta-Learning Algorithms. arXiv:1803.02999
  • Brown, T., et al. (2020). Language Models are Few-Shot Learners (GPT-3). NeurIPS 2020. arXiv:2005.14165
  • learn2learn library: https://github.com/learnables/learn2learn