Federated Learning Complete Guide: Privacy-Preserving Distributed AI

One of the greatest ironies of modern AI is that as model performance improves, more data is needed, and as more data is collected, the risk of privacy violations grows. Hospitals cannot share patient data, smartphone manufacturers cannot send users' typing patterns to servers, and financial institutions cannot share transaction records with competitors.

Federated Learning (FL) solves this dilemma. Instead of sending data to a central server, the model is sent to where the data lives, trained locally, and only the model updates (weight changes) are aggregated. The concept is: "data stays put; only intelligence moves."

This guide covers everything from the theoretical foundations of FL to hands-on implementation.


1. Federated Learning Fundamentals

1.1 Problems with Traditional Centralized Learning

Consider the traditional ML pipeline: collect patient data from thousands of hospitals, store it on a central server, and train a diagnostic model on that data. What are the problems?

Data Privacy Issues

Patient records, financial transaction histories, and personal communications are extremely sensitive. Transmitting such data to a central server creates the following risks:

  • Eavesdropping risk during transmission
  • Large-scale data breaches from server hacking
  • Loss of trust in the data-owning organization
  • Legal sanctions for regulatory violations

Legal Regulations

Data privacy regulations are tightening worldwide.

  • GDPR (General Data Protection Regulation): The EU's data protection law specifying processing purpose disclosure, consent requirements, and data minimization principles.
  • HIPAA (Health Insurance Portability and Accountability Act): US law protecting patient health information (PHI).
  • CCPA (California Consumer Privacy Act): Gives California residents rights over their personal information collected by businesses.

Communication Costs

Transmitting data from millions of edge devices (smartphones, IoT sensors) to a central location requires enormous network bandwidth. Large data types like images and audio make this even more problematic.

1.2 The Core Idea of Federated Learning

Federated Learning, proposed by McMahan et al. at Google in 2016, is based on the following principle:

"Data stays local; only knowledge (model updates) moves."

The basic FL process:

  1. Initialization: The central server initializes a global model.
  2. Distribution: The server distributes the current global model to selected clients.
  3. Local Training: Each client trains the model on its local data.
  4. Upload: Clients send model updates (gradients or weight differences) to the server.
  5. Aggregation: The server aggregates the updates (e.g., by averaging) to update the global model.
  6. Repeat: Steps 2–5 repeat until convergence.

In this approach, raw data never leaves the client device. Only model parameter updates are transmitted.
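The round structure above can be sketched in a few lines of NumPy. The three clients, their data sizes, and the scalar mean-squared toy objective are illustrative assumptions, not part of any real deployment:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup (illustrative): 3 clients, each with local scalar data around a different mean
client_data = [rng.normal(loc=m, scale=1.0, size=n)
               for m, n in [(1.0, 100), (2.0, 300), (3.0, 600)]]

w = 0.0  # step 1: server initializes the global model (a single weight here)

for round_t in range(20):                      # step 6: repeat
    updates, sizes = [], []
    for data in client_data:                   # step 2: distribute w to clients
        w_local = w
        for _ in range(5):                     # step 3: local SGD on a mean-squared loss
            grad = w_local - data.mean()
            w_local -= 0.1 * grad
        updates.append(w_local)                # step 4: upload updated weights
        sizes.append(len(data))
    # step 5: aggregate by data-size-weighted average
    w = sum(u * s for u, s in zip(updates, sizes)) / sum(sizes)

print(f"global w after training: {w:.3f}")     # approaches the weighted mean of client means
```

Note that only the scalar `w_local` ever travels in either direction; the arrays in `client_data` are never read by the server.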

1.3 Applications of Federated Learning

Mobile / Edge Devices

Google applied FL to Gboard (mobile keyboard). Users' typing patterns are not sent to servers; instead, the next-word prediction model is improved directly on the device. Hundreds of millions of users' data contribute without any personal information leaving the device.

Healthcare

Better diagnostic AI can be built without sharing patient data from multiple hospitals. For rare diseases, single-hospital data is insufficient for model training, but FL can combine knowledge from multiple hospitals.

Finance

Multiple financial institutions can collaborate on fraud detection and credit scoring without sharing customer data. Especially for cross-border financial transactions, a joint model can be trained while respecting data sovereignty of each country.

Autonomous Vehicles

Multiple automakers can jointly improve road hazard detection models without sharing their proprietary driving data.


2. Federated Learning Architecture

2.1 Client-Server Structure

The most common FL architecture consists of a central aggregation server and multiple clients.

              ┌──────────────────┐
              │  Central Server  │
              │   (Aggregator)   │
              └────────┬─────────┘
                       │ model distribution / update aggregation
           ┌───────────┼───────────┐
           ↓           ↓           ↓
     ┌──────────┐ ┌──────────┐ ┌──────────┐
     │ Client 1 │ │ Client 2 │ │ Client 3 │
     │  (local  │ │  (local  │ │  (local  │
     │ training)│ │ training)│ │ training)│
     └──────────┘ └──────────┘ └──────────┘

Server Responsibilities

  • Maintain and manage the global model
  • Select participating clients for each round
  • Aggregate client updates
  • Distribute the aggregated model to clients

Client Responsibilities

  • Hold local data
  • Fine-tune the received model on local data
  • Transmit the updated model (or gradients)

2.2 Horizontal Federated Learning

Horizontal FL is used when all clients have the same feature space but different data samples. For example, multiple hospitals measure the same diagnostic items (blood pressure, blood glucose, age) but have different patients.

Client 1: [feature1, feature2, feature3] x [samples 1-1000]
Client 2: [feature1, feature2, feature3] x [samples 1001-2000]
Client 3: [feature1, feature2, feature3] x [samples 2001-3000]

Same feature space, different sample space. The most common form of FL.
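A horizontal split like this is easy to simulate in code. The feature count and partition boundaries below are arbitrary choices for illustration:

```python
import numpy as np

rng = np.random.default_rng(42)

# One shared feature space: every client sees the same 3 columns
X = rng.normal(size=(3000, 3))   # 3000 samples x 3 features

# Horizontal split: disjoint sample ranges, identical feature columns
clients = {
    "client_1": X[0:1000],
    "client_2": X[1000:2000],
    "client_3": X[2000:3000],
}

for name, shard in clients.items():
    print(name, shard.shape)   # every shard is (1000, 3): same features, different samples
```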

2.3 Vertical Federated Learning

Vertical FL is used when clients hold the same users (data samples) but have different features. For example, a bank has a user's financial information while a hospital has the same user's medical information.

Client A (Bank):     [financial features] x [users 1-10000]
Client B (Hospital): [medical features]   x [users 1-10000]

Same sample space, different feature space.

Vertical FL requires more complex protocols. Cryptographic techniques are needed for cooperation between a client holding labels and clients holding only features.
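One common realization of vertical FL is a split network: each party embeds its own feature columns locally, and only the label holder combines the embeddings. The sketch below is a deliberately simplified illustration; it omits the encrypted entity alignment and activation protection that real protocols require, and the feature dimensions and network shapes are made-up assumptions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# 256 users already aligned across both parties (real protocols align them cryptographically)
n_users = 256
x_bank = torch.randn(n_users, 5)        # party A: financial features
x_hospital = torch.randn(n_users, 8)    # party B: medical features (also holds labels)
y = torch.randint(0, 2, (n_users,))

bottom_a = nn.Linear(5, 4)              # each party embeds its own columns locally
bottom_b = nn.Linear(8, 4)
top = nn.Linear(8, 2)                   # label holder combines both embeddings

opt = torch.optim.SGD(
    list(bottom_a.parameters()) + list(bottom_b.parameters()) + list(top.parameters()),
    lr=0.1,
)
criterion = nn.CrossEntropyLoss()

for step in range(100):
    opt.zero_grad()
    h_a = bottom_a(x_bank)              # only this 4-dim embedding crosses party lines
    h_b = bottom_b(x_hospital)
    logits = top(torch.cat([h_a, h_b], dim=1))
    loss = criterion(logits, y)
    loss.backward()                     # gradients w.r.t. h_a flow back to party A
    opt.step()

print(f"final loss: {loss.item():.4f}")
```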

2.4 Federated Transfer Learning

When clients have partially overlapping sample and feature spaces, transfer learning techniques are combined. This allows FL to be applied even when data overlap is minimal.


3. The FedAvg Algorithm

3.1 McMahan et al. (2017) Original Algorithm

FedAvg (Federated Averaging) is the foundational FL algorithm, published by McMahan et al. of Google at AISTATS 2017. The key idea: each client performs multiple local SGD updates, then the server averages the resulting weights.

Algorithm Overview

Server executes:
  initialize w_0
  for round t = 1, 2, ..., T:
    m = max(C x K, 1)  // C: participation fraction, K: total clients
    S_t = randomly select m clients
    for each client k in S_t (parallel):
      w_{t+1}^k = ClientUpdate(k, w_t)
    w_{t+1} = sum (n_k / n) x w_{t+1}^k  // weighted average

Client k executes:
  B = split local data into batches
  for local epoch e = 1, ..., E:
    for batch b in B:
      w = w - lr x grad_loss(w; b)
  return w

Key Parameters

  • C: fraction of clients participating in each round (0 < C <= 1)
  • E: number of local epochs per client
  • B: local mini-batch size
  • lr: learning rate

When E=1 and B is the full local dataset (i.e., one full-batch gradient step per round), FedAvg is equivalent to FedSGD. Increasing E reduces the number of communication rounds needed but increases the risk of client drift.

3.2 Complete FedAvg Implementation

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import numpy as np
from copy import deepcopy
from typing import List, Dict, Tuple
import random


# ========== Model Definition ==========

class SimpleNet(nn.Module):
    """Simple classification network"""
    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),  # accept image tensors (e.g., 1x28x28 -> 784)
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, x):
        return self.net(x)


# ========== Client ==========

class FLClient:
    """Federated Learning Client"""

    def __init__(
        self,
        client_id: int,
        dataset,
        device: str = 'cpu'
    ):
        self.client_id = client_id
        self.dataset = dataset
        self.device = device

    def local_train(
        self,
        model: nn.Module,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> Tuple[Dict, float]:
        """
        Train the model on local data
        Returns: (updated weights, local loss)
        """
        model = deepcopy(model).to(self.device)
        model.train()

        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True
        )
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        n_batches = 0

        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                output = model(X)
                loss = criterion(output, y)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                n_batches += 1

        avg_loss = total_loss / max(n_batches, 1)
        return model.state_dict(), avg_loss

    def evaluate(self, model: nn.Module) -> Tuple[float, float]:
        """Evaluate model on local data"""
        model = deepcopy(model).to(self.device)
        model.eval()

        loader = DataLoader(self.dataset, batch_size=64)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                output = model(X)
                loss = criterion(output, y)
                total_loss += loss.item()

                _, predicted = output.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()

        return total_loss / len(loader), correct / total


# ========== Server ==========

class FedAvgServer:
    """FedAvg Server"""

    def __init__(
        self,
        global_model: nn.Module,
        clients: List[FLClient],
        fraction: float = 0.1,
        device: str = 'cpu'
    ):
        self.global_model = global_model.to(device)
        self.clients = clients
        self.fraction = fraction
        self.device = device
        self.round_history = []

    def select_clients(self) -> List[FLClient]:
        """Select clients for each round"""
        m = max(int(self.fraction * len(self.clients)), 1)
        return random.sample(self.clients, m)

    def aggregate(
        self,
        client_weights: List[Dict],
        client_sizes: List[int]
    ) -> Dict:
        """
        Weighted average aggregation (FedAvg)
        Weighted by n_k / n
        """
        total_size = sum(client_sizes)
        aggregated = {}

        for key in client_weights[0].keys():
            aggregated[key] = torch.zeros_like(
                client_weights[0][key], dtype=torch.float32
            )
            for w, size in zip(client_weights, client_sizes):
                weight = size / total_size
                aggregated[key] += weight * w[key].float()

        return aggregated

    def train_round(
        self,
        local_epochs: int = 5,
        batch_size: int = 32,
        lr: float = 0.01
    ) -> Dict:
        """Execute one FL round"""
        selected = self.select_clients()

        client_weights = []
        client_sizes = []
        client_losses = []

        for client in selected:
            weights, loss = client.local_train(
                self.global_model, local_epochs, batch_size, lr
            )
            client_weights.append(weights)
            client_sizes.append(len(client.dataset))
            client_losses.append(loss)

        new_weights = self.aggregate(client_weights, client_sizes)
        self.global_model.load_state_dict(new_weights)

        round_info = {
            'num_clients': len(selected),
            'avg_local_loss': np.mean(client_losses),
            'client_losses': client_losses
        }
        self.round_history.append(round_info)
        return round_info

    def evaluate_global(self, test_loader: DataLoader) -> Tuple[float, float]:
        """Evaluate global model"""
        self.global_model.eval()
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for X, y in test_loader:
                X, y = X.to(self.device), y.to(self.device)
                output = self.global_model(X)
                loss = criterion(output, y)
                total_loss += loss.item()

                _, predicted = output.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()

        return total_loss / len(test_loader), correct / total

    def federated_train(
        self,
        num_rounds: int,
        local_epochs: int = 5,
        batch_size: int = 32,
        lr: float = 0.01,
        test_loader: DataLoader = None
    ):
        """Full FL training loop"""
        print(f"FL Training: {num_rounds} rounds, {len(self.clients)} clients")

        for round_num in range(1, num_rounds + 1):
            round_info = self.train_round(local_epochs, batch_size, lr)

            if test_loader and round_num % 10 == 0:
                test_loss, test_acc = self.evaluate_global(test_loader)
                print(
                    f"Round {round_num:3d}/{num_rounds} | "
                    f"Clients: {round_info['num_clients']} | "
                    f"Local Loss: {round_info['avg_local_loss']:.4f} | "
                    f"Test Acc: {test_acc:.4f}"
                )

        print("Federated training complete!")


# ========== Non-IID Data Partitioning ==========

def create_non_iid_partition(
    dataset,
    num_clients: int,
    num_classes: int,
    alpha: float = 0.5
) -> List[List[int]]:
    """
    Non-IID data partitioning using Dirichlet distribution
    Lower alpha = more heterogeneous distribution
    """
    labels = np.array([dataset[i][1] for i in range(len(dataset))])
    client_indices = [[] for _ in range(num_clients)]

    for cls in range(num_classes):
        cls_indices = np.where(labels == cls)[0]
        np.random.shuffle(cls_indices)

        proportions = np.random.dirichlet([alpha] * num_clients)
        proportions = (proportions * len(cls_indices)).astype(int)
        proportions[-1] = len(cls_indices) - proportions[:-1].sum()

        start = 0
        for k, prop in enumerate(proportions):
            client_indices[k].extend(
                cls_indices[start:start + prop].tolist()
            )
            start += prop

    return client_indices


def run_fedavg_demo():
    """Run FedAvg demo"""
    import torchvision
    import torchvision.transforms as transforms

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = torchvision.datasets.MNIST(
        './data', train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.MNIST(
        './data', train=False, transform=transform
    )

    num_clients = 20
    client_indices = create_non_iid_partition(
        train_dataset, num_clients, num_classes=10, alpha=0.5
    )

    clients = []
    for k in range(num_clients):
        subset = Subset(train_dataset, client_indices[k])
        clients.append(FLClient(k, subset))

    print(f"Number of clients: {num_clients}")
    print(f"Avg samples per client: {np.mean([len(c.dataset) for c in clients]):.1f}")

    global_model = SimpleNet(784, 256, 10)
    server = FedAvgServer(global_model, clients, fraction=0.1)
    test_loader = DataLoader(test_dataset, batch_size=256)

    server.federated_train(
        num_rounds=100,
        local_epochs=5,
        batch_size=32,
        lr=0.01,
        test_loader=test_loader
    )

    _, final_acc = server.evaluate_global(test_loader)
    print(f"\nFinal Test Accuracy: {final_acc:.4f}")


if __name__ == '__main__':
    run_fedavg_demo()

4. Challenges in Federated Learning

4.1 Data Heterogeneity (Non-IID Problem)

The biggest technical challenge in FL is Non-IID (Non-Independent and Identically Distributed) data. In real environments, each client's data distribution is different.

Types of Non-IID

  • Feature distribution skew: Different input distributions per client (e.g., regional weather patterns)
  • Label distribution skew: Different class proportions per client (each smartphone user has different frequent words)
  • Concept drift: Different labels for the same input (different cultural contexts by region)
  • Quantity imbalance: Extremely different amounts of data per client

Non-IID data in FedAvg causes client drift. When each client's local optimum differs from the global optimum, aggregation becomes ineffective.

4.2 System Heterogeneity

In real FL systems, device performance, battery status, and network connectivity vary widely.

  • Compute heterogeneity: GPU-equipped servers and low-end mobile devices participating simultaneously
  • Network heterogeneity: High-speed wired and unstable mobile networks coexisting
  • Memory heterogeneity: Some clients may not be able to load the full model

4.3 The Straggler Problem

If some clients are slow or unresponsive, the entire training is delayed. In Synchronous FL, all selected clients must return their updates before the next round. Solutions include:

  • Asynchronous FL: Immediately aggregate updates from responding clients
  • Timeout setting: Only use clients that respond within a time limit
  • FedProx: Add a proximal term to local updates to tolerate stragglers
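The timeout approach can be simulated in a few lines. The per-client latencies and the 5-second deadline below are invented for illustration:

```python
import random

random.seed(7)

# Simulated per-round client response times in seconds (illustrative values)
clients = {f"client_{i}": random.uniform(0.5, 12.0) for i in range(10)}
DEADLINE = 5.0  # server waits this long, then aggregates whatever has arrived

arrived = {c: t for c, t in clients.items() if t <= DEADLINE}
stragglers = sorted(set(clients) - set(arrived))

print(f"aggregating {len(arrived)}/{len(clients)} updates")
print("dropped stragglers:", stragglers)
```

The trade-off: a tight deadline keeps rounds fast but systematically excludes slow devices, which can bias the global model toward data on fast hardware.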

5. Advanced FL Algorithms

5.1 FedProx: Handling Non-IID

FedProx, proposed by Li et al. in 2020, adds a proximal term to local optimization. This term prevents the local model from drifting too far from the global model.

FedProx Objective Function

A proximal term is added to the local objective function.

h_k(w; w^t) = F_k(w) + (mu/2) x ||w - w^t||^2

Here mu is a hyperparameter controlling the strength of the proximal term. When mu=0, it is equivalent to FedAvg.

class FedProxClient(FLClient):
    """FedProx client: adds proximal term"""

    def local_train_prox(
        self,
        model: nn.Module,
        global_weights: Dict,
        local_epochs: int,
        batch_size: int,
        lr: float,
        mu: float = 0.01
    ) -> Tuple[Dict, float]:
        """
        Local training with proximal term
        h_k(w) = F_k(w) + (mu/2) * ||w - w^t||^2
        """
        model = deepcopy(model).to(self.device)
        model.train()

        global_model = deepcopy(model)
        global_model.load_state_dict(global_weights)
        for param in global_model.parameters():
            param.requires_grad = False

        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True
        )
        optimizer = optim.SGD(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        n_batches = 0

        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()

                output = model(X)
                task_loss = criterion(output, y)

                # Proximal term: (mu/2) * ||w - w_global||^2
                prox_loss = 0.0
                for w, w_global in zip(
                    model.parameters(),
                    global_model.parameters()
                ):
                    prox_loss += (mu / 2) * torch.norm(w - w_global) ** 2

                total_batch_loss = task_loss + prox_loss
                total_batch_loss.backward()
                optimizer.step()

                total_loss += task_loss.item()
                n_batches += 1

        avg_loss = total_loss / max(n_batches, 1)
        return model.state_dict(), avg_loss

5.2 SCAFFOLD: Correcting Client Drift

SCAFFOLD (Stochastic Controlled Averaging for Federated Learning) uses control variates to directly correct client drift. Each client and server maintain control variates c_k and c to correct gradient bias.

class ScaffoldClient:
    """SCAFFOLD client"""

    def __init__(self, client_id, dataset, device='cpu'):
        self.client_id = client_id
        self.dataset = dataset
        self.device = device
        self.c_k = None  # client control variate

    def init_control_variate(self, model: nn.Module):
        """Initialize control variate"""
        self.c_k = {
            name: torch.zeros_like(param)
            for name, param in model.named_parameters()
        }

    def local_train_scaffold(
        self,
        model: nn.Module,
        server_control: Dict,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> Tuple[Dict, Dict, float]:
        """
        SCAFFOLD local training
        Returns: (updated weights, control variate update, loss)
        """
        if self.c_k is None:
            self.init_control_variate(model)

        model = deepcopy(model).to(self.device)
        model.train()

        initial_weights = deepcopy(model.state_dict())
        loader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        n_batches = 0
        total_steps = local_epochs * len(loader)

        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                output = model(X)
                loss = criterion(output, y)
                loss.backward()

                with torch.no_grad():
                    for name, param in model.named_parameters():
                        if param.grad is not None:
                            # SCAFFOLD corrected step: w <- w - lr * (g - c_k + c)
                            correction = (
                                server_control[name].to(self.device)
                                - self.c_k[name].to(self.device)
                            )
                            param -= lr * (param.grad + correction)
                            param.grad.zero_()

                total_loss += loss.item()
                n_batches += 1

        final_weights = model.state_dict()

        # Update control variate
        # c_k+ = c_k - c + (1 / (K * lr)) * (w_0 - w_K)
        new_c_k = {}
        c_k_diff = {}
        for name in self.c_k:
            w_diff = (
                initial_weights[name].float() - final_weights[name].float()
            )
            new_c_k[name] = (
                self.c_k[name]
                - server_control[name]
                + w_diff / (total_steps * lr)
            )
            c_k_diff[name] = new_c_k[name] - self.c_k[name]

        self.c_k = new_c_k
        avg_loss = total_loss / max(n_batches, 1)
        return final_weights, c_k_diff, avg_loss

6. Differential Privacy

6.1 Mathematical Definition of DP

Differential Privacy (DP) is a mathematical framework for privacy protection. Intuitively, it guarantees that "the output distribution does not change significantly when a single data point is added or removed."

Epsilon-Delta DP Definition

A randomized mechanism M satisfies (epsilon, delta)-DP if for all neighboring datasets D and D' and all output sets S:

Pr[M(D) in S] <= exp(epsilon) x Pr[M(D') in S] + delta

  • epsilon: privacy budget; lower values give stronger privacy guarantees
  • delta: failure probability; typically set below 1/|dataset|
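The classic randomized response mechanism makes the definition concrete: report the true bit with probability 1/2, otherwise report a fresh random bit. The worst-case ratio of output probabilities across neighboring inputs is 3, so the mechanism satisfies (ln 3, 0)-DP:

```python
import math
import random

random.seed(0)

def randomized_response(true_bit: int) -> int:
    """Report the truth with prob 1/2; otherwise report a uniformly random bit."""
    if random.random() < 0.5:
        return true_bit
    return random.getrandbits(1)

# Pr[report=1 | true=1] = 3/4 and Pr[report=1 | true=0] = 1/4, so the
# worst-case probability ratio is 3 and epsilon = ln 3 with delta = 0.
epsilon = math.log((3 / 4) / (1 / 4))
print(f"randomized response satisfies ({epsilon:.3f}, 0)-DP")  # ln 3 ~= 1.099

# Empirical check of the two conditional probabilities
n = 100_000
p1 = sum(randomized_response(1) for _ in range(n)) / n
p0 = sum(randomized_response(0) for _ in range(n)) / n
print(f"empirical Pr[1|1]={p1:.3f}, Pr[1|0]={p0:.3f}")
```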

6.2 Gaussian Mechanism and Clipping

To apply DP in FL, noise must be added to each client's update.

Gradient Clipping: First clip the L2 norm of gradients to a maximum value C.

g_clipped = g x min(1, C / ||g||_2)

Noise Addition: Add Gaussian noise to the clipped gradient.

g_dp = g_clipped + N(0, sigma^2 x C^2 x I)

Here sigma is the noise multiplier.
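The two formulas translate directly into a small helper; the gradient values and parameter choices below are illustrative:

```python
import torch

torch.manual_seed(0)

def privatize(grad: torch.Tensor, clip_norm: float, sigma: float) -> torch.Tensor:
    """Clip the gradient's L2 norm to clip_norm, then add Gaussian noise N(0, sigma^2 C^2)."""
    norm = grad.norm(p=2)
    g_clipped = grad * min(1.0, clip_norm / (norm + 1e-12))  # g x min(1, C / ||g||_2)
    noise = torch.randn_like(g_clipped) * sigma * clip_norm   # std = sigma x C
    return g_clipped + noise

g = torch.tensor([3.0, 4.0])                  # ||g||_2 = 5
g_dp = privatize(g, clip_norm=1.0, sigma=0.5)
print(g_dp)                                   # clipped to norm 1, then noised
```

Clipping bounds each example's (or each client's) influence, which is what makes the calibrated noise sufficient for the DP guarantee.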

6.3 DP-FL Implementation

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from copy import deepcopy
from typing import Dict, List, Tuple
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator


def make_private_model(model: nn.Module) -> nn.Module:
    """Convert to Opacus-compatible model (BatchNorm -> GroupNorm)"""
    model = ModuleValidator.fix(model)
    return model


class DPFLClient:
    """FL client with differential privacy"""

    def __init__(
        self,
        client_id: int,
        dataset,
        target_epsilon: float = 1.0,
        target_delta: float = 1e-5,
        max_grad_norm: float = 1.0,
        device: str = 'cpu'
    ):
        self.client_id = client_id
        self.dataset = dataset
        self.target_epsilon = target_epsilon
        self.target_delta = target_delta
        self.max_grad_norm = max_grad_norm
        self.device = device

    def dp_local_train(
        self,
        model: nn.Module,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> Tuple[Dict, float, float]:
        """
        DP local training
        Returns: (weights, loss, epsilon used)
        """
        model = make_private_model(deepcopy(model)).to(self.device)
        model.train()

        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True  # required by Opacus
        )
        optimizer = optim.SGD(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        privacy_engine = PrivacyEngine()
        model, optimizer, loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=loader,
            epochs=local_epochs,
            target_epsilon=self.target_epsilon,
            target_delta=self.target_delta,
            max_grad_norm=self.max_grad_norm
        )

        total_loss = 0.0
        n_batches = 0

        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                output = model(X)
                loss = criterion(output, y)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                n_batches += 1

        epsilon_used = privacy_engine.get_epsilon(self.target_delta)
        avg_loss = total_loss / max(n_batches, 1)

        # Remove Opacus wrapper to return clean weights
        clean_weights = {
            k.replace('_module.', ''): v
            for k, v in model.state_dict().items()
        }

        return clean_weights, avg_loss, epsilon_used


class DPFLServer:
    """DP-FL server: server-side DP aggregation"""

    def __init__(
        self,
        global_model: nn.Module,
        clients: List['DPFLClient'],
        noise_multiplier: float = 0.5,
        max_grad_norm: float = 1.0,
        device: str = 'cpu'
    ):
        self.global_model = global_model.to(device)
        self.clients = clients
        self.noise_multiplier = noise_multiplier
        self.max_grad_norm = max_grad_norm
        self.device = device

    def clip_and_aggregate(
        self,
        client_weights: List[Dict],
        reference_weights: Dict
    ) -> Dict:
        """
        Server-side clipping and noise addition
        Clip each update then add Gaussian noise
        """
        n_clients = len(client_weights)

        updates = []
        for w in client_weights:
            delta = {
                k: w[k].float() - reference_weights[k].float()
                for k in reference_weights
            }
            updates.append(delta)

        clipped_updates = []
        for delta in updates:
            total_norm = torch.sqrt(
                sum(torch.norm(v) ** 2 for v in delta.values())
            )
            clip_factor = min(1.0, self.max_grad_norm / (total_norm + 1e-8))
            clipped = {k: v * clip_factor for k, v in delta.items()}
            clipped_updates.append(clipped)

        summed = {}
        for key in reference_weights:
            summed[key] = sum(u[key] for u in clipped_updates)

        sigma = self.noise_multiplier * self.max_grad_norm
        noisy = {}
        for key in summed:
            noise = torch.randn_like(summed[key]) * sigma
            noisy[key] = (summed[key] + noise) / n_clients

        aggregated = {
            k: reference_weights[k].float() + noisy[k]
            for k in reference_weights
        }
        return aggregated

7. Secure Aggregation

7.1 The Need for Encrypted Aggregation

DP provides statistical privacy, but the server can still see each client's update individually. Secure Aggregation (SecAgg) is a cryptographic protocol that ensures the server sees only the aggregated result.

7.2 Secret Sharing-Based SecAgg

The Secure Aggregation protocol by Bonawitz et al. (2017) uses the following principle:

  • Clients exchange masks (random number pairs) with each other.
  • Each client adds its mask to its update before sending to the server.
  • When the server sums all masked updates, the masks cancel out, leaving the pure sum of updates.

class SecureAggregation:
    """
    Simplified secure aggregation simulation
    Real implementations use Diffie-Hellman key exchange
    """

    def __init__(self, num_clients: int, seed_length: int = 32):
        self.num_clients = num_clients
        self.seed_length = seed_length

    def generate_pairwise_masks(
        self,
        client_id: int,
        model_shape: Dict[str, torch.Size],
        shared_seeds: Dict
    ) -> Dict:
        """
        Generate pairwise masks between clients
        For client pair i and j: add if i > j, subtract if i < j
        """
        mask = {k: torch.zeros(v) for k, v in model_shape.items()}

        for other_id, seed in shared_seeds.items():
            torch.manual_seed(seed)
            direction = 1 if client_id > other_id else -1

            for key in mask:
                mask[key] += direction * torch.randn(model_shape[key])

        return mask

    def client_mask_update(
        self,
        local_update: Dict,
        mask: Dict
    ) -> Dict:
        """Apply mask to update"""
        return {k: local_update[k] + mask[k] for k in local_update}

    def server_aggregate_masked(
        self,
        masked_updates: List[Dict]
    ) -> Dict:
        """Sum masked updates (masks cancel each other out)"""
        aggregated = {}
        for key in masked_updates[0]:
            aggregated[key] = sum(u[key] for u in masked_updates)
            aggregated[key] /= len(masked_updates)
        return aggregated
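The mask-cancellation idea can be checked standalone with plain tensors (reimplemented here rather than using the class above, so the snippet runs on its own; the seed and update values are arbitrary):

```python
import torch

# Two clients' true updates (what the server should never see individually)
u1 = torch.tensor([1.0, 2.0, 3.0])
u2 = torch.tensor([4.0, 5.0, 6.0])

# Pairwise mask derived from a seed both clients share (e.g., via key exchange)
gen = torch.Generator().manual_seed(12345)
mask = torch.randn(3, generator=gen)

masked1 = u1 + mask      # client 1 adds the pairwise mask
masked2 = u2 - mask      # client 2 subtracts the same mask

# The server only ever sees masked1 and masked2; their mean equals the true mean
agg = (masked1 + masked2) / 2
print(agg)               # == (u1 + u2) / 2 = tensor([2.5000, 3.5000, 4.5000])
```

Neither masked vector alone reveals its client's update, yet the sum is exact because every pairwise mask appears once with each sign.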

7.3 Homomorphic Encryption Overview

Homomorphic Encryption (HE) is a cryptographic scheme that allows computations to be performed on encrypted data. In FL, HE allows the server to aggregate client updates without decrypting them.

  • Partially Homomorphic Encryption (PHE): Supports only addition or multiplication (Paillier cryptosystem)
  • Fully Homomorphic Encryption (FHE): Supports arbitrary operations (CKKS, BGV) — very high computational cost

In practical FL, PHE supporting only addition is sufficient for aggregation (weighted average).
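Paillier's additive property can be demonstrated with a toy implementation: multiplying two ciphertexts yields a ciphertext of the sum of the plaintexts. The tiny hard-coded primes below are for illustration only and offer no security:

```python
import math
import random

random.seed(0)

# Toy Paillier keypair (insecure, illustration-only 16-bit primes)
p, q = 32771, 32779
n = p * q
n2 = n * n
g = n + 1                                   # standard choice g = n + 1
lam = math.lcm(p - 1, q - 1)                # Carmichael function of n
mu = pow(lam, -1, n)                        # lam^{-1} mod n

def encrypt(m: int) -> int:
    """c = g^m * r^n mod n^2 for random r coprime to n."""
    r = random.randrange(1, n)
    while math.gcd(r, n) != 1:
        r = random.randrange(1, n)
    return (pow(g, m, n2) * pow(r, n, n2)) % n2

def decrypt(c: int) -> int:
    """m = L(c^lam mod n^2) * mu mod n, where L(x) = (x - 1) // n."""
    x = pow(c, lam, n2)
    return ((x - 1) // n * mu) % n

# Homomorphic addition: multiplying ciphertexts adds the underlying plaintexts
a, b = 123, 456
c_sum = (encrypt(a) * encrypt(b)) % n2
print(decrypt(c_sum))                       # -> 579, computed without decrypting a or b
```

In an FL setting, clients would encrypt their (quantized) updates and the server would multiply the ciphertexts to obtain an encrypted sum that only the key holder can open.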


8. The Flower Framework

8.1 Introduction to Flower

Flower (flwr) is a Python framework for federated learning that is framework-agnostic. It works with PyTorch, TensorFlow, JAX, and other ML frameworks.

pip install flwr
pip install flwr[simulation]

8.2 Flower Client Implementation

import flwr as fl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Dict, List, Tuple
import numpy as np


class FlowerClient(fl.client.NumPyClient):
    """Flower FL Client"""

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: str = 'cpu'
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.criterion = nn.CrossEntropyLoss()

    def get_parameters(self, config: Dict) -> List[np.ndarray]:
        """Return current model parameters as NumPy arrays"""
        return [
            val.cpu().numpy()
            for _, val in self.model.state_dict().items()
        ]

    def set_parameters(self, parameters: List[np.ndarray]):
        """Set model parameters from NumPy arrays"""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(
        self,
        parameters: List[np.ndarray],
        config: Dict
    ) -> Tuple[List[np.ndarray], int, Dict]:
        """Local training with parameters received from server"""
        self.set_parameters(parameters)

        local_epochs = int(config.get('local_epochs', 1))
        lr = float(config.get('lr', 0.01))

        self.model.train()
        optimizer = torch.optim.SGD(
            self.model.parameters(), lr=lr, momentum=0.9
        )

        train_loss = 0.0
        for epoch in range(local_epochs):
            for X, y in self.train_loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                loss = self.criterion(self.model(X), y)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

        avg_loss = train_loss / (local_epochs * len(self.train_loader))
        return (
            self.get_parameters(config={}),
            len(self.train_loader.dataset),
            {'train_loss': avg_loss}
        )

    def evaluate(
        self,
        parameters: List[np.ndarray],
        config: Dict
    ) -> Tuple[float, int, Dict]:
        """Evaluate with parameters received from server"""
        self.set_parameters(parameters)
        self.model.eval()

        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for X, y in self.val_loader:
                X, y = X.to(self.device), y.to(self.device)
                output = self.model(X)
                val_loss += self.criterion(output, y).item()

                _, predicted = output.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()

        accuracy = correct / total
        return (
            val_loss / len(self.val_loader),
            len(self.val_loader.dataset),
            {'accuracy': accuracy}
        )

8.3 Flower Server Strategy

from flwr.server.strategy import FedAvg
from flwr.common import Parameters, FitIns, EvaluateRes
from flwr.server.client_proxy import ClientProxy


def weighted_accuracy(metrics):
    """Aggregate client accuracies weighted by number of examples"""
    total = sum(num for num, _ in metrics)
    accuracy = sum(num * m['accuracy'] for num, m in metrics) / total
    return {'accuracy': accuracy}


class CustomFedAvgStrategy(FedAvg):
    """Custom FedAvg strategy"""

    def __init__(
        self,
        fraction_fit: float = 0.1,
        fraction_evaluate: float = 0.1,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        initial_parameters=None,
    ):
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            initial_parameters=initial_parameters,
            # Without this, FedAvg discards client metrics and the
            # accuracy log in aggregate_evaluate stays empty
            evaluate_metrics_aggregation_fn=weighted_accuracy,
        )
        self.round_metrics = []

    def configure_fit(
        self,
        server_round: int,
        parameters: Parameters,
        client_manager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure training for each round"""
        # Decay learning rate as rounds progress
        lr = 0.01 * (0.99 ** server_round)

        config = {
            'local_epochs': 5,
            'lr': lr,
            'round': server_round,
        }

        fit_ins = FitIns(parameters, config)
        clients = client_manager.sample(
            num_clients=max(
                int(client_manager.num_available() * self.fraction_fit), 1
            ),
            min_num_clients=self.min_fit_clients
        )
        return [(client, fit_ins) for client in clients]

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures
    ):
        """Aggregate evaluation results and log"""
        aggregated_loss, metrics = super().aggregate_evaluate(
            server_round, results, failures
        )

        if metrics:
            print(
                f"Round {server_round}: "
                f"Aggregated accuracy = {metrics.get('accuracy', 0):.4f}"
            )
            self.round_metrics.append({
                'round': server_round,
                'loss': aggregated_loss,
                'metrics': metrics
            })

        return aggregated_loss, metrics


def run_flower_simulation():
    """Run Flower simulation"""
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import Subset

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = torchvision.datasets.MNIST(
        './data', train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.MNIST(
        './data', train=False, transform=transform
    )

    num_clients = 10

    def client_fn(cid: str):
        cid = int(cid)
        n = len(train_dataset) // num_clients
        start = cid * n
        indices = list(range(start, min(start + n, len(train_dataset))))
        train_subset = Subset(train_dataset, indices)

        val_indices = list(range(cid * 100, (cid + 1) * 100))
        val_subset = Subset(test_dataset, val_indices)

        model = SimpleNet(784, 256, 10)
        return FlowerClient(
            model=model,
            train_loader=DataLoader(train_subset, batch_size=32),
            val_loader=DataLoader(val_subset, batch_size=64)
        ).to_client()

    global_model = SimpleNet(784, 256, 10)
    initial_params = fl.common.ndarrays_to_parameters(
        [val.numpy() for val in global_model.state_dict().values()]
    )

    strategy = CustomFedAvgStrategy(
        fraction_fit=0.5,
        fraction_evaluate=0.5,
        min_fit_clients=2,
        min_evaluate_clients=2,
        min_available_clients=num_clients,
        initial_parameters=initial_params
    )

    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=num_clients,
        config=fl.server.ServerConfig(num_rounds=50),
        strategy=strategy,
    )

    print("\n=== Flower Simulation Complete ===")
    print(f"Final distributed loss: {history.losses_distributed[-1]}")

9. Hands-On FL Project: Hospital Federated Learning

9.1 Multi-Hospital Chest X-Ray Diagnosis

Consider a scenario in which multiple hospitals jointly train a lung-disease diagnostic model without ever sharing patient chest X-ray data.

import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import pandas as pd


class ChestXRayDataset(Dataset):
    """Chest X-Ray Dataset (per hospital)"""

    def __init__(
        self,
        data_dir: str,
        labels_file: str,
        transform=None,
        hospital_id: int = 0
    ):
        self.data_dir = data_dir
        self.labels = pd.read_csv(labels_file)
        self.transform = transform
        self.hospital_id = hospital_id

        # Filter to this hospital's data only
        self.labels = self.labels[
            self.labels['hospital_id'] == hospital_id
        ].reset_index(drop=True)

        self.classes = [
            'Normal', 'Pneumonia', 'COVID-19', 'Tuberculosis'
        ]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_path = os.path.join(
            self.data_dir, self.labels.iloc[idx]['filename']
        )
        image = Image.open(img_path).convert('RGB')
        label = self.labels.iloc[idx]['label']

        if self.transform:
            image = self.transform(image)

        return image, label


def build_chest_model(num_classes: int = 4, pretrained: bool = True):
    """ResNet50-based chest X-ray classification model"""
    # torchvision >= 0.13 uses the weights API instead of pretrained=
    weights = models.ResNet50_Weights.DEFAULT if pretrained else None
    model = models.resnet50(weights=weights)

    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Linear(256, num_classes)
    )
    return model


def train_hospital_federation():
    """Run hospital federated learning"""
    print("=== Multi-Hospital Federated Learning ===")
    print("Patient data at each hospital never leaves its location.")
    print("Only model parameter updates are aggregated.\n")

    global_model = build_chest_model(num_classes=4)

    total_params = sum(p.numel() for p in global_model.parameters())
    print(f"Global model parameters: {total_params:,}")
    print(
        f"Communication cost (FP32): "
        f"{total_params * 4 / 1024 / 1024:.1f} MB/round"
    )
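Under the same scenario, the per-round loop can be sketched as follows. To keep the sketch runnable anywhere, it substitutes a tiny linear model and synthetic tensors for ResNet50 and the hospitals' real ChestXRayDataset loaders; names such as hospital_loaders are illustrative:

```python
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
num_hospitals, num_rounds, local_epochs = 3, 2, 1

# Synthetic stand-ins for each hospital's private X-ray data
hospital_loaders = [
    DataLoader(
        TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,))),
        batch_size=16,
    )
    for _ in range(num_hospitals)
]

global_model = nn.Linear(16, 4)

for rnd in range(num_rounds):
    local_states, local_sizes = [], []
    for loader in hospital_loaders:
        local = copy.deepcopy(global_model)   # the model travels, not the data
        opt = torch.optim.SGD(local.parameters(), lr=0.05)
        for _ in range(local_epochs):
            for X, y in loader:
                opt.zero_grad()
                nn.functional.cross_entropy(local(X), y).backward()
                opt.step()
        local_states.append(local.state_dict())
        local_sizes.append(len(loader.dataset))

    # FedAvg: weight each hospital's parameters by its dataset size
    total = sum(local_sizes)
    new_state = {
        k: sum(s[k] * (n / total) for s, n in zip(local_states, local_sizes))
        for k in local_states[0]
    }
    global_model.load_state_dict(new_state)
    print(f"Round {rnd + 1}: aggregated updates from {num_hospitals} hospitals")
```

In a real deployment the same loop runs over the network (e.g. via Flower), with each hospital executing only its local-training step.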

10. The Future of Federated Learning

Federated Learning is establishing itself as a key technology for resolving the conflict between AI and privacy. Key trends to watch:

Cross-Device vs Cross-Silo

  • Cross-device: Millions of mobile devices participate (Google's Gboard)
  • Cross-silo: A small number of institutions (hospitals, banks) participate — more trustworthy but smaller scale

FL + LLMs

With the rise of Large Language Models (LLMs), FL has become even more important. When fine-tuning models on user conversations, FL ensures those conversations never leave the user's device. Combining FL with Parameter-Efficient Fine-Tuning (PEFT, such as LoRA) further reduces communication costs, since only the small adapter weights need to be exchanged.
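A rough sketch of the communication savings: with a hand-rolled low-rank adapter (a stand-in for what a library such as peft provides), each round would exchange only the small A and B matrices instead of the frozen base weight:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen base linear layer plus a trainable rank-r adapter"""

    def __init__(self, in_features, out_features, rank=8):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        for p in self.base.parameters():
            p.requires_grad_(False)            # pretrained weights stay frozen
        self.A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(out_features, rank))

    def forward(self, x):
        # Base output plus low-rank update x A^T B^T
        return self.base(x) + x @ self.A.t() @ self.B.t()

layer = LoRALinear(4096, 4096, rank=8)
full = layer.base.weight.numel()               # what full FL would send
adapter = layer.A.numel() + layer.B.numel()    # what LoRA-FL sends
print(f"full: {full:,}  adapter: {adapter:,}  "
      f"ratio: {full / adapter:.0f}x smaller")
```

For a 4096-by-4096 layer at rank 8, the adapter is 256 times smaller than the full weight matrix, and the gap widens at lower ranks.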

Regulation-Friendly AI

As regulations like GDPR and HIPAA tighten, FL is becoming a practical compliance solution. FL adoption is expected to expand rapidly in healthcare, finance, and legal sectors.


References

  • McMahan, H. B., et al. (2017). Communication-Efficient Learning of Deep Networks from Decentralized Data. AISTATS 2017.
  • Li, T., et al. (2020). Federated Optimization in Heterogeneous Networks (FedProx). MLSys 2020.
  • Karimireddy, S. P., et al. (2020). SCAFFOLD: Stochastic Controlled Averaging for Federated Learning. ICML 2020.
  • Bonawitz, K., et al. (2017). Practical Secure Aggregation for Privacy-Preserving Machine Learning. CCS 2017.
  • Flower Framework: https://flower.ai/docs/
  • Opacus (PyTorch DP): https://opacus.ai/