Federated Learning Complete Guide: Privacy-Preserving Distributed AI

One of the greatest ironies of modern AI is that as model performance improves, more data is needed, and as more data is collected, the risk of privacy violations grows. Hospitals cannot share patient data, smartphone manufacturers cannot send users' typing patterns to servers, and financial institutions cannot share transaction records with competitors.

Federated Learning (FL) solves this dilemma. Instead of sending data to a central server, the model is sent to where the data lives, trained locally, and only the model updates (weight changes) are aggregated. The concept is: "data stays put; only intelligence moves."

This guide covers everything from the theoretical foundations of FL to hands-on implementation.

1. Federated Learning Fundamentals

1.1 Problems with Traditional Centralized Learning

Consider the traditional ML pipeline: collect patient data from thousands of hospitals, store it on a central server, and train a diagnostic model on that data. What are the problems?

**Data Privacy Issues**

Patient records, financial transaction histories, and personal communications are extremely sensitive. Transmitting such data to a central server creates the following risks:

- Eavesdropping risk during transmission

- Large-scale data breaches from server hacking

- Loss of trust in the data-owning organization

- Legal sanctions for regulatory violations

**Legal Regulations**

Data privacy regulations are tightening worldwide.

- **GDPR (General Data Protection Regulation)**: The EU's data protection law specifying processing purpose disclosure, consent requirements, and data minimization principles.

- **HIPAA (Health Insurance Portability and Accountability Act)**: US law protecting patient health information (PHI).

- **CCPA (California Consumer Privacy Act)**: Gives California residents rights over their personal information collected by businesses.

**Communication Costs**

Transmitting data from millions of edge devices (smartphones, IoT sensors) to a central location requires enormous network bandwidth. Large data types like images and audio make this even more problematic.

1.2 The Core Idea of Federated Learning

Federated Learning, proposed by McMahan et al. at Google in 2016, is based on the following principle:

**"Data stays local; only knowledge (model updates) moves."**

The basic FL process:

1. **Initialization**: The central server initializes a global model.

2. **Distribution**: The server distributes the current global model to selected clients.

3. **Local Training**: Each client trains the model on its local data.

4. **Upload**: Clients send model updates (gradients or weight differences) to the server.

5. **Aggregation**: The server aggregates the updates (e.g., by averaging) to update the global model.

6. **Repeat**: Steps 2–5 repeat until convergence.

In this approach, raw data never leaves the client device. Only model parameter updates are transmitted.
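The six steps above can be sketched with a deliberately tiny model. This is a minimal illustration, not the production protocol: the "model" is a single number `w` that each client fits to its own data with a few gradient steps, and the function names (`local_train`, `run_rounds`) and the synthetic data are assumptions made for the example.

```python
import random


def local_train(w, data, lr=0.1, epochs=5):
    """Client step: gradient descent on the local squared error 0.5 * (w - x)^2."""
    for _ in range(epochs):
        for x in data:
            w -= lr * (w - x)
    return w


def run_rounds(client_data, rounds=20):
    w_global = 0.0                                   # 1. initialization
    for _ in range(rounds):                          # 6. repeat
        deltas, sizes = [], []
        for data in client_data:                     # 2. distribution
            w_local = local_train(w_global, data)    # 3. local training
            deltas.append(w_local - w_global)        # 4. upload only the update
            sizes.append(len(data))
        total = sum(sizes)                           # 5. weighted aggregation
        w_global += sum(n / total * d for n, d in zip(sizes, deltas))
    return w_global


# Three hypothetical clients whose private data are centred on 1.0, 2.0 and 3.0
rng = random.Random(0)
client_data = [[rng.gauss(mu, 0.1) for _ in range(50)] for mu in (1.0, 2.0, 3.0)]
w_final = run_rounds(client_data)
print(f"global weight after training: {w_final:.3f}")
```

The server only ever sees the three deltas, never the lists in `client_data`, yet the global weight settles near the overall mean of the data.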

1.3 Applications of Federated Learning

**Mobile / Edge Devices**

Google applied FL to Gboard (mobile keyboard). Users' typing patterns are not sent to servers; instead, the next-word prediction model is improved directly on the device. Hundreds of millions of users' data contribute without any personal information leaving the device.

**Healthcare**

Better diagnostic AI can be built without sharing patient data from multiple hospitals. For rare diseases, single-hospital data is insufficient for model training, but FL can combine knowledge from multiple hospitals.

**Finance**

Multiple financial institutions can collaborate on fraud detection and credit scoring without sharing customer data. Especially for cross-border financial transactions, a joint model can be trained while respecting data sovereignty of each country.

**Autonomous Vehicles**

Multiple automakers can jointly improve road hazard detection models without sharing their proprietary driving data.

2. Federated Learning Architecture

2.1 Client-Server Structure

The most common FL architecture consists of a central aggregation server and multiple clients.

```text
         ┌──────────────────┐
         │  Central Server  │
         │   (Aggregator)   │
         └────────┬─────────┘
                  │  model distribution / update aggregation
     ┌────────────┼────────────┐
     ↓            ↓            ↓
┌──────────┐ ┌──────────┐ ┌──────────┐
│ Client 1 │ │ Client 2 │ │ Client 3 │
│ (local   │ │ (local   │ │ (local   │
│ training)│ │ training)│ │ training)│
└──────────┘ └──────────┘ └──────────┘
```

**Server Responsibilities**

- Maintain and manage the global model

- Select participating clients for each round

- Aggregate client updates

- Distribute the aggregated model to clients

**Client Responsibilities**

- Hold local data

- Fine-tune the received model on local data

- Transmit the updated model (or gradients)

2.2 Horizontal Federated Learning

Horizontal FL is used when all clients have the same feature space but different data samples. For example, multiple hospitals measure the same diagnostic items (blood pressure, blood glucose, age) but have different patients.

Client 1: [feature1, feature2, feature3] x [samples 1-1000]

Client 2: [feature1, feature2, feature3] x [samples 1001-2000]

Client 3: [feature1, feature2, feature3] x [samples 2001-3000]

Same feature space, different sample space. The most common form of FL.

2.3 Vertical Federated Learning

Vertical FL is used when clients hold the same users (data samples) but have different features. For example, a bank has a user's financial information while a hospital has the same user's medical information.

Client A (Bank): [financial features] x [users 1-10000]

Client B (Hospital): [medical features] x [users 1-10000]

Same sample space, different feature space.

Vertical FL requires more complex protocols. Cryptographic techniques are needed for cooperation between a client holding labels and clients holding only features.
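The difference between the horizontal and vertical settings is easiest to see on a toy table. The sketch below is only an illustration: the column names and values are invented, and `horizontal_split` / `vertical_split` are hypothetical helpers rather than part of any FL library.

```python
# Toy table: one row per user, one column per feature (all values are made up).
table = [
    {"user": 1, "income": 52, "balance": 8, "blood_pressure": 120, "glucose": 90},
    {"user": 2, "income": 61, "balance": 3, "blood_pressure": 135, "glucose": 110},
    {"user": 3, "income": 47, "balance": 9, "blood_pressure": 118, "glucose": 95},
    {"user": 4, "income": 75, "balance": 1, "blood_pressure": 142, "glucose": 130},
]


def horizontal_split(rows, num_clients):
    """Horizontal FL: every client has all columns but a disjoint set of rows."""
    return [rows[i::num_clients] for i in range(num_clients)]


def vertical_split(rows, feature_groups):
    """Vertical FL: every client has all users but a disjoint set of columns."""
    return [
        [{"user": r["user"], **{f: r[f] for f in feats}} for r in rows]
        for feats in feature_groups
    ]


h_parts = horizontal_split(table, 2)
v_parts = vertical_split(
    table, [["income", "balance"], ["blood_pressure", "glucose"]]
)

print("horizontal, client 0 users:", [r["user"] for r in h_parts[0]])
print("vertical, client 0 columns:", sorted(v_parts[0][0]))
```

In the vertical case the shared `user` key is what the parties must align on, which is why that setting needs the additional cryptographic protocols mentioned above.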

2.4 Federated Transfer Learning

When clients have partially overlapping sample and feature spaces, transfer learning techniques are combined. This allows FL to be applied even when data overlap is minimal.

3. The FedAvg Algorithm

3.1 McMahan et al. (2017) Original Algorithm

FedAvg (Federated Averaging) is the foundational FL algorithm, published by McMahan et al. at Google in 2017. The key idea is that each client performs multiple local SGD updates, then the server averages the weights.

**Algorithm Overview**

```text
Server executes:
    initialize w_0
    for round t = 0, 1, ..., T-1:
        m = max(C x K, 1)                 // C: participation fraction, K: total clients
        S_t = randomly select m clients
        for each client k in S_t (parallel):
            w_{t+1}^k = ClientUpdate(k, w_t)
        n = sum of n_k over k in S_t      // n_k: number of samples on client k
        w_{t+1} = sum over k in S_t of (n_k / n) x w_{t+1}^k   // weighted average

ClientUpdate(k, w), executed on client k:
    B = split local data into batches
    for local epoch e = 1, ..., E:
        for batch b in B:
            w = w - lr x grad_loss(w; b)
    return w
```

**Key Parameters**

- C: fraction of clients participating in each round, a value in (0, 1]; C = 1 selects every client

- E: number of local epochs per client

- B: local mini-batch size

- lr: learning rate

When E=1 and B equals all data, this is equivalent to FedSGD. Increasing E reduces communication rounds but increases the risk of client drift.
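As a small worked example of the two formulas in the algorithm, the sketch below computes the number of selected clients and the sample-weighted average. The client weights and sample counts are hypothetical numbers chosen so the arithmetic is easy to follow by hand.

```python
def num_selected(C: float, K: int) -> int:
    """m = max(C * K, 1), truncated to an integer."""
    return max(int(C * K), 1)


def fedavg(weights, sizes):
    """Average per-client weight vectors, weighting client k by n_k / n."""
    n = sum(sizes)
    dim = len(weights[0])
    return [
        sum(n_k / n * w[i] for w, n_k in zip(weights, sizes))
        for i in range(dim)
    ]


# Three hypothetical clients holding 100, 300 and 600 samples
client_weights = [[1.0, 0.0], [2.0, 1.0], [3.0, 2.0]]
client_sizes = [100, 300, 600]
w_new = fedavg(client_weights, client_sizes)

# First coordinate:  0.1 * 1 + 0.3 * 2 + 0.6 * 3 = 2.5
# Second coordinate: 0.1 * 0 + 0.3 * 1 + 0.6 * 2 = 1.5
print(w_new)
print(num_selected(0.5, 20), num_selected(0.01, 20))
```

The client with 600 samples dominates the average, which is the intended behaviour: clients contribute in proportion to how much data they trained on.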

3.2 Complete FedAvg Implementation

```python
import random
from copy import deepcopy
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset


# ========== Model Definition ==========
class SimpleNet(nn.Module):
    """Simple classification network"""

    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int):
        super().__init__()
        self.net = nn.Sequential(
            # MNIST batches arrive as [B, 1, 28, 28]; flatten them to [B, 784]
            nn.Flatten(),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, x):
        return self.net(x)
```

```python
# ========== Client ==========
class FLClient:
    """Federated Learning Client"""

    def __init__(
        self,
        client_id: int,
        dataset,
        device: str = 'cpu'
    ):
        self.client_id = client_id
        self.dataset = dataset
        self.device = device

    def local_train(
        self,
        model: nn.Module,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> Tuple[Dict, float]:
        """
        Train the model on local data
        Returns: (updated weights, local loss)
        """
        # Train a copy so the server's global model is never modified in place
        model = deepcopy(model).to(self.device)
        model.train()

        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True
        )
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        n_batches = 0
        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                output = model(X)
                loss = criterion(output, y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                n_batches += 1

        avg_loss = total_loss / max(n_batches, 1)
        return model.state_dict(), avg_loss

    def evaluate(self, model: nn.Module) -> Tuple[float, float]:
        """Evaluate model on local data"""
        model = deepcopy(model).to(self.device)
        model.eval()

        loader = DataLoader(self.dataset, batch_size=64)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                output = model(X)
                loss = criterion(output, y)
                total_loss += loss.item()
                _, predicted = output.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()

        # Mean loss per batch and overall accuracy
        return total_loss / len(loader), correct / total
```

```python
# ========== Server ==========
class FedAvgServer:
    """FedAvg Server"""

    def __init__(
        self,
        global_model: nn.Module,
        clients: List[FLClient],
        fraction: float = 0.1,
        device: str = 'cpu'
    ):
        self.global_model = global_model.to(device)
        self.clients = clients
        self.fraction = fraction
        self.device = device
        self.round_history = []

    def select_clients(self) -> List[FLClient]:
        """Select clients for each round"""
        m = max(int(self.fraction * len(self.clients)), 1)
        return random.sample(self.clients, m)

    def aggregate(
        self,
        client_weights: List[Dict],
        client_sizes: List[int]
    ) -> Dict:
        """
        Weighted average aggregation (FedAvg)
        Weighted by n_k / n
        """
        total_size = sum(client_sizes)
        aggregated = {}
        for key in client_weights[0].keys():
            aggregated[key] = torch.zeros_like(
                client_weights[0][key], dtype=torch.float32
            )
            for w, size in zip(client_weights, client_sizes):
                weight = size / total_size
                aggregated[key] += weight * w[key].float()
        return aggregated

    def train_round(
        self,
        local_epochs: int = 5,
        batch_size: int = 32,
        lr: float = 0.01
    ) -> Dict:
        """Execute one FL round"""
        selected = self.select_clients()

        client_weights = []
        client_sizes = []
        client_losses = []
        for client in selected:
            weights, loss = client.local_train(
                self.global_model, local_epochs, batch_size, lr
            )
            client_weights.append(weights)
            client_sizes.append(len(client.dataset))
            client_losses.append(loss)

        new_weights = self.aggregate(client_weights, client_sizes)
        self.global_model.load_state_dict(new_weights)

        round_info = {
            'num_clients': len(selected),
            'avg_local_loss': np.mean(client_losses),
            'client_losses': client_losses
        }
        self.round_history.append(round_info)
        return round_info

    def evaluate_global(self, test_loader: DataLoader) -> Tuple[float, float]:
        """Evaluate global model"""
        self.global_model.eval()
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for X, y in test_loader:
                X, y = X.to(self.device), y.to(self.device)
                output = self.global_model(X)
                loss = criterion(output, y)
                total_loss += loss.item()
                _, predicted = output.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()

        return total_loss / len(test_loader), correct / total

    def federated_train(
        self,
        num_rounds: int,
        local_epochs: int = 5,
        batch_size: int = 32,
        lr: float = 0.01,
        test_loader: DataLoader = None
    ):
        """Full FL training loop"""
        print(f"FL Training: {num_rounds} rounds, {len(self.clients)} clients")

        for round_num in range(1, num_rounds + 1):
            round_info = self.train_round(local_epochs, batch_size, lr)

            # Report test accuracy every 10 rounds when a test set is available
            if test_loader is not None and round_num % 10 == 0:
                test_loss, test_acc = self.evaluate_global(test_loader)
                print(
                    f"Round {round_num:3d}/{num_rounds} | "
                    f"Clients: {round_info['num_clients']} | "
                    f"Local Loss: {round_info['avg_local_loss']:.4f} | "
                    f"Test Acc: {test_acc:.4f}"
                )

        print("Federated training complete!")
```

```python
# ========== Non-IID Data Partitioning ==========
def create_non_iid_partition(
    dataset,
    num_clients: int,
    num_classes: int,
    alpha: float = 0.5
) -> List[List[int]]:
    """
    Non-IID data partitioning using Dirichlet distribution
    Lower alpha = more heterogeneous distribution
    """
    labels = np.array([dataset[i][1] for i in range(len(dataset))])
    client_indices = [[] for _ in range(num_clients)]

    for cls in range(num_classes):
        cls_indices = np.where(labels == cls)[0]
        np.random.shuffle(cls_indices)

        # Split this class across clients according to Dirichlet proportions
        proportions = np.random.dirichlet([alpha] * num_clients)
        proportions = (proportions * len(cls_indices)).astype(int)
        # Give the rounding remainder to the last client so every sample is assigned
        proportions[-1] = len(cls_indices) - proportions[:-1].sum()

        start = 0
        for k, prop in enumerate(proportions):
            client_indices[k].extend(
                cls_indices[start:start + prop].tolist()
            )
            start += prop

    return client_indices
```

```python
def run_fedavg_demo():
    """Run FedAvg demo"""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = torchvision.datasets.MNIST(
        './data', train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.MNIST(
        './data', train=False, transform=transform
    )

    num_clients = 20
    client_indices = create_non_iid_partition(
        train_dataset, num_clients, num_classes=10, alpha=0.5
    )

    clients = []
    for k in range(num_clients):
        subset = Subset(train_dataset, client_indices[k])
        clients.append(FLClient(k, subset))

    print(f"Number of clients: {num_clients}")
    print(f"Avg samples per client: {np.mean([len(c.dataset) for c in clients]):.1f}")

    global_model = SimpleNet(784, 256, 10)
    server = FedAvgServer(global_model, clients, fraction=0.1)
    test_loader = DataLoader(test_dataset, batch_size=256)

    server.federated_train(
        num_rounds=100,
        local_epochs=5,
        batch_size=32,
        lr=0.01,
        test_loader=test_loader
    )

    _, final_acc = server.evaluate_global(test_loader)
    print(f"\nFinal Test Accuracy: {final_acc:.4f}")


if __name__ == '__main__':
    run_fedavg_demo()
```

4. Challenges in Federated Learning

4.1 Data Heterogeneity (Non-IID Problem)

The biggest technical challenge in FL is **Non-IID (Non-Independent and Identically Distributed)** data. In real environments, each client's data distribution is different.

**Types of Non-IID**

- **Feature distribution skew**: Different input distributions per client (e.g., regional weather patterns)

- **Label distribution skew**: Different class proportions per client (each smartphone user has different frequent words)

- **Concept drift**: Different labels for the same input (different cultural contexts by region)

- **Quantity imbalance**: Extremely different amounts of data per client

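With Non-IID data, FedAvg suffers from **client drift**: each client's local training moves toward its own local optimum, so when local optima differ from the global optimum, averaging the local models can pull the global model away from it and slow or destabilize convergence.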
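Label distribution skew can be simulated with a Dirichlet distribution, the same device the partitioning function in Section 3.2 uses. The sketch below is a standard-library illustration under stated assumptions: `dirichlet` draws a symmetric Dirichlet sample via normalized Gamma variates, and `label_skew` is a hypothetical summary statistic (the average, over classes, of the largest share any one client receives).

```python
import random


def dirichlet(alpha, k, rng):
    """Sample a symmetric Dirichlet(alpha) vector via normalized Gamma draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(draws)
    return [d / total for d in draws]


def label_skew(alpha, num_clients=10, num_classes=10, seed=0):
    """Mean, over classes, of the largest share held by a single client."""
    rng = random.Random(seed)
    shares = [dirichlet(alpha, num_clients, rng) for _ in range(num_classes)]
    return sum(max(row) for row in shares) / num_classes


skew_low_alpha = label_skew(0.1)      # strongly Non-IID
skew_high_alpha = label_skew(100.0)   # close to IID

print(f"alpha=0.1   -> largest client share per class: {skew_low_alpha:.2f}")
print(f"alpha=100.0 -> largest client share per class: {skew_high_alpha:.2f}")
```

With ten clients, a perfectly even split gives every client a share of 0.1 per class. Small `alpha` pushes most of a class onto one or two clients, which is the regime where client drift becomes visible.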

4.2 System Heterogeneity

In real FL systems, device performance, battery status, and network connectivity vary widely.

- **Compute heterogeneity**: GPU-equipped servers and low-end mobile devices participating simultaneously

- **Network heterogeneity**: High-speed wired and unstable mobile networks coexisting

- **Memory heterogeneity**: Some clients may not be able to load the full model

4.3 The Straggler Problem

If some clients are slow or unresponsive, the entire training is delayed. In **Synchronous FL**, all selected clients must return their updates before the next round. Solutions include:

- **Asynchronous FL**: Immediately aggregate updates from responding clients

- **Timeout setting**: Only use clients that respond within a time limit

- **FedProx**: Add a proximal term to local updates to tolerate stragglers
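The timeout option can be sketched as a filter applied before aggregation. This is a minimal illustration with scalar updates; the function name `aggregate_with_timeout` and the latency figures are assumptions for the example, and a real system would measure arrival times rather than receive them as a list.

```python
def aggregate_with_timeout(updates, latencies, sizes, timeout):
    """Weighted-average only the updates that arrived within `timeout` seconds."""
    on_time = [i for i, t in enumerate(latencies) if t <= timeout]
    if not on_time:
        # Nothing arrived in time: the caller keeps the previous global model
        return None, on_time
    n = sum(sizes[i] for i in on_time)
    avg = sum(sizes[i] / n * updates[i] for i in on_time)
    return avg, on_time


# Hypothetical round: the third client is a straggler (45 s) and is skipped
updates = [0.10, 0.30, 0.90]
latencies = [2.0, 5.0, 45.0]
sizes = [100, 300, 600]

avg_update, used = aggregate_with_timeout(updates, latencies, sizes, timeout=10.0)
print(used, avg_update)   # weights renormalize over the clients that responded
```

Dropping stragglers keeps rounds short, but note the trade-off visible here: the skipped client held most of the data, so systematic timeouts can bias the model toward fast devices.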

5. Advanced FL Algorithms

5.1 FedProx: Handling Non-IID

FedProx, proposed by Li et al. in 2020, adds a **proximal term** to local optimization. This term prevents the local model from drifting too far from the global model.

**FedProx Objective Function**

A proximal term is added to the local objective function.

h_k(w; w^t) = F_k(w) + (mu/2) x ||w - w^t||^2

Here mu is a hyperparameter controlling the strength of the proximal term. When mu=0, it is equivalent to FedAvg.
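The effect of mu can be checked on a one-dimensional toy problem. Assume a hypothetical local loss F_k(w) = 0.5 (w - a)^2, whose minimizer is `a`. Minimizing h_k then has the closed-form solution (a + mu * w^t) / (1 + mu), so the local solution is pulled from `a` back toward the global weight `w^t` as mu grows. The sketch below confirms this with plain gradient descent.

```python
def fedprox_local_solve(a, w_t, mu, lr=0.1, steps=500):
    """Gradient descent on h(w) = 0.5*(w - a)^2 + (mu/2)*(w - w_t)^2 from w = w_t."""
    w = w_t
    for _ in range(steps):
        grad = (w - a) + mu * (w - w_t)   # dF_k/dw plus the proximal gradient
        w -= lr * grad
    return w


a, w_t = 4.0, 0.0   # hypothetical local optimum and current global weight

w_fedavg = fedprox_local_solve(a, w_t, mu=0.0)   # no proximal term: reaches a
w_prox = fedprox_local_solve(a, w_t, mu=1.0)     # halfway between w_t and a

print(f"mu=0: {w_fedavg:.4f}   mu=1: {w_prox:.4f}")
```

With mu = 0 the client converges to its own optimum, exactly as in FedAvg; with mu = 1 it stops at (4 + 0) / 2 = 2, limiting how far a single client can drift in one round.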

```python
class FedProxClient(FLClient):
    """FedProx client: adds proximal term"""

    def local_train_prox(
        self,
        model: nn.Module,
        global_weights: Dict,
        local_epochs: int,
        batch_size: int,
        lr: float,
        mu: float = 0.01
    ) -> Tuple[Dict, float]:
        """
        Local training with proximal term
        h_k(w) = F_k(w) + (mu/2) * ||w - w^t||^2
        """
        model = deepcopy(model).to(self.device)
        model.train()

        # Frozen copy of the global model w^t, used only as the proximal anchor
        global_model = deepcopy(model)
        global_model.load_state_dict(global_weights)
        for param in global_model.parameters():
            param.requires_grad = False

        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True
        )
        optimizer = optim.SGD(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        n_batches = 0
        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                output = model(X)
                task_loss = criterion(output, y)

                # Proximal term: (mu/2) * ||w - w_global||^2
                # A plain squared sum equals norm() ** 2 without taking a square root
                prox_loss = 0.0
                for w, w_global in zip(
                    model.parameters(),
                    global_model.parameters()
                ):
                    prox_loss += (mu / 2) * (w - w_global).pow(2).sum()

                total_batch_loss = task_loss + prox_loss
                total_batch_loss.backward()
                optimizer.step()

                # Track only the task loss so it is comparable with FedAvg
                total_loss += task_loss.item()
                n_batches += 1

        avg_loss = total_loss / max(n_batches, 1)
        return model.state_dict(), avg_loss
```

5.2 SCAFFOLD: Correcting Client Drift

SCAFFOLD (Stochastic Controlled Averaging for Federated Learning) uses control variates to directly correct client drift. Each client and server maintain control variates c_k and c to correct gradient bias.

```python
class ScaffoldClient:
    """SCAFFOLD client"""

    def __init__(self, client_id, dataset, device='cpu'):
        self.client_id = client_id
        self.dataset = dataset
        self.device = device
        self.c_k = None  # client control variate

    def init_control_variate(self, model: nn.Module):
        """Initialize control variate"""
        self.c_k = {
            name: torch.zeros_like(param)
            for name, param in model.named_parameters()
        }

    def local_train_scaffold(
        self,
        model: nn.Module,
        server_control: Dict,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> Tuple[Dict, Dict, float]:
        """
        SCAFFOLD local training
        Returns: (updated weights, control variate update, loss)
        """
        if self.c_k is None:
            self.init_control_variate(model)

        model = deepcopy(model).to(self.device)
        model.train()
        initial_weights = deepcopy(model.state_dict())

        loader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        n_batches = 0
        total_steps = max(local_epochs * len(loader), 1)

        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                output = model(X)
                loss = criterion(output, y)
                loss.backward()

                with torch.no_grad():
                    for name, param in model.named_parameters():
                        if param.grad is not None:
                            # SCAFFOLD correction: g - c_k + c
                            # (the original code had the sign of this term reversed)
                            correction = (
                                server_control[name].to(self.device)
                                - self.c_k[name].to(self.device)
                            )
                            param -= lr * (param.grad + correction)
                            param.grad.zero_()

                total_loss += loss.item()
                n_batches += 1

        final_weights = model.state_dict()

        # Update control variate (K = total number of local steps):
        # c_k+ = c_k - c + (1 / (K * lr)) * (w_0 - w_K)
        new_c_k = {}
        c_k_diff = {}
        for name in self.c_k:
            w_diff = (
                initial_weights[name].float() - final_weights[name].float()
            )
            new_c_k[name] = (
                self.c_k[name].to(self.device)
                - server_control[name].to(self.device)
                + w_diff / (total_steps * lr)
            )
            c_k_diff[name] = new_c_k[name] - self.c_k[name].to(self.device)

        self.c_k = new_c_k
        avg_loss = total_loss / max(n_batches, 1)
        return final_weights, c_k_diff, avg_loss
```
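The client above returns a control-variate difference, which the server still has to fold into its own state. The following is a hedged sketch of that server step as described in the SCAFFOLD paper (Karimireddy et al., 2020), using plain lists of floats instead of tensors. The function name `scaffold_server_step` and the `global_lr` parameter name are assumptions for this example; the original article does not show a server implementation.

```python
def scaffold_server_step(x, c, deltas_y, deltas_c, num_total_clients, global_lr=1.0):
    """
    x, c              : global weights and server control variate
    deltas_y          : weight changes (y_k - x) from the sampled clients
    deltas_c          : control-variate changes (c_k_new - c_k) from the same clients
    num_total_clients : N, every client in the federation, not only the sampled ones
    """
    num_sampled = len(deltas_y)
    dim = len(x)
    # Model update: average the sampled clients' weight changes
    new_x = [
        x[i] + global_lr * sum(d[i] for d in deltas_y) / num_sampled
        for i in range(dim)
    ]
    # Control variate update: divide by N so c tracks the average over all clients
    new_c = [
        c[i] + sum(d[i] for d in deltas_c) / num_total_clients
        for i in range(dim)
    ]
    return new_x, new_c


x, c = [0.0, 0.0], [0.0, 0.0]
deltas_y = [[1.0, 2.0], [3.0, 4.0]]       # two sampled clients (made-up numbers)
deltas_c = [[0.5, 0.5], [1.5, -0.5]]
x_new, c_new = scaffold_server_step(x, c, deltas_y, deltas_c, num_total_clients=10)
print(x_new, c_new)
```

Note the two different denominators: weights are averaged over the clients that took part in the round, while the control variate is scaled by the full population size.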

6. Differential Privacy

6.1 Mathematical Definition of DP

Differential Privacy (DP) is a mathematical framework for privacy protection. Intuitively, it guarantees that "the output distribution does not change significantly when a single data point is added or removed."

**Epsilon-Delta DP Definition**

A randomized mechanism M satisfies (epsilon, delta)-DP if for all neighboring datasets D and D' and all output sets S:

Pr[M(D) in S] <= exp(epsilon) x Pr[M(D') in S] + delta

- epsilon: privacy budget — lower values give stronger privacy guarantees

- delta: failure probability — typically set below 1/|dataset|
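To make the inequality concrete, it can be checked exhaustively for a mechanism with only two outputs. The sketch below uses randomized response, a classic textbook mechanism that is not part of the original article and is introduced here purely as an illustration: a single bit is reported truthfully with probability `p` and flipped otherwise, and the two neighboring "datasets" are the two possible values of that bit.

```python
import math


def randomized_response_probs(true_bit: int, p: float):
    """Output distribution: report the true bit with probability p, else flip it."""
    return {true_bit: p, 1 - true_bit: 1.0 - p}


def satisfies_dp(p: float, epsilon: float, delta: float = 0.0) -> bool:
    """Check Pr[M(D) in S] <= exp(eps) * Pr[M(D') in S] + delta for every set S."""
    dist0 = randomized_response_probs(0, p)
    dist1 = randomized_response_probs(1, p)
    subsets = [(), (0,), (1,), (0, 1)]
    for a, b in ((dist0, dist1), (dist1, dist0)):
        for s in subsets:
            pa = sum(a[o] for o in s)
            pb = sum(b[o] for o in s)
            # Small tolerance guards against floating-point rounding at equality
            if pa > math.exp(epsilon) * pb + delta + 1e-12:
                return False
    return True


p = 0.75
eps = math.log(p / (1 - p))   # ln 3, roughly 1.10

print(satisfies_dp(p, eps))   # the bound holds at eps = ln(p / (1 - p))
print(satisfies_dp(p, 0.5))   # a smaller budget is violated
```

Telling the truth 75% of the time therefore costs a privacy budget of ln 3 with delta = 0. Moving `p` toward 0.5 lowers epsilon, at the price of a noisier answer.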

6.2 Gaussian Mechanism and Clipping

To apply DP in FL, noise must be added to each client's update.

**Gradient Clipping**: First clip the L2 norm of gradients to a maximum value C.

g_clipped = g x min(1, C / ||g||_2)

**Noise Addition**: Add Gaussian noise to the clipped gradient.

g_dp = g_clipped + N(0, sigma^2 x C^2 x I)

Here sigma is the noise multiplier.
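The two formulas can be sketched with the standard library alone. This is a minimal illustration on a two-element gradient, not a privacy-accounted implementation; the function names and the numbers are assumptions for the example.

```python
import math
import random


def clip_l2(g, C):
    """g_clipped = g * min(1, C / ||g||_2)"""
    norm = math.sqrt(sum(v * v for v in g))
    scale = min(1.0, C / norm) if norm > 0 else 1.0
    return [v * scale for v in g]


def gaussian_mechanism(g, C, sigma, rng):
    """Clip to L2 norm C, then add N(0, (sigma * C)^2) noise to every coordinate."""
    clipped = clip_l2(g, C)
    return [v + rng.gauss(0.0, sigma * C) for v in clipped]


g = [3.0, 4.0]                     # ||g||_2 = 5
g_clipped = clip_l2(g, C=1.0)      # rescaled to norm 1, direction unchanged
g_dp = gaussian_mechanism(g, C=1.0, sigma=1.1, rng=random.Random(0))

print(g_clipped)
print(g_dp)
```

Clipping is what makes the noise scale meaningful: once every update has norm at most C, one client can change the sum by at most C, so noise with standard deviation sigma x C hides any single contribution.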

6.3 DP-FL Implementation

```python
from copy import deepcopy
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator


def make_private_model(model: nn.Module) -> nn.Module:
    """Convert to Opacus-compatible model (BatchNorm -> GroupNorm)"""
    model = ModuleValidator.fix(model)
    return model


class DPFLClient:
    """FL client with differential privacy"""

    def __init__(
        self,
        client_id: int,
        dataset,
        target_epsilon: float = 1.0,
        target_delta: float = 1e-5,
        max_grad_norm: float = 1.0,
        device: str = 'cpu'
    ):
        self.client_id = client_id
        self.dataset = dataset
        self.target_epsilon = target_epsilon
        self.target_delta = target_delta
        self.max_grad_norm = max_grad_norm
        self.device = device

    def dp_local_train(
        self,
        model: nn.Module,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> Tuple[Dict, float, float]:
        """
        DP local training
        Returns: (weights, loss, epsilon used)
        """
        model = make_private_model(deepcopy(model)).to(self.device)
        model.train()

        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True
        )
        optimizer = optim.SGD(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        # A fresh PrivacyEngine per call means epsilon is accounted per round,
        # not accumulated across federated rounds
        privacy_engine = PrivacyEngine()
        model, optimizer, loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=loader,
            epochs=local_epochs,
            target_epsilon=self.target_epsilon,
            target_delta=self.target_delta,
            max_grad_norm=self.max_grad_norm
        )

        total_loss = 0.0
        n_batches = 0
        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                output = model(X)
                loss = criterion(output, y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                n_batches += 1

        epsilon_used = privacy_engine.get_epsilon(self.target_delta)
        avg_loss = total_loss / max(n_batches, 1)

        # Remove Opacus wrapper prefix to return clean weights
        clean_weights = {
            k.replace('_module.', ''): v
            for k, v in model.state_dict().items()
        }
        return clean_weights, avg_loss, epsilon_used
```

```python
class DPFLServer:
    """DP-FL server: server-side DP aggregation"""

    def __init__(
        self,
        global_model: nn.Module,
        clients: List['DPFLClient'],
        noise_multiplier: float = 0.5,
        max_grad_norm: float = 1.0,
        device: str = 'cpu'
    ):
        self.global_model = global_model.to(device)
        self.clients = clients
        self.noise_multiplier = noise_multiplier
        self.max_grad_norm = max_grad_norm
        self.device = device

    def clip_and_aggregate(
        self,
        client_weights: List[Dict],
        reference_weights: Dict
    ) -> Dict:
        """
        Server-side clipping and noise addition
        Clip each update then add Gaussian noise
        """
        n_clients = len(client_weights)

        # Per-client update = client weights minus the round's starting weights
        updates = []
        for w in client_weights:
            delta = {
                k: w[k].float() - reference_weights[k].float()
                for k in reference_weights
            }
            updates.append(delta)

        # Clip each update to a global L2 norm of at most max_grad_norm
        clipped_updates = []
        for delta in updates:
            total_norm = torch.sqrt(
                sum(torch.norm(v) ** 2 for v in delta.values())
            )
            clip_factor = torch.clamp(
                self.max_grad_norm / (total_norm + 1e-8), max=1.0
            )
            clipped = {k: v * clip_factor for k, v in delta.items()}
            clipped_updates.append(clipped)

        summed = {}
        for key in reference_weights:
            summed[key] = sum(u[key] for u in clipped_updates)

        # Add Gaussian noise to the sum, then average
        sigma = self.noise_multiplier * self.max_grad_norm
        noisy = {}
        for key in summed:
            noise = torch.randn_like(summed[key]) * sigma
            noisy[key] = (summed[key] + noise) / n_clients

        aggregated = {
            k: reference_weights[k].float() + noisy[k]
            for k in reference_weights
        }
        return aggregated
```

7. Secure Aggregation

7.1 The Need for Encrypted Aggregation

DP provides statistical privacy, but the server can still see each client's update individually. Secure Aggregation (SecAgg) is a cryptographic protocol that ensures the server sees **only the aggregated result**.

7.2 Secret Sharing-Based SecAgg

The Secure Aggregation protocol by Bonawitz et al. (2017) uses the following principle:

- Clients exchange masks (random number pairs) with each other.

- Each client adds its mask to its update before sending to the server.

- When the server sums all masked updates, the masks cancel out, leaving the pure sum of updates.

```python
from typing import Dict, List

import torch


class SecureAggregation:
    """
    Simplified secure aggregation simulation
    Real implementations use Diffie-Hellman key exchange
    """

    def __init__(self, num_clients: int, seed_length: int = 32):
        self.num_clients = num_clients
        self.seed_length = seed_length

    def generate_pairwise_masks(
        self,
        client_id: int,
        model_shape: Dict[str, torch.Size],
        shared_seeds: Dict
    ) -> Dict:
        """
        Generate pairwise masks between clients
        For client pair i and j: add if i > j, subtract if i < j
        """
        mask = {k: torch.zeros(v) for k, v in model_shape.items()}
        for other_id, seed in shared_seeds.items():
            # Both members of a pair seed an identical local generator, so they
            # draw the same tensors and their +/- contributions cancel.
            # A local generator also avoids resetting the global RNG state.
            gen = torch.Generator().manual_seed(seed)
            direction = 1 if client_id > other_id else -1
            for key in mask:
                mask[key] += direction * torch.randn(
                    model_shape[key], generator=gen
                )
        return mask

    def client_mask_update(
        self,
        local_update: Dict,
        mask: Dict
    ) -> Dict:
        """Apply mask to update"""
        return {k: local_update[k] + mask[k] for k in local_update}

    def server_aggregate_masked(
        self,
        masked_updates: List[Dict]
    ) -> Dict:
        """Sum masked updates (masks cancel each other out), then average"""
        aggregated = {}
        for key in masked_updates[0]:
            aggregated[key] = sum(u[key] for u in masked_updates)
            aggregated[key] /= len(masked_updates)
        return aggregated
```
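The cancellation property can be checked numerically without any ML framework. The sketch below is a standard-library illustration under simplifying assumptions: `pair_seed` is a hypothetical stand-in for a secret that clients i and j would agree on via key exchange, and masks are floating-point numbers, whereas real protocols work with integers modulo a fixed value so that cancellation is exact.

```python
import random


def pair_seed(i, j):
    """Hypothetical stand-in for a secret shared by clients i and j."""
    return min(i, j) * 1000 + max(i, j)


def pairwise_mask(client_id, all_ids, dim):
    """Add the pair's random stream if client_id > other, subtract it otherwise."""
    mask = [0.0] * dim
    for other in all_ids:
        if other == client_id:
            continue
        rng = random.Random(pair_seed(client_id, other))   # same stream for both peers
        sign = 1.0 if client_id > other else -1.0
        for i in range(dim):
            mask[i] += sign * rng.uniform(-100.0, 100.0)
    return mask


ids = [0, 1, 2]
updates = {0: [0.1, 0.2], 1: [0.3, 0.4], 2: [0.5, 0.6]}

# What each client actually sends: its update hidden behind a large mask
masked = {
    k: [u + m for u, m in zip(updates[k], pairwise_mask(k, ids, 2))]
    for k in ids
}

masked_sum = [sum(masked[k][i] for k in ids) for i in range(2)]
true_sum = [sum(updates[k][i] for k in ids) for i in range(2)]
print(masked[0], "<- individual masked update reveals little")
print(masked_sum, "vs", true_sum)
```

Each masked vector looks like noise on its own, but the server's sum matches the true sum up to floating-point rounding. The sketch ignores client dropout, which the full protocol handles with secret sharing of the seeds.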

7.3 Homomorphic Encryption Overview

Homomorphic Encryption (HE) is a cryptographic scheme that allows computations to be performed on encrypted data. In FL, HE allows the server to aggregate client updates without decrypting them.

- **Partially Homomorphic Encryption (PHE)**: Supports a single operation on ciphertexts, either addition or multiplication (for example, the Paillier cryptosystem is additively homomorphic)

- **Fully Homomorphic Encryption (FHE)**: Supports arbitrary operations (CKKS, BGV) — very high computational cost

In practical FL, PHE supporting only addition is sufficient for aggregation (weighted average).
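To show why addition alone suffices, here is a toy Paillier sketch in pure Python. It is an illustration only and not secure: the primes are tiny and hard-coded, the function names are assumptions for this example, and updates are assumed to be already quantized to small non-negative integers. It needs Python 3.8 or newer for `pow(x, -1, n)`.

```python
import math
import random


def keygen(p=1009, q=1013):
    """Toy key generation with tiny primes (insecure, for illustration only)."""
    n = p * q
    lam = (p - 1) * (q - 1) // math.gcd(p - 1, q - 1)   # lcm(p - 1, q - 1)
    mu = pow(lam, -1, n)        # modular inverse; valid for generator g = n + 1
    return n, (lam, mu)


def encrypt(n, m, rng):
    """c = (n + 1)^m * r^n mod n^2 for a random r coprime to n."""
    r = rng.randrange(1, n)
    while math.gcd(r, n) != 1:
        r = rng.randrange(1, n)
    n2 = n * n
    return (pow(n + 1, m, n2) * pow(r, n, n2)) % n2


def add_encrypted(n, c1, c2):
    """Multiplying ciphertexts adds the underlying plaintexts (mod n)."""
    return (c1 * c2) % (n * n)


def decrypt(n, priv, c):
    lam, mu = priv
    x = pow(c, lam, n * n)
    return ((x - 1) // n * mu) % n


rng = random.Random(0)
n, priv = keygen()

client_updates = [12, 30, 7]                       # hypothetical quantized updates
ciphertexts = [encrypt(n, m, rng) for m in client_updates]

# The server combines ciphertexts without ever seeing a plaintext
total_ct = ciphertexts[0]
for ct in ciphertexts[1:]:
    total_ct = add_encrypted(n, total_ct, ct)

total = decrypt(n, priv, total_ct)                 # done by the key holder
print(total, total / len(client_updates))
```

The server only multiplies ciphertexts; whoever holds the private key decrypts the sum and divides by the number of clients to obtain the average. In a real deployment the key must not be held by the aggregating server itself.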

8. The Flower Framework

8.1 Introduction to Flower

Flower (flwr) is a Python framework for federated learning that is framework-agnostic. It works with PyTorch, TensorFlow, JAX, and other ML frameworks.

```bash
pip install flwr
# Quote the extra so shells such as zsh do not expand the brackets
pip install "flwr[simulation]"
```

8.2 Flower Client Implementation

```python
from typing import Dict, List, Tuple

import flwr as fl
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


class FlowerClient(fl.client.NumPyClient):
    """Flower FL Client"""

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: str = 'cpu'
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.criterion = nn.CrossEntropyLoss()

    def get_parameters(self, config: Dict) -> List[np.ndarray]:
        """Return current model parameters as NumPy arrays"""
        return [
            val.cpu().numpy()
            for _, val in self.model.state_dict().items()
        ]

    def set_parameters(self, parameters: List[np.ndarray]):
        """Set model parameters from NumPy arrays"""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(
        self,
        parameters: List[np.ndarray],
        config: Dict
    ) -> Tuple[List[np.ndarray], int, Dict]:
        """Local training with parameters received from server"""
        self.set_parameters(parameters)

        # Hyperparameters are sent by the server strategy in `config`
        local_epochs = int(config.get('local_epochs', 1))
        lr = float(config.get('lr', 0.01))

        self.model.train()
        optimizer = torch.optim.SGD(
            self.model.parameters(), lr=lr, momentum=0.9
        )

        train_loss = 0.0
        for epoch in range(local_epochs):
            for X, y in self.train_loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                loss = self.criterion(self.model(X), y)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

        # Return: new parameters, number of training examples, metrics
        # (train_loss is the sum over all batches, not an average)
        return (
            self.get_parameters(config={}),
            len(self.train_loader.dataset),
            {'train_loss': train_loss}
        )

    def evaluate(
        self,
        parameters: List[np.ndarray],
        config: Dict
    ) -> Tuple[float, int, Dict]:
        """Evaluate with parameters received from server"""
        self.set_parameters(parameters)
        self.model.eval()

        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for X, y in self.val_loader:
                X, y = X.to(self.device), y.to(self.device)
                output = self.model(X)
                val_loss += self.criterion(output, y).item()
                _, predicted = output.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()

        accuracy = correct / total
        return (
            val_loss / len(self.val_loader),
            len(self.val_loader.dataset),
            {'accuracy': accuracy}
        )
```

8.3 Flower Server Strategy

```python
from typing import List, Tuple

from flwr.server.strategy import FedAvg
from flwr.common import Parameters, FitIns, EvaluateRes
from flwr.server.client_proxy import ClientProxy


class CustomFedAvgStrategy(FedAvg):
    """Custom FedAvg strategy"""

    def __init__(
        self,
        fraction_fit: float = 0.1,
        fraction_evaluate: float = 0.1,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        initial_parameters=None,
        evaluate_metrics_aggregation_fn=None,
    ):
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            initial_parameters=initial_parameters,
            # Without this function FedAvg returns an empty metrics dict
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.round_metrics = []

    def configure_fit(
        self,
        server_round: int,
        parameters: Parameters,
        client_manager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure training for each round"""
        # Decay learning rate as rounds progress
        lr = 0.01 * (0.99 ** server_round)

        config = {
            'local_epochs': 5,
            'lr': lr,
            'round': server_round,
        }
        fit_ins = FitIns(parameters, config)

        clients = client_manager.sample(
            num_clients=max(
                int(client_manager.num_available() * self.fraction_fit), 1
            ),
            min_num_clients=self.min_fit_clients
        )
        return [(client, fit_ins) for client in clients]

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures
    ):
        """Aggregate evaluation results and log"""
        aggregated_loss, metrics = super().aggregate_evaluate(
            server_round, results, failures
        )

        # `metrics` is empty unless a metrics aggregation function was supplied
        if metrics:
            print(
                f"Round {server_round}: "
                f"Aggregated accuracy = {metrics.get('accuracy', 0):.4f}"
            )

        self.round_metrics.append({
            'round': server_round,
            'loss': aggregated_loss,
            'metrics': metrics
        })
        return aggregated_loss, metrics
```

def run_flower_simulation():
    """Run Flower simulation"""
    from torch.utils.data import Subset

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = torchvision.datasets.MNIST(
        './data', train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.MNIST(
        './data', train=False, transform=transform
    )
    num_clients = 10

    def client_fn(cid: str):
        # Give each client a contiguous, equally sized shard of the training set
        cid = int(cid)
        n = len(train_dataset) // num_clients
        start = cid * n
        indices = list(range(start, min(start + n, len(train_dataset))))
        train_subset = Subset(train_dataset, indices)
        # Each client validates on its own slice of 100 test images
        val_indices = list(range(cid * 100, (cid + 1) * 100))
        val_subset = Subset(test_dataset, val_indices)
        model = SimpleNet(784, 256, 10)
        return FlowerClient(
            model=model,
            train_loader=DataLoader(train_subset, batch_size=32),
            val_loader=DataLoader(val_subset, batch_size=64)
        ).to_client()

    global_model = SimpleNet(784, 256, 10)
    initial_params = fl.common.ndarrays_to_parameters(
        [val.cpu().numpy() for val in global_model.state_dict().values()]
    )
    strategy = CustomFedAvgStrategy(
        fraction_fit=0.5,
        fraction_evaluate=0.5,
        min_fit_clients=2,
        min_evaluate_clients=2,
        min_available_clients=num_clients,
        initial_parameters=initial_params
    )
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=num_clients,
        config=fl.server.ServerConfig(num_rounds=50),
        strategy=strategy,
    )
    print("\n=== Flower Simulation Complete ===")
    # losses_distributed holds (round, loss) tuples
    final_round, final_loss = history.losses_distributed[-1]
    print(f"Final distributed loss (round {final_round}): {final_loss:.4f}")
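The `aggregate_evaluate` override above prints an accuracy only when the base strategy returns aggregated metrics. Flower's `FedAvg` produces them when it is given an `evaluate_metrics_aggregation_fn`, a callable that receives a list of `(num_examples, metrics)` pairs. The following is a minimal sketch of such a function in plain Python. It assumes every client reports an `accuracy` entry in its evaluation metrics. The name `weighted_average` and the sample numbers are illustrative and do not come from the original code.

```python
from typing import Dict, List, Tuple

Metrics = Dict[str, float]


def weighted_average(results: List[Tuple[int, Metrics]]) -> Metrics:
    """Average client accuracies, weighting each client by its example count."""
    total_examples = sum(num_examples for num_examples, _ in results)
    if total_examples == 0:
        # Nothing to aggregate (no clients, or clients with no data)
        return {}
    weighted_sum = sum(
        num_examples * metrics["accuracy"] for num_examples, metrics in results
    )
    return {"accuracy": weighted_sum / total_examples}


if __name__ == "__main__":
    # Hypothetical evaluation results from three clients
    client_results = [
        (100, {"accuracy": 0.90}),
        (300, {"accuracy": 0.80}),
        (600, {"accuracy": 0.95}),
    ]
    print(weighted_average(client_results))
```

To use it with the custom strategy, `CustomFedAvgStrategy.__init__` would need to accept this function and forward it as `evaluate_metrics_aggregation_fn` in its `super().__init__(...)` call.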

9. Hands-On FL Project: Hospital Federated Learning

9.1 Multi-Hospital Chest X-Ray Diagnosis

In this scenario, multiple hospitals jointly train a lung disease diagnostic model without sharing their patients' chest X-ray data.

import os

import pandas as pd
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import models

class ChestXRayDataset(Dataset):
    """Chest X-Ray Dataset (per hospital)"""

    def __init__(
        self,
        data_dir: str,
        labels_file: str,
        transform=None,
        hospital_id: int = 0
    ):
        self.data_dir = data_dir
        self.labels = pd.read_csv(labels_file)
        self.transform = transform
        self.hospital_id = hospital_id
        # Filter to this hospital's data only
        self.labels = self.labels[
            self.labels['hospital_id'] == hospital_id
        ].reset_index(drop=True)
        self.classes = [
            'Normal', 'Pneumonia', 'COVID-19', 'Tuberculosis'
        ]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_path = os.path.join(
            self.data_dir, self.labels.iloc[idx]['filename']
        )
        image = Image.open(img_path).convert('RGB')
        label = self.labels.iloc[idx]['label']
        if self.transform:
            image = self.transform(image)
        return image, label

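def build_chest_model(num_classes: int = 4, pretrained: bool = True):
    """ResNet50-based chest X-ray classification model"""
    # torchvision >= 0.13 replaces the deprecated `pretrained=` flag with `weights=`;
    # IMAGENET1K_V1 is the checkpoint that `pretrained=True` used to load.
    weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.resnet50(weights=weights)
    # Replace the ImageNet classifier with a small head for our classes
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Linear(256, num_classes)
    )
    return model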

def train_hospital_federation():
    """Run hospital federated learning"""
    print("=== Multi-Hospital Federated Learning ===")
    print("Patient data at each hospital never leaves its location.")
    print("Only model parameter updates are aggregated.\n")
    global_model = build_chest_model(num_classes=4)
    total_params = sum(p.numel() for p in global_model.parameters())
    print(f"Global model parameters: {total_params:,}")
    print(
        f"Communication cost (FP32): "
        f"{total_params * 4 / 1024 / 1024:.1f} MB/round"
    )
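The statement that only model parameter updates are aggregated can be made concrete with a small FedAvg sketch. The example below uses plain Python lists in place of tensors so that it runs without any dependencies. The helper names `fedavg` and `comm_cost_mb`, the hospital sample counts, and the weight values are all hypothetical. The head size follows from the `2048 -> 256 -> 4` classifier defined in `build_chest_model`, since 2048 is the input width of ResNet50's final layer.

```python
from typing import List, Sequence


def fedavg(
    updates: Sequence[Sequence[float]], num_samples: Sequence[int]
) -> List[float]:
    """Sample-weighted average of flattened parameter vectors (FedAvg)."""
    if not updates or len(updates) != len(num_samples):
        raise ValueError("need one sample count per update")
    total = sum(num_samples)
    dim = len(updates[0])
    return [
        sum(n * u[i] for u, n in zip(updates, num_samples)) / total
        for i in range(dim)
    ]


def comm_cost_mb(num_params: int, bytes_per_param: int = 4) -> float:
    """Size of one full parameter transfer in MB (FP32 by default)."""
    return num_params * bytes_per_param / 1024 / 1024


if __name__ == "__main__":
    # Hypothetical two-parameter "models" from three hospitals
    hospital_weights = [[0.10, 0.20], [0.30, 0.40], [0.50, 0.60]]
    hospital_samples = [1000, 3000, 6000]
    print(fedavg(hospital_weights, hospital_samples))

    # Parameters in the replaced classification head: 2048 -> 256 -> 4
    head_params = (2048 * 256 + 256) + (256 * 4 + 4)
    print(head_params, f"{comm_cost_mb(head_params):.2f} MB")
```

Hospitals with more samples pull the global weights further toward their local solution, which is why the server needs each hospital's sample count alongside its parameters.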

10. The Future of Federated Learning

Federated Learning is establishing itself as a key technology for resolving the conflict between AI and privacy. Key trends to watch:

**Cross-Device vs Cross-Silo**

- **Cross-device**: Millions of mobile devices participate (for example, Google's Gboard keyboard).

- **Cross-silo**: A small number of institutions (hospitals, banks) participate; the participants are more reliable and trusted, but far fewer in number.

**FL + LLMs**

With the rise of Large Language Models (LLMs), FL has become even more relevant. When models are fine-tuned on user conversations, FL keeps the raw conversations on the user's device and shares only model updates. Combining FL with parameter-efficient fine-tuning (PEFT) methods such as LoRA further reduces communication costs, because only the small adapter weights need to be exchanged each round.
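A rough calculation shows why LoRA reduces the per-round payload. LoRA represents the update to a `d x k` weight matrix as the product of two low-rank factors of shapes `d x r` and `r x k`, so a client transmits `r * (d + k)` values instead of `d * k`. The sketch below works through this for a single layer. The dimensions are hypothetical and the function names are illustrative.

```python
def full_update_params(d: int, k: int) -> int:
    """Values sent when the whole d x k weight matrix is exchanged."""
    return d * k


def lora_update_params(d: int, k: int, r: int) -> int:
    """Values sent when only the rank-r factors (d x r and r x k) are exchanged."""
    return r * (d + k)


def to_mb(num_params: int, bytes_per_param: int = 4) -> float:
    """Payload size in MB, assuming FP32 values by default."""
    return num_params * bytes_per_param / 1024 / 1024


if __name__ == "__main__":
    d, k, r = 4096, 4096, 8  # hypothetical layer size and LoRA rank
    full = full_update_params(d, k)
    lora = lora_update_params(d, k, r)
    print(f"Full update: {full:,} params ({to_mb(full):.2f} MB)")
    print(f"LoRA update: {lora:,} params ({to_mb(lora):.2f} MB)")
    print(f"Reduction:   {full / lora:.0f}x")
```

For this single layer the payload drops from 64 MB to 0.25 MB, a 256x reduction. The saving across a whole model depends on which layers carry adapters and on the chosen rank.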

**Regulation-Friendly AI**

As regulations like GDPR and HIPAA tighten, FL is becoming a practical compliance solution. FL adoption is expected to expand rapidly in healthcare, finance, and legal sectors.

References

- McMahan, H. B., et al. (2017). Communication-Efficient Learning of Deep Networks from Decentralized Data. _AISTATS 2017_.

- Li, T., et al. (2020). Federated Optimization in Heterogeneous Networks (FedProx). _MLSys 2020_.

- Karimireddy, S. P., et al. (2020). SCAFFOLD: Stochastic Controlled Averaging for Federated Learning. _ICML 2020_.

- Bonawitz, K., et al. (2017). Practical Secure Aggregation for Privacy-Preserving Machine Learning. _CCS 2017_.

- Flower Framework: https://flower.ai/docs/

- Opacus (PyTorch DP): https://opacus.ai/
