Meta-Learning and Few-Shot Learning Complete Guide: MAML, Prototypical Networks, In-Context Learning

Humans learn new concepts quickly from just a few examples. Show a child "this is a zebra" once, and they can recognize a zebra in an entirely new photograph the next day. But traditional deep learning models require thousands of images just to build a zebra classifier.

**Meta-learning** and **few-shot learning** bridge this gap. The core idea of meta-learning is "learning to learn." The model gains the ability to rapidly adapt to new tasks by experiencing a variety of tasks during training.

This guide covers everything from the theoretical foundations of meta-learning to the latest developments in In-Context Learning.

1. Meta-Learning Fundamentals

1.1 Learning to Learn

In traditional machine learning, a model is trained for a single task. A cat classifier is trained only on cat data, and when a new animal (e.g., an ocelot) needs to be classified, training must start from scratch.

Meta-learning shifts the perspective. What the model needs to learn is not "how to classify cats" but "how to learn to quickly classify new animals."

Meta-learning therefore has two levels of learning.

- **Meta-learning (Outer Loop)**: Learning a good initialization or learning algorithm by experiencing many tasks

- **Task learning (Inner Loop)**: Rapidly adapting to each specific task from a small number of examples

1.2 Limitations of Traditional Learning

The limitations of traditional learning are:

**Data inefficiency**: Models trained on ImageNet learn from over a million labeled images, and adding a new class typically requires hundreds to thousands of labeled samples.

**Lack of generalization**: Performance drops sharply on new tasks that differ significantly from the training distribution.

**Catastrophic forgetting**: Training sequentially on new tasks tends to overwrite what the model learned earlier, degrading performance on previously learned tasks.

1.3 Task Distribution

The central concept in meta-learning is the **task distribution p(T)**. Meta-learning does not simply learn from data; it learns from a distribution of tasks.

Each task T includes:

- A distribution p(x, y) of input-output pairs

- A task loss function L

The meta-learning objective is:

$$\min_{\theta} \; \mathbb{E}_{T \sim p(T)}\left[\mathcal{L}_{T}(f_{\theta})\right]$$
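To make the expectation concrete, here is a minimal sketch that estimates the objective by Monte Carlo sampling over tasks. The task family (1-D regression tasks $y = a x$ with slope $a \sim \mathrm{Uniform}(1, 3)$), the one-parameter model $f_\theta(x) = \theta x$, the sample sizes, and the grid search are all illustrative assumptions. The objective as written has no per-task adaptation step, so the best single $\theta$ simply lands near the "average" task.

```python
import numpy as np


def sample_task(rng):
    """Hypothetical task family: y = a * x with slope a ~ Uniform(1, 3)."""
    a = rng.uniform(1.0, 3.0)
    x = rng.uniform(-1.0, 1.0, size=20)
    return x, a * x


def task_loss(theta, task):
    """L_T(f_theta): mean squared error of f_theta(x) = theta * x on one task."""
    x, y = task
    return np.mean((theta * x - y) ** 2)


def meta_objective(theta, tasks):
    """Monte Carlo estimate of E_{T ~ p(T)}[L_T(f_theta)] over sampled tasks."""
    return np.mean([task_loss(theta, t) for t in tasks])


rng = np.random.default_rng(0)
tasks = [sample_task(rng) for _ in range(200)]
thetas = np.linspace(0.0, 4.0, 201)
objective = [meta_objective(t, tasks) for t in thetas]
best_theta = thetas[int(np.argmin(objective))]
print(f"best theta: {best_theta:.2f}")  # should be close to the mean slope, 2.0
```

The methods in the following sections replace this single shared $\theta$ with an initialization or embedding that is adapted to each task.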

1.4 Support Set vs Query Set

In few-shot learning, data is divided into two roles.

**Support Set**: A small number of examples used as references when the model learns a new task. These correspond to training data in traditional learning, but are extremely few (e.g., 1–5 per class).

**Query Set**: Data used to evaluate model performance. These correspond to test data in traditional learning.

1.5 N-way K-shot Setup

The standard setup for few-shot learning is **N-way K-shot**.

- **N-way**: Number of classes to classify

- **K-shot**: Number of support examples per class

For example, 5-way 1-shot is a task of classifying 5 classes with only 1 example per class. 5-way 5-shot uses 5 examples per class.

```python
from typing import List, Optional, Tuple

import numpy as np
import torch


def create_episode(
    dataset,
    n_way: int,
    k_shot: int,
    n_query: int,
    classes: Optional[List[int]] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Create an N-way K-shot episode.
    Assumes dataset[i] returns (x, y) and dataset.targets is a 1-D tensor
    of integer labels (as in torchvision's MNIST).
    Returns: (support_x, support_y, query_x, query_y)
    """
    all_classes = sorted(set(dataset.targets.tolist()))
    if classes is None:
        selected_classes = np.random.choice(all_classes, n_way, replace=False)
    else:
        selected_classes = classes

    support_x, support_y = [], []
    query_x, query_y = [], []

    # Relabel the sampled classes as 0..n_way-1 inside the episode
    for new_label, cls in enumerate(selected_classes):
        cls_indices = (dataset.targets == int(cls)).nonzero(as_tuple=True)[0]
        chosen = np.random.choice(len(cls_indices), k_shot + n_query, replace=False)
        chosen = torch.as_tensor(chosen, dtype=torch.long)
        for i, idx in enumerate(cls_indices[chosen]):
            x, _ = dataset[idx.item()]
            if i < k_shot:
                support_x.append(x)
                support_y.append(new_label)
            else:
                query_x.append(x)
                query_y.append(new_label)

    support_x = torch.stack(support_x)
    support_y = torch.tensor(support_y, dtype=torch.long)
    # torch.stack fails on an empty list, so handle n_query=0 (used by Reptile)
    query_x = torch.stack(query_x) if query_x else torch.empty(0)
    query_y = torch.tensor(query_y, dtype=torch.long)
    return support_x, support_y, query_x, query_y
```

2. Distance-Based Meta-Learning

2.1 Matching Networks

Matching Networks, proposed by Vinyals et al. in 2016, combines attention mechanisms with the kNN idea. The core idea is to predict a query sample as an attention-weighted sum of support set labels.

**Prediction formula**

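$$\hat{y} = \sum_{i=1}^{k} a(\hat{x}, x_i)\, y_i$$

Here $a(\hat{x}, x_i)$ is the attention weight between the query $\hat{x}$ and the support sample $x_i$ (with one-hot label $y_i$), computed as a softmax over cosine similarities of the embeddings:

$$a(\hat{x}, x_i) = \frac{\exp\big(\cos(f_\phi(\hat{x}), f_\phi(x_i))\big)}{\sum_{j=1}^{k} \exp\big(\cos(f_\phi(\hat{x}), f_\phi(x_j))\big)}$$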
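The attention-weighted prediction can be checked by hand on a toy case. This NumPy sketch assumes precomputed 2-D embeddings (hypothetical values standing in for encoder outputs) and a single shared embedding function for queries and support samples, as in the simplified implementation that follows.

```python
import numpy as np


def matching_predict(query_emb, support_emb, support_labels, n_way):
    """Class probabilities as an attention-weighted sum of one-hot support labels."""
    # Cosine similarity between the query and every support embedding
    q = query_emb / np.linalg.norm(query_emb)
    s = support_emb / np.linalg.norm(support_emb, axis=1, keepdims=True)
    cos = s @ q                                # (n_support,)
    # Softmax over the support set gives the attention weights a(x_hat, x_i)
    attn = np.exp(cos - cos.max())
    attn /= attn.sum()
    one_hot = np.eye(n_way)[support_labels]    # (n_support, n_way)
    return attn @ one_hot                      # (n_way,), sums to 1


# Hypothetical embeddings: two support points per class
support_emb = np.array([[1.0, 0.1], [0.9, -0.1], [0.1, 1.0], [-0.1, 0.9]])
support_labels = np.array([0, 0, 1, 1])
query_emb = np.array([0.8, 0.2])

probs = matching_predict(query_emb, support_emb, support_labels, n_way=2)
print(probs.round(3), probs.argmax())
```

Because the output already sums to 1, it is a probability vector rather than a set of logits.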

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MatchingNetworks(nn.Module):
    """Matching Networks implementation (simplified, without FCE)"""

    def __init__(self, encoder: nn.Module, use_fce: bool = False):
        """
        encoder: feature extractor
        use_fce: whether to use Full Context Embedding
                 (FCE is not implemented here; the flag is only stored)
        """
        super().__init__()
        self.encoder = encoder
        self.use_fce = use_fce

    def cosine_similarity(
        self,
        query: torch.Tensor,
        support: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute cosine similarity
        query: (n_query, embed_dim)
        support: (n_support, embed_dim)
        Returns: (n_query, n_support)
        """
        query_norm = F.normalize(query, dim=-1)
        support_norm = F.normalize(support, dim=-1)
        return torch.mm(query_norm, support_norm.t())

    def forward(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """
        support_x: (n_way * k_shot, C, H, W)
        support_y: (n_way * k_shot,)
        query_x: (n_query, C, H, W)
        Returns: class probabilities (n_query, n_way)
        """
        support_emb = self.encoder(support_x)  # (n_support, D)
        query_emb = self.encoder(query_x)      # (n_query, D)

        similarities = self.cosine_similarity(
            query_emb, support_emb
        )  # (n_query, n_support)

        # Softmax attention over the support set
        attention = F.softmax(similarities, dim=-1)

        # One-hot labels
        support_labels_one_hot = F.one_hot(
            support_y, n_way
        ).float()  # (n_support, n_way)

        # Attention-weighted prediction. Rows already sum to 1, so these are
        # probabilities, not logits: train with NLLLoss on their log
        probs = torch.mm(attention, support_labels_one_hot)
        return probs
```

2.2 Prototypical Networks

Prototypical Networks, published by Snell et al. in 2017, is one of the most elegant and intuitive meta-learning algorithms. The core idea: **represent each class as a single prototype (centroid) in embedding space.**

**Prototype computation**

The prototype of class c is the mean of the embeddings of the support samples for that class.

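$$p_c = \frac{1}{|S_c|} \sum_{(x_i, y_i) \in S_c} f_\phi(x_i)$$

where $S_c$ is the set of support examples labeled $c$ and $f_\phi$ is the embedding network.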

**Classification**

Classify a query sample x to the nearest prototype.

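$$p(y = c \mid x) = \frac{\exp\big(-d(f_\phi(x), p_c)\big)}{\sum_{c'} \exp\big(-d(f_\phi(x), p_{c'})\big)}$$

where $d$ is a distance in embedding space; Snell et al. use the squared Euclidean distance, $d(z, z') = \lVert z - z' \rVert^2$.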
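The two formulas above amount to "average, then softmax over negative distances". Here is a minimal NumPy sketch on hypothetical 2-D embeddings (standing in for encoder outputs), using squared Euclidean distance:

```python
import numpy as np


def proto_classify(support_emb, support_labels, query_emb, n_way):
    # Prototype p_c: mean embedding of the support examples of class c
    prototypes = np.stack([
        support_emb[support_labels == c].mean(axis=0) for c in range(n_way)
    ])
    # Squared Euclidean distance from each query to each prototype: (n_query, n_way)
    d = ((query_emb[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    # Softmax over negative distances (shifted by the row max for stability)
    logits = -d
    logits -= logits.max(axis=1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    return prototypes, probs


support_emb = np.array([[0.0, 0.0], [0.0, 2.0],    # class 0 -> prototype (0, 1)
                        [4.0, 0.0], [4.0, 2.0]])   # class 1 -> prototype (4, 1)
support_labels = np.array([0, 0, 1, 1])
query_emb = np.array([[1.0, 1.0], [3.5, 0.5]])

prototypes, probs = proto_classify(support_emb, support_labels, query_emb, n_way=2)
print(probs.argmax(axis=1))  # [0 1]
```

The first query is at squared distance 1 from prototype 0 and 9 from prototype 1, so almost all of its probability mass goes to class 0.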

```python
import numpy as np
import torch
import torch.nn as nn


class ConvEncoder(nn.Module):
    """4-layer CNN encoder for few-shot learning"""

    def __init__(
        self,
        in_channels: int = 1,
        hidden_dim: int = 64,
        out_dim: int = 64
    ):
        super().__init__()

        def conv_block(in_ch, out_ch):
            # Conv -> BatchNorm -> ReLU -> 2x2 max-pool (halves H and W)
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2)
            )

        self.net = nn.Sequential(
            conv_block(in_channels, hidden_dim),
            conv_block(hidden_dim, hidden_dim),
            conv_block(hidden_dim, hidden_dim),
            conv_block(hidden_dim, out_dim),
            nn.Flatten()
        )

    def forward(self, x):
        return self.net(x)


class PrototypicalNetworks(nn.Module):
    """Complete Prototypical Networks implementation"""

    def __init__(self, encoder: nn.Module):
        super().__init__()
        self.encoder = encoder

    def compute_prototypes(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """
        Compute class prototypes (mean of embeddings)
        support_x: (n_way * k_shot, C, H, W)
        support_y: (n_way * k_shot,)
        Returns: (n_way, embed_dim)
        """
        support_emb = self.encoder(support_x)  # (n_support, D)
        prototypes = []
        for cls in range(n_way):
            mask = (support_y == cls)
            cls_embeddings = support_emb[mask]
            prototype = cls_embeddings.mean(dim=0)
            prototypes.append(prototype)
        return torch.stack(prototypes)  # (n_way, D)

    def euclidean_dist(
        self,
        x: torch.Tensor,
        y: torch.Tensor
    ) -> torch.Tensor:
        """
        Squared Euclidean distance (the choice used by Snell et al.)
        x: (n, D)
        y: (m, D)
        Returns: (n, m)
        """
        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2*x.y
        n = x.size(0)
        m = y.size(0)
        x_sq = (x ** 2).sum(dim=1, keepdim=True).expand(n, m)
        y_sq = (y ** 2).sum(dim=1, keepdim=True).expand(m, n).t()
        xy = torch.mm(x, y.t())
        dist = x_sq + y_sq - 2 * xy
        # Clamp tiny negatives from floating-point error. No sqrt: its gradient
        # is infinite at zero distance
        return dist.clamp(min=0)

    def forward(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """
        Prototypical Networks forward pass
        Returns: log-probabilities for query samples (n_query, n_way)
        """
        prototypes = self.compute_prototypes(
            support_x, support_y, n_way
        )  # (n_way, D)
        query_emb = self.encoder(query_x)  # (n_query, D)
        dists = self.euclidean_dist(
            query_emb, prototypes
        )  # (n_query, n_way)
        log_probs = nn.functional.log_softmax(-dists, dim=-1)
        return log_probs


# ========== Training Loop ==========

def train_prototypical(
    model: PrototypicalNetworks,
    train_dataset,
    n_way: int = 5,
    k_shot: int = 5,
    n_query: int = 15,
    n_episodes: int = 100,
    lr: float = 1e-3,
    device: str = 'cpu'
):
    """Train Prototypical Networks"""
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.NLLLoss()  # forward() already returns log-probabilities
    model.train()

    episode_losses = []
    episode_accs = []

    for episode in range(n_episodes):
        support_x, support_y, query_x, query_y = create_episode(
            train_dataset, n_way, k_shot, n_query
        )
        support_x = support_x.to(device)
        support_y = support_y.to(device)
        query_x = query_x.to(device)
        query_y = query_y.to(device)

        optimizer.zero_grad()
        log_probs = model(support_x, support_y, query_x, n_way)
        loss = criterion(log_probs, query_y)
        loss.backward()
        optimizer.step()

        preds = log_probs.argmax(dim=-1)
        acc = (preds == query_y).float().mean().item()
        episode_losses.append(loss.item())
        episode_accs.append(acc)

        if (episode + 1) % 100 == 0:
            print(
                f"Episode {episode+1}/{n_episodes} | "
                f"Loss: {np.mean(episode_losses[-100:]):.4f} | "
                f"Acc: {np.mean(episode_accs[-100:]):.4f}"
            )

    return episode_losses, episode_accs
```

2.3 Relation Networks

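Relation Networks (Sung et al., 2018) are similar to Prototypical Networks, but replace the fixed distance function with a learnable neural network, the relation module. The relation module takes the concatenation of a query embedding and a class representation and outputs a relation score in [0, 1]. The simplified version below uses vector embeddings, class-mean prototypes, and an MLP; the original paper works on convolutional feature maps, sums the embeddings of each class's support samples, uses a convolutional relation module, and trains the scores with a mean-squared-error loss against one-hot targets.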

```python
import torch
import torch.nn as nn


class RelationNetwork(nn.Module):
    """Relation Networks: learnable distance function (simplified MLP version)"""

    def __init__(self, encoder: nn.Module, embed_dim: int = 64):
        super().__init__()
        self.encoder = encoder
        # Relation module: takes concatenated embeddings, outputs relation score
        self.relation_module = nn.Sequential(
            nn.Linear(embed_dim * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        n_way: int
    ) -> torch.Tensor:
        """
        Returns: relation scores in [0, 1], shape (n_query, n_way)
        """
        support_emb = self.encoder(support_x)
        query_emb = self.encoder(query_x)

        # Class prototypes: mean support embedding per class
        prototypes = []
        for cls in range(n_way):
            mask = (support_y == cls)
            proto = support_emb[mask].mean(dim=0)
            prototypes.append(proto)
        prototypes = torch.stack(prototypes)  # (n_way, D)

        # Pair every query embedding with every prototype
        n_query = query_emb.size(0)
        query_expanded = query_emb.unsqueeze(1).expand(
            n_query, n_way, -1
        )  # (n_query, n_way, D)
        proto_expanded = prototypes.unsqueeze(0).expand(
            n_query, n_way, -1
        )  # (n_query, n_way, D)
        pairs = torch.cat(
            [query_expanded, proto_expanded], dim=-1
        )  # (n_query, n_way, 2D)

        scores = self.relation_module(
            pairs.view(-1, pairs.size(-1))
        ).view(n_query, n_way)
        return scores
```

3. Optimization-Based Meta-Learning: MAML

3.1 The Core Idea of MAML

MAML (Model-Agnostic Meta-Learning), proposed by Finn et al. in 2017, is one of the most influential works in meta-learning.

MAML's goal is to **"find an initial set of parameters theta from which fast adaptation is possible."**

Concretely, it finds an initialization point from which only a few gradient update steps are needed to achieve good performance on a new task.

3.2 Inner Loop vs Outer Loop

MAML consists of two loops.

**Inner Loop (Task-specific adaptation)**

For each task T_i:

$$\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{T_i}(f_\theta)$$

Perform 1–5 such gradient updates on the support set to obtain the task-specific parameters $\theta_i'$.

**Outer Loop (Meta-update)**

$$\theta \leftarrow \theta - \beta \nabla_\theta \sum_i \mathcal{L}_{T_i}(f_{\theta_i'})$$

Compute the loss on the query set using the task-adapted parameters $\theta_i'$, then update the meta-parameters $\theta$ based on this loss.

3.3 Second-Order Gradients

The key technical challenge of MAML is that the outer loop gradient requires **second-order gradients (backpropagation through the inner loop)**.

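$$\nabla_\theta \mathcal{L}_{T_i}(f_{\theta_i'}) = \nabla_\theta \mathcal{L}_{T_i}\big(f_{\theta - \alpha \nabla_\theta \mathcal{L}_{T_i}(f_\theta)}\big)$$

Because $\theta_i'$ itself depends on $\theta$, the chain rule gives, for a single inner step,

$$\nabla_\theta \mathcal{L}_{T_i}(f_{\theta_i'}) = \big(I - \alpha \nabla^2_\theta \mathcal{L}_{T_i}(f_\theta)\big)\, \nabla_{\theta_i'} \mathcal{L}_{T_i}(f_{\theta_i'})$$

where the Hessian $\nabla^2_\theta \mathcal{L}_{T_i}(f_\theta)$ is taken on the support-set loss and the last gradient on the query-set loss. Computing this requires differentiating through the inner-loop update, which is expensive in both memory and time.

In practice, **FOMAML (First-Order MAML)** is often used instead: it drops the second-order term (treating $I - \alpha \nabla^2_\theta \mathcal{L}_{T_i}$ as the identity) and uses $\nabla_{\theta_i'} \mathcal{L}_{T_i}(f_{\theta_i'})$ as an approximate meta-gradient.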
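The difference between the two meta-gradients is easiest to see on a toy problem. This sketch assumes a hypothetical 1-D task with loss $\mathcal{L}(\theta) = \tfrac{1}{2}(\theta - c)^2$ and, for simplicity, uses the same loss for the support and query sets. Differentiating through one inner step $\theta' = \theta - \alpha(\theta - c)$ gives a MAML meta-gradient of $(1-\alpha)^2(\theta - c)$, while FOMAML uses $\theta' - c = (1-\alpha)(\theta - c)$. The MAML gradient is computed by a finite difference through the inner update, so no autograd library is needed.

```python
def inner_update(theta, c, alpha):
    """One inner-loop SGD step on the task loss L(theta) = 0.5 * (theta - c)**2."""
    return theta - alpha * (theta - c)


def query_loss_after_adaptation(theta, c, alpha):
    """Outer objective L(f_{theta'}) where theta' = inner_update(theta)."""
    theta_prime = inner_update(theta, c, alpha)
    return 0.5 * (theta_prime - c) ** 2


theta, c, alpha = 0.0, 2.0, 0.1

# Full MAML meta-gradient: differentiate *through* the inner update
# (central finite difference; exact up to rounding for a quadratic)
h = 1e-6
maml_grad = (query_loss_after_adaptation(theta + h, c, alpha)
             - query_loss_after_adaptation(theta - h, c, alpha)) / (2 * h)

# FOMAML: gradient of the loss at theta', treating d(theta')/d(theta) as 1
theta_prime = inner_update(theta, c, alpha)
fomaml_grad = theta_prime - c

print(maml_grad, (1 - alpha) ** 2 * (theta - c))  # both about -1.62
print(fomaml_grad, (1 - alpha) * (theta - c))     # both about -1.8
```

Here the Hessian is the constant 1, so the only difference is the extra factor $(1 - \alpha)$; with curved, high-dimensional losses the dropped term can matter much more.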

3.4 Complete MAML Implementation

```python
from copy import deepcopy
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call  # PyTorch 2.0+
from torch.utils.data import DataLoader


class MAML:
    """
    MAML: Model-Agnostic Meta-Learning
    Finn et al., 2017 (arXiv:1703.03400)
    """

    def __init__(
        self,
        model: nn.Module,
        inner_lr: float = 0.01,     # alpha: inner loop learning rate
        outer_lr: float = 0.001,    # beta: outer loop learning rate
        n_inner_steps: int = 5,     # number of inner loop updates
        first_order: bool = False,  # whether to use FOMAML
        device: str = 'cpu'
    ):
        self.model = model.to(device)
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.n_inner_steps = n_inner_steps
        self.first_order = first_order
        self.device = device
        self.meta_optimizer = torch.optim.Adam(
            self.model.parameters(), lr=outer_lr
        )

    def inner_loop(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        model_params=None
    ) -> dict:
        """
        Inner loop: task-specific adaptation on the support set
        Returns: adapted parameter dictionary
        """
        if model_params is None:
            # clone() keeps the copies connected to the meta-parameters
            params = {
                name: param.clone()
                for name, param in self.model.named_parameters()
            }
        else:
            params = {k: v.clone() for k, v in model_params.items()}

        for step in range(self.n_inner_steps):
            logits = self._forward_with_params(support_x, params)
            loss = F.cross_entropy(logits, support_y)
            # create_graph=True records the update so the outer loop can
            # backpropagate through it (second-order MAML); FOMAML skips this
            grads = torch.autograd.grad(
                loss,
                list(params.values()),
                create_graph=not self.first_order
            )
            params = {
                name: param - self.inner_lr * grad
                for (name, param), grad in zip(params.items(), grads)
            }
        return params

    def _forward_with_params(
        self,
        x: torch.Tensor,
        params: dict
    ) -> torch.Tensor:
        """
        Forward pass with specific parameters.
        functional_call substitutes `params` for the module's parameters while
        keeping them in the autograd graph. (Swapping `param.data` instead
        would detach them, so no gradient could flow through the inner loop.)
        """
        return functional_call(self.model, params, (x,))

    def meta_train_step(
        self,
        tasks: List[Tuple]
    ) -> float:
        """
        One meta-training step (over multiple tasks)
        tasks: [(support_x, support_y, query_x, query_y), ...]
        """
        self.meta_optimizer.zero_grad()
        meta_loss = 0.0

        for support_x, support_y, query_x, query_y in tasks:
            support_x = support_x.to(self.device)
            support_y = support_y.to(self.device)
            query_x = query_x.to(self.device)
            query_y = query_y.to(self.device)

            adapted_params = self.inner_loop(support_x, support_y)
            query_logits = self._forward_with_params(
                query_x, adapted_params
            )
            query_loss = F.cross_entropy(query_logits, query_y)
            meta_loss += query_loss

        meta_loss /= len(tasks)
        meta_loss.backward()
        self.meta_optimizer.step()
        return meta_loss.item()

    def fine_tune(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        n_steps: int = None
    ) -> nn.Module:
        """
        Fine-tune on a new task (used at inference)
        """
        n_steps = n_steps or self.n_inner_steps
        model_copy = deepcopy(self.model)
        optimizer = torch.optim.SGD(
            model_copy.parameters(), lr=self.inner_lr
        )
        model_copy.train()

        for step in range(n_steps):
            optimizer.zero_grad()
            logits = model_copy(support_x.to(self.device))
            loss = F.cross_entropy(logits, support_y.to(self.device))
            loss.backward()
            optimizer.step()

        return model_copy

    def evaluate(
        self,
        tasks: List[Tuple],
        n_fine_tune_steps: int = 5
    ) -> Tuple[float, float]:
        """
        Meta-test: adapt to new tasks and evaluate
        """
        total_loss = 0.0
        total_acc = 0.0

        for support_x, support_y, query_x, query_y in tasks:
            adapted_model = self.fine_tune(
                support_x, support_y, n_fine_tune_steps
            )
            adapted_model.eval()
            with torch.no_grad():
                query_logits = adapted_model(query_x.to(self.device))
                loss = F.cross_entropy(
                    query_logits, query_y.to(self.device)
                )
                preds = query_logits.argmax(dim=-1)
                acc = (preds == query_y.to(self.device)).float().mean()
            total_loss += loss.item()
            total_acc += acc.item()

        return total_loss / len(tasks), total_acc / len(tasks)


def train_maml(
    maml: MAML,
    dataset,
    n_way: int = 5,
    k_shot: int = 1,
    n_query: int = 15,
    meta_batch_size: int = 32,
    n_iterations: int = 60000,
    device: str = 'cpu'
):
    """MAML training loop"""
    print(f"MAML Training: {n_way}-way {k_shot}-shot")
    print(f"Meta-batch size: {meta_batch_size}")
    print(f"Total iterations: {n_iterations}")

    losses = []
    for iteration in range(n_iterations):
        tasks = []
        for _ in range(meta_batch_size):
            task = create_episode(dataset, n_way, k_shot, n_query)
            tasks.append(task)

        meta_loss = maml.meta_train_step(tasks)
        losses.append(meta_loss)

        if (iteration + 1) % 1000 == 0:
            avg_loss = np.mean(losses[-1000:])
            print(
                f"Iteration {iteration+1}/{n_iterations} | "
                f"Meta Loss: {avg_loss:.4f}"
            )

    return losses
```

3.5 Reptile: Simplifying MAML

Reptile (Nichol et al., 2018) is a first-order meta-learning algorithm closely related to MAML. It requires no second-order gradients and is much easier to implement.

**Core idea**: After running SGD multiple times on a task, move the meta-parameters toward the resulting parameters.

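$$\theta \leftarrow \theta + \epsilon\,(W_k - \theta)$$

where $W_k$ denotes the parameters after $k$ SGD updates on task $T$, starting from $\theta$. With a meta-batch of tasks, $W_k$ is replaced by its average over the tasks, as in `meta_update` below.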
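To see what this update does, the sketch below runs Reptile on hypothetical 1-D quadratic tasks $\mathcal{L}_c(\theta) = \tfrac12(\theta - c)^2$ with optima $c \in \{1, 2, 3\}$, averaging $W_k$ over all three tasks each iteration. The task family and all step sizes are illustrative assumptions. For these tasks $W_k - \theta = \big(1 - (1-\alpha)^k\big)(c - \theta)$, so $\theta$ is pulled toward the mean of the task optima, an initialization that is close to every task.

```python
def sgd_steps(theta, c, alpha, k):
    """k SGD steps on L(theta) = 0.5 * (theta - c)**2, starting at theta; returns W_k."""
    w = theta
    for _ in range(k):
        w -= alpha * (w - c)
    return w


task_centers = [1.0, 2.0, 3.0]   # hypothetical task optima
theta, alpha, k, epsilon = 0.0, 0.1, 5, 0.1

for _ in range(500):
    # Average W_k over the meta-batch (here: all three tasks)...
    w_mean = sum(sgd_steps(theta, c, alpha, k) for c in task_centers) / len(task_centers)
    # ...then move theta a fraction epsilon of the way toward it
    theta += epsilon * (w_mean - theta)

print(round(theta, 4))  # approaches 2.0, the mean of the task optima
```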

```python
from copy import deepcopy
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


class Reptile:
    """
    Reptile: A Scalable Meta-learning Algorithm
    Nichol et al., 2018 (arXiv:1803.02999)
    """

    def __init__(
        self,
        model: nn.Module,
        inner_lr: float = 0.02,
        outer_lr: float = 0.001,  # epsilon
        n_inner_steps: int = 5,
        device: str = 'cpu'
    ):
        self.model = model.to(device)
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.n_inner_steps = n_inner_steps
        self.device = device

    def inner_train(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor
    ) -> dict:
        """Task-specific inner training (k SGD steps)"""
        model_copy = deepcopy(self.model)
        optimizer = torch.optim.SGD(
            model_copy.parameters(), lr=self.inner_lr
        )
        model_copy.train()

        for step in range(self.n_inner_steps):
            optimizer.zero_grad()
            logits = model_copy(support_x)
            loss = F.cross_entropy(logits, support_y)
            loss.backward()
            optimizer.step()

        # W_k for this task
        return dict(model_copy.named_parameters())

    def meta_update(self, task_params_list: List[dict]):
        """
        Reptile meta-update:
        theta += epsilon * (mean(W_k) - theta)
        """
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                task_mean = torch.stack([
                    task_params[name].data
                    for task_params in task_params_list
                ]).mean(dim=0)
                param.data += self.outer_lr * (task_mean - param.data)

    def train(
        self,
        dataset,
        n_way: int = 5,
        k_shot: int = 5,
        meta_batch_size: int = 5,
        n_iterations: int = 100000
    ):
        """Reptile full training loop"""
        print(f"Reptile Training: {n_way}-way {k_shot}-shot")

        for iteration in range(n_iterations):
            task_params_list = []
            for _ in range(meta_batch_size):
                # Reptile trains only on the support set, so no queries are sampled
                support_x, support_y, _, _ = create_episode(
                    dataset, n_way, k_shot, n_query=0
                )
                support_x = support_x.to(self.device)
                support_y = support_y.to(self.device)
                task_params = self.inner_train(support_x, support_y)
                task_params_list.append(task_params)

            self.meta_update(task_params_list)

            if (iteration + 1) % 10000 == 0:
                print(f"Iteration {iteration+1}/{n_iterations} complete")
```

4. In-Context Learning in LLMs

4.1 What Is In-Context Learning?

In-Context Learning (ICL) is the ability of large language models (LLMs) to perform new tasks from examples within the prompt. The model does not update its parameters; it "learns" solely from the input context (prompt).

This ability gained enormous attention with the arrival of GPT-3. For example:

```text
English to French translation:
sea otter => loutre de mer
peppermint => menthe poivrée
plush giraffe => girafe en peluche
cheese => ?
```

Given this prompt, GPT-3 completes the last line with "fromage." The model was not fine-tuned for English-to-French translation; it picks up the task from the pattern in the prompt, relying on knowledge acquired during pre-training.

4.2 Why Is It Effective?

The reasons ICL is effective are still being actively studied. Major hypotheses include:

**Pattern completion view**: LLMs are trained to predict the next token, so when the prompt shows a repeated input-output pattern, the most likely continuation is one that follows the same pattern.

**Latent concept inference view**: Xie et al. frame ICL as implicit Bayesian inference: the model uses the prompt examples to infer a latent concept (the task) they share and conditions its prediction on that concept.

**Gradient descent metaphor**: Akyürek et al. showed that, on linear regression tasks, Transformers can implement learning algorithms such as gradient descent in their forward pass, and that trained models make in-context predictions closely resembling those of such algorithms.
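The simplest version of this correspondence can be checked directly. The sketch below assumes a linear regression task given in context, a single gradient step from $w = 0$ on the squared loss, and an unnormalized linear attention layer (no softmax) with keys $x_i$, values $y_i$, and query $x_q$. Under these assumptions both predictions equal $\eta \sum_i y_i\, x_i^\top x_q$. It is a toy identity, not evidence about what a trained LLM actually computes.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))          # in-context inputs x_i (hypothetical data)
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true                       # in-context targets y_i
x_q = rng.normal(size=3)             # query input
eta = 0.1                            # learning rate, also the attention scale

# (a) One gradient-descent step from w = 0 on L(w) = 0.5 * sum_i (w . x_i - y_i)^2
grad_at_zero = -(X.T @ y)            # dL/dw evaluated at w = 0
w_1 = 0.0 - eta * grad_at_zero
pred_gd = w_1 @ x_q

# (b) Unnormalized linear attention: keys x_i, values y_i, query x_q
scores = X @ x_q                     # k_i . q for every context example
pred_attn = eta * (scores @ y)       # sum_i v_i * (k_i . q), scaled by eta

print(np.isclose(pred_gd, pred_attn))  # True
```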

4.3 Effective Few-Shot Prompting Strategies

```python
from typing import Any, Dict, List, Optional

import numpy as np


class FewShotPromptBuilder:
    """Few-shot prompt builder"""

    def __init__(self):
        self.examples = []
        self.instruction = ""
        self.template = "{input} => {output}"

    def set_instruction(self, instruction: str):
        """Set task instruction"""
        self.instruction = instruction
        return self

    def add_example(self, input_text: str, output_text: str):
        """Add an example"""
        self.examples.append({
            'input': input_text,
            'output': output_text
        })
        return self

    def build_prompt(self, query: str) -> str:
        """Build complete few-shot prompt"""
        parts = []
        if self.instruction:
            parts.append(self.instruction)
            parts.append("")
        for ex in self.examples:
            parts.append(self.template.format(
                input=ex['input'],
                output=ex['output']
            ))
        # Query with blank output
        parts.append(f"{query} =>")
        return "\n".join(parts)

    def build_chat_messages(
        self,
        query: str,
        system_prompt: Optional[str] = None
    ) -> List[Dict]:
        """Build chat-format messages (for GPT-4, Claude, etc.)"""
        messages = []
        # Fall back to the task instruction if no system prompt is given
        system_prompt = system_prompt or self.instruction
        if system_prompt:
            messages.append({
                'role': 'system',
                'content': system_prompt
            })
        # Each example becomes a user turn followed by the expected assistant turn
        for ex in self.examples:
            messages.append({
                'role': 'user',
                'content': ex['input']
            })
            messages.append({
                'role': 'assistant',
                'content': ex['output']
            })
        messages.append({
            'role': 'user',
            'content': query
        })
        return messages


class DynamicExampleSelector:
    """
    Dynamic example selector
    Selects the examples most similar to the query for more effective few-shot learning
    """

    def __init__(self, examples: List[Dict], encoder=None):
        self.examples = examples
        self.encoder = encoder  # e.g. a sentence-transformers model

    def select_similar(
        self,
        query: str,
        n_examples: int = 3
    ) -> List[Dict]:
        """
        Select the n examples most similar to the query
        """
        if self.encoder is None:
            return np.random.choice(
                self.examples, n_examples, replace=False
            ).tolist()

        query_emb = self.encoder.encode(query)
        example_embs = self.encoder.encode(
            [ex['input'] for ex in self.examples]
        )
        # Cosine similarity between every example and the query
        similarities = np.dot(example_embs, query_emb) / (
            np.linalg.norm(example_embs, axis=1)
            * np.linalg.norm(query_emb)
        )
        top_indices = np.argsort(similarities)[-n_examples:][::-1]
        return [self.examples[i] for i in top_indices]

    def select_diverse(
        self,
        n_examples: int = 3
    ) -> List[Dict]:
        """
        Select diverse examples by greedy farthest-point selection: starting
        from the first example, repeatedly add the example whose maximum cosine
        similarity to the already-selected ones is lowest. (This is only the
        diversity term of MMR; relevance to the query is ignored.)
        """
        if len(self.examples) <= n_examples:
            return self.examples

        if self.encoder is None:
            return np.random.choice(
                self.examples, n_examples, replace=False
            ).tolist()

        embeddings = self.encoder.encode(
            [ex['input'] for ex in self.examples]
        )
        selected = [0]
        remaining = list(range(1, len(self.examples)))

        while len(selected) < n_examples:
            selected_embs = embeddings[selected]
            best_idx = None
            best_score = float('-inf')
            for idx in remaining:
                sim_to_selected = np.max(
                    np.dot(selected_embs, embeddings[idx]) / (
                        np.linalg.norm(selected_embs, axis=1)
                        * np.linalg.norm(embeddings[idx])
                    )
                )
                score = -sim_to_selected  # maximize diversity
                if score > best_score:
                    best_score = score
                    best_idx = idx
            selected.append(best_idx)
            remaining.remove(best_idx)

        return [self.examples[i] for i in selected]
```
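`select_diverse` above uses only the diversity term. Maximal Marginal Relevance as usually defined trades off relevance to the query against redundancy with the examples already chosen, through a weight $\lambda$. Below is a minimal NumPy sketch of that greedy rule; the function name `mmr_select`, the default `lam=0.5`, and the toy embeddings are illustrative assumptions, and in practice the embeddings would come from an encoder like the one `DynamicExampleSelector` takes.

```python
import numpy as np


def mmr_select(query_emb, example_embs, n_examples, lam=0.5):
    """Greedy Maximal Marginal Relevance selection (returns example indices).

    score(i) = lam * sim(query, i) - (1 - lam) * max_{j in selected} sim(i, j)
    lam=1 reduces to plain top-k similarity; lam=0 is pure diversity.
    """
    E = example_embs / np.linalg.norm(example_embs, axis=1, keepdims=True)
    q = query_emb / np.linalg.norm(query_emb)
    relevance = E @ q        # cosine similarity of each example to the query
    pairwise = E @ E.T       # cosine similarity between examples
    selected, remaining = [], list(range(len(E)))
    while remaining and len(selected) < n_examples:
        if selected:
            redundancy = pairwise[np.ix_(remaining, selected)].max(axis=1)
        else:
            redundancy = np.zeros(len(remaining))
        scores = lam * relevance[remaining] - (1 - lam) * redundancy
        best = remaining[int(np.argmax(scores))]
        selected.append(best)
        remaining.remove(best)
    return selected


# Hypothetical 2-D embeddings: examples 0 and 1 are near-duplicates
embs = np.array([[1.0, 0.0], [0.99, 0.05], [0.6, 0.8], [0.0, 1.0]])
query = np.array([1.0, 0.2])
print(mmr_select(query, embs, n_examples=2, lam=1.0))  # relevance only
print(mmr_select(query, embs, n_examples=2, lam=0.5))  # skips the near-duplicate 0
```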

```python
# ========== Practical Examples ==========

def sentiment_analysis_few_shot():
    """Sentiment analysis few-shot example"""
    builder = FewShotPromptBuilder()
    builder.set_instruction(
        "Analyze the sentiment of the following movie review. "
        "Answer Positive or Negative."
    )
    builder.add_example(
        "This movie was truly moving and the acting was superb.",
        "Positive"
    )
    builder.add_example(
        "The story was too boring and the ending was disappointing.",
        "Negative"
    )
    builder.add_example(
        "The special effects were impressive but the plot was too weak.",
        "Negative"
    )
    builder.add_example(
        "A warm film the whole family can enjoy together.",
        "Positive"
    )

    query = "The direction was unique and the music paired perfectly with the visuals."
    prompt = builder.build_prompt(query)
    print("=== Few-Shot Prompt ===")
    print(prompt)
    return prompt


def code_generation_few_shot():
    """Code generation few-shot example"""
    builder = FewShotPromptBuilder()
    builder.set_instruction(
        "Convert the natural language description into Python code."
    )
    # Function bodies inside the strings keep their 4-space indentation
    builder.add_example(
        "Find the maximum element in a list",
        "def find_max(lst):\n    return max(lst)"
    )
    builder.add_example(
        "Check if a string is a palindrome",
        "def is_palindrome(s):\n    return s == s[::-1]"
    )
    builder.add_example(
        "Flatten a nested list",
        "def flatten(lst):\n    return [x for sublist in lst for x in sublist]"
    )

    query = "Count the frequency of each element in a list"
    prompt = builder.build_prompt(query)
    print("=== Code Generation Few-Shot Prompt ===")
    print(prompt)
    return prompt
```

4.4 Cross-lingual Few-shot

Multilingual models such as XLM-RoBERTa can show zero-shot cross-lingual transfer: a task learned only from English data can often be applied to Korean or other languages without target-language training examples.

```python
from typing import List

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer


class CrossLingualFewShot:
    """
    Cross-lingual few-shot learning
    English examples + target-language query
    """

    def __init__(self, model_name: str = "xlm-roberta-large"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Only hidden states are used, so the plain encoder is enough
        # (a sequence-classification head would be randomly initialized and unused)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    def encode_text(self, text: str) -> torch.Tensor:
        """Encode text to a (1, hidden_dim) embedding"""
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=128
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Hidden state of the first token (<s>). Raw XLM-R is not trained as a
        # sentence encoder, so these similarities are rough; a model fine-tuned
        # for sentence embeddings usually works better.
        embedding = outputs.last_hidden_state[:, 0, :]
        return embedding

    def classify_zero_shot(
        self,
        query: str,
        class_descriptions_en: List[str]
    ) -> int:
        """
        Classify a query using English class descriptions
        (query can be in any language)
        """
        query_emb = self.encode_text(query)
        class_embs = torch.cat([
            self.encode_text(desc) for desc in class_descriptions_en
        ])
        similarities = F.cosine_similarity(
            query_emb.expand(len(class_descriptions_en), -1),
            class_embs
        )
        return similarities.argmax().item()

    def few_shot_classify(
        self,
        support_texts: List[str],
        support_labels: List[int],
        query_text: str,
        n_classes: int
    ) -> int:
        """
        Few-shot classification (prototype approach)
        Support and query can be in different languages
        """
        support_embs = torch.cat([
            self.encode_text(t) for t in support_texts
        ])
        query_emb = self.encode_text(query_text)

        prototypes = []
        for cls in range(n_classes):
            mask = torch.tensor([label == cls for label in support_labels])
            proto = support_embs[mask].mean(dim=0, keepdim=True)
            prototypes.append(proto)
        prototypes = torch.cat(prototypes)  # (n_classes, hidden_dim)

        # Nearest prototype by Euclidean distance
        dists = torch.cdist(query_emb, prototypes)
        return dists.argmin().item()
```

5. Hands-On: Medical Image Few-Shot Classification

5.1 Rare Disease Diagnosis System

In clinical settings, rare diseases have very limited training data. Few-shot learning enables recognition of new disease patterns from only a small number of confirmed cases.

from torch.utils.data import Dataset, DataLoader

from PIL import Image

class MedicalImageEncoder(nn.Module):

"""

Encoder for medical images

Based on ResNet-18, adapted for medical imaging characteristics

"""

def __init__(self, embed_dim: int = 512, pretrained: bool = True):

super().__init__()

backbone = models.resnet18(pretrained=pretrained)

self.backbone = nn.Sequential(*list(backbone.children())[:-1])

self.embed_head = nn.Sequential(

nn.Flatten(),

nn.Linear(512, embed_dim),

nn.LayerNorm(embed_dim),

nn.ReLU(),

nn.Linear(embed_dim, embed_dim)

)

def forward(self, x: torch.Tensor) -> torch.Tensor:

features = self.backbone(x)

return self.embed_head(features)

class MedicalFewShotClassifier:

"""

Medical image few-shot classifier

Based on Prototypical Networks

"""

def __init__(

self,

encoder: MedicalImageEncoder,

device: str = 'cpu'

):

self.encoder = encoder.to(device)

self.device = device

self.prototypes = {}

self.disease_names = {}

def register_disease(

self,

disease_id: int,

disease_name: str,

support_images: List,

transform=None

):

"""

Register a new disease (from a small number of examples)

"""

self.disease_names[disease_id] = disease_name

self.encoder.eval()

embeddings = []

with torch.no_grad():

for img in support_images:

if transform:

img_tensor = transform(img).unsqueeze(0).to(self.device)

else:

                img_tensor = img.unsqueeze(0).to(self.device)
                emb = self.encoder(img_tensor)
                embeddings.append(emb)
        # Prototype = mean of the support embeddings, shape (embed_dim,)
        prototype = torch.cat(embeddings).mean(dim=0)
        self.prototypes[disease_id] = prototype
        print(
            f"Disease registered: {disease_name} "
            f"({len(support_images)} examples)"
        )

    def diagnose(
        self,
        query_image: torch.Tensor,
        top_k: int = 3
    ) -> List[Dict]:
        """
        Diagnose a query image.

        Returns: the top-k diseases with their cosine similarity scores
        """
        self.encoder.eval()
        with torch.no_grad():
            query_emb = self.encoder(
                query_image.unsqueeze(0).to(self.device)
            )
        results = []
        for disease_id, prototype in self.prototypes.items():
            # Cosine similarity between the query embedding and each prototype
            similarity = F.cosine_similarity(
                query_emb, prototype.unsqueeze(0)
            ).item()
            results.append({
                'disease_id': disease_id,
                'disease_name': self.disease_names[disease_id],
                'similarity': similarity
            })
        results.sort(key=lambda x: x['similarity'], reverse=True)
        return results[:top_k]

    def update_prototype(
        self,
        disease_id: int,
        new_image: torch.Tensor,
        momentum: float = 0.9
    ):
        """
        Update a prototype with a newly confirmed case (online learning).
        """
        self.encoder.eval()
        with torch.no_grad():
            new_emb = self.encoder(
                new_image.unsqueeze(0).to(self.device)
            ).squeeze(0)
        if disease_id in self.prototypes:
            # Exponential moving average update: keep `momentum` of the old
            # prototype and blend in (1 - momentum) of the new embedding
            self.prototypes[disease_id] = (
                momentum * self.prototypes[disease_id]
                + (1 - momentum) * new_emb
            )
            print(
                f"Prototype updated: {self.disease_names[disease_id]}"
            )
        else:
            print(f"Warning: disease {disease_id} is not registered.")


def demo_medical_few_shot():
    """Medical few-shot classification demo"""
    # Preprocessing for support and query images (ImageNet normalization statistics)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    encoder = MedicalImageEncoder(embed_dim=512, pretrained=True)
    classifier = MedicalFewShotClassifier(encoder)
    print("=== Medical Image Few-Shot Classification System ===")
    print("This system recognizes new diseases from only a few confirmed cases.")
    print()
    print("Usage:")
    print("1. classifier.register_disease(id, name, support_images)")
    print("2. results = classifier.diagnose(patient_image)")
    print("3. classifier.update_prototype(disease_id, new_confirmed_image)")
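The register/diagnose/update cycle above reduces to three operations on embedding vectors:

- Average the support embeddings into a prototype.
- Rank the prototypes by cosine similarity to the query.
- Blend a newly confirmed case into its prototype with an exponential moving average.

The sketch below shows that arithmetic in plain Python. It assumes the embeddings have already been produced by an encoder. The function names and toy 2-D vectors are hypothetical and are not part of the classifier above.

```python
import math
from typing import Dict, List, Tuple

Vector = List[float]


def mean_vector(vectors: List[Vector]) -> Vector:
    """Prototype = element-wise mean of the support embeddings."""
    n = len(vectors)
    return [sum(component) / n for component in zip(*vectors)]


def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity: dot product divided by the product of the norms."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def rank_prototypes(
    query: Vector, prototypes: Dict[str, Vector], top_k: int = 3
) -> List[Tuple[str, float]]:
    """Return the top_k (name, similarity) pairs, most similar first."""
    scores = [(name, cosine(query, proto)) for name, proto in prototypes.items()]
    scores.sort(key=lambda item: item[1], reverse=True)
    return scores[:top_k]


def ema_update(prototype: Vector, new_emb: Vector, momentum: float = 0.9) -> Vector:
    """Keep `momentum` of the old prototype, blend in (1 - momentum) of the new case."""
    return [momentum * p + (1 - momentum) * e for p, e in zip(prototype, new_emb)]


# Hypothetical 2-D embeddings: two diseases, two confirmed support cases each
prototypes = {
    "disease_a": mean_vector([[1.0, 0.0], [0.8, 0.2]]),
    "disease_b": mean_vector([[0.0, 1.0], [0.2, 0.8]]),
}
print(rank_prototypes([0.9, 0.1], prototypes, top_k=2))
prototypes["disease_a"] = ema_update(prototypes["disease_a"], [0.5, 0.5])
```

Cosine similarity ignores vector length, so any change in a prototype's norm caused by the EMA update does not bias the ranking. Only the prototype's direction matters.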

6. Using the learn2learn Library

6.1 Introduction to learn2learn

learn2learn is a library that makes it easy to implement meta-learning algorithms such as MAML and ProtoNet.

```bash
pip install learn2learn
```

6.2 MAML with learn2learn

```python
from typing import Tuple

import torch
import torch.nn as nn
import learn2learn as l2l


def build_l2l_maml(
    model: nn.Module,
    lr: float = 0.01,
    first_order: bool = False
) -> l2l.algorithms.MAML:
    """Create a learn2learn MAML wrapper."""
    return l2l.algorithms.MAML(
        model,
        lr=lr,                    # inner-loop (adaptation) learning rate
        first_order=first_order,  # True: first-order approximation, no second derivatives
        allow_unused=True
    )


def train_with_l2l(
    maml_model: l2l.algorithms.MAML,
    tasksets,
    n_way: int = 5,
    k_shot: int = 1,
    n_query: int = 15,
    meta_lr: float = 0.003,
    n_iterations: int = 1000,
    adaptation_steps: int = 1,
    meta_batch_size: int = 4,
    device: str = 'cpu'
):
    """MAML training with learn2learn."""
    maml_model = maml_model.to(device)
    meta_optimizer = torch.optim.Adam(maml_model.parameters(), lr=meta_lr)
    criterion = nn.CrossEntropyLoss(reduction='mean')

    for iteration in range(n_iterations):
        meta_optimizer.zero_grad()
        meta_train_loss = 0.0
        meta_train_acc = 0.0

        for task in range(meta_batch_size):
            X, y = tasksets.train.sample()
            X, y = X.to(device), y.to(device)
            # clone() returns a differentiable copy of the meta-parameters
            learner = maml_model.clone()

            # The first k_shot samples of each class form the support set;
            # the remaining samples are queries. The mask must live on the same
            # device as y, or indexing it with CUDA indices fails.
            support_indices = torch.zeros(X.size(0), dtype=torch.bool, device=X.device)
            for cls in range(n_way):
                cls_idx = (y == cls).nonzero(as_tuple=True)[0]
                support_idx = cls_idx[:k_shot]
                support_indices[support_idx] = True
            query_indices = ~support_indices
            support_x, support_y = X[support_indices], y[support_indices]
            query_x, query_y = X[query_indices], y[query_indices]

            # Inner loop: adaptation on the support set
            for step in range(adaptation_steps):
                support_logits = learner(support_x)
                support_loss = criterion(support_logits, support_y)
                learner.adapt(support_loss)

            # Outer loop: the query loss of the adapted learner drives the meta-gradient
            query_logits = learner(query_x)
            query_loss = criterion(query_logits, query_y)
            meta_train_loss += query_loss
            preds = query_logits.argmax(dim=-1)
            acc = (preds == query_y).float().mean()
            meta_train_acc += acc

        # Average over the meta-batch, then update the meta-parameters
        meta_train_loss /= meta_batch_size
        meta_train_acc /= meta_batch_size
        meta_train_loss.backward()
        meta_optimizer.step()

        if (iteration + 1) % 100 == 0:
            print(
                f"Iteration {iteration+1}/{n_iterations} | "
                f"Meta Loss: {meta_train_loss.item():.4f} | "
                f"Meta Acc: {meta_train_acc.item():.4f}"
            )


def setup_omniglot_maml():
    """
    Set up MAML with the Omniglot dataset.

    Omniglot: 1623 character classes from 50 alphabets (20 samples each)
    """
    tasksets = l2l.vision.benchmarks.get_tasksets(
        'omniglot',
        train_ways=5,
        # Samples per class = k_shot (1) + n_query (15); Omniglot has only 20 per class
        train_samples=1 + 15,
        test_ways=5,
        test_samples=1 + 15,
        root='./data',
        device='cpu'
    )
    model = l2l.vision.models.OmniglotCNN(
        output_size=5,
        hidden_size=64,
        layers=4
    )
    maml = build_l2l_maml(model, lr=0.4, first_order=False)
    return maml, tasksets


def evaluate_l2l(
    maml_model: l2l.algorithms.MAML,
    tasksets,
    n_way: int = 5,
    k_shot: int = 1,
    n_query: int = 15,
    n_test_tasks: int = 600,
    adaptation_steps: int = 3,
    device: str = 'cpu'
) -> Tuple[float, float]:
    """Meta-evaluation with learn2learn."""
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_acc = 0.0
    maml_model.eval()

    for _ in range(n_test_tasks):
        X, y = tasksets.test.sample()
        X, y = X.to(device), y.to(device)
        learner = maml_model.clone()

        support_indices = torch.zeros(X.size(0), dtype=torch.bool, device=X.device)
        for cls in range(n_way):
            cls_idx = (y == cls).nonzero(as_tuple=True)[0]
            support_idx = cls_idx[:k_shot]
            support_indices[support_idx] = True
        query_indices = ~support_indices
        support_x, support_y = X[support_indices], y[support_indices]
        query_x, query_y = X[query_indices], y[query_indices]

        # Adaptation needs gradients, so it runs outside torch.no_grad()
        for step in range(adaptation_steps):
            support_loss = criterion(learner(support_x), support_y)
            learner.adapt(support_loss)

        with torch.no_grad():
            query_logits = learner(query_x)
            loss = criterion(query_logits, query_y)
            acc = (query_logits.argmax(dim=-1) == query_y).float().mean()
        total_loss += loss.item()
        total_acc += acc.item()

    return total_loss / n_test_tasks, total_acc / n_test_tasks
```
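learn2learn hides the inner-loop arithmetic behind `clone()` and `adapt()`. To make the `first_order` flag concrete, the following framework-free sketch computes one MAML meta-gradient by hand. It uses a toy scalar model f(x) = w * x with mean-squared-error loss, chosen because the derivative of the adapted weight can be written out explicitly. The model, the two tasks, and all function names are illustrative assumptions, not learn2learn internals.

```python
from typing import List, Tuple

Task = Tuple[List[float], List[float]]  # (inputs, targets)


def mse(w: float, data: Task) -> float:
    xs, ys = data
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def mse_grad(w: float, data: Task) -> float:
    """d/dw of mean((w*x - y)^2) = mean(2*x*(w*x - y))."""
    xs, ys = data
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


def mse_second_derivative(data: Task) -> float:
    """d^2/dw^2 of the MSE loss = mean(2*x^2), independent of w."""
    xs, _ = data
    return sum(2 * x * x for x in xs) / len(xs)


def maml_meta_gradient(w: float, support: Task, query: Task,
                       inner_lr: float, first_order: bool = False) -> float:
    # Inner loop: one gradient step on the support set
    w_adapted = w - inner_lr * mse_grad(w, support)
    # Outer loop: gradient of the query loss evaluated at the adapted weight
    outer = mse_grad(w_adapted, query)
    if first_order:
        # First-order MAML treats d(w_adapted)/dw as 1
        return outer
    # Full MAML applies the chain rule: d(w_adapted)/dw = 1 - inner_lr * L_S''(w)
    return outer * (1 - inner_lr * mse_second_derivative(support))


# One meta-update over a meta-batch of two hypothetical tasks y = slope * x
tasks = [
    (([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])),     # slope 2
    (([1.0, 2.0], [-1.0, -2.0]), ([3.0], [-3.0])),  # slope -1
]
w, inner_lr, meta_lr = 0.0, 0.1, 0.01
meta_grad = sum(maml_meta_gradient(w, s, q, inner_lr) for s, q in tasks) / len(tasks)
w -= meta_lr * meta_grad
```

The factor `(1 - inner_lr * L_S'')` is the second-order term. Conceptually, `first_order=True` in `build_l2l_maml` drops exactly this kind of term. That avoids differentiating through the inner update, at the cost of a biased meta-gradient.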

7. Benchmarks and Evaluation

7.1 Key Benchmark Datasets

**Omniglot**

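1623 character classes from 50 alphabets. Each character was drawn by 20 different people, giving 20 samples per class. It is commonly evaluated in 5-way and 20-way settings with 1 or 5 shots; 20-way 1-shot is the hardest of these.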

**Mini-ImageNet**

A 100-class subset of ImageNet with 600 images per class, downsampled to 84x84 pixels. The 5-way 1-shot and 5-way 5-shot settings are standard.

**tieredImageNet**

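A larger ImageNet subset of 608 classes grouped into 34 high-level categories. Meta-train and meta-test classes come from different high-level categories. This creates a larger semantic gap than in Mini-ImageNet and makes tieredImageNet a harder benchmark.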

**CIFAR-FS**

A few-shot benchmark derived from CIFAR-100 (100 classes of 32x32 images). Because the images are small, it is faster to experiment with than Mini-ImageNet.
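All four benchmarks are used episodically. Each episode draws N classes, then K labeled support examples and a set of query examples from each of those classes. Labels are remapped to 0..N-1 within the episode. The sketch below shows that sampling in pure Python. It assumes the dataset is given as a mapping from class name to a list of examples. `sample_episode` and the toy data are hypothetical, not a library API.

```python
import random
from typing import Dict, List, Sequence, Tuple


def sample_episode(
    data_by_class: Dict[str, Sequence],
    n_way: int,
    k_shot: int,
    n_query: int,
    rng: random.Random,
) -> Tuple[List[tuple], List[tuple]]:
    """Return (support, query) lists of (example, episode_label) pairs."""
    classes = rng.sample(sorted(data_by_class), n_way)
    support, query = [], []
    for episode_label, cls in enumerate(classes):
        # Draw k_shot + n_query distinct examples; the class must have at least that many
        examples = rng.sample(list(data_by_class[cls]), k_shot + n_query)
        support += [(x, episode_label) for x in examples[:k_shot]]
        query += [(x, episode_label) for x in examples[k_shot:]]
    return support, query


# Hypothetical Omniglot-like toy data: 30 classes with 20 examples each
toy = {f"char_{c}": [f"char_{c}_img_{i}" for i in range(20)] for c in range(30)}
support, query = sample_episode(toy, n_way=5, k_shot=1, n_query=15, rng=random.Random(0))
```

Remapping to episode-local labels lets the same N-output classifier head be reused across episodes. It also means `k_shot + n_query` cannot exceed the samples available per class, which is 20 for Omniglot.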

7.2 Standard Evaluation Protocol

```python
from typing import Dict

import numpy as np
import torch


def standard_few_shot_evaluation(
    model,
    test_dataset,
    n_way: int = 5,
    k_shot: int = 1,
    n_query: int = 15,
    n_episodes: int = 600,
    confidence_interval: bool = True
) -> Dict:
    """
    Standard few-shot evaluation protocol.

    Mean accuracy and 95% confidence interval over n_episodes (600 by default).
    Assumes the episode sampler `create_episode` and a model called as
    model(support_x, support_y, query_x, n_way) -> log-probabilities,
    as defined earlier in this guide.
    """
    accs = []
    model.eval()

    for episode in range(n_episodes):
        support_x, support_y, query_x, query_y = create_episode(
            test_dataset, n_way, k_shot, n_query
        )
        with torch.no_grad():
            log_probs = model(support_x, support_y, query_x, n_way)
        preds = log_probs.argmax(dim=-1)
        acc = (preds == query_y).float().mean().item()
        accs.append(acc)

    mean_acc = np.mean(accs)
    std_acc = np.std(accs)

    if confidence_interval:
        # Normal-approximation 95% CI half-width: 1.96 * std / sqrt(n)
        ci = 1.96 * std_acc / np.sqrt(n_episodes)
        return {
            'mean_accuracy': mean_acc,
            'std': std_acc,
            'confidence_interval_95': ci,
            'result_string': f"{mean_acc*100:.2f} +/- {ci*100:.2f}%"
        }

    return {'mean_accuracy': mean_acc, 'std': std_acc}
```
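The reported `+/-` value is the half-width of a normal-approximation 95% confidence interval, 1.96 * std / sqrt(n). Here std is the standard deviation of the per-episode accuracies and n is the number of episodes. The stdlib-only sketch below performs the same computation so the scaling is explicit. It uses the population standard deviation, which matches `np.std`'s default. The function name and sample accuracies are hypothetical.

```python
import math
import statistics
from typing import List, Tuple


def mean_and_ci95(episode_accs: List[float]) -> Tuple[float, float]:
    """Mean accuracy and 95% CI half-width under a normal approximation."""
    mean_acc = statistics.fmean(episode_accs)
    std_acc = statistics.pstdev(episode_accs)  # ddof=0, like np.std's default
    return mean_acc, 1.96 * std_acc / math.sqrt(len(episode_accs))


# Hypothetical per-episode accuracies
mean_acc, half_width = mean_and_ci95([0.62, 0.58, 0.65, 0.60])
print(f"{mean_acc*100:.2f} +/- {half_width*100:.2f}%")
```

Because the half-width shrinks with 1/sqrt(n), quadrupling the number of episodes only halves the interval.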

8. The Future of Meta-Learning

Meta-learning is establishing itself as a core paradigm in AI research. Key trends to watch:

**Convergence with LLMs**

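Large language models such as GPT-4 and Claude show strong in-context learning (ICL) capabilities: they can perform a new task from a few examples in the prompt, without any weight updates. An active line of research interprets these models as meta-learners that perform few-shot learning across arbitrary domains.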

**Multimodal Few-Shot Learning**

This trend extends few-shot learning across text, images, and audio. Models such as GPT-4V and Gemini Ultra accept interleaved images and text, and they show strong performance on visual few-shot tasks.

**Combination with Continual Learning**

Some studies report that models initialized via meta-learning forget less of their previous knowledge when learning new tasks. Combining continual learning with meta-learning is an active research area.

**Domain Adaptation Applications**

Industrial applications where data is scarce, such as rare disease diagnosis, satellite imagery analysis, and specialized code generation, are emerging as the most practical use cases for few-shot learning.

References

- Finn, C., et al. (2017). Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. _ICML 2017_. arXiv:1703.03400

- Snell, J., et al. (2017). Prototypical Networks for Few-shot Learning. _NeurIPS 2017_. arXiv:1703.05175

- Vinyals, O., et al. (2016). Matching Networks for One Shot Learning. _NeurIPS 2016_. arXiv:1606.04080

- Nichol, A., et al. (2018). On First-Order Meta-Learning Algorithms. arXiv:1803.02999

- Brown, T., et al. (2020). Language Models are Few-Shot Learners (GPT-3). _NeurIPS 2020_. arXiv:2005.14165

- learn2learn library: https://github.com/learnables/learn2learn
