Meta-Learning and Few-Shot Learning Complete Guide: MAML, Prototypical Networks, In-Context Learning

Humans learn new concepts quickly from just a few examples. Show a child "this is a zebra" once, and they can recognize a zebra in an entirely new photograph the next day. But traditional deep learning models require thousands of images just to build a zebra classifier.

Meta-learning and few-shot learning bridge this gap. The core idea of meta-learning is "learning to learn." The model gains the ability to rapidly adapt to new tasks by experiencing a variety of tasks during training.

This guide covers everything from the theoretical foundations of meta-learning to the latest developments in In-Context Learning.


1. Meta-Learning Fundamentals

1.1 Learning to Learn

In traditional machine learning, a model is trained for a single task. A cat classifier is trained only on cat data, and when a new animal (e.g., an ocelot) needs to be classified, training must start from scratch.

Meta-learning shifts the perspective. What the model needs to learn is not "how to classify cats" but "how to learn to quickly classify new animals."

Meta-learning therefore has two levels of learning.

  • Meta-learning (Outer Loop): Learning a good initialization or learning algorithm by experiencing many tasks
  • Task learning (Inner Loop): Rapidly adapting to each specific task from a small number of examples

1.2 Limitations of Traditional Learning

The limitations of traditional learning are:

Data inefficiency: A model trained on ImageNet requires millions of images. Adding new classes requires thousands of samples.

Lack of generalization: Performance drops sharply on new tasks that differ significantly from the training distribution.

Catastrophic forgetting: Learning new tasks causes the model to forget previously learned tasks.

1.3 Task Distribution

The central concept in meta-learning is the task distribution p(T). Meta-learning does not simply learn from data; it learns from a distribution of tasks.

Each task T includes:

  • A distribution p(x, y) of input-output pairs
  • A task loss function L

The meta-learning objective is:

min over theta: E over T ~ p(T) [L_T(f_theta)]

1.4 Support Set vs Query Set

In few-shot learning, data is divided into two roles.

Support Set: A small number of examples used as references when the model learns a new task. These correspond to training data in traditional learning, but are extremely few (e.g., 1–5 per class).

Query Set: Data used to evaluate model performance. These correspond to test data in traditional learning.

1.5 N-way K-shot Setup

The most important configuration in few-shot learning is N-way K-shot.

  • N-way: Number of classes to classify
  • K-shot: Number of support examples per class

For example, 5-way 1-shot is a task of classifying 5 classes with only 1 example per class. 5-way 5-shot uses 5 examples per class.

import torch
import torch.nn as nn
import numpy as np
from typing import List, Tuple, Dict


def create_episode(
    dataset,
    n_way: int,
    k_shot: int,
    n_query: int,
    classes: List[int] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Create N-way K-shot episode
    Returns: (support_x, support_y, query_x, query_y)
    """
    all_classes = list(set(dataset.targets.tolist()))
    if classes is None:
        selected_classes = np.random.choice(
            all_classes, n_way, replace=False
        )
    else:
        selected_classes = classes

    support_x, support_y = [], []
    query_x, query_y = [], []

    for new_label, cls in enumerate(selected_classes):
        cls_indices = (dataset.targets == cls).nonzero(as_tuple=True)[0]
        chosen = np.random.choice(
            len(cls_indices), k_shot + n_query, replace=False
        )

        for i, idx in enumerate(cls_indices[chosen]):
            x, _ = dataset[idx.item()]
            if i < k_shot:
                support_x.append(x)
                support_y.append(new_label)
            else:
                query_x.append(x)
                query_y.append(new_label)

    support_x = torch.stack(support_x)
    support_y = torch.tensor(support_y)
    if query_x:
        query_x = torch.stack(query_x)
        query_y = torch.tensor(query_y)
    else:  # n_query == 0 (e.g. Reptile only needs a support set)
        query_x = torch.empty(0)
        query_y = torch.empty(0, dtype=torch.long)

    return support_x, support_y, query_x, query_y
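As a quick sanity check of the episode shapes, here is a standalone sketch on synthetic data (`ToyDataset` and `sample_episode` are illustrative stand-ins for the helper above, not a fixed API):

```python
import numpy as np
import torch

class ToyDataset:
    """Synthetic labeled dataset: 20 classes x 30 samples of 1x28x28 noise."""
    def __init__(self, n_classes=20, per_class=30):
        self.data = torch.randn(n_classes * per_class, 1, 28, 28)
        self.targets = torch.arange(n_classes).repeat_interleave(per_class)
    def __getitem__(self, i):
        return self.data[i], self.targets[i].item()

def sample_episode(ds, n_way, k_shot, n_query):
    """Condensed episode sampler following the same recipe as above."""
    classes = np.random.choice(ds.targets.unique().numpy(), n_way, replace=False)
    sx, sy, qx, qy = [], [], [], []
    for label, c in enumerate(classes):
        idx = (ds.targets == c).nonzero(as_tuple=True)[0]
        chosen = idx[torch.randperm(len(idx))[:k_shot + n_query]]
        for j, i in enumerate(chosen):
            (sx if j < k_shot else qx).append(ds[i.item()][0])
            (sy if j < k_shot else qy).append(label)
    return torch.stack(sx), torch.tensor(sy), torch.stack(qx), torch.tensor(qy)

ds = ToyDataset()
sx, sy, qx, qy = sample_episode(ds, n_way=5, k_shot=1, n_query=15)
print(tuple(sx.shape), tuple(qx.shape))  # (5, 1, 28, 28) (75, 1, 28, 28)
```

A 5-way 1-shot episode thus has 5 support samples in total, while 15 queries per class give 75 query samples.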

2. Distance-Based Meta-Learning

2.1 Matching Networks

Matching Networks, proposed by Vinyals et al. in 2016, combines attention mechanisms with the kNN idea. The core idea is to predict a query sample as an attention-weighted sum of support set labels.

Prediction formula

y_hat = sum over i: a(x_hat, x_i) * y_i

Here a(x_hat, x_i) is the attention weight between query x_hat and support sample x_i, computed as a softmax over cosine similarities.

class MatchingNetworks(nn.Module):
    """Matching Networks implementation"""

    def __init__(self, encoder: nn.Module, use_fce: bool = False):
        """
        encoder: feature extractor
        use_fce: whether to use Full Context Embedding
        """
        super().__init__()
        self.encoder = encoder
        self.use_fce = use_fce

    def cosine_similarity(
        self,
        query: torch.Tensor,
        support: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute cosine similarity
        query: (n_query, embed_dim)
        support: (n_support, embed_dim)
        Returns: (n_query, n_support)
        """
        query_norm = nn.functional.normalize(query, dim=-1)
        support_norm = nn.functional.normalize(support, dim=-1)
        return torch.mm(query_norm, support_norm.t())

    def forward(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """
        support_x: (n_way * k_shot, C, H, W)
        support_y: (n_way * k_shot,)
        query_x: (n_query, C, H, W)
        """
        support_emb = self.encoder(support_x)   # (n_support, D)
        query_emb = self.encoder(query_x)        # (n_query, D)

        similarities = self.cosine_similarity(
            query_emb, support_emb
        )  # (n_query, n_support)

        # Softmax attention
        attention = nn.functional.softmax(similarities, dim=-1)

        # One-hot labels
        support_labels_one_hot = nn.functional.one_hot(
            support_y, n_way
        ).float()  # (n_support, n_way)

        # Attention-weighted prediction
        logits = torch.mm(attention, support_labels_one_hot)
        return logits
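The attention-weighted prediction can be checked in isolation on random tensors (a standalone sketch; the vectors below are placeholders for real encoder outputs):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_way, k_shot, n_query, d = 5, 5, 10, 32
support = torch.randn(n_way * k_shot, d)                 # pretend embeddings
support_y = torch.arange(n_way).repeat_interleave(k_shot)
query = torch.randn(n_query, d)

# cosine similarities -> softmax attention over the support set
sims = F.normalize(query, dim=-1) @ F.normalize(support, dim=-1).t()
attention = F.softmax(sims, dim=-1)                      # (n_query, n_support)

# attention-weighted sum of one-hot support labels
probs = attention @ F.one_hot(support_y, n_way).float()  # (n_query, n_way)
print(probs.shape)  # torch.Size([10, 5])
```

Because each attention row sums to 1 and the one-hot columns partition the support set, each output row is a valid probability distribution over the N classes.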

2.2 Prototypical Networks

Prototypical Networks, published by Snell et al. in 2017, is one of the most elegant and intuitive meta-learning algorithms. The core idea: represent each class as a single prototype (centroid) in embedding space.

Prototype computation

The prototype of class c is the mean of the embeddings of the support samples for that class.

p_c = (1/|S_c|) sum over (x_i, y_i) in S_c: f_phi(x_i)

Classification

Classify a query sample x to the nearest prototype.

p(y=c | x) = exp(-d(f_phi(x), p_c)) / sum over c': exp(-d(f_phi(x), p_c'))

where d is the Euclidean distance.

class ConvEncoder(nn.Module):
    """4-layer CNN encoder for few-shot learning"""

    def __init__(
        self,
        in_channels: int = 1,
        hidden_dim: int = 64,
        out_dim: int = 64
    ):
        super().__init__()

        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2)
            )

        self.net = nn.Sequential(
            conv_block(in_channels, hidden_dim),
            conv_block(hidden_dim, hidden_dim),
            conv_block(hidden_dim, hidden_dim),
            conv_block(hidden_dim, out_dim),
            nn.Flatten()
        )

    def forward(self, x):
        return self.net(x)


class PrototypicalNetworks(nn.Module):
    """Complete Prototypical Networks implementation"""

    def __init__(self, encoder: nn.Module):
        super().__init__()
        self.encoder = encoder

    def compute_prototypes(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """
        Compute class prototypes (mean of embeddings)
        support_x: (n_way * k_shot, C, H, W)
        support_y: (n_way * k_shot,)
        Returns: (n_way, embed_dim)
        """
        support_emb = self.encoder(support_x)  # (n_support, D)
        prototypes = []

        for cls in range(n_way):
            mask = (support_y == cls)
            cls_embeddings = support_emb[mask]
            prototype = cls_embeddings.mean(dim=0)
            prototypes.append(prototype)

        return torch.stack(prototypes)  # (n_way, D)

    def euclidean_dist(
        self,
        x: torch.Tensor,
        y: torch.Tensor
    ) -> torch.Tensor:
        """
        Euclidean distance
        x: (n, D)
        y: (m, D)
        Returns: (n, m)
        """
        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2*x.y
        n = x.size(0)
        m = y.size(0)
        x_sq = (x ** 2).sum(dim=1, keepdim=True).expand(n, m)
        y_sq = (y ** 2).sum(dim=1, keepdim=True).expand(m, n).t()
        xy = torch.mm(x, y.t())
        dist = x_sq + y_sq - 2 * xy
        # clamp above zero: sqrt has an infinite gradient at exactly 0
        return dist.clamp(min=1e-12).sqrt()

    def forward(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """
        Prototypical Networks forward pass
        Returns: log-probabilities for query samples (n_query, n_way)
        """
        prototypes = self.compute_prototypes(
            support_x, support_y, n_way
        )  # (n_way, D)

        query_emb = self.encoder(query_x)  # (n_query, D)

        dists = self.euclidean_dist(
            query_emb, prototypes
        )  # (n_query, n_way)

        log_probs = nn.functional.log_softmax(-dists, dim=-1)
        return log_probs


# ========== Training Loop ==========

def train_prototypical(
    model: PrototypicalNetworks,
    train_dataset,
    n_way: int = 5,
    k_shot: int = 5,
    n_query: int = 15,
    n_episodes: int = 100,
    lr: float = 1e-3,
    device: str = 'cpu'
):
    """Train Prototypical Networks"""
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.NLLLoss()

    model.train()
    episode_losses = []
    episode_accs = []

    for episode in range(n_episodes):
        support_x, support_y, query_x, query_y = create_episode(
            train_dataset, n_way, k_shot, n_query
        )
        support_x = support_x.to(device)
        support_y = support_y.to(device)
        query_x = query_x.to(device)
        query_y = query_y.to(device)

        optimizer.zero_grad()

        log_probs = model(support_x, support_y, query_x, n_way)
        loss = criterion(log_probs, query_y)
        loss.backward()
        optimizer.step()

        preds = log_probs.argmax(dim=-1)
        acc = (preds == query_y).float().mean().item()

        episode_losses.append(loss.item())
        episode_accs.append(acc)

        if (episode + 1) % 100 == 0:
            print(
                f"Episode {episode+1}/{n_episodes} | "
                f"Loss: {np.mean(episode_losses[-100:]):.4f} | "
                f"Acc: {np.mean(episode_accs[-100:]):.4f}"
            )

    return episode_losses, episode_accs
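A minimal numeric check of the prototype rule, using synthetic, well-separated clusters in place of learned embeddings:

```python
import torch

torch.manual_seed(0)
n_way, k_shot, d = 3, 5, 8
centers = torch.eye(n_way, d) * 10.0   # three well-separated class centers

# noisy support samples around each center, plus one query per class
support = centers.repeat_interleave(k_shot, dim=0) + 0.1 * torch.randn(n_way * k_shot, d)
support_y = torch.arange(n_way).repeat_interleave(k_shot)
query = centers + 0.1 * torch.randn(n_way, d)
query_y = torch.arange(n_way)

# prototypes = per-class means; classify queries by nearest prototype
prototypes = torch.stack([support[support_y == c].mean(dim=0) for c in range(n_way)])
dists = torch.cdist(query, prototypes)          # (n_query, n_way)
preds = (-dists).softmax(dim=-1).argmax(dim=-1)
print((preds == query_y).all().item())  # True
```

With cluster noise far smaller than the inter-center distance, every query lands nearest its own prototype, which is exactly the behavior the encoder is trained to produce in embedding space.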

2.3 Relation Networks

Relation Networks (Sung et al., 2018) are similar to Prototypical Networks, but replace the fixed distance function with a learnable neural network. The relation module computes a relation score from the concatenation of the query embedding and the class prototype.

class RelationNetwork(nn.Module):
    """Relation Networks: learnable distance function"""

    def __init__(self, encoder: nn.Module, embed_dim: int = 64):
        super().__init__()
        self.encoder = encoder

        # Relation module: takes concatenated embeddings, outputs relation score
        self.relation_module = nn.Sequential(
            nn.Linear(embed_dim * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """
        Returns: relation scores (n_query, n_way)
        """
        support_emb = self.encoder(support_x)
        query_emb = self.encoder(query_x)

        prototypes = []
        for cls in range(n_way):
            mask = (support_y == cls)
            proto = support_emb[mask].mean(dim=0)
            prototypes.append(proto)
        prototypes = torch.stack(prototypes)  # (n_way, D)

        n_query = query_emb.size(0)

        query_expanded = query_emb.unsqueeze(1).expand(
            n_query, n_way, -1
        )  # (n_query, n_way, D)
        proto_expanded = prototypes.unsqueeze(0).expand(
            n_query, n_way, -1
        )  # (n_query, n_way, D)

        pairs = torch.cat(
            [query_expanded, proto_expanded], dim=-1
        )  # (n_query, n_way, 2D)

        scores = self.relation_module(
            pairs.view(-1, pairs.size(-1))
        ).view(n_query, n_way)

        return scores

3. Optimization-Based Meta-Learning: MAML

3.1 The Core Idea of MAML

MAML (Model-Agnostic Meta-Learning), proposed by Finn et al. in 2017, is one of the most influential works in meta-learning.

MAML's goal is to "find an initial set of parameters theta from which fast adaptation is possible."

Concretely, it finds an initialization point from which only a few gradient update steps are needed to achieve good performance on a new task.

3.2 Inner Loop vs Outer Loop

MAML consists of two loops.

Inner Loop (Task-specific adaptation)

For each task T_i:

theta_i' = theta - alpha * grad_theta L_{T_i}(f_theta)

Perform 1–5 gradient updates on the support set to obtain task-specific parameters theta_i'.

Outer Loop (Meta-update)

theta = theta - beta * grad_theta sum_i L_{T_i}(f_{theta_i'})

Compute the loss on the query set using task-adapted parameters theta_i', then update the meta-parameters theta based on this loss.

3.3 Second-Order Gradients

The key technical challenge of MAML is that the outer loop gradient requires second-order gradients (backpropagation through the inner loop).

grad_theta L_{T_i}(f_{theta_i'}) = grad_theta L_{T_i}(f_{theta - alpha * grad_theta L_{T_i}(f_theta)})

Computing this gradient requires differentiating through the inner-loop update itself, which introduces second-order (Hessian-vector) terms and is computationally expensive.

In practice, FOMAML (First-Order MAML) is often used, which ignores the second-order gradient terms and uses an approximate gradient.
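The gap between the full second-order gradient and the first-order approximation can be seen on a one-parameter toy problem (an illustrative sketch with loss L(w) = (w - t)^2, not a real benchmark):

```python
import torch

alpha, t = 0.1, 3.0                      # inner-loop lr and task target
w = torch.tensor(1.0, requires_grad=True)

# inner step on L(w) = (w - t)^2, keeping the graph for second-order terms
g, = torch.autograd.grad((w - t) ** 2, w, create_graph=True)
w_adapted = w - alpha * g                # w' = 1.4

# full MAML: outer (query) loss differentiated back through the inner step
full_grad, = torch.autograd.grad((w_adapted - t) ** 2, w)

# FOMAML: gradient taken at w', ignoring the dw'/dw factor
fo_grad = 2 * (w_adapted - t).detach()

print(full_grad.item(), fo_grad.item())  # approx. -2.56 vs -3.2
```

Analytically the full gradient is 2(w' - t)(1 - 2*alpha) = -2.56, while FOMAML drops the (1 - 2*alpha) factor and returns -3.2; `create_graph=True` is what lets autograd recover the extra factor.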

3.4 Complete MAML Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from copy import deepcopy
from typing import List, Tuple


class MAML:
    """
    MAML: Model-Agnostic Meta-Learning
    Finn et al., 2017 (arXiv:1703.03400)
    """

    def __init__(
        self,
        model: nn.Module,
        inner_lr: float = 0.01,       # alpha: inner loop learning rate
        outer_lr: float = 0.001,      # beta: outer loop learning rate
        n_inner_steps: int = 5,       # number of inner loop updates
        first_order: bool = False,    # whether to use FOMAML
        device: str = 'cpu'
    ):
        self.model = model.to(device)
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.n_inner_steps = n_inner_steps
        self.first_order = first_order
        self.device = device

        self.meta_optimizer = torch.optim.Adam(
            self.model.parameters(), lr=outer_lr
        )

    def inner_loop(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        model_params=None
    ) -> dict:
        """
        Inner loop: task-specific adaptation on the support set
        Returns: adapted parameter dictionary
        """
        if model_params is None:
            params = {
                name: param.clone()
                for name, param in self.model.named_parameters()
            }
        else:
            params = {k: v.clone() for k, v in model_params.items()}

        for step in range(self.n_inner_steps):
            logits = self._forward_with_params(support_x, params)
            loss = F.cross_entropy(logits, support_y)

            grads = torch.autograd.grad(
                loss,
                params.values(),
                create_graph=not self.first_order
            )

            params = {
                name: param - self.inner_lr * grad
                for (name, param), grad in zip(params.items(), grads)
            }

        return params

    def _forward_with_params(
        self,
        x: torch.Tensor,
        params: dict
    ) -> torch.Tensor:
        """
        Functional forward pass with the given parameters (PyTorch >= 2.0).
        torch.func.functional_call keeps the adapted parameters in the
        autograd graph; swapping each module's param.data in and out
        would silently detach them and break the second-order
        meta-gradient.
        """
        return torch.func.functional_call(self.model, params, (x,))

    def meta_train_step(
        self,
        tasks: List[Tuple]
    ) -> float:
        """
        One meta-training step (over multiple tasks)
        tasks: [(support_x, support_y, query_x, query_y), ...]
        """
        self.meta_optimizer.zero_grad()

        meta_loss = 0.0

        for support_x, support_y, query_x, query_y in tasks:
            support_x = support_x.to(self.device)
            support_y = support_y.to(self.device)
            query_x = query_x.to(self.device)
            query_y = query_y.to(self.device)

            adapted_params = self.inner_loop(support_x, support_y)

            query_logits = self._forward_with_params(
                query_x, adapted_params
            )
            query_loss = F.cross_entropy(query_logits, query_y)
            meta_loss += query_loss

        meta_loss /= len(tasks)
        meta_loss.backward()
        self.meta_optimizer.step()

        return meta_loss.item()

    def fine_tune(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        n_steps: int = None
    ) -> nn.Module:
        """
        Fine-tune on a new task (used at inference)
        """
        n_steps = n_steps or self.n_inner_steps
        model_copy = deepcopy(self.model)
        optimizer = torch.optim.SGD(
            model_copy.parameters(), lr=self.inner_lr
        )

        model_copy.train()
        for step in range(n_steps):
            optimizer.zero_grad()
            logits = model_copy(support_x.to(self.device))
            loss = F.cross_entropy(logits, support_y.to(self.device))
            loss.backward()
            optimizer.step()

        return model_copy

    def evaluate(
        self,
        tasks: List[Tuple],
        n_fine_tune_steps: int = 5
    ) -> Tuple[float, float]:
        """
        Meta-test: adapt to new tasks and evaluate
        """
        total_loss = 0.0
        total_acc = 0.0

        for support_x, support_y, query_x, query_y in tasks:
            adapted_model = self.fine_tune(
                support_x, support_y, n_fine_tune_steps
            )
            adapted_model.eval()

            with torch.no_grad():
                query_logits = adapted_model(query_x.to(self.device))
                loss = F.cross_entropy(
                    query_logits, query_y.to(self.device)
                )
                preds = query_logits.argmax(dim=-1)
                acc = (preds == query_y.to(self.device)).float().mean()

            total_loss += loss.item()
            total_acc += acc.item()

        return total_loss / len(tasks), total_acc / len(tasks)


def train_maml(
    maml: MAML,
    dataset,
    n_way: int = 5,
    k_shot: int = 1,
    n_query: int = 15,
    meta_batch_size: int = 32,
    n_iterations: int = 60000,
    device: str = 'cpu'
):
    """MAML training loop"""
    print(f"MAML Training: {n_way}-way {k_shot}-shot")
    print(f"Meta-batch size: {meta_batch_size}")
    print(f"Total iterations: {n_iterations}")

    losses = []

    for iteration in range(n_iterations):
        tasks = []
        for _ in range(meta_batch_size):
            task = create_episode(dataset, n_way, k_shot, n_query)
            tasks.append(task)

        meta_loss = maml.meta_train_step(tasks)
        losses.append(meta_loss)

        if (iteration + 1) % 1000 == 0:
            avg_loss = np.mean(losses[-1000:])
            print(
                f"Iteration {iteration+1}/{n_iterations} | "
                f"Meta Loss: {avg_loss:.4f}"
            )

    return losses

3.5 Reptile: Simplifying MAML

Reptile (Nichol et al., 2018) is a greatly simplified version of MAML. It requires no second-order gradients and is much easier to implement.

Core idea: After running SGD multiple times on a task, move the meta-parameters toward the resulting parameters.

theta = theta + epsilon * (W_k - theta)

where W_k is the parameter after k SGD updates on task T.

class Reptile:
    """
    Reptile: A Scalable Meta-learning Algorithm
    Nichol et al., 2018 (arXiv:1803.02999)
    """

    def __init__(
        self,
        model: nn.Module,
        inner_lr: float = 0.02,
        outer_lr: float = 0.001,  # epsilon
        n_inner_steps: int = 5,
        device: str = 'cpu'
    ):
        self.model = model.to(device)
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.n_inner_steps = n_inner_steps
        self.device = device

    def inner_train(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor
    ) -> dict:
        """Task-specific inner training (k SGD steps)"""
        model_copy = deepcopy(self.model)
        optimizer = torch.optim.SGD(
            model_copy.parameters(), lr=self.inner_lr
        )

        model_copy.train()
        for step in range(self.n_inner_steps):
            optimizer.zero_grad()
            logits = model_copy(support_x)
            loss = F.cross_entropy(logits, support_y)
            loss.backward()
            optimizer.step()

        return dict(model_copy.named_parameters())

    def meta_update(self, task_params_list: List[dict]):
        """
        Reptile meta-update:
        theta += epsilon * (mean(W_k) - theta)
        """
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                task_mean = torch.stack([
                    task_params[name].data
                    for task_params in task_params_list
                ]).mean(dim=0)

                param.data += self.outer_lr * (task_mean - param.data)

    def train(
        self,
        dataset,
        n_way: int = 5,
        k_shot: int = 5,
        meta_batch_size: int = 5,
        n_iterations: int = 100000
    ):
        """Reptile full training loop"""
        print(f"Reptile Training: {n_way}-way {k_shot}-shot")

        for iteration in range(n_iterations):
            task_params_list = []

            for _ in range(meta_batch_size):
                support_x, support_y, _, _ = create_episode(
                    dataset, n_way, k_shot, n_query=0
                )
                support_x = support_x.to(self.device)
                support_y = support_y.to(self.device)

                task_params = self.inner_train(support_x, support_y)
                task_params_list.append(task_params)

            self.meta_update(task_params_list)

            if (iteration + 1) % 10000 == 0:
                print(f"Iteration {iteration+1}/{n_iterations} complete")
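The update rule above can be watched converging on a toy scalar problem (an illustrative sketch: each "task" is a quadratic with target t, and Reptile should settle near the mean of the task optima):

```python
import torch

inner_lr, outer_lr, k = 0.1, 0.5, 5
theta = torch.tensor(5.0)                # meta-parameter (scalar)
targets = [-1.0, 1.0]                    # two tasks: L_t(w) = (w - t)^2

for _ in range(200):
    adapted = []
    for t in targets:
        w = theta.clone().requires_grad_(True)
        for _ in range(k):               # k SGD steps on the task loss
            g, = torch.autograd.grad((w - t) ** 2, w)
            w = (w - inner_lr * g).detach().requires_grad_(True)
        adapted.append(w.detach())
    # theta += epsilon * (mean(W_k) - theta)
    theta = theta + outer_lr * (torch.stack(adapted).mean() - theta)

print(round(theta.item(), 3))  # ~0.0: near the mean task optimum
```

No second-order gradients appear anywhere: each task contributes only the endpoint of plain SGD, which is what makes Reptile so easy to implement.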

4. In-Context Learning in LLMs

4.1 What Is In-Context Learning?

In-Context Learning (ICL) is the ability of large language models (LLMs) to perform new tasks from examples within the prompt. The model does not update its parameters; it "learns" solely from the input context (prompt).

This ability gained enormous attention with the arrival of GPT-3. For example:

English to French translation:
sea otter => loutre de mer
peppermint => menthe poivrée
plush giraffe => girafe en peluche
cheese => ?

Given this prompt, GPT-3 answers "fromage." The model was never explicitly fine-tuned for translation; it absorbed the pattern during pre-training and completes it from context alone.

4.2 Why Is It Effective?

The reasons ICL is effective are still being actively studied. Major hypotheses include:

Pattern completion view: LLMs are trained to predict the next token. When the prompt lays out a consistent input-to-output pattern, the most probable continuation is simply to extend that pattern.

Latent concept inference view: Xie et al. (2022) describe ICL as implicit Bayesian inference: the model infers a latent concept shared by the prompt examples and conditions its prediction on it.

Gradient descent metaphor: Akyürek et al. showed that the Transformer's attention mechanism implicitly performs operations analogous to gradient descent.

4.3 Effective Few-Shot Prompting Strategies

from typing import List, Dict, Any
import numpy as np


class FewShotPromptBuilder:
    """Few-shot prompt builder"""

    def __init__(self):
        self.examples = []
        self.instruction = ""
        self.template = "{input} => {output}"

    def set_instruction(self, instruction: str):
        """Set task instruction"""
        self.instruction = instruction
        return self

    def add_example(self, input_text: str, output_text: str):
        """Add an example"""
        self.examples.append({
            'input': input_text,
            'output': output_text
        })
        return self

    def build_prompt(self, query: str) -> str:
        """Build complete few-shot prompt"""
        parts = []

        if self.instruction:
            parts.append(self.instruction)
            parts.append("")

        for ex in self.examples:
            parts.append(self.template.format(
                input=ex['input'],
                output=ex['output']
            ))

        # Query with blank output
        parts.append(f"{query} =>")

        return "\n".join(parts)

    def build_chat_messages(
        self,
        query: str,
        system_prompt: str = None
    ) -> List[Dict]:
        """Build chat-format messages (for GPT-4, Claude, etc.)"""
        messages = []

        if system_prompt:
            messages.append({
                'role': 'system',
                'content': system_prompt
            })

        for ex in self.examples:
            messages.append({
                'role': 'user',
                'content': ex['input']
            })
            messages.append({
                'role': 'assistant',
                'content': ex['output']
            })

        messages.append({
            'role': 'user',
            'content': query
        })

        return messages


class DynamicExampleSelector:
    """
    Dynamic example selector
    Selects the most similar examples to the query for more effective few-shot learning
    """

    def __init__(self, examples: List[Dict], encoder=None):
        self.examples = examples
        self.encoder = encoder  # sentence-transformers etc.

    def select_similar(
        self,
        query: str,
        n_examples: int = 3
    ) -> List[Dict]:
        """
        Select the n examples most similar to the query
        """
        if self.encoder is None:
            return np.random.choice(
                self.examples, n_examples, replace=False
            ).tolist()

        query_emb = self.encoder.encode(query)
        example_embs = self.encoder.encode(
            [ex['input'] for ex in self.examples]
        )

        similarities = np.dot(example_embs, query_emb) / (
            np.linalg.norm(example_embs, axis=1)
            * np.linalg.norm(query_emb)
        )

        top_indices = np.argsort(similarities)[-n_examples:][::-1]
        return [self.examples[i] for i in top_indices]

    def select_diverse(
        self,
        n_examples: int = 3
    ) -> List[Dict]:
        """
        Greedily select mutually dissimilar examples
        (an MMR-style criterion without the relevance term)
        """
        if len(self.examples) <= n_examples:
            return self.examples

        if self.encoder is None:
            return np.random.choice(
                self.examples, n_examples, replace=False
            ).tolist()

        embeddings = self.encoder.encode(
            [ex['input'] for ex in self.examples]
        )

        selected = [0]
        remaining = list(range(1, len(self.examples)))

        while len(selected) < n_examples:
            selected_embs = embeddings[selected]
            best_idx = None
            best_score = float('-inf')

            for idx in remaining:
                sim_to_selected = np.max(
                    np.dot(selected_embs, embeddings[idx]) / (
                        np.linalg.norm(selected_embs, axis=1)
                        * np.linalg.norm(embeddings[idx])
                    )
                )
                score = -sim_to_selected  # maximize diversity

                if score > best_score:
                    best_score = score
                    best_idx = idx

            selected.append(best_idx)
            remaining.remove(best_idx)

        return [self.examples[i] for i in selected]
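The similarity-based selection can be illustrated standalone with toy 2-D vectors standing in for sentence embeddings (the example texts and vectors below are made up for illustration):

```python
import numpy as np

# toy 2-D "embeddings" standing in for sentence encodings
examples = ["refund request", "login error", "slow performance", "billing question"]
embs = np.array([[1.0, 0.0], [0.0, 1.0], [0.2, 0.9], [0.9, 0.2]])
query_emb = np.array([0.95, 0.1])       # query close to the billing/refund region

# cosine similarity of each example to the query
sims = embs @ query_emb / (np.linalg.norm(embs, axis=1) * np.linalg.norm(query_emb))
top2 = np.argsort(sims)[-2:][::-1]      # indices of the two most similar examples
print([examples[i] for i in top2])      # ['refund request', 'billing question']
```

In a real selector, a sentence encoder (e.g. sentence-transformers) would produce the embeddings; the ranking logic is identical.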


# ========== Practical Examples ==========

def sentiment_analysis_few_shot():
    """Sentiment analysis few-shot example"""
    builder = FewShotPromptBuilder()

    builder.set_instruction(
        "Analyze the sentiment of the following movie review. "
        "Answer Positive or Negative."
    )

    builder.add_example(
        "This movie was truly moving and the acting was superb.",
        "Positive"
    )
    builder.add_example(
        "The story was too boring and the ending was disappointing.",
        "Negative"
    )
    builder.add_example(
        "The special effects were impressive but the plot was too weak.",
        "Negative"
    )
    builder.add_example(
        "A warm film the whole family can enjoy together.",
        "Positive"
    )

    query = "The direction was unique and the music paired perfectly with the visuals."
    prompt = builder.build_prompt(query)

    print("=== Few-Shot Prompt ===")
    print(prompt)
    return prompt


def code_generation_few_shot():
    """Code generation few-shot example"""
    builder = FewShotPromptBuilder()

    builder.set_instruction(
        "Convert the natural language description into Python code."
    )

    builder.add_example(
        "Find the maximum element in a list",
        "def find_max(lst):\n    return max(lst)"
    )
    builder.add_example(
        "Check if a string is a palindrome",
        "def is_palindrome(s):\n    return s == s[::-1]"
    )
    builder.add_example(
        "Flatten a nested list",
        "def flatten(lst):\n    return [x for sublist in lst for x in sublist]"
    )

    query = "Count the frequency of each element in a list"
    prompt = builder.build_prompt(query)

    print("=== Code Generation Few-Shot Prompt ===")
    print(prompt)
    return prompt

4.4 Cross-lingual Few-shot

Multilingual models demonstrate zero-shot cross-lingual transfer ability. Tasks trained only in English can be applied to Korean or other languages.

from typing import List

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class CrossLingualFewShot:
    """
    Cross-lingual few-shot learning
    English examples + target-language query
    """

    def __init__(self, model_name: str = "xlm-roberta-large"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=2
        )
        self.model.eval()  # deterministic embeddings (disable dropout)

    def encode_text(self, text: str) -> torch.Tensor:
        """Encode text to embedding"""
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=128
        )
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            embedding = outputs.hidden_states[-1][:, 0, :]
        return embedding

    def classify_zero_shot(
        self,
        query: str,
        class_descriptions_en: List[str]
    ) -> int:
        """
        Classify a query using English class descriptions
        (query can be in any language)
        """
        import torch
        import torch.nn.functional as F

        query_emb = self.encode_text(query)
        class_embs = torch.cat([
            self.encode_text(desc) for desc in class_descriptions_en
        ])

        similarities = F.cosine_similarity(
            query_emb.expand(len(class_descriptions_en), -1),
            class_embs
        )

        return similarities.argmax().item()

    def few_shot_classify(
        self,
        support_texts: List[str],
        support_labels: List[int],
        query_text: str,
        n_classes: int
    ) -> int:
        """
        Few-shot classification (prototype approach)
        Support and query can be in different languages
        """
        import torch
        import torch.nn.functional as F

        support_embs = torch.cat([
            self.encode_text(t) for t in support_texts
        ])
        query_emb = self.encode_text(query_text)

        prototypes = []
        for cls in range(n_classes):
            mask = torch.tensor([l == cls for l in support_labels])
            proto = support_embs[mask].mean(dim=0, keepdim=True)
            prototypes.append(proto)
        prototypes = torch.cat(prototypes)

        dists = torch.cdist(query_emb, prototypes)
        return dists.argmin().item()
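
The distance-to-prototype decision inside few_shot_classify can be exercised in isolation with synthetic embeddings, avoiding the XLM-R download (a minimal sketch; the 2-D vectors stand in for encode_text outputs):

```python
import torch

# Synthetic 2-D "embeddings" standing in for encode_text outputs:
# class 0 clusters near (1, 0), class 1 near (0, 1)
support_embs = torch.tensor([
    [0.9, 0.1], [1.1, -0.1],   # class 0 support examples
    [0.1, 0.9], [-0.1, 1.1],   # class 1 support examples
])
support_labels = [0, 0, 1, 1]
query_emb = torch.tensor([[0.8, 0.2]])  # close to the class-0 cluster

# Same prototype computation as few_shot_classify: per-class mean embedding
prototypes = []
for cls in range(2):
    mask = torch.tensor([l == cls for l in support_labels])
    prototypes.append(support_embs[mask].mean(dim=0, keepdim=True))
prototypes = torch.cat(prototypes)

# Nearest prototype wins
pred = torch.cdist(query_emb, prototypes).argmin().item()
print(pred)  # 0
```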

5. Hands-On: Medical Image Few-Shot Classification

5.1 Rare Disease Diagnosis System

In clinical settings, rare diseases have very limited training data. Few-shot learning enables recognition of new disease patterns from only a small number of confirmed cases.

from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np


class MedicalImageEncoder(nn.Module):
    """
    Encoder for medical images
    Based on ResNet-18, adapted for medical imaging characteristics
    """

    def __init__(self, embed_dim: int = 512, pretrained: bool = True):
        super().__init__()

        # torchvision >= 0.13 deprecates `pretrained` in favor of `weights`
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        backbone = models.resnet18(weights=weights)
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])

        self.embed_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)
        return self.embed_head(features)


class MedicalFewShotClassifier:
    """
    Medical image few-shot classifier
    Based on Prototypical Networks
    """

    def __init__(
        self,
        encoder: MedicalImageEncoder,
        device: str = 'cpu'
    ):
        self.encoder = encoder.to(device)
        self.device = device
        self.prototypes = {}
        self.disease_names = {}

    def register_disease(
        self,
        disease_id: int,
        disease_name: str,
        support_images: List,
        transform=None
    ):
        """
        Register a new disease (from a small number of examples)
        """
        self.disease_names[disease_id] = disease_name
        self.encoder.eval()

        embeddings = []
        with torch.no_grad():
            for img in support_images:
                if transform:
                    img_tensor = transform(img).unsqueeze(0).to(self.device)
                else:
                    img_tensor = img.unsqueeze(0).to(self.device)
                emb = self.encoder(img_tensor)
                embeddings.append(emb)

        prototype = torch.cat(embeddings).mean(dim=0)
        self.prototypes[disease_id] = prototype

        print(
            f"Disease registered: {disease_name} "
            f"({len(support_images)} examples)"
        )

    def diagnose(
        self,
        query_image: torch.Tensor,
        top_k: int = 3
    ) -> List[Dict]:
        """
        Diagnose a query image
        Returns: top-k diseases with similarity scores
        """
        self.encoder.eval()

        with torch.no_grad():
            query_emb = self.encoder(
                query_image.unsqueeze(0).to(self.device)
            )

        results = []
        for disease_id, prototype in self.prototypes.items():
            similarity = F.cosine_similarity(
                query_emb, prototype.unsqueeze(0)
            ).item()
            results.append({
                'disease_id': disease_id,
                'disease_name': self.disease_names[disease_id],
                'similarity': similarity
            })

        results.sort(key=lambda x: x['similarity'], reverse=True)
        return results[:top_k]

    def update_prototype(
        self,
        disease_id: int,
        new_image: torch.Tensor,
        momentum: float = 0.9
    ):
        """
        Update prototype with a new confirmed case (online learning)
        """
        self.encoder.eval()

        with torch.no_grad():
            new_emb = self.encoder(
                new_image.unsqueeze(0).to(self.device)
            ).squeeze(0)

        if disease_id in self.prototypes:
            # Exponential moving average update
            self.prototypes[disease_id] = (
                momentum * self.prototypes[disease_id]
                + (1 - momentum) * new_emb
            )
            print(
                f"Prototype updated: {self.disease_names[disease_id]}"
            )
        else:
            print(f"Warning: disease {disease_id} is not registered.")


def demo_medical_few_shot():
    """Medical few-shot classification demo"""
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    encoder = MedicalImageEncoder(embed_dim=512, pretrained=True)
    classifier = MedicalFewShotClassifier(encoder)

    print("=== Medical Image Few-Shot Classification System ===")
    print("This system recognizes new diseases from only a few confirmed cases.")
    print()
    print("Usage:")
    print("1. classifier.register_disease(id, name, support_images)")
    print("2. results = classifier.diagnose(patient_image)")
    print("3. classifier.update_prototype(disease_id, new_confirmed_image)")
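
The exponential-moving-average rule in update_prototype can be verified numerically on its own (a standalone sketch with toy vectors in place of encoder embeddings):

```python
import torch

momentum = 0.9
old_proto = torch.tensor([1.0, 0.0])   # existing disease prototype
new_emb = torch.tensor([0.0, 1.0])     # embedding of a new confirmed case

# Same EMA update as update_prototype: keep 90% of the old prototype,
# blend in 10% of the new case's embedding
updated = momentum * old_proto + (1 - momentum) * new_emb
print(updated)  # tensor([0.9000, 0.1000])
```

A high momentum keeps prototypes stable against outlier cases while still drifting toward the distribution of newly confirmed examples.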

6. Using the learn2learn Library

6.1 Introduction to learn2learn

learn2learn is a library that makes it easy to implement meta-learning algorithms such as MAML and ProtoNet.

pip install learn2learn

6.2 MAML with learn2learn

import learn2learn as l2l
import torch
import torch.nn as nn
from typing import Tuple


def build_l2l_maml(
    model: nn.Module,
    lr: float = 0.01,
    first_order: bool = False
) -> l2l.algorithms.MAML:
    """Create learn2learn MAML wrapper"""
    return l2l.algorithms.MAML(
        model,
        lr=lr,
        first_order=first_order,
        allow_unused=True
    )


def train_with_l2l(
    maml_model: l2l.algorithms.MAML,
    tasksets,
    n_way: int = 5,
    k_shot: int = 1,
    n_query: int = 15,
    meta_lr: float = 0.003,
    n_iterations: int = 1000,
    adaptation_steps: int = 1,
    device: str = 'cpu'
):
    """MAML training with learn2learn"""
    maml_model = maml_model.to(device)
    meta_optimizer = torch.optim.Adam(
        maml_model.parameters(), lr=meta_lr
    )
    criterion = nn.CrossEntropyLoss(reduction='mean')

    for iteration in range(n_iterations):
        meta_optimizer.zero_grad()
        meta_train_loss = 0.0
        meta_train_acc = 0.0

        for task in range(4):
            X, y = tasksets.train.sample()
            X, y = X.to(device), y.to(device)

            learner = maml_model.clone()

            support_indices = torch.zeros(X.size(0), dtype=torch.bool)
            for cls in range(n_way):
                cls_idx = (y == cls).nonzero(as_tuple=True)[0]
                support_idx = cls_idx[:k_shot]
                support_indices[support_idx] = True

            query_indices = ~support_indices
            support_x, support_y = X[support_indices], y[support_indices]
            query_x, query_y = X[query_indices], y[query_indices]

            # Inner loop: adaptation
            for step in range(adaptation_steps):
                support_logits = learner(support_x)
                support_loss = criterion(support_logits, support_y)
                learner.adapt(support_loss)

            # Outer loop: meta-gradient
            query_logits = learner(query_x)
            query_loss = criterion(query_logits, query_y)
            meta_train_loss += query_loss

            preds = query_logits.argmax(dim=-1)
            acc = (preds == query_y).float().mean()
            meta_train_acc += acc

        meta_train_loss /= 4
        meta_train_acc /= 4

        meta_train_loss.backward()
        meta_optimizer.step()

        if (iteration + 1) % 100 == 0:
            print(
                f"Iteration {iteration+1}/{n_iterations} | "
                f"Meta Loss: {meta_train_loss.item():.4f} | "
                f"Meta Acc: {meta_train_acc.item():.4f}"
            )


def setup_omniglot_maml():
    """
    Set up MAML with the Omniglot dataset
    Omniglot: 1623 character classes from 50 alphabets (20 samples each)
    """
    tasksets = l2l.vision.benchmarks.get_tasksets(
        'omniglot',
        train_samples=1 + 15,  # k_shot support + n_query query per class
        train_ways=5,
        test_samples=1 + 15,  # Omniglot has only 20 samples per class
        test_ways=5,
        root='./data',
        device='cpu'
    )

    model = l2l.vision.models.OmniglotCNN(
        output_size=5,
        hidden_size=64,
        layers=4
    )

    maml = build_l2l_maml(model, lr=0.4, first_order=False)

    return maml, tasksets


def evaluate_l2l(
    maml_model: l2l.algorithms.MAML,
    tasksets,
    n_way: int = 5,
    k_shot: int = 1,
    n_query: int = 15,
    n_test_tasks: int = 600,
    adaptation_steps: int = 3,
    device: str = 'cpu'
) -> Tuple[float, float]:
    """Meta-evaluation with learn2learn"""
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_acc = 0.0

    maml_model.eval()

    for _ in range(n_test_tasks):
        X, y = tasksets.test.sample()
        X, y = X.to(device), y.to(device)

        learner = maml_model.clone()

        support_indices = torch.zeros(X.size(0), dtype=torch.bool)
        for cls in range(n_way):
            cls_idx = (y == cls).nonzero(as_tuple=True)[0]
            support_idx = cls_idx[:k_shot]
            support_indices[support_idx] = True

        query_indices = ~support_indices
        support_x, support_y = X[support_indices], y[support_indices]
        query_x, query_y = X[query_indices], y[query_indices]

        for step in range(adaptation_steps):
            support_loss = criterion(learner(support_x), support_y)
            learner.adapt(support_loss)

        with torch.no_grad():
            query_logits = learner(query_x)
            loss = criterion(query_logits, query_y)
            acc = (query_logits.argmax(dim=-1) == query_y).float().mean()

        total_loss += loss.item()
        total_acc += acc.item()

    return total_loss / n_test_tasks, total_acc / n_test_tasks

7. Benchmarks and Evaluation

7.1 Key Benchmark Datasets

Omniglot

1623 character classes from 50 alphabet systems, 20 samples per class. Commonly evaluated in the 5-way and 20-way, 1-shot and 5-shot settings.

Mini-ImageNet

A 100-class subset of ImageNet. 600 images per class (84x84). The 5-way 1/5-shot setting is standard.

tieredImageNet

A harder version of Mini-ImageNet. Classes are grouped by superclass concepts, creating a larger semantic gap between meta-train and meta-test classes.

CIFAR-FS

A few-shot benchmark derived from CIFAR-100. Faster to experiment with than Mini-ImageNet.
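
For orientation, the episode sizes implied by an N-way K-shot protocol with Q queries per class are just N*K support and N*Q query examples; the standard Mini-ImageNet setting works out as follows (simple arithmetic, no assumptions beyond the protocol):

```python
n_way, k_shot, n_query = 5, 1, 15   # standard 5-way 1-shot setting

support_size = n_way * k_shot   # examples the model adapts on
query_size = n_way * n_query    # examples the episode is scored on

print(support_size, query_size)  # 5 75
```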

7.2 Standard Evaluation Protocol

def standard_few_shot_evaluation(
    model,
    test_dataset,
    n_way: int = 5,
    k_shot: int = 1,
    n_query: int = 15,
    n_episodes: int = 600,
    confidence_interval: bool = True
) -> Dict:
    """
    Standard few-shot evaluation protocol
    Mean and 95% confidence interval over 600 episodes
    """
    accs = []
    model.eval()

    for episode in range(n_episodes):
        support_x, support_y, query_x, query_y = create_episode(
            test_dataset, n_way, k_shot, n_query
        )

        with torch.no_grad():
            log_probs = model(support_x, support_y, query_x, n_way)
            preds = log_probs.argmax(dim=-1)
            acc = (preds == query_y).float().mean().item()

        accs.append(acc)

    mean_acc = np.mean(accs)
    std_acc = np.std(accs)

    if confidence_interval:
        ci = 1.96 * std_acc / np.sqrt(n_episodes)
        return {
            'mean_accuracy': mean_acc,
            'std': std_acc,
            'confidence_interval_95': ci,
            'result_string': f"{mean_acc*100:.2f} +/- {ci*100:.2f}%"
        }

    return {'mean_accuracy': mean_acc, 'std': std_acc}

8. The Future of Meta-Learning

Meta-learning is establishing itself as a core paradigm in AI research. Key trends to watch:

Convergence with LLMs

Large language models like GPT-4 and Claude demonstrate powerful in-context learning (ICL) capabilities. Active research views these models as meta-learners performing few-shot learning across arbitrary domains.

Multimodal Few-Shot Learning

Few-shot learning integrating text, images, and audio. Models like GPT-4V and Gemini Ultra are showing impressive performance on visual few-shot tasks.

Combination with Continual Learning

Research shows that models initialized via meta-learning forget less of previous knowledge when learning new tasks. The combination of continual learning and meta-learning is an active research area.

Domain Adaptation Applications

Industrial applications where data is scarce — rare disease diagnosis, satellite imagery analysis, specialized code generation — are emerging as the most practical use cases for few-shot learning.


References

  • Finn, C., et al. (2017). Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. ICML 2017. arXiv:1703.03400
  • Snell, J., et al. (2017). Prototypical Networks for Few-shot Learning. NeurIPS 2017. arXiv:1703.05175
  • Vinyals, O., et al. (2016). Matching Networks for One Shot Learning. NeurIPS 2016. arXiv:1606.04080
  • Nichol, A., et al. (2018). On First-Order Meta-Learning Algorithms. arXiv:1803.02999
  • Brown, T., et al. (2020). Language Models are Few-Shot Learners (GPT-3). NeurIPS 2020.
  • learn2learn library: https://github.com/learnables/learn2learn