Meta-Learning and Few-Shot Learning Complete Guide: MAML, Prototypical Networks, In-Context Learning
- Author: Youngju Kim (@fjvbn20031)
Humans learn new concepts quickly from just a few examples. Show a child "this is a zebra" once, and they can recognize a zebra in an entirely new photograph the next day. But traditional deep learning models require thousands of images just to build a zebra classifier.
Meta-learning and few-shot learning bridge this gap. The core idea of meta-learning is "learning to learn." The model gains the ability to rapidly adapt to new tasks by experiencing a variety of tasks during training.
This guide covers everything from the theoretical foundations of meta-learning to the latest developments in In-Context Learning.
1. Meta-Learning Fundamentals
1.1 Learning to Learn
In traditional machine learning, a model is trained for a single task. A cat classifier is trained only on cat data, and when a new animal (e.g., an ocelot) needs to be classified, training must start from scratch.
Meta-learning shifts the perspective. What the model needs to learn is not "how to classify cats" but "how to learn to quickly classify new animals."
Meta-learning therefore has two levels of learning.
- Meta-learning (Outer Loop): Learning a good initialization or learning algorithm by experiencing many tasks
- Task learning (Inner Loop): Rapidly adapting to each specific task from a small number of examples
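The two levels can be sketched with plain-Python stand-ins. The quadratic "loss," the toy tasks, and the learning rates below are illustrative placeholders, not part of any real algorithm; the point is only the nesting of the loops:

```python
# A structural sketch of the two loops. The quadratic "loss", the toy
# tasks, and the learning rates are placeholders chosen for illustration.
def inner_loop(theta, task, lr=0.1, steps=3):
    """Task learning: adapt theta to one task with a few updates."""
    for _ in range(steps):
        grad = 2 * (theta - task["target"])  # gradient of (theta - target)^2
        theta = theta - lr * grad
    return theta

def outer_loop(theta, tasks, meta_lr=0.5):
    """Meta-learning: move theta so that adaptation succeeds across tasks."""
    adapted = [inner_loop(theta, t) for t in tasks]
    # First-order shortcut: average the post-adaptation loss gradients
    meta_grad = sum(
        2 * (a - t["target"]) for a, t in zip(adapted, tasks)
    ) / len(tasks)
    return theta - meta_lr * meta_grad

theta = 0.0
tasks = [{"target": 1.0}, {"target": 3.0}]
for _ in range(20):
    theta = outer_loop(theta, tasks)
print(round(theta, 2))  # converges to 2.0, equidistant from both tasks
```

The meta-parameters settle at a point from which a few inner steps reach either task, which is exactly the intuition behind optimization-based meta-learning covered in Section 3.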
1.2 Limitations of Traditional Learning
The limitations of traditional learning are:
Data inefficiency: A model trained on ImageNet requires millions of images. Adding new classes requires thousands of samples.
Lack of generalization: Performance drops sharply on new tasks that differ significantly from the training distribution.
Catastrophic forgetting: Learning new tasks causes the model to forget previously learned tasks.
1.3 Task Distribution
The central concept in meta-learning is the task distribution p(T). Meta-learning does not simply learn from data; it learns from a distribution of tasks.
Each task T includes:
- A distribution p(x, y) of input-output pairs
- A task loss function L
The meta-learning objective is:
min over theta: E over T ~ p(T) [L_T(f_theta)]
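A concrete instance makes p(T) tangible: the sinusoid-regression family used in the MAML paper, where every amplitude-phase pair defines one task. The helper below is a minimal sketch assuming the paper's sampling ranges:

```python
import numpy as np

# A minimal task distribution p(T): sinusoid regression, as in the MAML
# paper. Each task is one sine curve with its own amplitude and phase;
# sampling a task returns a function that draws (x, y) pairs from it.
def sample_sinusoid_task(rng=np.random):
    amplitude = rng.uniform(0.1, 5.0)
    phase = rng.uniform(0.0, np.pi)

    def sample_batch(n):
        x = rng.uniform(-5.0, 5.0, size=(n, 1))
        y = amplitude * np.sin(x + phase)
        return x, y

    return sample_batch

task = sample_sinusoid_task()  # one draw T ~ p(T)
x, y = task(10)                # 10 input-output pairs from that task
print(x.shape, y.shape)        # (10, 1) (10, 1)
```

Meta-training repeatedly draws fresh tasks from this distribution, so the model never sees "the" dataset — it sees an ever-changing stream of small learning problems.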
1.4 Support Set vs Query Set
In few-shot learning, data is divided into two roles.
Support Set: A small number of examples used as references when the model learns a new task. These correspond to training data in traditional learning, but are extremely few (e.g., 1–5 per class).
Query Set: Data used to evaluate model performance. These correspond to test data in traditional learning.
1.5 N-way K-shot Setup
The most important configuration in few-shot learning is N-way K-shot.
- N-way: Number of classes to classify
- K-shot: Number of support examples per class
For example, 5-way 1-shot is a task of classifying 5 classes with only 1 example per class. 5-way 5-shot uses 5 examples per class.
import torch
import torch.nn as nn
import numpy as np
from typing import List, Tuple, Dict
def create_episode(
dataset,
n_way: int,
k_shot: int,
n_query: int,
classes: List[int] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Create N-way K-shot episode
Returns: (support_x, support_y, query_x, query_y)
"""
all_classes = list(set(dataset.targets.tolist()))
if classes is None:
selected_classes = np.random.choice(
all_classes, n_way, replace=False
)
else:
selected_classes = classes
support_x, support_y = [], []
query_x, query_y = [], []
for new_label, cls in enumerate(selected_classes):
cls_indices = (dataset.targets == cls).nonzero(as_tuple=True)[0]
chosen = np.random.choice(
len(cls_indices), k_shot + n_query, replace=False
)
for i, idx in enumerate(cls_indices[chosen]):
x, _ = dataset[idx.item()]
if i < k_shot:
support_x.append(x)
support_y.append(new_label)
else:
query_x.append(x)
query_y.append(new_label)
    support_x = torch.stack(support_x)
    support_y = torch.tensor(support_y)
    # Guard against n_query == 0 (torch.stack on an empty list raises);
    # Reptile, for example, samples episodes with no query set.
    if query_x:
        query_x = torch.stack(query_x)
        query_y = torch.tensor(query_y)
    else:
        query_x = torch.empty(0)
        query_y = torch.empty(0, dtype=torch.long)
    return support_x, support_y, query_x, query_y
2. Distance-Based Meta-Learning
2.1 Matching Networks
Matching Networks, proposed by Vinyals et al. in 2016, combines attention mechanisms with the kNN idea. The core idea is to predict a query sample as an attention-weighted sum of support set labels.
Prediction formula
y_hat = sum over i: a(x_hat, x_i) * y_i
Here a(x_hat, x_i) is the attention weight between query x_hat and support sample x_i, computed as a softmax over cosine similarities.
class MatchingNetworks(nn.Module):
"""Matching Networks implementation"""
def __init__(self, encoder: nn.Module, use_fce: bool = False):
"""
encoder: feature extractor
use_fce: whether to use Full Context Embedding
"""
super().__init__()
self.encoder = encoder
self.use_fce = use_fce
def cosine_similarity(
self,
query: torch.Tensor,
support: torch.Tensor
) -> torch.Tensor:
"""
Compute cosine similarity
query: (n_query, embed_dim)
support: (n_support, embed_dim)
Returns: (n_query, n_support)
"""
query_norm = nn.functional.normalize(query, dim=-1)
support_norm = nn.functional.normalize(support, dim=-1)
return torch.mm(query_norm, support_norm.t())
def forward(
self,
support_x: torch.Tensor,
support_y: torch.Tensor,
query_x: torch.Tensor,
n_way: int
) -> torch.Tensor:
"""
support_x: (n_way * k_shot, C, H, W)
support_y: (n_way * k_shot,)
query_x: (n_query, C, H, W)
"""
support_emb = self.encoder(support_x) # (n_support, D)
query_emb = self.encoder(query_x) # (n_query, D)
similarities = self.cosine_similarity(
query_emb, support_emb
) # (n_query, n_support)
# Softmax attention
attention = nn.functional.softmax(similarities, dim=-1)
# One-hot labels
support_labels_one_hot = nn.functional.one_hot(
support_y, n_way
).float() # (n_support, n_way)
# Attention-weighted prediction
logits = torch.mm(attention, support_labels_one_hot)
return logits
2.2 Prototypical Networks
Prototypical Networks, published by Snell et al. in 2017, is one of the most elegant and intuitive meta-learning algorithms. The core idea: represent each class as a single prototype (centroid) in embedding space.
Prototype computation
The prototype of class c is the mean of the embeddings of the support samples for that class.
p_c = (1/|S_c|) sum over (x_i, y_i) in S_c: f_phi(x_i)
Classification
Classify a query sample x to the nearest prototype.
p(y=c | x) = softmax(-d(f_phi(x), p_c))
where d is the Euclidean distance.
class ConvEncoder(nn.Module):
"""4-layer CNN encoder for few-shot learning"""
def __init__(
self,
in_channels: int = 1,
hidden_dim: int = 64,
out_dim: int = 64
):
super().__init__()
def conv_block(in_ch, out_ch):
return nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.net = nn.Sequential(
conv_block(in_channels, hidden_dim),
conv_block(hidden_dim, hidden_dim),
conv_block(hidden_dim, hidden_dim),
conv_block(hidden_dim, out_dim),
nn.Flatten()
)
def forward(self, x):
return self.net(x)
class PrototypicalNetworks(nn.Module):
"""Complete Prototypical Networks implementation"""
def __init__(self, encoder: nn.Module):
super().__init__()
self.encoder = encoder
def compute_prototypes(
self,
support_x: torch.Tensor,
support_y: torch.Tensor,
n_way: int
) -> torch.Tensor:
"""
Compute class prototypes (mean of embeddings)
support_x: (n_way * k_shot, C, H, W)
support_y: (n_way * k_shot,)
Returns: (n_way, embed_dim)
"""
support_emb = self.encoder(support_x) # (n_support, D)
prototypes = []
for cls in range(n_way):
mask = (support_y == cls)
cls_embeddings = support_emb[mask]
prototype = cls_embeddings.mean(dim=0)
prototypes.append(prototype)
return torch.stack(prototypes) # (n_way, D)
def euclidean_dist(
self,
x: torch.Tensor,
y: torch.Tensor
) -> torch.Tensor:
"""
Euclidean distance
x: (n, D)
y: (m, D)
Returns: (n, m)
"""
# ||x - y||^2 = ||x||^2 + ||y||^2 - 2*x.y
n = x.size(0)
m = y.size(0)
x_sq = (x ** 2).sum(dim=1, keepdim=True).expand(n, m)
y_sq = (y ** 2).sum(dim=1, keepdim=True).expand(m, n).t()
xy = torch.mm(x, y.t())
dist = x_sq + y_sq - 2 * xy
return dist.clamp(min=0).sqrt()
def forward(
self,
support_x: torch.Tensor,
support_y: torch.Tensor,
query_x: torch.Tensor,
n_way: int
) -> torch.Tensor:
"""
Prototypical Networks forward pass
Returns: log-probabilities for query samples (n_query, n_way)
"""
prototypes = self.compute_prototypes(
support_x, support_y, n_way
) # (n_way, D)
query_emb = self.encoder(query_x) # (n_query, D)
dists = self.euclidean_dist(
query_emb, prototypes
) # (n_query, n_way)
log_probs = nn.functional.log_softmax(-dists, dim=-1)
return log_probs
# ========== Training Loop ==========
def train_prototypical(
model: PrototypicalNetworks,
train_dataset,
n_way: int = 5,
k_shot: int = 5,
n_query: int = 15,
n_episodes: int = 100,
lr: float = 1e-3,
device: str = 'cpu'
):
"""Train Prototypical Networks"""
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.NLLLoss()
model.train()
episode_losses = []
episode_accs = []
for episode in range(n_episodes):
support_x, support_y, query_x, query_y = create_episode(
train_dataset, n_way, k_shot, n_query
)
support_x = support_x.to(device)
support_y = support_y.to(device)
query_x = query_x.to(device)
query_y = query_y.to(device)
optimizer.zero_grad()
log_probs = model(support_x, support_y, query_x, n_way)
loss = criterion(log_probs, query_y)
loss.backward()
optimizer.step()
preds = log_probs.argmax(dim=-1)
acc = (preds == query_y).float().mean().item()
episode_losses.append(loss.item())
episode_accs.append(acc)
if (episode + 1) % 100 == 0:
print(
f"Episode {episode+1}/{n_episodes} | "
f"Loss: {np.mean(episode_losses[-100:]):.4f} | "
f"Acc: {np.mean(episode_accs[-100:]):.4f}"
)
return episode_losses, episode_accs
2.3 Relation Networks
Relation Networks (Sung et al., 2018) are similar to Prototypical Networks, but replace the fixed distance function with a learnable neural network: a relation module computes a relation score from the concatenation of the query embedding and the class prototype.
class RelationNetwork(nn.Module):
"""Relation Networks: learnable distance function"""
def __init__(self, encoder: nn.Module, embed_dim: int = 64):
super().__init__()
self.encoder = encoder
# Relation module: takes concatenated embeddings, outputs relation score
self.relation_module = nn.Sequential(
nn.Linear(embed_dim * 2, 256),
nn.ReLU(),
nn.Linear(256, 64),
nn.ReLU(),
nn.Linear(64, 1),
nn.Sigmoid()
)
def forward(
self,
support_x: torch.Tensor,
support_y: torch.Tensor,
query_x: torch.Tensor,
n_way: int
) -> torch.Tensor:
"""
Returns: relation scores (n_query, n_way)
"""
support_emb = self.encoder(support_x)
query_emb = self.encoder(query_x)
prototypes = []
for cls in range(n_way):
mask = (support_y == cls)
proto = support_emb[mask].mean(dim=0)
prototypes.append(proto)
prototypes = torch.stack(prototypes) # (n_way, D)
n_query = query_emb.size(0)
query_expanded = query_emb.unsqueeze(1).expand(
n_query, n_way, -1
) # (n_query, n_way, D)
proto_expanded = prototypes.unsqueeze(0).expand(
n_query, n_way, -1
) # (n_query, n_way, D)
pairs = torch.cat(
[query_expanded, proto_expanded], dim=-1
) # (n_query, n_way, 2D)
scores = self.relation_module(
pairs.view(-1, pairs.size(-1))
).view(n_query, n_way)
return scores
3. Optimization-Based Meta-Learning: MAML
3.1 The Core Idea of MAML
MAML (Model-Agnostic Meta-Learning), proposed by Finn et al. in 2017, is one of the most influential works in meta-learning.
MAML's goal is to "find an initial set of parameters theta from which fast adaptation is possible."
Concretely, it finds an initialization point from which only a few gradient update steps are needed to achieve good performance on a new task.
3.2 Inner Loop vs Outer Loop
MAML consists of two loops.
Inner Loop (Task-specific adaptation)
For each task T_i:
theta_i' = theta - alpha * grad_theta L_{T_i}(f_theta)
Perform 1–5 gradient updates on the support set to obtain task-specific parameters theta_i'.
Outer Loop (Meta-update)
theta = theta - beta * grad_theta sum_i L_{T_i}(f_{theta_i'})
Compute the loss on the query set using task-adapted parameters theta_i', then update the meta-parameters theta based on this loss.
3.3 Second-Order Gradients
The key technical challenge of MAML is that the outer loop gradient requires second-order gradients (backpropagation through the inner loop).
grad_theta L(f_{theta_i'}) = grad_theta L(f_{theta - alpha * grad L(theta)})
Computing gradients with respect to theta requires differentiating through the inner loop update (involves the Hessian matrix). This is computationally expensive.
In practice, FOMAML (First-Order MAML) is often used, which ignores the second-order gradient terms and uses an approximate gradient.
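The gap between the exact meta-gradient and the first-order shortcut can be computed by hand on a one-dimensional toy problem (the quadratic loss and step size here are illustrative, not from the paper):

```python
import torch

# Inner loss L(theta) = theta^2; one inner step gives
#   theta' = theta - alpha * 2*theta = (1 - 2*alpha) * theta.
# Exact (second-order) outer gradient of (theta')^2 w.r.t. theta:
#   2 * (1 - 2*alpha)^2 * theta
# FOMAML treats the inner gradient as a constant and uses 2 * theta'.
alpha = 0.1
theta = torch.tensor(3.0, requires_grad=True)

inner_loss = theta ** 2
(g,) = torch.autograd.grad(inner_loss, theta, create_graph=True)

# Second-order MAML: keep the graph through the inner update
theta_adapted = theta - alpha * g
outer_loss = theta_adapted ** 2
(maml_grad,) = torch.autograd.grad(outer_loss, theta)

# FOMAML: detach the inner gradient, so the outer gradient ignores
# how the update itself depends on theta
theta_adapted_fo = theta - alpha * g.detach()
outer_loss_fo = theta_adapted_fo ** 2
(fomaml_grad,) = torch.autograd.grad(outer_loss_fo, theta)

print(maml_grad.item())    # 2 * 0.8^2 * 3 = 3.84
print(fomaml_grad.item())  # 2 * 2.4 = 4.8
```

`create_graph=True` is what makes the inner gradient itself differentiable; dropping it (as FOMAML does) halves memory and compute at the cost of a biased meta-gradient, which in practice is often a worthwhile trade.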
3.4 Complete MAML Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from copy import deepcopy
from typing import List, Tuple
class MAML:
"""
MAML: Model-Agnostic Meta-Learning
Finn et al., 2017 (arXiv:1703.03400)
"""
def __init__(
self,
model: nn.Module,
inner_lr: float = 0.01, # alpha: inner loop learning rate
outer_lr: float = 0.001, # beta: outer loop learning rate
n_inner_steps: int = 5, # number of inner loop updates
first_order: bool = False, # whether to use FOMAML
device: str = 'cpu'
):
self.model = model.to(device)
self.inner_lr = inner_lr
self.outer_lr = outer_lr
self.n_inner_steps = n_inner_steps
self.first_order = first_order
self.device = device
self.meta_optimizer = torch.optim.Adam(
self.model.parameters(), lr=outer_lr
)
def inner_loop(
self,
support_x: torch.Tensor,
support_y: torch.Tensor,
model_params=None
) -> dict:
"""
Inner loop: task-specific adaptation on the support set
Returns: adapted parameter dictionary
"""
if model_params is None:
params = {
name: param.clone()
for name, param in self.model.named_parameters()
}
else:
params = {k: v.clone() for k, v in model_params.items()}
for step in range(self.n_inner_steps):
logits = self._forward_with_params(support_x, params)
loss = F.cross_entropy(logits, support_y)
grads = torch.autograd.grad(
loss,
params.values(),
create_graph=not self.first_order
)
params = {
name: param - self.inner_lr * grad
for (name, param), grad in zip(params.items(), grads)
}
return params
    def _forward_with_params(
        self,
        x: torch.Tensor,
        params: dict
    ) -> torch.Tensor:
        """
        Forward pass with specific parameters.
        Uses torch.func.functional_call so the adapted parameters stay
        connected to the computation graph; swapping param.data in place
        would silently detach them and break the second-order meta-gradient.
        """
        return torch.func.functional_call(self.model, params, (x,))
def meta_train_step(
self,
tasks: List[Tuple]
) -> float:
"""
One meta-training step (over multiple tasks)
tasks: [(support_x, support_y, query_x, query_y), ...]
"""
self.meta_optimizer.zero_grad()
meta_loss = 0.0
for support_x, support_y, query_x, query_y in tasks:
support_x = support_x.to(self.device)
support_y = support_y.to(self.device)
query_x = query_x.to(self.device)
query_y = query_y.to(self.device)
adapted_params = self.inner_loop(support_x, support_y)
query_logits = self._forward_with_params(
query_x, adapted_params
)
query_loss = F.cross_entropy(query_logits, query_y)
meta_loss += query_loss
meta_loss /= len(tasks)
meta_loss.backward()
self.meta_optimizer.step()
return meta_loss.item()
def fine_tune(
self,
support_x: torch.Tensor,
support_y: torch.Tensor,
n_steps: int = None
) -> nn.Module:
"""
Fine-tune on a new task (used at inference)
"""
n_steps = n_steps or self.n_inner_steps
model_copy = deepcopy(self.model)
optimizer = torch.optim.SGD(
model_copy.parameters(), lr=self.inner_lr
)
model_copy.train()
for step in range(n_steps):
optimizer.zero_grad()
logits = model_copy(support_x.to(self.device))
loss = F.cross_entropy(logits, support_y.to(self.device))
loss.backward()
optimizer.step()
return model_copy
def evaluate(
self,
tasks: List[Tuple],
n_fine_tune_steps: int = 5
) -> Tuple[float, float]:
"""
Meta-test: adapt to new tasks and evaluate
"""
total_loss = 0.0
total_acc = 0.0
for support_x, support_y, query_x, query_y in tasks:
adapted_model = self.fine_tune(
support_x, support_y, n_fine_tune_steps
)
adapted_model.eval()
with torch.no_grad():
query_logits = adapted_model(query_x.to(self.device))
loss = F.cross_entropy(
query_logits, query_y.to(self.device)
)
preds = query_logits.argmax(dim=-1)
acc = (preds == query_y.to(self.device)).float().mean()
total_loss += loss.item()
total_acc += acc.item()
return total_loss / len(tasks), total_acc / len(tasks)
def train_maml(
maml: MAML,
dataset,
n_way: int = 5,
k_shot: int = 1,
n_query: int = 15,
meta_batch_size: int = 32,
n_iterations: int = 60000,
device: str = 'cpu'
):
"""MAML training loop"""
print(f"MAML Training: {n_way}-way {k_shot}-shot")
print(f"Meta-batch size: {meta_batch_size}")
print(f"Total iterations: {n_iterations}")
losses = []
for iteration in range(n_iterations):
tasks = []
for _ in range(meta_batch_size):
task = create_episode(dataset, n_way, k_shot, n_query)
tasks.append(task)
meta_loss = maml.meta_train_step(tasks)
losses.append(meta_loss)
if (iteration + 1) % 1000 == 0:
avg_loss = np.mean(losses[-1000:])
print(
f"Iteration {iteration+1}/{n_iterations} | "
f"Meta Loss: {avg_loss:.4f}"
)
return losses
3.5 Reptile: Simplifying MAML
Reptile (Nichol et al., 2018) is a greatly simplified version of MAML. It requires no second-order gradients and is much easier to implement.
Core idea: After running SGD multiple times on a task, move the meta-parameters toward the resulting parameters.
theta = theta + epsilon * (W_k - theta)
where W_k is the parameter after k SGD updates on task T.
class Reptile:
"""
Reptile: A Scalable Meta-learning Algorithm
Nichol et al., 2018 (arXiv:1803.02999)
"""
def __init__(
self,
model: nn.Module,
inner_lr: float = 0.02,
outer_lr: float = 0.001, # epsilon
n_inner_steps: int = 5,
device: str = 'cpu'
):
self.model = model.to(device)
self.inner_lr = inner_lr
self.outer_lr = outer_lr
self.n_inner_steps = n_inner_steps
self.device = device
def inner_train(
self,
support_x: torch.Tensor,
support_y: torch.Tensor
) -> dict:
"""Task-specific inner training (k SGD steps)"""
model_copy = deepcopy(self.model)
optimizer = torch.optim.SGD(
model_copy.parameters(), lr=self.inner_lr
)
model_copy.train()
for step in range(self.n_inner_steps):
optimizer.zero_grad()
logits = model_copy(support_x)
loss = F.cross_entropy(logits, support_y)
loss.backward()
optimizer.step()
return dict(model_copy.named_parameters())
def meta_update(self, task_params_list: List[dict]):
"""
Reptile meta-update:
theta += epsilon * (mean(W_k) - theta)
"""
with torch.no_grad():
for name, param in self.model.named_parameters():
task_mean = torch.stack([
task_params[name].data
for task_params in task_params_list
]).mean(dim=0)
param.data += self.outer_lr * (task_mean - param.data)
def train(
self,
dataset,
n_way: int = 5,
k_shot: int = 5,
meta_batch_size: int = 5,
n_iterations: int = 100000
):
"""Reptile full training loop"""
print(f"Reptile Training: {n_way}-way {k_shot}-shot")
for iteration in range(n_iterations):
task_params_list = []
for _ in range(meta_batch_size):
support_x, support_y, _, _ = create_episode(
dataset, n_way, k_shot, n_query=0
)
support_x = support_x.to(self.device)
support_y = support_y.to(self.device)
task_params = self.inner_train(support_x, support_y)
task_params_list.append(task_params)
self.meta_update(task_params_list)
if (iteration + 1) % 10000 == 0:
print(f"Iteration {iteration+1}/{n_iterations} complete")
4. In-Context Learning in LLMs
4.1 What Is In-Context Learning?
In-Context Learning (ICL) is the ability of large language models (LLMs) to perform new tasks from examples within the prompt. The model does not update its parameters; it "learns" solely from the input context (prompt).
This ability gained enormous attention with the arrival of GPT-3. For example:
English to French translation:
sea otter => loutre de mer
peppermint => menthe poivrée
plush giraffe => girafe en peluche
cheese => ?
Given this prompt format, GPT-3 answers "fromage." The model was never explicitly trained on an English-to-French translation objective; it picked up the pattern during pre-training and completes it from the prompt alone.
4.2 Why Is It Effective?
The reasons ICL is effective are still being actively studied. Major hypotheses include:
Pattern completion view: LLMs are trained to predict the next token. When the prompt lays out a consistent pattern of input-output pairs, continuing that pattern is precisely what next-token prediction has trained the model to do.
Latent concept inference view: Xie et al. frame ICL as implicit Bayesian inference: the model infers a latent concept shared by the prompt examples and conditions its predictions on it.
Gradient descent metaphor: Akyürek et al. showed that Transformers can implement, within a single forward pass, operations analogous to gradient descent on the in-context examples.
4.3 Effective Few-Shot Prompting Strategies
from typing import List, Dict, Any
import numpy as np
class FewShotPromptBuilder:
"""Few-shot prompt builder"""
def __init__(self):
self.examples = []
self.instruction = ""
self.template = "{input} => {output}"
def set_instruction(self, instruction: str):
"""Set task instruction"""
self.instruction = instruction
return self
def add_example(self, input_text: str, output_text: str):
"""Add an example"""
self.examples.append({
'input': input_text,
'output': output_text
})
return self
def build_prompt(self, query: str) -> str:
"""Build complete few-shot prompt"""
parts = []
if self.instruction:
parts.append(self.instruction)
parts.append("")
for ex in self.examples:
parts.append(self.template.format(
input=ex['input'],
output=ex['output']
))
# Query with blank output
parts.append(f"{query} =>")
return "\n".join(parts)
def build_chat_messages(
self,
query: str,
system_prompt: str = None
) -> List[Dict]:
"""Build chat-format messages (for GPT-4, Claude, etc.)"""
messages = []
if system_prompt:
messages.append({
'role': 'system',
'content': system_prompt
})
for ex in self.examples:
messages.append({
'role': 'user',
'content': ex['input']
})
messages.append({
'role': 'assistant',
'content': ex['output']
})
messages.append({
'role': 'user',
'content': query
})
return messages
class DynamicExampleSelector:
"""
Dynamic example selector
Selects the most similar examples to the query for more effective few-shot learning
"""
def __init__(self, examples: List[Dict], encoder=None):
self.examples = examples
self.encoder = encoder # sentence-transformers etc.
def select_similar(
self,
query: str,
n_examples: int = 3
) -> List[Dict]:
"""
Select the n examples most similar to the query
"""
if self.encoder is None:
return np.random.choice(
self.examples, n_examples, replace=False
).tolist()
query_emb = self.encoder.encode(query)
example_embs = self.encoder.encode(
[ex['input'] for ex in self.examples]
)
similarities = np.dot(example_embs, query_emb) / (
np.linalg.norm(example_embs, axis=1)
* np.linalg.norm(query_emb)
)
top_indices = np.argsort(similarities)[-n_examples:][::-1]
return [self.examples[i] for i in top_indices]
def select_diverse(
self,
n_examples: int = 3
) -> List[Dict]:
"""
Select examples that maximize diversity (MMR algorithm)
"""
if len(self.examples) <= n_examples:
return self.examples
if self.encoder is None:
return np.random.choice(
self.examples, n_examples, replace=False
).tolist()
embeddings = self.encoder.encode(
[ex['input'] for ex in self.examples]
)
selected = [0]
remaining = list(range(1, len(self.examples)))
while len(selected) < n_examples:
selected_embs = embeddings[selected]
best_idx = None
best_score = float('-inf')
for idx in remaining:
sim_to_selected = np.max(
np.dot(selected_embs, embeddings[idx]) / (
np.linalg.norm(selected_embs, axis=1)
* np.linalg.norm(embeddings[idx])
)
)
score = -sim_to_selected # maximize diversity
if score > best_score:
best_score = score
best_idx = idx
selected.append(best_idx)
remaining.remove(best_idx)
return [self.examples[i] for i in selected]
# ========== Practical Examples ==========
def sentiment_analysis_few_shot():
"""Sentiment analysis few-shot example"""
builder = FewShotPromptBuilder()
builder.set_instruction(
"Analyze the sentiment of the following movie review. "
"Answer Positive or Negative."
)
builder.add_example(
"This movie was truly moving and the acting was superb.",
"Positive"
)
builder.add_example(
"The story was too boring and the ending was disappointing.",
"Negative"
)
builder.add_example(
"The special effects were impressive but the plot was too weak.",
"Negative"
)
builder.add_example(
"A warm film the whole family can enjoy together.",
"Positive"
)
query = "The direction was unique and the music paired perfectly with the visuals."
prompt = builder.build_prompt(query)
print("=== Few-Shot Prompt ===")
print(prompt)
return prompt
def code_generation_few_shot():
"""Code generation few-shot example"""
builder = FewShotPromptBuilder()
builder.set_instruction(
"Convert the natural language description into Python code."
)
builder.add_example(
"Find the maximum element in a list",
"def find_max(lst):\n return max(lst)"
)
builder.add_example(
"Check if a string is a palindrome",
"def is_palindrome(s):\n return s == s[::-1]"
)
builder.add_example(
"Flatten a nested list",
"def flatten(lst):\n return [x for sublist in lst for x in sublist]"
)
query = "Count the frequency of each element in a list"
prompt = builder.build_prompt(query)
print("=== Code Generation Few-Shot Prompt ===")
print(prompt)
return prompt
4.4 Cross-lingual Few-shot
Multilingual models exhibit zero-shot cross-lingual transfer: a task learned only from English examples can often be applied to Korean or other languages.
class CrossLingualFewShot:
"""
Cross-lingual few-shot learning
English examples + target-language query
"""
def __init__(self, model_name: str = "xlm-roberta-large"):
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=2
)
def encode_text(self, text: str) -> torch.Tensor:
"""Encode text to embedding"""
inputs = self.tokenizer(
text,
return_tensors='pt',
padding=True,
truncation=True,
max_length=128
)
with torch.no_grad():
outputs = self.model(**inputs, output_hidden_states=True)
embedding = outputs.hidden_states[-1][:, 0, :]
return embedding
def classify_zero_shot(
self,
query: str,
class_descriptions_en: List[str]
) -> int:
"""
Classify a query using English class descriptions
(query can be in any language)
"""
import torch.nn.functional as F
query_emb = self.encode_text(query)
class_embs = torch.cat([
self.encode_text(desc) for desc in class_descriptions_en
])
similarities = F.cosine_similarity(
query_emb.expand(len(class_descriptions_en), -1),
class_embs
)
return similarities.argmax().item()
def few_shot_classify(
self,
support_texts: List[str],
support_labels: List[int],
query_text: str,
n_classes: int
) -> int:
"""
Few-shot classification (prototype approach)
Support and query can be in different languages
"""
import torch
import torch.nn.functional as F
support_embs = torch.cat([
self.encode_text(t) for t in support_texts
])
query_emb = self.encode_text(query_text)
prototypes = []
for cls in range(n_classes):
mask = torch.tensor([l == cls for l in support_labels])
proto = support_embs[mask].mean(dim=0, keepdim=True)
prototypes.append(proto)
prototypes = torch.cat(prototypes)
dists = torch.cdist(query_emb, prototypes)
return dists.argmin().item()
5. Hands-On: Medical Image Few-Shot Classification
5.1 Rare Disease Diagnosis System
In clinical settings, rare diseases have very limited training data. Few-shot learning enables recognition of new disease patterns from only a small number of confirmed cases.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
class MedicalImageEncoder(nn.Module):
"""
Encoder for medical images
Based on ResNet-18, adapted for medical imaging characteristics
"""
def __init__(self, embed_dim: int = 512, pretrained: bool = True):
super().__init__()
        # The pretrained= flag is deprecated in recent torchvision; use weights=
        backbone = models.resnet18(
            weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        )
self.backbone = nn.Sequential(*list(backbone.children())[:-1])
self.embed_head = nn.Sequential(
nn.Flatten(),
nn.Linear(512, embed_dim),
nn.LayerNorm(embed_dim),
nn.ReLU(),
nn.Linear(embed_dim, embed_dim)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
features = self.backbone(x)
return self.embed_head(features)
class MedicalFewShotClassifier:
"""
Medical image few-shot classifier
Based on Prototypical Networks
"""
def __init__(
self,
encoder: MedicalImageEncoder,
device: str = 'cpu'
):
self.encoder = encoder.to(device)
self.device = device
self.prototypes = {}
self.disease_names = {}
def register_disease(
self,
disease_id: int,
disease_name: str,
support_images: List,
transform=None
):
"""
Register a new disease (from a small number of examples)
"""
self.disease_names[disease_id] = disease_name
self.encoder.eval()
embeddings = []
with torch.no_grad():
for img in support_images:
if transform:
img_tensor = transform(img).unsqueeze(0).to(self.device)
else:
img_tensor = img.unsqueeze(0).to(self.device)
emb = self.encoder(img_tensor)
embeddings.append(emb)
prototype = torch.cat(embeddings).mean(dim=0)
self.prototypes[disease_id] = prototype
print(
f"Disease registered: {disease_name} "
f"({len(support_images)} examples)"
)
def diagnose(
self,
query_image: torch.Tensor,
top_k: int = 3
) -> List[Dict]:
"""
Diagnose a query image
Returns: top-k diseases with similarity scores
"""
self.encoder.eval()
with torch.no_grad():
query_emb = self.encoder(
query_image.unsqueeze(0).to(self.device)
)
results = []
for disease_id, prototype in self.prototypes.items():
similarity = F.cosine_similarity(
query_emb, prototype.unsqueeze(0)
).item()
results.append({
'disease_id': disease_id,
'disease_name': self.disease_names[disease_id],
'similarity': similarity
})
results.sort(key=lambda x: x['similarity'], reverse=True)
return results[:top_k]
def update_prototype(
self,
disease_id: int,
new_image: torch.Tensor,
momentum: float = 0.9
):
"""
Update prototype with a new confirmed case (online learning)
"""
self.encoder.eval()
with torch.no_grad():
new_emb = self.encoder(
new_image.unsqueeze(0).to(self.device)
).squeeze(0)
if disease_id in self.prototypes:
# Exponential moving average update
self.prototypes[disease_id] = (
momentum * self.prototypes[disease_id]
+ (1 - momentum) * new_emb
)
print(
f"Prototype updated: {self.disease_names[disease_id]}"
)
else:
print(f"Warning: disease {disease_id} is not registered.")
def demo_medical_few_shot():
"""Medical few-shot classification demo"""
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
encoder = MedicalImageEncoder(embed_dim=512, pretrained=True)
classifier = MedicalFewShotClassifier(encoder)
print("=== Medical Image Few-Shot Classification System ===")
print("This system recognizes new diseases from only a few confirmed cases.")
print()
print("Usage:")
print("1. classifier.register_disease(id, name, support_images)")
print("2. results = classifier.diagnose(patient_image)")
print("3. classifier.update_prototype(disease_id, new_confirmed_image)")
6. Using the learn2learn Library
6.1 Introduction to learn2learn
learn2learn is a library that makes it easy to implement meta-learning algorithms such as MAML and ProtoNet.
pip install learn2learn
6.2 MAML with learn2learn
import learn2learn as l2l
import torch
import torch.nn as nn
from typing import Tuple
def build_l2l_maml(
model: nn.Module,
lr: float = 0.01,
first_order: bool = False
) -> l2l.algorithms.MAML:
"""Create learn2learn MAML wrapper"""
return l2l.algorithms.MAML(
model,
lr=lr,
first_order=first_order,
allow_unused=True
)
def train_with_l2l(
maml_model: l2l.algorithms.MAML,
tasksets,
n_way: int = 5,
k_shot: int = 1,
n_query: int = 15,
meta_lr: float = 0.003,
n_iterations: int = 1000,
adaptation_steps: int = 1,
device: str = 'cpu'
):
"""MAML training with learn2learn"""
maml_model = maml_model.to(device)
meta_optimizer = torch.optim.Adam(
maml_model.parameters(), lr=meta_lr
)
criterion = nn.CrossEntropyLoss(reduction='mean')
for iteration in range(n_iterations):
meta_optimizer.zero_grad()
meta_train_loss = 0.0
meta_train_acc = 0.0
        for task in range(4):  # meta-batch: accumulate gradients over 4 tasks
X, y = tasksets.train.sample()
X, y = X.to(device), y.to(device)
learner = maml_model.clone()
support_indices = torch.zeros(X.size(0), dtype=torch.bool)
for cls in range(n_way):
cls_idx = (y == cls).nonzero(as_tuple=True)[0]
support_idx = cls_idx[:k_shot]
support_indices[support_idx] = True
query_indices = ~support_indices
support_x, support_y = X[support_indices], y[support_indices]
query_x, query_y = X[query_indices], y[query_indices]
# Inner loop: adaptation
for step in range(adaptation_steps):
support_logits = learner(support_x)
support_loss = criterion(support_logits, support_y)
learner.adapt(support_loss)
# Outer loop: meta-gradient
query_logits = learner(query_x)
query_loss = criterion(query_logits, query_y)
meta_train_loss += query_loss
preds = query_logits.argmax(dim=-1)
acc = (preds == query_y).float().mean()
meta_train_acc += acc
meta_train_loss /= 4
meta_train_acc /= 4
meta_train_loss.backward()
meta_optimizer.step()
if (iteration + 1) % 100 == 0:
print(
f"Iteration {iteration+1}/{n_iterations} | "
f"Meta Loss: {meta_train_loss.item():.4f} | "
f"Meta Acc: {meta_train_acc.item():.4f}"
)
def setup_omniglot_maml():
"""
Set up MAML with the Omniglot dataset
Omniglot: 1623 character classes from 50 alphabets (20 samples each)
"""
    tasksets = l2l.vision.benchmarks.get_tasksets(
        'omniglot',
        train_ways=5,
        # Omniglot has only 20 samples per class, so request
        # k_shot + n_query = 1 + 15 = 16 samples per class per task
        train_samples=1 + 15,
        test_ways=5,
        test_samples=1 + 15,
        root='./data',
        device='cpu'
    )
model = l2l.vision.models.OmniglotCNN(
output_size=5,
hidden_size=64,
layers=4
)
maml = build_l2l_maml(model, lr=0.4, first_order=False)
return maml, tasksets
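The per-class support/query split appears twice above (in `train_with_l2l` and again in `evaluate_l2l`), so it is worth seeing in isolation. A self-contained sketch (the helper name `split_support_query` is my own):

```python
import torch
from typing import Tuple

def split_support_query(
    X: torch.Tensor, y: torch.Tensor, n_way: int, k_shot: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Take the first k_shot samples of each class as support, the rest as query."""
    support_mask = torch.zeros(X.size(0), dtype=torch.bool)
    for cls in range(n_way):
        cls_idx = (y == cls).nonzero(as_tuple=True)[0]
        support_mask[cls_idx[:k_shot]] = True
    return X[support_mask], y[support_mask], X[~support_mask], y[~support_mask]

# 5-way 1-shot with 3 samples per class: 5 support, 10 query samples
y = torch.arange(5).repeat_interleave(3)
X = torch.randn(15, 8)
sx, sy, qx, qy = split_support_query(X, y, n_way=5, k_shot=1)
```

Factoring the split into one function avoids the duplicated index logic and makes the k-shot convention (first k samples per class are support) explicit in one place.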
def evaluate_l2l(
maml_model: l2l.algorithms.MAML,
tasksets,
n_way: int = 5,
k_shot: int = 1,
n_query: int = 15,
n_test_tasks: int = 600,
adaptation_steps: int = 3,
device: str = 'cpu'
) -> Tuple[float, float]:
"""Meta-evaluation with learn2learn"""
criterion = nn.CrossEntropyLoss()
total_loss = 0.0
total_acc = 0.0
maml_model.eval()
for _ in range(n_test_tasks):
X, y = tasksets.test.sample()
X, y = X.to(device), y.to(device)
learner = maml_model.clone()
support_indices = torch.zeros(X.size(0), dtype=torch.bool)
for cls in range(n_way):
cls_idx = (y == cls).nonzero(as_tuple=True)[0]
support_idx = cls_idx[:k_shot]
support_indices[support_idx] = True
query_indices = ~support_indices
support_x, support_y = X[support_indices], y[support_indices]
query_x, query_y = X[query_indices], y[query_indices]
for step in range(adaptation_steps):
support_loss = criterion(learner(support_x), support_y)
learner.adapt(support_loss)
with torch.no_grad():
query_logits = learner(query_x)
loss = criterion(query_logits, query_y)
acc = (query_logits.argmax(dim=-1) == query_y).float().mean()
total_loss += loss.item()
total_acc += acc.item()
return total_loss / n_test_tasks, total_acc / n_test_tasks
7. Benchmarks and Evaluation
7.1 Key Benchmark Datasets
Omniglot
1623 character classes from 50 alphabets, 20 samples per class. Commonly evaluated in the 5-way and 20-way, 1-shot and 5-shot settings.
Mini-ImageNet
A 100-class subset of ImageNet. 600 images per class (84x84). The 5-way 1/5-shot setting is standard.
tieredImageNet
A larger, harder variant of Mini-ImageNet: 608 ImageNet classes grouped into 34 broader categories, so meta-train and meta-test classes come from disjoint superclasses, creating a larger semantic gap between them.
CIFAR-FS
A few-shot benchmark derived from CIFAR-100 (100 classes, 600 images per class at 32x32). The small image size makes it much faster to experiment with than Mini-ImageNet.
7.2 Standard Evaluation Protocol
import numpy as np
import torch
from typing import Dict

def standard_few_shot_evaluation(
model,
test_dataset,
n_way: int = 5,
k_shot: int = 1,
n_query: int = 15,
n_episodes: int = 600,
confidence_interval: bool = True
) -> Dict:
"""
Standard few-shot evaluation protocol
Mean and 95% confidence interval over 600 episodes
"""
accs = []
model.eval()
for episode in range(n_episodes):
support_x, support_y, query_x, query_y = create_episode(
test_dataset, n_way, k_shot, n_query
)
with torch.no_grad():
log_probs = model(support_x, support_y, query_x, n_way)
preds = log_probs.argmax(dim=-1)
acc = (preds == query_y).float().mean().item()
accs.append(acc)
mean_acc = np.mean(accs)
std_acc = np.std(accs)
if confidence_interval:
ci = 1.96 * std_acc / np.sqrt(n_episodes)
return {
'mean_accuracy': mean_acc,
'std': std_acc,
'confidence_interval_95': ci,
'result_string': f"{mean_acc*100:.2f} +/- {ci*100:.2f}%"
}
return {'mean_accuracy': mean_acc, 'std': std_acc}
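The protocol above assumes a `create_episode` helper that samples an N-way K-shot task from the test split. A minimal sketch over an in-memory dataset, here represented as a dict mapping class ids to sample tensors (this data layout is an assumption for illustration, not a fixed API):

```python
import random
import torch
from typing import Dict, Tuple

def create_episode(
    dataset: Dict[int, torch.Tensor],  # class id -> (n_samples, ...) tensor
    n_way: int, k_shot: int, n_query: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample an N-way episode; labels are re-indexed to 0..n_way-1."""
    classes = random.sample(list(dataset.keys()), n_way)
    support_x, support_y, query_x, query_y = [], [], [], []
    for new_label, cls in enumerate(classes):
        pool = dataset[cls]
        idx = torch.randperm(pool.size(0))[: k_shot + n_query]
        support_x.append(pool[idx[:k_shot]])
        query_x.append(pool[idx[k_shot:]])
        support_y += [new_label] * k_shot
        query_y += [new_label] * n_query
    return (torch.cat(support_x), torch.tensor(support_y),
            torch.cat(query_x), torch.tensor(query_y))

# 5-way 1-shot with 15 queries per class over a toy dataset
toy = {c: torch.randn(20, 3, 28, 28) for c in range(10)}
sx, sy, qx, qy = create_episode(toy, n_way=5, k_shot=1, n_query=15)
```

Re-indexing the sampled classes to 0..N-1 is important: the model must not be able to exploit the original class ids, since at meta-test time the classes are by definition unseen.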
8. The Future of Meta-Learning
Meta-learning is establishing itself as a core paradigm in AI research. Key trends to watch:
Convergence with LLMs
Large language models like GPT-4 and Claude demonstrate powerful in-context learning (ICL) capabilities. An active line of research views these models as meta-learners performing few-shot learning across arbitrary domains.
Multimodal Few-Shot Learning
Few-shot learning integrating text, images, and audio. Models like GPT-4V and Gemini Ultra are showing impressive performance on visual few-shot tasks.
Combination with Continual Learning
Research shows that models initialized via meta-learning forget less of previous knowledge when learning new tasks. The combination of continual learning and meta-learning is an active research area.
Domain Adaptation Applications
Industrial applications where data is scarce — rare disease diagnosis, satellite imagery analysis, specialized code generation — are emerging as the most practical use cases for few-shot learning.
References
- Finn, C., et al. (2017). Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. ICML 2017. arXiv:1703.03400
- Snell, J., et al. (2017). Prototypical Networks for Few-shot Learning. NeurIPS 2017. arXiv:1703.05175
- Vinyals, O., et al. (2016). Matching Networks for One Shot Learning. NeurIPS 2016. arXiv:1606.04080
- Nichol, A., et al. (2018). On First-Order Meta-Learning Algorithms. arXiv:1803.02999
- Brown, T., et al. (2020). Language Models are Few-Shot Learners (GPT-3). NeurIPS 2020.
- learn2learn library: https://github.com/learnables/learn2learn