Author: Youngju Kim (@fjvbn20031)
Federated Learning Complete Guide: Privacy-Preserving Distributed AI
One of the greatest ironies of modern AI is that as model performance improves, more data is needed, and as more data is collected, the risk of privacy violations grows. Hospitals cannot share patient data, smartphone manufacturers cannot send users' typing patterns to servers, and financial institutions cannot share transaction records with competitors.
Federated Learning (FL) solves this dilemma. Instead of sending data to a central server, the model is sent to where the data lives, trained locally, and only the model updates (weight changes) are aggregated. The concept is: "data stays put; only intelligence moves."
This guide covers everything from the theoretical foundations of FL to hands-on implementation.
1. Federated Learning Fundamentals
1.1 Problems with Traditional Centralized Learning
Consider the traditional ML pipeline: collect patient data from thousands of hospitals, store it on a central server, and train a diagnostic model on that data. What are the problems?
Data Privacy Issues
Patient records, financial transaction histories, and personal communications are extremely sensitive. Transmitting such data to a central server creates the following risks:
- Eavesdropping risk during transmission
- Large-scale data breaches from server hacking
- Loss of trust in the data-owning organization
- Legal sanctions for regulatory violations
Legal Regulations
Data privacy regulations are tightening worldwide.
- GDPR (General Data Protection Regulation): The EU's data protection law specifying processing purpose disclosure, consent requirements, and data minimization principles.
- HIPAA (Health Insurance Portability and Accountability Act): US law protecting patient health information (PHI).
- CCPA (California Consumer Privacy Act): Gives California residents rights over their personal information collected by businesses.
Communication Costs
Transmitting data from millions of edge devices (smartphones, IoT sensors) to a central location requires enormous network bandwidth. Large data types like images and audio make this even more problematic.
1.2 The Core Idea of Federated Learning
Federated Learning, proposed by McMahan et al. at Google in 2016, is based on the following principle:
"Data stays local; only knowledge (model updates) moves."
The basic FL process:
- Initialization: The central server initializes a global model.
- Distribution: The server distributes the current global model to selected clients.
- Local Training: Each client trains the model on its local data.
- Upload: Clients send model updates (gradients or weight differences) to the server.
- Aggregation: The server aggregates the updates (e.g., by averaging) to update the global model.
- Repeat: Steps 2–5 repeat until convergence.
In this approach, raw data never leaves the client device. Only model parameter updates are transmitted.
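The round structure above can be sketched in a few lines. This is a toy simulation, not real training: the random "local updates" merely stand in for local SGD, and only the weighted-averaging step mirrors actual FedAvg aggregation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "global model": a single weight vector.
global_w = np.zeros(4)

# Each client holds a private dataset of size n_k; only updates leave the device.
client_sizes = [100, 300, 600]

def local_update(w, client_id):
    # Stand-in for local SGD: each client nudges the weights differently.
    return w + rng.normal(loc=client_id + 1, scale=0.1, size=w.shape)

for round_t in range(3):
    # Steps 2-4: distribute, train locally, upload.
    updates = [local_update(global_w, k) for k in range(len(client_sizes))]
    total = sum(client_sizes)
    # Step 5: aggregate by weighted average (n_k / n), as in FedAvg.
    global_w = sum((n / total) * w for n, w in zip(client_sizes, updates))

print(global_w.shape)
```

At no point does any client's "data" (here, the random seed driving its updates) reach the aggregation step; only the weight vectors do.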
1.3 Applications of Federated Learning
Mobile / Edge Devices
Google applied FL to Gboard (mobile keyboard). Users' typing patterns are not sent to servers; instead, the next-word prediction model is improved directly on the device. Hundreds of millions of users' data contribute without any personal information leaving the device.
Healthcare
Better diagnostic AI can be built without sharing patient data from multiple hospitals. For rare diseases, single-hospital data is insufficient for model training, but FL can combine knowledge from multiple hospitals.
Finance
Multiple financial institutions can collaborate on fraud detection and credit scoring without sharing customer data. Especially for cross-border financial transactions, a joint model can be trained while respecting data sovereignty of each country.
Autonomous Vehicles
Multiple automakers can jointly improve road hazard detection models without sharing their proprietary driving data.
2. Federated Learning Architecture
2.1 Client-Server Structure
The most common FL architecture consists of a central aggregation server and multiple clients.
        ┌──────────────────┐
        │  Central Server  │
        │   (Aggregator)   │
        └────────┬─────────┘
                 │  model distribution / update aggregation
         ┌───────┼───────┐
         ↓       ↓       ↓
    ┌─────────┐ ┌─────────┐ ┌─────────┐
    │Client 1 │ │Client 2 │ │Client 3 │
    │(local   │ │(local   │ │(local   │
    │training)│ │training)│ │training)│
    └─────────┘ └─────────┘ └─────────┘
Server Responsibilities
- Maintain and manage the global model
- Select participating clients for each round
- Aggregate client updates
- Distribute the aggregated model to clients
Client Responsibilities
- Hold local data
- Fine-tune the received model on local data
- Transmit the updated model (or gradients)
2.2 Horizontal Federated Learning
Horizontal FL is used when all clients have the same feature space but different data samples. For example, multiple hospitals measure the same diagnostic items (blood pressure, blood glucose, age) but have different patients.
Client 1: [feature1, feature2, feature3] x [samples 1-1000]
Client 2: [feature1, feature2, feature3] x [samples 1001-2000]
Client 3: [feature1, feature2, feature3] x [samples 2001-3000]
Same feature space, different sample space. The most common form of FL.
2.3 Vertical Federated Learning
Vertical FL is used when clients hold the same users (data samples) but have different features. For example, a bank has a user's financial information while a hospital has the same user's medical information.
Client A (Bank): [financial features] x [users 1-10000]
Client B (Hospital): [medical features] x [users 1-10000]
Same sample space, different feature space.
Vertical FL requires more complex protocols. Cryptographic techniques are needed for cooperation between a client holding labels and clients holding only features.
2.4 Federated Transfer Learning
When clients have partially overlapping sample and feature spaces, transfer learning techniques are combined. This allows FL to be applied even when data overlap is minimal.
3. The FedAvg Algorithm
3.1 McMahan et al. (2017) Original Algorithm
FedAvg (Federated Averaging) is the foundational FL algorithm, published by McMahan et al. at Google in 2017. The key idea is that each client performs multiple local SGD updates, then the server averages the weights.
Algorithm Overview
Server executes:
    initialize w_0
    for round t = 1, 2, ..., T:
        m = max(C x K, 1)                    // C: participation fraction, K: total clients
        S_t = randomly select m clients
        for each client k in S_t (in parallel):
            w_{t+1}^k = ClientUpdate(k, w_t)
        w_{t+1} = sum_k (n_k / n) x w_{t+1}^k    // weighted average

Client k executes ClientUpdate(k, w):
    B = split local data into batches of size B
    for local epoch e = 1, ..., E:
        for batch b in B:
            w = w - lr x grad_loss(w; b)
    return w
Key Parameters
- C: fraction of clients participating in each round, a value in (0, 1]
- E: number of local epochs per client
- B: local mini-batch size
- lr: learning rate

When E = 1 and B equals the full local dataset (a single full-batch gradient step per round), FedAvg reduces to FedSGD. Increasing E reduces the number of communication rounds needed but increases the risk of client drift.
3.2 Complete FedAvg Implementation
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import numpy as np
from copy import deepcopy
from typing import List, Dict, Tuple
import random

# ========== Model Definition ==========
class SimpleNet(nn.Module):
    """Simple classification network"""
    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),  # flatten image batches such as MNIST [B, 1, 28, 28] to [B, 784]
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, x):
        return self.net(x)
# ========== Client ==========
class FLClient:
    """Federated Learning Client"""
    def __init__(
        self,
        client_id: int,
        dataset,
        device: str = 'cpu'
    ):
        self.client_id = client_id
        self.dataset = dataset
        self.device = device

    def local_train(
        self,
        model: nn.Module,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> Tuple[Dict, float]:
        """
        Train the model on local data
        Returns: (updated weights, local loss)
        """
        model = deepcopy(model).to(self.device)
        model.train()
        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True
        )
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        criterion = nn.CrossEntropyLoss()
        total_loss = 0.0
        n_batches = 0
        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                output = model(X)
                loss = criterion(output, y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                n_batches += 1
        avg_loss = total_loss / max(n_batches, 1)
        return model.state_dict(), avg_loss

    def evaluate(self, model: nn.Module) -> Tuple[float, float]:
        """Evaluate the model on local data"""
        model = deepcopy(model).to(self.device)
        model.eval()
        loader = DataLoader(self.dataset, batch_size=64)
        criterion = nn.CrossEntropyLoss()
        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                output = model(X)
                loss = criterion(output, y)
                total_loss += loss.item()
                _, predicted = output.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()
        return total_loss / len(loader), correct / total
# ========== Server ==========
class FedAvgServer:
    """FedAvg Server"""
    def __init__(
        self,
        global_model: nn.Module,
        clients: List[FLClient],
        fraction: float = 0.1,
        device: str = 'cpu'
    ):
        self.global_model = global_model.to(device)
        self.clients = clients
        self.fraction = fraction
        self.device = device
        self.round_history = []

    def select_clients(self) -> List[FLClient]:
        """Select clients for each round"""
        m = max(int(self.fraction * len(self.clients)), 1)
        return random.sample(self.clients, m)

    def aggregate(
        self,
        client_weights: List[Dict],
        client_sizes: List[int]
    ) -> Dict:
        """
        Weighted average aggregation (FedAvg),
        weighted by n_k / n
        """
        total_size = sum(client_sizes)
        aggregated = {}
        for key in client_weights[0].keys():
            aggregated[key] = torch.zeros_like(
                client_weights[0][key], dtype=torch.float32
            )
            for w, size in zip(client_weights, client_sizes):
                weight = size / total_size
                aggregated[key] += weight * w[key].float()
        return aggregated

    def train_round(
        self,
        local_epochs: int = 5,
        batch_size: int = 32,
        lr: float = 0.01
    ) -> Dict:
        """Execute one FL round"""
        selected = self.select_clients()
        client_weights = []
        client_sizes = []
        client_losses = []
        for client in selected:
            weights, loss = client.local_train(
                self.global_model, local_epochs, batch_size, lr
            )
            client_weights.append(weights)
            client_sizes.append(len(client.dataset))
            client_losses.append(loss)
        new_weights = self.aggregate(client_weights, client_sizes)
        self.global_model.load_state_dict(new_weights)
        round_info = {
            'num_clients': len(selected),
            'avg_local_loss': np.mean(client_losses),
            'client_losses': client_losses
        }
        self.round_history.append(round_info)
        return round_info

    def evaluate_global(self, test_loader: DataLoader) -> Tuple[float, float]:
        """Evaluate the global model"""
        self.global_model.eval()
        criterion = nn.CrossEntropyLoss()
        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for X, y in test_loader:
                X, y = X.to(self.device), y.to(self.device)
                output = self.global_model(X)
                loss = criterion(output, y)
                total_loss += loss.item()
                _, predicted = output.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()
        return total_loss / len(test_loader), correct / total

    def federated_train(
        self,
        num_rounds: int,
        local_epochs: int = 5,
        batch_size: int = 32,
        lr: float = 0.01,
        test_loader: DataLoader = None
    ):
        """Full FL training loop"""
        print(f"FL Training: {num_rounds} rounds, {len(self.clients)} clients")
        for round_num in range(1, num_rounds + 1):
            round_info = self.train_round(local_epochs, batch_size, lr)
            if test_loader and round_num % 10 == 0:
                test_loss, test_acc = self.evaluate_global(test_loader)
                print(
                    f"Round {round_num:3d}/{num_rounds} | "
                    f"Clients: {round_info['num_clients']} | "
                    f"Local Loss: {round_info['avg_local_loss']:.4f} | "
                    f"Test Acc: {test_acc:.4f}"
                )
        print("Federated training complete!")
# ========== Non-IID Data Partitioning ==========
def create_non_iid_partition(
    dataset,
    num_clients: int,
    num_classes: int,
    alpha: float = 0.5
) -> List[List[int]]:
    """
    Non-IID data partitioning using a Dirichlet distribution.
    Lower alpha = more heterogeneous distribution.
    """
    labels = np.array([dataset[i][1] for i in range(len(dataset))])
    client_indices = [[] for _ in range(num_clients)]
    for cls in range(num_classes):
        cls_indices = np.where(labels == cls)[0]
        np.random.shuffle(cls_indices)
        proportions = np.random.dirichlet([alpha] * num_clients)
        proportions = (proportions * len(cls_indices)).astype(int)
        proportions[-1] = len(cls_indices) - proportions[:-1].sum()
        start = 0
        for k, prop in enumerate(proportions):
            client_indices[k].extend(
                cls_indices[start:start + prop].tolist()
            )
            start += prop
    return client_indices

def run_fedavg_demo():
    """Run the FedAvg demo"""
    import torchvision
    import torchvision.transforms as transforms

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = torchvision.datasets.MNIST(
        './data', train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.MNIST(
        './data', train=False, transform=transform
    )

    num_clients = 20
    client_indices = create_non_iid_partition(
        train_dataset, num_clients, num_classes=10, alpha=0.5
    )
    clients = []
    for k in range(num_clients):
        subset = Subset(train_dataset, client_indices[k])
        clients.append(FLClient(k, subset))
    print(f"Number of clients: {num_clients}")
    print(f"Avg samples per client: {np.mean([len(c.dataset) for c in clients]):.1f}")

    global_model = SimpleNet(784, 256, 10)
    server = FedAvgServer(global_model, clients, fraction=0.1)
    test_loader = DataLoader(test_dataset, batch_size=256)
    server.federated_train(
        num_rounds=100,
        local_epochs=5,
        batch_size=32,
        lr=0.01,
        test_loader=test_loader
    )
    _, final_acc = server.evaluate_global(test_loader)
    print(f"\nFinal Test Accuracy: {final_acc:.4f}")

if __name__ == '__main__':
    run_fedavg_demo()
4. Challenges in Federated Learning
4.1 Data Heterogeneity (Non-IID Problem)
The biggest technical challenge in FL is non-IID data, i.e., data that is not independent and identically distributed across clients. In real environments, each client's data distribution is different.
Types of Non-IID
- Feature distribution skew: Different input distributions per client (e.g., regional weather patterns)
- Label distribution skew: Different class proportions per client (each smartphone user has different frequent words)
- Concept drift: Different labels for the same input (different cultural contexts by region)
- Quantity imbalance: Extremely different amounts of data per client
Non-IID data in FedAvg causes client drift. When each client's local optimum differs from the global optimum, aggregation becomes ineffective.
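Label distribution skew is easy to visualize by sampling class proportions from a Dirichlet prior, as the partitioning code in section 3.2 does. This small numpy-only sketch contrasts a near-IID concentration parameter with a highly skewed one:

```python
import numpy as np

rng = np.random.default_rng(42)
num_clients, num_classes = 5, 10

shares = {}
for alpha in (100.0, 0.1):
    # One Dirichlet draw per class: how that class's samples split across clients.
    props = rng.dirichlet([alpha] * num_clients, size=num_classes)
    shares[alpha] = props
    print(f"alpha={alpha}: class-0 split across clients = {np.round(props[0], 2)}")

# Large alpha -> each client gets roughly 1/num_clients of every class (near-IID).
# Small alpha -> most of a class lands on one or two clients (severe label skew).
```

With alpha = 100 every client sees about 20% of each class; with alpha = 0.1 single clients often hold nearly all of a class, which is exactly the regime where client drift hurts FedAvg most.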
4.2 System Heterogeneity
In real FL systems, device performance, battery status, and network connectivity vary widely.
- Compute heterogeneity: GPU-equipped servers and low-end mobile devices participating simultaneously
- Network heterogeneity: High-speed wired and unstable mobile networks coexisting
- Memory heterogeneity: Some clients may not be able to load the full model
4.3 The Straggler Problem
If some clients are slow or unresponsive, the entire training is delayed. In Synchronous FL, all selected clients must return their updates before the next round. Solutions include:
- Asynchronous FL: Immediately aggregate updates from responding clients
- Timeout setting: Only use clients that respond within a time limit
- FedProx: Add a proximal term to local updates to tolerate stragglers
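The timeout approach can be sketched with threads standing in for clients (the delays here are synthetic): the server waits a fixed budget per round and aggregates only what has arrived.

```python
import concurrent.futures as cf
import random
import time

def client_update(cid: int) -> int:
    # Synthetic system heterogeneity: some clients are much slower than others.
    time.sleep(random.uniform(0.01, 0.5))
    return cid  # stand-in for a model update

def run_round(num_clients: int = 8, deadline_s: float = 0.2):
    with cf.ThreadPoolExecutor(max_workers=num_clients) as pool:
        futures = [pool.submit(client_update, k) for k in range(num_clients)]
        # Wait at most deadline_s; stragglers are simply dropped this round.
        done, _ = cf.wait(futures, timeout=deadline_s)
        return sorted(f.result() for f in done)
        # (The pool still joins slow threads on exit; a real server would
        # keep serving and fold late arrivals into a later round, if at all.)

arrived = run_round()
print(f"{len(arrived)}/8 clients made the deadline")
```

Asynchronous FL goes one step further and aggregates each update the moment it arrives, at the cost of mixing updates computed against stale global models.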
5. Advanced FL Algorithms
5.1 FedProx: Handling Non-IID
FedProx, proposed by Li et al. in 2020, adds a proximal term to local optimization. This term prevents the local model from drifting too far from the global model.
FedProx Objective Function
A proximal term is added to the local objective function.
h_k(w; w^t) = F_k(w) + (mu/2) x ||w - w^t||^2
Here mu is a hyperparameter controlling the strength of the proximal term. When mu=0, it is equivalent to FedAvg.
class FedProxClient(FLClient):
    """FedProx client: adds a proximal term"""
    def local_train_prox(
        self,
        model: nn.Module,
        global_weights: Dict,
        local_epochs: int,
        batch_size: int,
        lr: float,
        mu: float = 0.01
    ) -> Tuple[Dict, float]:
        """
        Local training with a proximal term
        h_k(w) = F_k(w) + (mu/2) * ||w - w^t||^2
        """
        model = deepcopy(model).to(self.device)
        model.train()
        global_model = deepcopy(model)
        global_model.load_state_dict(global_weights)
        for param in global_model.parameters():
            param.requires_grad = False
        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True
        )
        optimizer = optim.SGD(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        total_loss = 0.0
        n_batches = 0
        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                output = model(X)
                task_loss = criterion(output, y)
                # Proximal term: (mu/2) * ||w - w_global||^2
                prox_loss = 0.0
                for w, w_global in zip(
                    model.parameters(),
                    global_model.parameters()
                ):
                    prox_loss += (mu / 2) * torch.norm(w - w_global) ** 2
                total_batch_loss = task_loss + prox_loss
                total_batch_loss.backward()
                optimizer.step()
                total_loss += task_loss.item()
                n_batches += 1
        avg_loss = total_loss / max(n_batches, 1)
        return model.state_dict(), avg_loss
5.2 SCAFFOLD: Correcting Client Drift
SCAFFOLD (Stochastic Controlled Averaging for Federated Learning) uses control variates to directly correct client drift. Each client and server maintain control variates c_k and c to correct gradient bias.
class ScaffoldClient:
    """SCAFFOLD client"""
    def __init__(self, client_id, dataset, device='cpu'):
        self.client_id = client_id
        self.dataset = dataset
        self.device = device
        self.c_k = None  # client control variate

    def init_control_variate(self, model: nn.Module):
        """Initialize the control variate"""
        self.c_k = {
            name: torch.zeros_like(param)
            for name, param in model.named_parameters()
        }

    def local_train_scaffold(
        self,
        model: nn.Module,
        server_control: Dict,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> Tuple[Dict, Dict, float]:
        """
        SCAFFOLD local training
        Returns: (updated weights, control variate update, loss)
        """
        if self.c_k is None:
            self.init_control_variate(model)
        model = deepcopy(model).to(self.device)
        model.train()
        initial_weights = deepcopy(model.state_dict())
        loader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
        criterion = nn.CrossEntropyLoss()
        total_loss = 0.0
        n_batches = 0
        total_steps = local_epochs * len(loader)
        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                output = model(X)
                loss = criterion(output, y)
                loss.backward()
                with torch.no_grad():
                    for name, param in model.named_parameters():
                        if param.grad is not None:
                            # SCAFFOLD corrected step: w <- w - lr * (g - c_k + c)
                            correction = (
                                server_control[name].to(self.device)
                                - self.c_k[name].to(self.device)
                            )
                            param -= lr * (param.grad + correction)
                            param.grad.zero_()
                total_loss += loss.item()
                n_batches += 1
        final_weights = model.state_dict()
        # Update the control variate:
        # c_k+ = c_k - c + (1 / (K * lr)) * (w_0 - w_K)
        new_c_k = {}
        c_k_diff = {}
        for name in self.c_k:
            w_diff = (
                initial_weights[name].float() - final_weights[name].float()
            )
            new_c_k[name] = (
                self.c_k[name]
                - server_control[name]
                + w_diff / (total_steps * lr)
            )
            c_k_diff[name] = new_c_k[name] - self.c_k[name]
        self.c_k = new_c_k
        avg_loss = total_loss / max(n_batches, 1)
        return final_weights, c_k_diff, avg_loss
6. Differential Privacy
6.1 Mathematical Definition of DP
Differential Privacy (DP) is a mathematical framework for privacy protection. Intuitively, it guarantees that "the output distribution does not change significantly when a single data point is added or removed."
Epsilon-Delta DP Definition
A randomized mechanism M satisfies (epsilon, delta)-DP if for all neighboring datasets D and D' and all output sets S:
Pr[M(D) in S] <= exp(epsilon) x Pr[M(D') in S] + delta
- epsilon: privacy budget — lower values give stronger privacy guarantees
- delta: failure probability — typically set below 1/|dataset|
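A worked example makes the definition concrete. Randomized response, which reports a private bit truthfully with probability e^epsilon / (1 + e^epsilon) and flipped otherwise, is a classic mechanism that satisfies (epsilon, 0)-DP exactly, and the bound can be checked directly:

```python
import math
import random

def randomized_response(bit: int, epsilon: float) -> int:
    """Report `bit` truthfully with prob e^eps / (1 + e^eps), else flip it."""
    p_truth = math.exp(epsilon) / (1.0 + math.exp(epsilon))
    return bit if random.random() < p_truth else 1 - bit

eps = 1.0
p = math.exp(eps) / (1.0 + math.exp(eps))
# Neighboring "datasets" here are the two values of the bit. For output 1:
#   Pr[M(1) = 1] = p   and   Pr[M(0) = 1] = 1 - p
# so the ratio is p / (1 - p) = e^eps, exactly the (epsilon, 0)-DP bound.
ratio = p / (1.0 - p)
print(round(ratio, 6), round(math.exp(eps), 6))
```

No single report reveals the true bit with certainty, yet an aggregator averaging many reports can still estimate the population frequency, the same trade-off DP-FL makes at the gradient level.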
6.2 Gaussian Mechanism and Clipping
To apply DP in FL, noise must be added to each client's update.
Gradient Clipping: First clip the L2 norm of gradients to a maximum value C.
g_clipped = g x min(1, C / ||g||_2)
Noise Addition: Add Gaussian noise to the clipped gradient.
g_dp = g_clipped + N(0, sigma^2 x C^2 x I)
Here sigma is the noise multiplier.
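The two formulas translate directly into code; this sketch applies clipping and the Gaussian mechanism to a single flattened gradient tensor:

```python
import torch

def privatize_gradient(g: torch.Tensor, C: float, sigma: float) -> torch.Tensor:
    # Clipping: g_clipped = g * min(1, C / ||g||_2)
    g_clipped = g * min(1.0, C / (g.norm(p=2).item() + 1e-12))
    # Gaussian mechanism: add N(0, sigma^2 * C^2) noise per coordinate
    return g_clipped + torch.randn_like(g) * sigma * C

g = torch.randn(10_000) * 5.0          # a large raw gradient
g_dp = privatize_gradient(g, C=1.0, sigma=0.5)
print(g_dp.shape)
```

Clipping bounds each client's sensitivity (no single update can move the sum by more than C), which is what lets the Gaussian noise scale sigma * C yield a quantifiable (epsilon, delta) guarantee.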
6.3 DP-FL Implementation
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from copy import deepcopy
from typing import Dict, List, Tuple
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator

def make_private_model(model: nn.Module) -> nn.Module:
    """Convert to an Opacus-compatible model (e.g., BatchNorm -> GroupNorm)"""
    model = ModuleValidator.fix(model)
    return model

class DPFLClient:
    """FL client with differential privacy"""
    def __init__(
        self,
        client_id: int,
        dataset,
        target_epsilon: float = 1.0,
        target_delta: float = 1e-5,
        max_grad_norm: float = 1.0,
        device: str = 'cpu'
    ):
        self.client_id = client_id
        self.dataset = dataset
        self.target_epsilon = target_epsilon
        self.target_delta = target_delta
        self.max_grad_norm = max_grad_norm
        self.device = device

    def dp_local_train(
        self,
        model: nn.Module,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> Tuple[Dict, float, float]:
        """
        DP local training
        Returns: (weights, loss, epsilon used)
        """
        model = make_private_model(deepcopy(model)).to(self.device)
        model.train()
        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True  # required by Opacus
        )
        optimizer = optim.SGD(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        privacy_engine = PrivacyEngine()
        model, optimizer, loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=loader,
            epochs=local_epochs,
            target_epsilon=self.target_epsilon,
            target_delta=self.target_delta,
            max_grad_norm=self.max_grad_norm
        )
        total_loss = 0.0
        n_batches = 0
        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                output = model(X)
                loss = criterion(output, y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                n_batches += 1
        epsilon_used = privacy_engine.get_epsilon(self.target_delta)
        avg_loss = total_loss / max(n_batches, 1)
        # Strip the Opacus wrapper prefix to return clean weight names
        clean_weights = {
            k.replace('_module.', ''): v
            for k, v in model.state_dict().items()
        }
        return clean_weights, avg_loss, epsilon_used
class DPFLServer:
    """DP-FL server: server-side DP aggregation"""
    def __init__(
        self,
        global_model: nn.Module,
        clients: List['DPFLClient'],
        noise_multiplier: float = 0.5,
        max_grad_norm: float = 1.0,
        device: str = 'cpu'
    ):
        self.global_model = global_model.to(device)
        self.clients = clients
        self.noise_multiplier = noise_multiplier
        self.max_grad_norm = max_grad_norm
        self.device = device

    def clip_and_aggregate(
        self,
        client_weights: List[Dict],
        reference_weights: Dict
    ) -> Dict:
        """
        Server-side clipping and noise addition:
        clip each client's update, then add Gaussian noise to the sum
        """
        n_clients = len(client_weights)
        updates = []
        for w in client_weights:
            delta = {
                k: w[k].float() - reference_weights[k].float()
                for k in reference_weights
            }
            updates.append(delta)
        clipped_updates = []
        for delta in updates:
            total_norm = torch.sqrt(
                sum(torch.norm(v) ** 2 for v in delta.values())
            )
            clip_factor = min(1.0, self.max_grad_norm / (total_norm.item() + 1e-8))
            clipped = {k: v * clip_factor for k, v in delta.items()}
            clipped_updates.append(clipped)
        summed = {}
        for key in reference_weights:
            summed[key] = sum(u[key] for u in clipped_updates)
        sigma = self.noise_multiplier * self.max_grad_norm
        noisy = {}
        for key in summed:
            noise = torch.randn_like(summed[key]) * sigma
            noisy[key] = (summed[key] + noise) / n_clients
        aggregated = {
            k: reference_weights[k].float() + noisy[k]
            for k in reference_weights
        }
        return aggregated
7. Secure Aggregation
7.1 The Need for Encrypted Aggregation
DP provides statistical privacy, but the server can still see each client's update individually. Secure Aggregation (SecAgg) is a cryptographic protocol that ensures the server sees only the aggregated result.
7.2 Secret Sharing-Based SecAgg
The Secure Aggregation protocol by Bonawitz et al. (2017) uses the following principle:
- Clients exchange masks (random number pairs) with each other.
- Each client adds its mask to its update before sending to the server.
- When the server sums all masked updates, the masks cancel out, leaving the pure sum of updates.
class SecureAggregation:
    """
    Simplified secure aggregation simulation.
    Real implementations use Diffie-Hellman key exchange.
    """
    def __init__(self, num_clients: int, seed_length: int = 32):
        self.num_clients = num_clients
        self.seed_length = seed_length

    def generate_pairwise_masks(
        self,
        client_id: int,
        model_shape: Dict[str, torch.Size],
        shared_seeds: Dict
    ) -> Dict:
        """
        Generate pairwise masks between clients.
        For a client pair (i, j): add if i > j, subtract if i < j
        """
        mask = {k: torch.zeros(v) for k, v in model_shape.items()}
        for other_id, seed in shared_seeds.items():
            torch.manual_seed(seed)
            direction = 1 if client_id > other_id else -1
            for key in mask:
                mask[key] += direction * torch.randn(model_shape[key])
        return mask

    def client_mask_update(
        self,
        local_update: Dict,
        mask: Dict
    ) -> Dict:
        """Apply the mask to an update"""
        return {k: local_update[k] + mask[k] for k in local_update}

    def server_aggregate_masked(
        self,
        masked_updates: List[Dict]
    ) -> Dict:
        """Average the masked updates (the pairwise masks cancel out)"""
        aggregated = {}
        for key in masked_updates[0]:
            aggregated[key] = sum(u[key] for u in masked_updates)
            aggregated[key] /= len(masked_updates)
        return aggregated
7.3 Homomorphic Encryption Overview
Homomorphic Encryption (HE) is a cryptographic scheme that allows computations to be performed on encrypted data. In FL, HE allows the server to aggregate client updates without decrypting them.
- Partially Homomorphic Encryption (PHE): Supports only addition or multiplication (Paillier cryptosystem)
- Fully Homomorphic Encryption (FHE): Supports arbitrary operations (CKKS, BGV) — very high computational cost
In practical FL, PHE supporting only addition is sufficient for aggregation (weighted average).
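To see the additive property in action, here is a toy Paillier implementation with deliberately tiny primes. This is for intuition only: real deployments need 2048-bit keys and a vetted library (e.g., python-paillier), and the key generation below skips the checks a real implementation performs.

```python
import math
import random

# Toy Paillier parameters (insecure key size, educational only)
p, q = 61, 53
n = p * q                      # public modulus
n2 = n * n
lam = math.lcm(p - 1, q - 1)   # private key lambda
mu = pow(lam, -1, n)           # works here because gcd(lam, n) == 1

def encrypt(m: int) -> int:
    """Enc(m) = (n+1)^m * r^n mod n^2, with random r coprime to n."""
    r = random.randrange(2, n)
    while math.gcd(r, n) != 1:
        r = random.randrange(2, n)
    return (pow(n + 1, m, n2) * pow(r, n, n2)) % n2

def decrypt(c: int) -> int:
    """Dec(c) = L(c^lambda mod n^2) * mu mod n, where L(x) = (x-1) // n."""
    x = pow(c, lam, n2)
    return ((x - 1) // n) * mu % n

# The server multiplies ciphertexts => the plaintexts add,
# so it can aggregate updates without ever decrypting an individual one.
a, b = 123, 456
c_sum = (encrypt(a) * encrypt(b)) % n2
print(decrypt(c_sum))  # 579
```

Multiplying ciphertexts corresponds to adding plaintexts, which is exactly the operation a FedAvg-style server needs (scaling by a public constant n_k / n also works, via ciphertext exponentiation).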
8. The Flower Framework
8.1 Introduction to Flower
Flower (flwr) is a Python framework for federated learning that is framework-agnostic. It works with PyTorch, TensorFlow, JAX, and other ML frameworks.
pip install flwr
pip install flwr[simulation]
8.2 Flower Client Implementation
import flwr as fl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Dict, List, Tuple
import numpy as np

class FlowerClient(fl.client.NumPyClient):
    """Flower FL Client"""
    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: str = 'cpu'
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.criterion = nn.CrossEntropyLoss()

    def get_parameters(self, config: Dict) -> List[np.ndarray]:
        """Return current model parameters as NumPy arrays"""
        return [
            val.cpu().numpy()
            for _, val in self.model.state_dict().items()
        ]

    def set_parameters(self, parameters: List[np.ndarray]):
        """Set model parameters from NumPy arrays"""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(
        self,
        parameters: List[np.ndarray],
        config: Dict
    ) -> Tuple[List[np.ndarray], int, Dict]:
        """Local training with parameters received from the server"""
        self.set_parameters(parameters)
        local_epochs = int(config.get('local_epochs', 1))
        lr = float(config.get('lr', 0.01))
        self.model.train()
        optimizer = torch.optim.SGD(
            self.model.parameters(), lr=lr, momentum=0.9
        )
        train_loss = 0.0
        for epoch in range(local_epochs):
            for X, y in self.train_loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                loss = self.criterion(self.model(X), y)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
        return (
            self.get_parameters(config={}),
            len(self.train_loader.dataset),
            {'train_loss': train_loss}
        )

    def evaluate(
        self,
        parameters: List[np.ndarray],
        config: Dict
    ) -> Tuple[float, int, Dict]:
        """Evaluate with parameters received from the server"""
        self.set_parameters(parameters)
        self.model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for X, y in self.val_loader:
                X, y = X.to(self.device), y.to(self.device)
                output = self.model(X)
                val_loss += self.criterion(output, y).item()
                _, predicted = output.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()
        accuracy = correct / total
        return (
            val_loss / len(self.val_loader),
            len(self.val_loader.dataset),
            {'accuracy': accuracy}
        )
8.3 Flower Server Strategy
from typing import List, Tuple
from flwr.server.strategy import FedAvg
from flwr.common import Parameters, FitIns, EvaluateRes
from flwr.server.client_proxy import ClientProxy

class CustomFedAvgStrategy(FedAvg):
    """Custom FedAvg strategy"""
    def __init__(
        self,
        fraction_fit: float = 0.1,
        fraction_evaluate: float = 0.1,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        initial_parameters=None,
    ):
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            initial_parameters=initial_parameters,
        )
        self.round_metrics = []

    def configure_fit(
        self,
        server_round: int,
        parameters: Parameters,
        client_manager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure training for each round"""
        # Decay the learning rate as rounds progress
        lr = 0.01 * (0.99 ** server_round)
        config = {
            'local_epochs': 5,
            'lr': lr,
            'round': server_round,
        }
        fit_ins = FitIns(parameters, config)
        clients = client_manager.sample(
            num_clients=max(
                int(client_manager.num_available() * self.fraction_fit), 1
            ),
            min_num_clients=self.min_fit_clients
        )
        return [(client, fit_ins) for client in clients]

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures
    ):
        """Aggregate evaluation results and log them"""
        aggregated_loss, metrics = super().aggregate_evaluate(
            server_round, results, failures
        )
        if metrics:
            print(
                f"Round {server_round}: "
                f"Aggregated accuracy = {metrics.get('accuracy', 0):.4f}"
            )
        self.round_metrics.append({
            'round': server_round,
            'loss': aggregated_loss,
            'metrics': metrics
        })
        return aggregated_loss, metrics
def run_flower_simulation():
    """Run a Flower simulation"""
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import Subset

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = torchvision.datasets.MNIST(
        './data', train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.MNIST(
        './data', train=False, transform=transform
    )
    num_clients = 10

    def client_fn(cid: str):
        cid = int(cid)
        n = len(train_dataset) // num_clients
        start = cid * n
        indices = list(range(start, min(start + n, len(train_dataset))))
        train_subset = Subset(train_dataset, indices)
        val_indices = list(range(cid * 100, (cid + 1) * 100))
        val_subset = Subset(test_dataset, val_indices)
        model = SimpleNet(784, 256, 10)
        return FlowerClient(
            model=model,
            train_loader=DataLoader(train_subset, batch_size=32),
            val_loader=DataLoader(val_subset, batch_size=64)
        ).to_client()

    global_model = SimpleNet(784, 256, 10)
    initial_params = fl.common.ndarrays_to_parameters(
        [val.numpy() for val in global_model.state_dict().values()]
    )
    strategy = CustomFedAvgStrategy(
        fraction_fit=0.5,
        fraction_evaluate=0.5,
        min_fit_clients=2,
        min_evaluate_clients=2,
        min_available_clients=num_clients,
        initial_parameters=initial_params
    )
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=num_clients,
        config=fl.server.ServerConfig(num_rounds=50),
        strategy=strategy,
    )
    print("\n=== Flower Simulation Complete ===")
    print(f"Final distributed loss: {history.losses_distributed[-1]}")
9. Hands-On FL Project: Hospital Federated Learning
9.1 Multi-Hospital Chest X-Ray Diagnosis
A scenario where multiple hospitals jointly train a lung disease diagnostic model without sharing patient chest X-ray data.
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import pandas as pd

class ChestXRayDataset(Dataset):
    """Chest X-ray dataset (one per hospital)"""
    def __init__(
        self,
        data_dir: str,
        labels_file: str,
        transform=None,
        hospital_id: int = 0
    ):
        self.data_dir = data_dir
        self.labels = pd.read_csv(labels_file)
        self.transform = transform
        self.hospital_id = hospital_id
        # Filter to this hospital's data only
        self.labels = self.labels[
            self.labels['hospital_id'] == hospital_id
        ].reset_index(drop=True)
        self.classes = [
            'Normal', 'Pneumonia', 'COVID-19', 'Tuberculosis'
        ]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_path = os.path.join(
            self.data_dir, self.labels.iloc[idx]['filename']
        )
        image = Image.open(img_path).convert('RGB')
        label = self.labels.iloc[idx]['label']
        if self.transform:
            image = self.transform(image)
        return image, label

def build_chest_model(num_classes: int = 4, pretrained: bool = True):
    """ResNet50-based chest X-ray classification model"""
    weights = models.ResNet50_Weights.DEFAULT if pretrained else None
    model = models.resnet50(weights=weights)
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Linear(256, num_classes)
    )
    return model

def train_hospital_federation():
    """Run hospital federated learning"""
    print("=== Multi-Hospital Federated Learning ===")
    print("Patient data at each hospital never leaves its location.")
    print("Only model parameter updates are aggregated.\n")
    global_model = build_chest_model(num_classes=4)
    total_params = sum(p.numel() for p in global_model.parameters())
    print(f"Global model parameters: {total_params:,}")
    print(
        f"Communication cost (FP32): "
        f"{total_params * 4 / 1024 / 1024:.1f} MB/round"
    )
10. The Future of Federated Learning
Federated Learning is establishing itself as a key technology for resolving the conflict between AI and privacy. Key trends to watch:
Cross-Device vs Cross-Silo
- Cross-device: Millions of mobile devices participate (Google's Gboard)
- Cross-silo: A small number of institutions (hospitals, banks) participate — more trustworthy but smaller scale
FL + LLMs
With the rise of Large Language Models (LLMs), FL has become even more important. When fine-tuning models on user conversations, FL ensures that conversations never leave the user's device. Combining with Parameter-Efficient Fine-Tuning (PEFT, LoRA) further reduces communication costs.
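The communication saving from LoRA-style adapters is easy to quantify. For a single d_in x d_out weight matrix with rank-r adapters (the dimensions below are illustrative, not taken from any specific model):

```python
# Trainable (and transmitted) parameters: full fine-tuning vs rank-r LoRA
d_in, d_out, r = 4096, 4096, 8

full_update = d_in * d_out            # the whole weight matrix moves each round
lora_update = r * (d_in + d_out)      # only A (d_in x r) and B (r x d_out) move

print(f"full: {full_update:,}  LoRA: {lora_update:,}  "
      f"-> {full_update // lora_update}x less traffic per layer per round")
```

A 256x reduction per adapted layer compounds over every communication round, which is what makes federated fine-tuning of billion-parameter models plausible on consumer uplinks.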
Regulation-Friendly AI
As regulations like GDPR and HIPAA tighten, FL is becoming a practical compliance solution. FL adoption is expected to expand rapidly in healthcare, finance, and legal sectors.
References
- McMahan, H. B., et al. (2017). Communication-Efficient Learning of Deep Networks from Decentralized Data. AISTATS 2017.
- Li, T., et al. (2020). Federated Optimization in Heterogeneous Networks (FedProx). MLSys 2020.
- Karimireddy, S. P., et al. (2020). SCAFFOLD: Stochastic Controlled Averaging for Federated Learning. ICML 2020.
- Bonawitz, K., et al. (2017). Practical Secure Aggregation for Privacy-Preserving Machine Learning. CCS 2017.
- Flower Framework: https://flower.ai/docs/
- Opacus (PyTorch DP): https://opacus.ai/