Federated Learning: A Complete Guide to Privacy-Preserving Distributed AI

One of the great ironies of modern AI is that better models need more data, and the more data you collect, the greater the privacy risk. Hospitals cannot share patient data, smartphone makers cannot send users' typing patterns to their servers, and financial institutions cannot share transaction histories with competitors.

Federated learning (FL) is a paradigm that resolves this dilemma. Instead of sending data to a central server, the model is sent to where the data lives, trained there, and only the model updates (weight changes) are aggregated. The idea: "the data stays put; only the intelligence moves."

This guide covers FL from its theoretical foundations to hands-on implementation.


1. Federated Learning Fundamentals

1.1 Problems with Traditional Centralized Training

Consider a traditional machine learning pipeline: patient data is collected from thousands of hospitals, stored on a central server, and used to train a diagnostic model. What is wrong with this approach?

Data privacy

Patient records, financial transactions, and private communications are highly sensitive. Sending such data to a central server creates risks such as:

  • Eavesdropping during transmission
  • Large-scale data breaches from server compromise
  • Loss of trust in the institutions that hold the data
  • Legal penalties for regulatory violations

Legal regulation

Data privacy regulation is tightening worldwide.

  • GDPR (General Data Protection Regulation): the EU's data protection law. It mandates stated processing purposes, consent requirements, and data minimization.
  • HIPAA (Health Insurance Portability and Accountability Act): the US health information privacy law, which protects patients' protected health information (PHI).
  • PIPA (Korea): Korea's Personal Information Protection Act places strict limits on the collection, use, and disclosure of personal data.

Communication cost

Transferring data from millions of edge devices (smartphones, IoT sensors) to a central location requires enormous network bandwidth. Large payloads such as images and audio make the problem even worse.

1.2 The Core Idea of Federated Learning

Federated learning was proposed by McMahan et al. at Google in 2016 and rests on the following principle.

"The data stays on site; only the knowledge (model updates) moves."

The basic FL process is as follows.

  1. Initialization: the central server initializes a global model.
  2. Distribution: the server sends the current global model to a selected set of clients.
  3. Local training: each client trains the model on its own local data.
  4. Upload: clients send their model updates (gradients or weight deltas) to the server.
  5. Aggregation: the server aggregates the updates (e.g. by averaging) to refresh the global model.
  6. Repeat: steps 2-5 repeat until convergence.

In this scheme, the raw data never leaves the client device. Only model parameter updates are transmitted.
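The round structure above can be sketched in a few lines of Python. This is a toy sketch: the "models" are plain numbers and `local_train` is a stand-in for real training, so only the orchestration pattern carries over:

```python
import random

def federated_round(global_model, clients, local_train, fraction=0.5):
    """One FL round: select clients, train locally, aggregate by data size."""
    m = max(int(fraction * len(clients)), 1)
    selected = random.sample(clients, m)                   # steps 1-2: select, distribute
    updates, sizes = [], []
    for client in selected:
        updates.append(local_train(global_model, client))  # step 3: local training
        sizes.append(client["num_samples"])                # step 4: upload
    n = sum(sizes)
    return sum(w * s / n for w, s in zip(updates, sizes))  # step 5: aggregate

# Toy usage: a "model" is one number; local training nudges it toward a target
clients = [{"num_samples": 100, "target": 1.0},
           {"num_samples": 300, "target": 3.0}]
local_step = lambda w, c: w + 0.5 * (c["target"] - w)
w = 0.0
for _ in range(20):                                        # step 6: repeat rounds
    w = federated_round(w, clients, local_step, fraction=1.0)
print(round(w, 2))  # → 2.5
```

With both clients participating every round, the global model converges to the size-weighted mean of the local targets (100·1.0 + 300·3.0)/400 = 2.5.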

1.3 Applications of Federated Learning

Mobile / edge devices

Google applies FL in Gboard (its mobile keyboard). Instead of sending typing patterns to a server, the next-word prediction model is improved directly on the device. The system benefits from millions of users' data while personal information never leaves the device.

Healthcare

Hospitals can build better diagnostic AI without sharing patient data. For rare diseases in particular, a single hospital's data is often too small to train a model, but FL can combine the knowledge of many hospitals.

Finance

In fraud detection and credit scoring, financial institutions can collaborate without sharing customer data. For cross-border transactions especially, a joint model can be trained while respecting each country's data sovereignty.

Autonomous driving

Car manufacturers can jointly improve road-hazard detection models without sharing their driving data.


2. Federated Learning Architectures

2.1 Client-Server Structure

The most common FL architecture consists of a central aggregation server and multiple clients.

         ┌─────────────────┐
         │ Central Server  │
         │  (Aggregator)   │
         └────────┬────────┘
                  │ model distribution / update aggregation
         ┌────────┼────────┐
         ↓        ↓        ↓
    ┌─────────┐ ┌─────────┐ ┌─────────┐
    │ Client 1│ │ Client 2│ │ Client 3│
    │ (trains │ │ (trains │ │ (trains │
    │ locally)│ │ locally)│ │ locally)│
    └─────────┘ └─────────┘ └─────────┘

Server responsibilities

  • Maintain and manage the global model
  • Select the clients participating in each round
  • Aggregate client updates
  • Distribute the aggregated model back to clients

Client responsibilities

  • Hold local data
  • Fine-tune the model received from the server on local data
  • Send the updated model (or gradients) back

2.2 Horizontal Federated Learning

Horizontal FL applies when all clients share the same feature space but hold different data samples. For example, several hospitals measure the same diagnostic variables (blood pressure, blood glucose, age, and so on) but see different patients.

Client 1: [feature1, feature2, feature3] × [samples 1-1000]
Client 2: [feature1, feature2, feature3] × [samples 1001-2000]
Client 3: [feature1, feature2, feature3] × [samples 2001-3000]

Same feature space, different sample space. This is the most common form of FL.

2.3 Vertical Federated Learning

Vertical FL applies when clients hold the same users (data samples) but different features. For example, a bank holds a user's financial data while a hospital holds the same user's medical data.

Client A (bank):     [financial features] × [users 1-10000]
Client B (hospital): [medical features]   × [users 1-10000]

Same sample space, different feature space.

Vertical FL needs more complex protocols: cryptographic techniques are required for cooperation between the client that holds the labels and the clients that only hold features.

2.4 Federated Transfer Learning

When clients' sample and feature spaces only partially overlap, FL is combined with transfer learning techniques. This makes FL applicable even when there is little data overlap.


3. The FedAvg Algorithm

3.1 The Original Algorithm of McMahan et al. (2017)

FedAvg (Federated Averaging) is the foundational FL algorithm, published by McMahan et al. at Google in 2017. The core idea is that each client performs several local SGD updates, after which the server averages the weights.
Algorithm overview

Server executes:
  initialize w_0
  for round t = 1, 2, ..., T:
    m = max(C × K, 1)  // C: participation fraction, K: total number of clients
    S_t = randomly select m clients
    for each client k in S_t (in parallel):
      w_{t+1}^k = ClientUpdate(k, w_t)
    w_{t+1} = Σ (n_k / n) × w_{t+1}^k  // weighted average

Client k executes:
  B = split local data into batches
  for local epoch e = 1, ..., E:
    for batch b in B:
      w = w - η × ∇ℓ(w; b)
  return w

Key parameters

  • C: fraction of clients participating in each round (0 < C ≤ 1)
  • E: number of local epochs per client
  • B: local mini-batch size
  • η: learning rate

With E = 1 and B = the full dataset, FedAvg reduces to FedSGD. Increasing E reduces the number of communication rounds but raises the risk of client drift.
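As a sanity check on the server's weighted average, here is the aggregation step on made-up numbers (two clients, weights as 2-vectors):

```python
import numpy as np

# Client 1 holds n_1 = 100 samples, client 2 holds n_2 = 300
w1 = np.array([1.0, 2.0])   # client 1's locally trained weights
w2 = np.array([5.0, 6.0])   # client 2's locally trained weights
n1, n2 = 100, 300

# w_{t+1} = sum_k (n_k / n) * w^k
w_global = (n1 * w1 + n2 * w2) / (n1 + n2)
print(w_global)  # → [4. 5.]
```

The larger client pulls the average three times as hard, which is exactly the n_k / n weighting in the algorithm above.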

3.2 A Complete FedAvg Implementation

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import numpy as np
from copy import deepcopy
from typing import List, Dict, Tuple
import random


# ========== Model definition ==========

class SimpleNet(nn.Module):
    """Simple classification network."""
    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),  # flatten image inputs, e.g. (B, 1, 28, 28) -> (B, 784)
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, x):
        return self.net(x)


# ========== Client ==========

class FLClient:
    """연합 학습 클라이언트"""

    def __init__(
        self,
        client_id: int,
        dataset,
        device: str = 'cpu'
    ):
        self.client_id = client_id
        self.dataset = dataset
        self.device = device

    def local_train(
        self,
        model: nn.Module,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> Tuple[Dict, float]:
        """
        Train the model on local data.
        Returns: (updated weights, average local loss)
        """
        model = deepcopy(model).to(self.device)
        model.train()

        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True
        )
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        n_batches = 0

        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                output = model(X)
                loss = criterion(output, y)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                n_batches += 1

        avg_loss = total_loss / max(n_batches, 1)
        return model.state_dict(), avg_loss

    def evaluate(self, model: nn.Module) -> Tuple[float, float]:
        """로컬 데이터로 모델 평가"""
        model = deepcopy(model).to(self.device)
        model.eval()

        loader = DataLoader(self.dataset, batch_size=64)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                output = model(X)
                loss = criterion(output, y)
                total_loss += loss.item()

                _, predicted = output.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()

        return total_loss / len(loader), correct / total


# ========== Server ==========

class FedAvgServer:
    """FedAvg 서버"""

    def __init__(
        self,
        global_model: nn.Module,
        clients: List[FLClient],
        fraction: float = 0.1,  # C
        device: str = 'cpu'
    ):
        self.global_model = global_model.to(device)
        self.clients = clients
        self.fraction = fraction
        self.device = device
        self.round_history = []

    def select_clients(self) -> List[FLClient]:
        """각 라운드에서 클라이언트 선택"""
        m = max(int(self.fraction * len(self.clients)), 1)
        return random.sample(self.clients, m)

    def aggregate(
        self,
        client_weights: List[Dict],
        client_sizes: List[int]
    ) -> Dict:
        """
        Aggregate by weighted average (FedAvg):
        each client weighted by n_k / n.
        """
        total_size = sum(client_sizes)
        aggregated = {}

        for key in client_weights[0].keys():
            aggregated[key] = torch.zeros_like(
                client_weights[0][key], dtype=torch.float32
            )
            for w, size in zip(client_weights, client_sizes):
                weight = size / total_size
                aggregated[key] += weight * w[key].float()

        return aggregated

    def train_round(
        self,
        local_epochs: int = 5,
        batch_size: int = 32,
        lr: float = 0.01
    ) -> Dict:
        """하나의 FL 라운드 실행"""
        selected = self.select_clients()

        client_weights = []
        client_sizes = []
        client_losses = []

        for client in selected:
            weights, loss = client.local_train(
                self.global_model, local_epochs, batch_size, lr
            )
            client_weights.append(weights)
            client_sizes.append(len(client.dataset))
            client_losses.append(loss)

        # Aggregate
        new_weights = self.aggregate(client_weights, client_sizes)
        self.global_model.load_state_dict(new_weights)

        round_info = {
            'num_clients': len(selected),
            'avg_local_loss': np.mean(client_losses),
            'client_losses': client_losses
        }
        self.round_history.append(round_info)
        return round_info

    def evaluate_global(self, test_loader: DataLoader) -> Tuple[float, float]:
        """글로벌 모델 평가"""
        self.global_model.eval()
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for X, y in test_loader:
                X, y = X.to(self.device), y.to(self.device)
                output = self.global_model(X)
                loss = criterion(output, y)
                total_loss += loss.item()

                _, predicted = output.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()

        return total_loss / len(test_loader), correct / total

    def federated_train(
        self,
        num_rounds: int,
        local_epochs: int = 5,
        batch_size: int = 32,
        lr: float = 0.01,
        test_loader: DataLoader = None
    ):
        """전체 FL 훈련 루프"""
        print(f"연합 학습 시작: {num_rounds} 라운드, {len(self.clients)} 클라이언트")

        for round_num in range(1, num_rounds + 1):
            round_info = self.train_round(local_epochs, batch_size, lr)

            if test_loader and round_num % 10 == 0:
                test_loss, test_acc = self.evaluate_global(test_loader)
                print(
                    f"Round {round_num:3d}/{num_rounds} | "
                    f"clients: {round_info['num_clients']} | "
                    f"local loss: {round_info['avg_local_loss']:.4f} | "
                    f"test accuracy: {test_acc:.4f}"
                )

        print("연합 학습 완료!")


# ========== Data partitioning (Non-IID simulation) ==========

def create_non_iid_partition(
    dataset,
    num_clients: int,
    num_classes: int,
    alpha: float = 0.5
) -> List[List[int]]:
    """
    Dirichlet 분포를 사용한 Non-IID 데이터 분할
    alpha가 낮을수록 더 불균형한 분포
    """
    labels = np.array([dataset[i][1] for i in range(len(dataset))])
    client_indices = [[] for _ in range(num_clients)]

    for cls in range(num_classes):
        cls_indices = np.where(labels == cls)[0]
        np.random.shuffle(cls_indices)

        # Allocate to clients via Dirichlet proportions
        proportions = np.random.dirichlet(
            [alpha] * num_clients
        )
        proportions = (proportions * len(cls_indices)).astype(int)
        proportions[-1] = len(cls_indices) - proportions[:-1].sum()

        start = 0
        for k, prop in enumerate(proportions):
            client_indices[k].extend(
                cls_indices[start:start + prop].tolist()
            )
            start += prop

    return client_indices


# ========== Main ==========

def run_fedavg_demo():
    """FedAvg 데모 실행"""
    import torchvision
    import torchvision.transforms as transforms

    # Load the MNIST dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = torchvision.datasets.MNIST(
        './data', train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.MNIST(
        './data', train=False, transform=transform
    )

    # Non-IID data split
    num_clients = 20
    client_indices = create_non_iid_partition(
        train_dataset, num_clients, num_classes=10, alpha=0.5
    )

    # Create the clients
    clients = []
    for k in range(num_clients):
        subset = Subset(train_dataset, client_indices[k])
        clients.append(FLClient(k, subset))

    print(f"클라이언트 수: {num_clients}")
    print(f"클라이언트당 평균 샘플: {np.mean([len(c.dataset) for c in clients]):.1f}")

    # Create the global model
    global_model = SimpleNet(784, 256, 10)

    # Create the server (sample 10% of clients per round)
    server = FedAvgServer(global_model, clients, fraction=0.1)

    # Test loader
    test_loader = DataLoader(test_dataset, batch_size=256)

    # FL training
    server.federated_train(
        num_rounds=100,
        local_epochs=5,
        batch_size=32,
        lr=0.01,
        test_loader=test_loader
    )

    # Final evaluation
    _, final_acc = server.evaluate_global(test_loader)
    print(f"\nFinal test accuracy: {final_acc:.4f}")


if __name__ == '__main__':
    run_fedavg_demo()

4. Challenges in FL

4.1 Data Heterogeneity (the Non-IID Problem)

The biggest technical challenge in FL is Non-IID (non-independent and identically distributed) data. In real deployments, each client's data distribution differs.

Types of Non-IID

  • Feature distribution skew: input distributions differ across clients (e.g. regional weather patterns)
  • Label distribution skew: class proportions differ across clients (smartphone users favor different words)
  • Concept drift: the same input maps to different labels (different cultural contexts by region)
  • Quantity skew: data volumes differ drastically across clients

Running FedAvg on Non-IID data causes client drift: each client's local optimum diverges from the global optimum, and aggregation becomes less effective.

4.2 System Heterogeneity

In real FL systems, device performance, battery state, and network connectivity vary widely.

  • Compute heterogeneity: GPU-equipped servers and low-end mobile devices participate side by side
  • Network heterogeneity: fast wired links mixed with unstable mobile connections
  • Memory heterogeneity: some clients may not be able to hold the full model

4.3 The Straggler Problem

When some clients are slow or unresponsive, the whole training process stalls. In synchronous FL, the server must wait for every selected client to return its update. Remedies include:

  • Asynchronous FL: aggregate each client's update as soon as it arrives
  • Timeouts: use only the clients that respond within a time limit
  • FedProx: add a proximal term to local updates so partial work from stragglers can still be used
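As a sketch of the asynchronous option, here is a staleness-weighted server update in the spirit of FedAsync: the server mixes in each arriving update immediately, with a mixing weight that decays with how stale the update is. The decay schedule and names below are illustrative, not taken from any particular reference implementation:

```python
def staleness_weight(base_alpha, current_round, update_round):
    """Decay the mixing weight polynomially with staleness (FedAsync-style)."""
    staleness = current_round - update_round
    return base_alpha / (1 + staleness)

def async_apply(global_w, client_w, base_alpha, current_round, update_round):
    """Mix a (possibly stale) client update into the global model right away."""
    alpha = staleness_weight(base_alpha, current_round, update_round)
    return {k: (1 - alpha) * global_w[k] + alpha * client_w[k] for k in global_w}

# A fresh update (staleness 0) moves the model more than a stale one (staleness 4)
g = {"w": 0.0}
fresh = async_apply(g, {"w": 1.0}, base_alpha=0.5, current_round=10, update_round=10)
stale = async_apply(g, {"w": 1.0}, base_alpha=0.5, current_round=10, update_round=6)
print(fresh["w"], stale["w"])  # → 0.5 0.1
```

No client can stall the server here, at the cost of noisier convergence from stale updates.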

5. Advanced FL Algorithms

5.1 FedProx: Tackling Non-IID Data

FedProx, proposed by Li et al. in 2020, adds a proximal term to the local optimization. This term keeps the local model from straying too far from the global model.

The FedProx objective

A proximal term is added to the local objective function.

h_k(w; w^t) = F_k(w) + (mu/2) × ||w - w^t||^2

Here mu is a hyperparameter controlling the strength of the proximal term. With mu = 0, FedProx reduces to FedAvg.

class FedProxClient(FLClient):
    """FedProx 클라이언트: 근접 항 추가"""

    def local_train_prox(
        self,
        model: nn.Module,
        global_weights: Dict,
        local_epochs: int,
        batch_size: int,
        lr: float,
        mu: float = 0.01  # proximal term weight
    ) -> Tuple[Dict, float]:
        """
        Local training with a proximal term:
        h_k(w) = F_k(w) + (mu/2) * ||w - w^t||^2
        """
        model = deepcopy(model).to(self.device)
        model.train()

        # Keep the global weights as a frozen reference point
        global_model = deepcopy(model)
        global_model.load_state_dict(global_weights)
        for param in global_model.parameters():
            param.requires_grad = False

        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True
        )
        optimizer = optim.SGD(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        n_batches = 0

        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()

                output = model(X)
                task_loss = criterion(output, y)

                # Proximal term: (mu/2) * ||w - w_global||^2
                prox_loss = 0.0
                for w, w_global in zip(
                    model.parameters(),
                    global_model.parameters()
                ):
                    prox_loss += (mu / 2) * torch.norm(w - w_global) ** 2

                total_batch_loss = task_loss + prox_loss
                total_batch_loss.backward()
                optimizer.step()

                total_loss += task_loss.item()
                n_batches += 1

        avg_loss = total_loss / max(n_batches, 1)
        return model.state_dict(), avg_loss

5.2 SCAFFOLD: Correcting Client Drift

SCAFFOLD (Stochastic Controlled Averaging for Federated Learning) uses control variates to correct client drift directly. Each client and the server maintain control variates c_k and c, respectively, to debias the gradients.

class ScaffoldClient:
    """SCAFFOLD 클라이언트"""

    def __init__(self, client_id, dataset, device='cpu'):
        self.client_id = client_id
        self.dataset = dataset
        self.device = device
        # Client control variate (initialized lazily)
        self.c_k = None

    def init_control_variate(self, model: nn.Module):
        """제어 변량 초기화"""
        self.c_k = {
            name: torch.zeros_like(param)
            for name, param in model.named_parameters()
        }

    def local_train_scaffold(
        self,
        model: nn.Module,
        server_control: Dict,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> Tuple[Dict, Dict, float]:
        """
        SCAFFOLD local training.
        Returns: (updated weights, control-variate delta, loss)
        """
        if self.c_k is None:
            self.init_control_variate(model)

        model = deepcopy(model).to(self.device)
        model.train()

        initial_weights = deepcopy(model.state_dict())
        loader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        n_batches = 0
        total_steps = local_epochs * len(loader)

        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)

                # Manual parameter update (with SCAFFOLD correction)
                output = model(X)
                loss = criterion(output, y)
                loss.backward()

                with torch.no_grad():
                    for name, param in model.named_parameters():
                        if param.grad is not None:
                            # SCAFFOLD correction: descend along g - c_k + c
                            correction = (
                                server_control[name].to(self.device)
                                - self.c_k[name].to(self.device)
                            )
                            param -= lr * (param.grad + correction)
                            param.grad.zero_()

                total_loss += loss.item()
                n_batches += 1

        final_weights = model.state_dict()

        # Update the control variate:
        # c_k+ = c_k - c + (1 / (K * lr)) * (w_0 - w_K)
        new_c_k = {}
        c_k_diff = {}
        for name in self.c_k:
            w_diff = (
                initial_weights[name].float() - final_weights[name].float()
            )
            new_c_k[name] = (
                self.c_k[name]
                - server_control[name]
                + w_diff / (total_steps * lr)
            )
            c_k_diff[name] = new_c_k[name] - self.c_k[name]

        self.c_k = new_c_k
        avg_loss = total_loss / max(n_batches, 1)
        return final_weights, c_k_diff, avg_loss

6. Differential Privacy

6.1 The Mathematical Definition of DP

Differential privacy (DP) is a mathematical framework for protecting personal data. Intuitively, it guarantees that "adding or removing a single data point does not significantly change the output distribution."

Epsilon-delta DP definition

A randomized mechanism M satisfies (ε, δ)-DP if the following holds.

For all adjacent datasets D, D' and all output sets S:
Pr[M(D) ∈ S] ≤ exp(ε) × Pr[M(D') ∈ S] + δ

  • ε (epsilon): the privacy budget. Lower values give stronger privacy guarantees.
  • δ (delta): the failure probability, usually set below 1/|dataset|.

6.2 The Gaussian Mechanism and Clipping

To apply DP in FL, noise must be added to each client's update.

Gradient clipping: first clip the gradient's L2 norm to a maximum C.

g_clipped = g × min(1, C / ||g||_2)

Noise addition: add Gaussian noise to the clipped gradient.

g_dp = g_clipped + N(0, σ^2 × C^2 × I)

Here σ is the noise multiplier.
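The two steps combine into a few lines of NumPy. The clip threshold and noise multiplier below are illustrative values, not recommendations:

```python
import numpy as np

def dp_sanitize(grad, clip_norm=1.0, noise_multiplier=0.5, rng=None):
    """Clip the L2 norm to clip_norm, then add N(0, (sigma * C)^2) noise."""
    rng = rng or np.random.default_rng(0)
    g_clipped = grad * min(1.0, clip_norm / (np.linalg.norm(grad) + 1e-12))
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=grad.shape)
    return g_clipped, g_clipped + noise

g = np.array([3.0, 4.0])                      # ||g||_2 = 5, so it gets scaled by 1/5
g_clipped, g_dp = dp_sanitize(g, clip_norm=1.0, noise_multiplier=0.5)
print(np.linalg.norm(g_clipped))              # ≈ 1.0 (clipped down to the threshold)
```

Clipping bounds each individual's influence (the sensitivity); the noise then hides whether any single example was present.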

6.3 DP-FL Implementation

import torch
import torch.nn as nn
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator


def make_private_model(model: nn.Module) -> nn.Module:
    """Opacus 호환 모델로 변환 (BatchNorm -> GroupNorm)"""
    model = ModuleValidator.fix(model)
    return model


class DPFLClient:
    """차등 프라이버시를 적용한 FL 클라이언트"""

    def __init__(
        self,
        client_id: int,
        dataset,
        target_epsilon: float = 1.0,    # privacy budget
        target_delta: float = 1e-5,     # failure probability
        max_grad_norm: float = 1.0,     # gradient clipping threshold
        device: str = 'cpu'
    ):
        self.client_id = client_id
        self.dataset = dataset
        self.target_epsilon = target_epsilon
        self.target_delta = target_delta
        self.max_grad_norm = max_grad_norm
        self.device = device

    def dp_local_train(
        self,
        model: nn.Module,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> Tuple[Dict, float, float]:
        """
        Differentially private local training.
        Returns: (weights, loss, epsilon spent)
        """
        model = make_private_model(deepcopy(model)).to(self.device)
        model.train()

        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True  # required by Opacus
        )
        optimizer = optim.SGD(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        # Attach the Opacus PrivacyEngine
        privacy_engine = PrivacyEngine()
        model, optimizer, loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=loader,
            epochs=local_epochs,
            target_epsilon=self.target_epsilon,
            target_delta=self.target_delta,
            max_grad_norm=self.max_grad_norm
        )

        total_loss = 0.0
        n_batches = 0

        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                output = model(X)
                loss = criterion(output, y)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                n_batches += 1

        epsilon_used = privacy_engine.get_epsilon(self.target_delta)
        avg_loss = total_loss / max(n_batches, 1)

        # Strip the Opacus wrapper prefix to return plain weights
        clean_weights = {
            k.replace('_module.', ''): v
            for k, v in model.state_dict().items()
        }

        return clean_weights, avg_loss, epsilon_used


class DPFLServer:
    """DP-FL 서버: 서버 측 DP 집계"""

    def __init__(
        self,
        global_model: nn.Module,
        clients: List['DPFLClient'],
        noise_multiplier: float = 0.5,
        max_grad_norm: float = 1.0,
        device: str = 'cpu'
    ):
        self.global_model = global_model.to(device)
        self.clients = clients
        self.noise_multiplier = noise_multiplier
        self.max_grad_norm = max_grad_norm
        self.device = device

    def clip_and_aggregate(
        self,
        client_weights: List[Dict],
        reference_weights: Dict
    ) -> Dict:
        """
        Server-side clipping and noise addition:
        clip each update, then add Gaussian noise.
        """
        n_clients = len(client_weights)

        # Compute updates (delta_w = w_k - w_global)
        updates = []
        for w in client_weights:
            delta = {
                k: w[k].float() - reference_weights[k].float()
                for k in reference_weights
            }
            updates.append(delta)

        # L2-norm clipping
        clipped_updates = []
        for delta in updates:
            total_norm = torch.sqrt(
                sum(torch.norm(v) ** 2 for v in delta.values())
            )
            clip_factor = min(1.0, self.max_grad_norm / (total_norm + 1e-8))
            clipped = {k: v * clip_factor for k, v in delta.items()}
            clipped_updates.append(clipped)

        # Sum the clipped updates
        summed = {}
        for key in reference_weights:
            summed[key] = sum(u[key] for u in clipped_updates)

        # Add Gaussian noise
        sigma = self.noise_multiplier * self.max_grad_norm
        noisy = {}
        for key in summed:
            noise = torch.randn_like(summed[key]) * sigma
            noisy[key] = (summed[key] + noise) / n_clients

        # Apply to the global model
        aggregated = {
            k: reference_weights[k].float() + noisy[k]
            for k in reference_weights
        }
        return aggregated

7. Secure Aggregation

7.1 Why Encrypted Aggregation

DP provides statistical privacy, but the server can still inspect each client's update individually. Secure aggregation (SecAgg) is a cryptographic protocol that lets the server see only the aggregated result.

7.2 Secret-Sharing-Based SecAgg

The Secure Aggregation protocol of Bonawitz et al. (2017) works on the following principle.

  • Clients exchange pairwise masks (paired random values) with each other.
  • Each client adds its mask to its update and sends the result to the server.
  • When the server sums the masked updates, the masks cancel out, leaving the plain aggregate.

import secrets
import hashlib


class SecureAggregation:
    """
    Simplified secure-aggregation simulation.
    A real implementation would use Diffie-Hellman key exchange.
    """

    def __init__(self, num_clients: int, seed_length: int = 32):
        self.num_clients = num_clients
        self.seed_length = seed_length

    def generate_pairwise_masks(
        self,
        client_id: int,
        model_shape: Dict[str, torch.Size],
        shared_seeds: Dict
    ) -> Dict:
        """
        Generate pairwise masks between clients.
        For the pair (i, j): add if i > j, subtract if i < j.
        """
        mask = {k: torch.zeros(v) for k, v in model_shape.items()}

        for other_id, seed in shared_seeds.items():
            # Derive a pseudorandom mask from the shared seed
            torch.manual_seed(seed)
            direction = 1 if client_id > other_id else -1

            for key in mask:
                mask[key] += direction * torch.randn(model_shape[key])

        return mask

    def client_mask_update(
        self,
        local_update: Dict,
        mask: Dict
    ) -> Dict:
        """업데이트에 마스크 적용"""
        return {k: local_update[k] + mask[k] for k in local_update}

    def server_aggregate_masked(
        self,
        masked_updates: List[Dict]
    ) -> Dict:
        """마스크된 업데이트 합산 (마스크는 서로 상쇄됨)"""
        aggregated = {}
        for key in masked_updates[0]:
            aggregated[key] = sum(u[key] for u in masked_updates)
            aggregated[key] /= len(masked_updates)
        return aggregated
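To see the cancellation concretely, here is a standalone NumPy sketch of the same principle: each pair (i, j) derives the same random vector from a shared seed, the higher-id client adds it and the lower-id client subtracts it, so the server-side sum equals the sum of the raw updates exactly:

```python
import numpy as np
from itertools import combinations

num_clients, dim = 3, 4
updates = [np.full(dim, float(i + 1)) for i in range(num_clients)]  # raw client updates

# One shared seed per client pair (agreed via e.g. Diffie-Hellman in a real protocol)
pair_seeds = {pair: 1000 + k for k, pair in
              enumerate(combinations(range(num_clients), 2))}

masked = []
for i in range(num_clients):
    mask = np.zeros(dim)
    for (a, b), seed in pair_seeds.items():
        if i not in (a, b):
            continue
        r = np.random.default_rng(seed).normal(size=dim)  # both peers derive the same r
        mask += r if i == max(a, b) else -r               # opposite signs, so pairs cancel
    masked.append(updates[i] + mask)

# The server only ever sees the masked vectors, yet the mean is exact
agg = sum(masked) / num_clients
print(np.allclose(agg, sum(updates) / num_clients))  # → True
```

Each individual masked vector looks like noise to the server; only the sum is meaningful. The full Bonawitz et al. protocol additionally handles clients that drop out mid-round via secret sharing of the seeds.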

7.3 Homomorphic Encryption Overview

Homomorphic encryption (HE) is a cryptosystem that allows computation on encrypted data. With HE, the server can aggregate client updates without ever decrypting them.

  • Partially homomorphic encryption (PHE): supports only addition or only multiplication (e.g. the Paillier cryptosystem)
  • Fully homomorphic encryption (FHE): supports arbitrary computation (CKKS, BGV), at a very high computational cost

In practice, additive-only PHE is sufficient to implement FL aggregation (weighted averaging).
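To make the PHE idea concrete, here is a toy Paillier cryptosystem with tiny primes. It is completely insecure and only meant to show the additive homomorphism that aggregation relies on (real deployments use 2048-bit moduli and vetted libraries):

```python
import math
import random

# Toy Paillier keypair with tiny primes -- for illustration only
p, q = 61, 53
n = p * q                       # public modulus
n2 = n * n
g = n + 1                       # standard generator choice
lam = math.lcm(p - 1, q - 1)    # private key (lambda)
mu = pow((pow(g, lam, n2) - 1) // n, -1, n)  # private key (mu)

def encrypt(m):
    r = random.randrange(1, n)
    while math.gcd(r, n) != 1:
        r = random.randrange(1, n)
    return (pow(g, m, n2) * pow(r, n, n2)) % n2

def decrypt(c):
    return ((pow(c, lam, n2) - 1) // n * mu) % n

# Additive homomorphism: multiplying ciphertexts adds the plaintexts,
# so the server can sum encrypted client updates without decrypting any of them
updates = [12, 7, 30]           # e.g. quantized model-update entries
cipher_sum = 1
for u in updates:
    cipher_sum = (cipher_sum * encrypt(u)) % n2
print(decrypt(cipher_sum))  # → 49
```

In an FL setting, clients would encrypt their (quantized) updates under a shared public key; the server multiplies the ciphertexts and only the key holder can decrypt the sum.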


8. The Flower Framework

8.1 Introducing Flower

Flower (flwr) is a framework-agnostic Python framework for federated learning. It works with PyTorch, TensorFlow, JAX, and other ML frameworks.

pip install flwr
pip install flwr[simulation]

8.2 Implementing a Flower Client

import flwr as fl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Dict, List, Tuple
import numpy as np


class FlowerClient(fl.client.NumPyClient):
    """Flower FL 클라이언트"""

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: str = 'cpu'
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.criterion = nn.CrossEntropyLoss()

    def get_parameters(self, config: Dict) -> List[np.ndarray]:
        """현재 모델 파라미터를 NumPy 배열로 반환"""
        return [
            val.cpu().numpy()
            for _, val in self.model.state_dict().items()
        ]

    def set_parameters(self, parameters: List[np.ndarray]):
        """NumPy 배열로 모델 파라미터 설정"""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {
            k: torch.tensor(v)
            for k, v in params_dict
        }
        self.model.load_state_dict(state_dict, strict=True)

    def fit(
        self,
        parameters: List[np.ndarray],
        config: Dict
    ) -> Tuple[List[np.ndarray], int, Dict]:
        """서버로부터 받은 파라미터로 로컬 훈련"""
        # 글로벌 파라미터 설정
        self.set_parameters(parameters)

        # Parse config values
        local_epochs = int(config.get('local_epochs', 1))
        lr = float(config.get('lr', 0.01))

        # Local training
        self.model.train()
        optimizer = torch.optim.SGD(
            self.model.parameters(), lr=lr, momentum=0.9
        )

        train_loss = 0.0
        for epoch in range(local_epochs):
            for X, y in self.train_loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                loss = self.criterion(self.model(X), y)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

        return (
            self.get_parameters(config={}),
            len(self.train_loader.dataset),
            {'train_loss': train_loss}
        )

    def evaluate(
        self,
        parameters: List[np.ndarray],
        config: Dict
    ) -> Tuple[float, int, Dict]:
        """서버로부터 받은 파라미터로 평가"""
        self.set_parameters(parameters)
        self.model.eval()

        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for X, y in self.val_loader:
                X, y = X.to(self.device), y.to(self.device)
                output = self.model(X)
                val_loss += self.criterion(output, y).item()

                _, predicted = output.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()

        accuracy = correct / total
        return (
            val_loss / len(self.val_loader),
            len(self.val_loader.dataset),
            {'accuracy': accuracy}
        )

8.3 Implementing a Flower Server Strategy

from flwr.server.strategy import FedAvg
from flwr.common import Parameters, FitIns, FitRes, EvaluateIns, EvaluateRes
from flwr.server.client_proxy import ClientProxy


class CustomFedAvgStrategy(FedAvg):
    """커스텀 FedAvg 전략"""

    def __init__(
        self,
        fraction_fit: float = 0.1,
        fraction_evaluate: float = 0.1,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        initial_parameters=None,
    ):
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            initial_parameters=initial_parameters,
        )
        self.round_metrics = []

    def configure_fit(
        self,
        server_round: int,
        parameters: Parameters,
        client_manager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """각 라운드의 훈련 설정"""
        # 라운드가 늘어날수록 학습률 감소
        lr = 0.01 * (0.99 ** server_round)

        config = {
            'local_epochs': 5,
            'lr': lr,
            'round': server_round,
        }

        fit_ins = FitIns(parameters, config)
        clients = client_manager.sample(
            num_clients=max(
                int(client_manager.num_available() * self.fraction_fit), 1
            ),
            min_num_clients=self.min_fit_clients
        )
        return [(client, fit_ins) for client in clients]

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures
    ):
        """평가 결과 집계 및 로깅"""
        aggregated_loss, metrics = super().aggregate_evaluate(
            server_round, results, failures
        )

        if metrics:
            print(f"라운드 {server_round}: 집계된 정확도 = {metrics.get('accuracy', 0):.4f}")
            self.round_metrics.append({
                'round': server_round,
                'loss': aggregated_loss,
                'metrics': metrics
            })

        return aggregated_loss, metrics


def run_flower_simulation():
    """Flower 시뮬레이션 실행"""
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import Subset

    # 데이터 준비
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = torchvision.datasets.MNIST(
        './data', train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.MNIST(
        './data', train=False, transform=transform
    )

    num_clients = 10

    def client_fn(cid: str):
        """클라이언트 생성 함수"""
        cid = int(cid)

        # Partition the training data
        n = len(train_dataset) // num_clients
        start = cid * n
        indices = list(range(start, min(start + n, len(train_dataset))))
        train_subset = Subset(train_dataset, indices)

        val_indices = list(range(cid * 100, (cid + 1) * 100))
        val_subset = Subset(test_dataset, val_indices)

        model = SimpleNet(784, 256, 10)

        return FlowerClient(
            model=model,
            train_loader=DataLoader(train_subset, batch_size=32),
            val_loader=DataLoader(val_subset, batch_size=64)
        ).to_client()

    # Initialize global model parameters
    global_model = SimpleNet(784, 256, 10)
    initial_params = fl.common.ndarrays_to_parameters(
        [val.numpy() for val in global_model.state_dict().values()]
    )

    # Configure the strategy
    strategy = CustomFedAvgStrategy(
        fraction_fit=0.5,
        fraction_evaluate=0.5,
        min_fit_clients=2,
        min_evaluate_clients=2,
        min_available_clients=num_clients,
        initial_parameters=initial_params
    )

    # Run the simulation
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=num_clients,
        config=fl.server.ServerConfig(num_rounds=50),
        strategy=strategy,
    )

    print("\n=== Flower 시뮬레이션 완료 ===")
    print(f"최종 분산 손실: {history.losses_distributed[-1]}")

9. Hands-On FL Project: Hospital Federation

9.1 Multi-Hospital Chest X-ray Diagnosis

A scenario in which several hospitals jointly train a lung-disease diagnosis model without sharing their chest X-ray data.

import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import pandas as pd


class ChestXRayDataset(Dataset):
    """흉부 X-ray 데이터셋 (병원별)"""

    def __init__(
        self,
        data_dir: str,
        labels_file: str,
        transform=None,
        hospital_id: int = 0
    ):
        self.data_dir = data_dir
        self.labels = pd.read_csv(labels_file)
        self.transform = transform
        self.hospital_id = hospital_id

        # Keep only this hospital's records
        self.labels = self.labels[
            self.labels['hospital_id'] == hospital_id
        ].reset_index(drop=True)

        self.classes = [
            'Normal', 'Pneumonia', 'COVID-19', 'Tuberculosis'
        ]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_path = os.path.join(
            self.data_dir, self.labels.iloc[idx]['filename']
        )
        image = Image.open(img_path).convert('RGB')
        label = self.labels.iloc[idx]['label']

        if self.transform:
            image = self.transform(image)

        return image, label


def build_chest_model(num_classes: int = 4, pretrained: bool = True):
    """ResNet50-based chest X-ray classifier"""
    weights = models.ResNet50_Weights.DEFAULT if pretrained else None
    model = models.resnet50(weights=weights)

    # Replace the final classification head
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Linear(256, num_classes)
    )
    return model


class HospitalFLClient(FlowerClient):
    """병원용 FL 클라이언트 (클래스 불균형 처리)"""

    def __init__(
        self,
        hospital_id: int,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        class_weights: torch.Tensor = None,
        device: str = 'cpu'
    ):
        super().__init__(model, train_loader, val_loader, device)
        self.hospital_id = hospital_id

        if class_weights is not None:
            self.criterion = nn.CrossEntropyLoss(
                weight=class_weights.to(device)
            )
        else:
            self.criterion = nn.CrossEntropyLoss()

    def compute_class_weights(self, dataset) -> torch.Tensor:
        """클래스 불균형 보정을 위한 가중치 계산"""
        labels = [dataset[i][1] for i in range(len(dataset))]
        class_counts = torch.bincount(torch.tensor(labels))
        weights = 1.0 / class_counts.float()
        weights = weights / weights.sum()
        return weights


def setup_hospital_federation(num_hospitals: int = 5):
    """
    다중 병원 연합 학습 설정
    실제로는 각 병원의 데이터가 로컬에 있음
    """
    import torchvision.transforms as transforms

    # Data preprocessing
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    hospitals = []
    for h_id in range(num_hospitals):
        # In production this runs on each hospital's own server;
        # everything runs on one machine here for simulation.
        print(f"Preparing data for hospital {h_id}...")

        # Generate dummy data (replace with real data in a deployment)
        # Only the example structure is shown here

    return hospitals


# Federated training loop
def train_hospital_federation():
    """Run the hospital federation"""
    print("=== Multi-Hospital Federated Learning ===")
    print("No patient data ever leaves a hospital.")
    print("Only model parameter updates are aggregated.")

    global_model = build_chest_model(num_classes=4)

    print(f"\n글로벌 모델 파라미터 수: {sum(p.numel() for p in global_model.parameters()):,}")
    print(f"통신 비용 (FP32): {sum(p.numel() for p in global_model.parameters()) * 4 / 1024 / 1024:.1f} MB/라운드")

10. Closing Thoughts: The Future of Federated Learning

Federated learning is establishing itself as a key technology for resolving the tension between AI and privacy. The following trends deserve particular attention.

Cross-Device vs. Cross-Silo

  • Cross-device: millions of mobile devices participate (Google's Gboard)
  • Cross-silo: a small number of institutions (hospitals, banks) participate; more reliable, but smaller in scale

FL + LLM

The rise of large language models (LLMs) has made FL even more important. When fine-tuning a model on users' conversations, FL keeps those conversations from ever being sent to a server. Combined with parameter-efficient fine-tuning (PEFT, e.g., LoRA), communication costs drop as well.
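A back-of-envelope count shows why PEFT helps here (the hidden size and LoRA rank below are illustrative assumptions, not tied to any particular model):

```python
# Back-of-envelope: parameters communicated per round when fine-tuning
# a single d x d attention weight, full update vs. LoRA rank-r adapters
# (A: d x r, B: r x d). All numbers are illustrative.
d, r = 4096, 8
full = d * d
lora = 2 * d * r

print(f"full update: {full:,} params")
print(f"LoRA update: {lora:,} params ({full // lora}x fewer)")
```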

Regulation-Friendly AI

As regulations such as GDPR and HIPAA tighten, FL is becoming a practical path to compliance. FL adoption is expected to spread rapidly in healthcare, finance, and legal domains.


References

  • McMahan, H. B., et al. (2017). Communication-Efficient Learning of Deep Networks from Decentralized Data. AISTATS 2017.
  • Li, T., et al. (2020). Federated Optimization in Heterogeneous Networks (FedProx). MLSys 2020.
  • Karimireddy, S. P., et al. (2020). SCAFFOLD: Stochastic Controlled Averaging for Federated Learning. ICML 2020.
  • Bonawitz, K., et al. (2017). Practical Secure Aggregation for Privacy-Preserving Machine Learning. CCS 2017.
  • Flower Framework: https://flower.ai/docs/
  • Opacus (PyTorch DP): https://opacus.ai/

Federated Learning Complete Guide: Privacy-Preserving Distributed AI


One of the greatest ironies of modern AI is that as model performance improves, more data is needed, and as more data is collected, the risk of privacy violations grows. Hospitals cannot share patient data, smartphone manufacturers cannot send users' typing patterns to servers, and financial institutions cannot share transaction records with competitors.

Federated Learning (FL) solves this dilemma. Instead of sending data to a central server, the model is sent to where the data lives, trained locally, and only the model updates (weight changes) are aggregated. The concept is: "data stays put; only intelligence moves."

This guide covers everything from the theoretical foundations of FL to hands-on implementation.


1. Federated Learning Fundamentals

1.1 Problems with Traditional Centralized Learning

Consider the traditional ML pipeline: collect patient data from thousands of hospitals, store it on a central server, and train a diagnostic model on that data. What are the problems?

Data Privacy Issues

Patient records, financial transaction histories, and personal communications are extremely sensitive. Transmitting such data to a central server creates the following risks:

  • Eavesdropping risk during transmission
  • Large-scale data breaches from server hacking
  • Loss of trust in the data-owning organization
  • Legal sanctions for regulatory violations

Legal Regulations

Data privacy regulations are tightening worldwide.

  • GDPR (General Data Protection Regulation): The EU's data protection law specifying processing purpose disclosure, consent requirements, and data minimization principles.
  • HIPAA (Health Insurance Portability and Accountability Act): US law protecting patient health information (PHI).
  • CCPA (California Consumer Privacy Act): Gives California residents rights over their personal information collected by businesses.

Communication Costs

Transmitting data from millions of edge devices (smartphones, IoT sensors) to a central location requires enormous network bandwidth. Large data types like images and audio make this even more problematic.
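Rough arithmetic makes the asymmetry concrete (the device count, payload sizes, and participation fraction below are all illustrative assumptions, not measurements):

```python
# Rough illustration of upload volume: raw data to a central server vs.
# model updates under FL. Every number here is an assumption.
num_devices = 1_000_000
raw_data_per_device_mb = 500     # photos/audio accumulated on each device
model_update_mb = 25             # one FP32 model update
rounds = 100
fraction_per_round = 0.01        # clients sampled per round (the C of Section 3)

centralized_tb = num_devices * raw_data_per_device_mb / 1_000_000
federated_tb = num_devices * fraction_per_round * model_update_mb * rounds / 1_000_000

print(f"Centralized upload: {centralized_tb:,.0f} TB (once)")
print(f"Federated upload:   {federated_tb:,.0f} TB over {rounds} rounds")
```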

1.2 The Core Idea of Federated Learning

Federated Learning, proposed by McMahan et al. at Google in 2016, is based on the following principle:

"Data stays local; only knowledge (model updates) moves."

The basic FL process:

  1. Initialization: The central server initializes a global model.
  2. Distribution: The server distributes the current global model to selected clients.
  3. Local Training: Each client trains the model on its local data.
  4. Upload: Clients send model updates (gradients or weight differences) to the server.
  5. Aggregation: The server aggregates the updates (e.g., by averaging) to update the global model.
  6. Repeat: Steps 2–5 repeat until convergence.

In this approach, raw data never leaves the client device. Only model parameter updates are transmitted.
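The six-step loop above can be sketched with scalar "models" (a toy illustration; the targets and learning rate are arbitrary, and Section 3.2 gives a full PyTorch implementation):

```python
import random

# Toy sketch of the FL loop. Each client's "dataset" is summarized by a
# local target value; local training just moves the weight toward it.
random.seed(42)
client_targets = [1.0, 2.0, 3.0, 4.0]    # stand-ins for local datasets
w_global = 0.0                            # step 1: initialize global model

for round_t in range(50):                 # step 6: repeat until convergence
    selected = random.sample(range(4), 2)           # server samples clients
    updates = []
    for k in selected:                    # steps 2-3: distribute + local training
        w_local = w_global
        for _ in range(5):                # a few local "SGD" steps
            w_local += 0.1 * (client_targets[k] - w_local)
        updates.append(w_local)           # step 4: upload updated weights
    w_global = sum(updates) / len(updates)          # step 5: aggregate (average)

print(f"global model: {w_global:.2f} (mean of targets = 2.5)")
```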

1.3 Applications of Federated Learning

Mobile / Edge Devices

Google applied FL to Gboard (mobile keyboard). Users' typing patterns are not sent to servers; instead, the next-word prediction model is improved directly on the device. Hundreds of millions of users' data contribute without any personal information leaving the device.

Healthcare

Better diagnostic AI can be built without sharing patient data from multiple hospitals. For rare diseases, single-hospital data is insufficient for model training, but FL can combine knowledge from multiple hospitals.

Finance

Multiple financial institutions can collaborate on fraud detection and credit scoring without sharing customer data. Especially for cross-border financial transactions, a joint model can be trained while respecting data sovereignty of each country.

Autonomous Vehicles

Multiple automakers can jointly improve road hazard detection models without sharing their proprietary driving data.


2. Federated Learning Architecture

2.1 Client-Server Structure

The most common FL architecture consists of a central aggregation server and multiple clients.

             ┌─────────────────┐
             │  Central Server │
             │   (Aggregator)  │
             └────────┬────────┘
                      │ model distribution / update aggregation
             ┌────────┼────────┐
             ↓        ↓        ↓
        ┌─────────┐ ┌─────────┐ ┌─────────┐
        │Client 1 │ │Client 2 │ │Client 3 │
        │(local   │ │(local   │ │(local   │
        │training)│ │training)│ │training)│
        └─────────┘ └─────────┘ └─────────┘

Server Responsibilities

  • Maintain and manage the global model
  • Select participating clients for each round
  • Aggregate client updates
  • Distribute the aggregated model to clients

Client Responsibilities

  • Hold local data
  • Fine-tune the received model on local data
  • Transmit the updated model (or gradients)

2.2 Horizontal Federated Learning

Horizontal FL is used when all clients have the same feature space but different data samples. For example, multiple hospitals measure the same diagnostic items (blood pressure, blood glucose, age) but have different patients.

Client 1: [feature1, feature2, feature3] x [samples 1-1000]
Client 2: [feature1, feature2, feature3] x [samples 1001-2000]
Client 3: [feature1, feature2, feature3] x [samples 2001-3000]

Same feature space, different sample space. The most common form of FL.

2.3 Vertical Federated Learning

Vertical FL is used when clients hold the same users (data samples) but have different features. For example, a bank has a user's financial information while a hospital has the same user's medical information.

Client A (Bank):     [financial features] x [users 1-10000]
Client B (Hospital): [medical features]   x [users 1-10000]

Same sample space, different feature space.

Vertical FL requires more complex protocols. Cryptographic techniques are needed for cooperation between a client holding labels and clients holding only features.
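The split-model structure (without the cryptographic layer) can be sketched as follows. All names and dimensions are illustrative; in a real deployment, only protected embeddings and gradients would cross the party boundary:

```python
import torch
import torch.nn as nn

# Minimal split-model sketch of vertical FL. Each party embeds its own
# features; only embeddings (forward) and gradients (backward) cross the
# boundary, never raw features. Real systems encrypt this exchange.
torch.manual_seed(0)
bank_features, hospital_features, n_users = 8, 12, 32

bank_encoder = nn.Linear(bank_features, 4)           # lives at the bank
hospital_encoder = nn.Linear(hospital_features, 4)   # lives at the hospital
head = nn.Linear(8, 2)                               # at the label-holding party

x_bank = torch.randn(n_users, bank_features)
x_hosp = torch.randn(n_users, hospital_features)
y = torch.randint(0, 2, (n_users,))

params = (list(bank_encoder.parameters())
          + list(hospital_encoder.parameters())
          + list(head.parameters()))
opt = torch.optim.SGD(params, lr=0.1)

for _ in range(20):
    opt.zero_grad()
    z = torch.cat([bank_encoder(x_bank), hospital_encoder(x_hosp)], dim=1)
    loss = nn.functional.cross_entropy(head(z), y)
    loss.backward()    # gradients flow back across the party boundary
    opt.step()

print(f"final loss: {loss.item():.3f}")
```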

2.4 Federated Transfer Learning

When clients have partially overlapping sample and feature spaces, transfer learning techniques are combined. This allows FL to be applied even when data overlap is minimal.


3. The FedAvg Algorithm

3.1 McMahan et al. (2017) Original Algorithm

FedAvg (Federated Averaging) is the foundational FL algorithm, published by McMahan et al. at Google in 2017. The key idea is that each client performs multiple local SGD updates, then the server averages the weights.

Algorithm Overview

Server executes:
  initialize w_0
  for round t = 1, 2, ..., T:
    m = max(C x K, 1)  // C: participation fraction, K: total clients
    S_t = randomly select m clients
    for each client k in S_t (parallel):
      w_{t+1}^k = ClientUpdate(k, w_t)
    w_{t+1} = sum (n_k / n) x w_{t+1}^k  // weighted average

Client k executes:
  B = split local data into batches
  for local epoch e = 1, ..., E:
    for batch b in B:
      w = w - lr x grad_loss(w; b)
  return w

Key Parameters

  • C: fraction of clients participating in each round (0 < C <= 1)
  • E: number of local epochs per client
  • B: local mini-batch size
  • lr: learning rate

When E=1 and B equals all data, this is equivalent to FedSGD. Increasing E reduces communication rounds but increases the risk of client drift.
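The weighted average from the server pseudocode can be checked with a tiny numeric example (toy weight vectors; the client sizes are made up):

```python
import numpy as np

# w_{t+1} = sum_k (n_k / n) * w_{t+1}^k, with two clients.
client_weights = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
client_sizes = [300, 100]   # n_1, n_2

n = sum(client_sizes)
w_next = sum((n_k / n) * w_k for w_k, n_k in zip(client_weights, client_sizes))
print(w_next)   # [0.75 0.25]: the larger client pulls harder
```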

3.2 Complete FedAvg Implementation

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import numpy as np
from copy import deepcopy
from typing import List, Dict, Tuple
import random


# ========== Model Definition ==========

class SimpleNet(nn.Module):
    """Simple classification network"""
    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, x):
        return self.net(x)


# ========== Client ==========

class FLClient:
    """Federated Learning Client"""

    def __init__(
        self,
        client_id: int,
        dataset,
        device: str = 'cpu'
    ):
        self.client_id = client_id
        self.dataset = dataset
        self.device = device

    def local_train(
        self,
        model: nn.Module,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> Tuple[Dict, float]:
        """
        Train the model on local data
        Returns: (updated weights, local loss)
        """
        model = deepcopy(model).to(self.device)
        model.train()

        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True
        )
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        n_batches = 0

        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                output = model(X)
                loss = criterion(output, y)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                n_batches += 1

        avg_loss = total_loss / max(n_batches, 1)
        return model.state_dict(), avg_loss

    def evaluate(self, model: nn.Module) -> Tuple[float, float]:
        """Evaluate model on local data"""
        model = deepcopy(model).to(self.device)
        model.eval()

        loader = DataLoader(self.dataset, batch_size=64)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                output = model(X)
                loss = criterion(output, y)
                total_loss += loss.item()

                _, predicted = output.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()

        return total_loss / len(loader), correct / total


# ========== Server ==========

class FedAvgServer:
    """FedAvg Server"""

    def __init__(
        self,
        global_model: nn.Module,
        clients: List[FLClient],
        fraction: float = 0.1,
        device: str = 'cpu'
    ):
        self.global_model = global_model.to(device)
        self.clients = clients
        self.fraction = fraction
        self.device = device
        self.round_history = []

    def select_clients(self) -> List[FLClient]:
        """Select clients for each round"""
        m = max(int(self.fraction * len(self.clients)), 1)
        return random.sample(self.clients, m)

    def aggregate(
        self,
        client_weights: List[Dict],
        client_sizes: List[int]
    ) -> Dict:
        """
        Weighted average aggregation (FedAvg)
        Weighted by n_k / n
        """
        total_size = sum(client_sizes)
        aggregated = {}

        for key in client_weights[0].keys():
            aggregated[key] = torch.zeros_like(
                client_weights[0][key], dtype=torch.float32
            )
            for w, size in zip(client_weights, client_sizes):
                weight = size / total_size
                aggregated[key] += weight * w[key].float()

        return aggregated

    def train_round(
        self,
        local_epochs: int = 5,
        batch_size: int = 32,
        lr: float = 0.01
    ) -> Dict:
        """Execute one FL round"""
        selected = self.select_clients()

        client_weights = []
        client_sizes = []
        client_losses = []

        for client in selected:
            weights, loss = client.local_train(
                self.global_model, local_epochs, batch_size, lr
            )
            client_weights.append(weights)
            client_sizes.append(len(client.dataset))
            client_losses.append(loss)

        new_weights = self.aggregate(client_weights, client_sizes)
        self.global_model.load_state_dict(new_weights)

        round_info = {
            'num_clients': len(selected),
            'avg_local_loss': np.mean(client_losses),
            'client_losses': client_losses
        }
        self.round_history.append(round_info)
        return round_info

    def evaluate_global(self, test_loader: DataLoader) -> Tuple[float, float]:
        """Evaluate global model"""
        self.global_model.eval()
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for X, y in test_loader:
                X, y = X.to(self.device), y.to(self.device)
                output = self.global_model(X)
                loss = criterion(output, y)
                total_loss += loss.item()

                _, predicted = output.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()

        return total_loss / len(test_loader), correct / total

    def federated_train(
        self,
        num_rounds: int,
        local_epochs: int = 5,
        batch_size: int = 32,
        lr: float = 0.01,
        test_loader: DataLoader = None
    ):
        """Full FL training loop"""
        print(f"FL Training: {num_rounds} rounds, {len(self.clients)} clients")

        for round_num in range(1, num_rounds + 1):
            round_info = self.train_round(local_epochs, batch_size, lr)

            if test_loader and round_num % 10 == 0:
                test_loss, test_acc = self.evaluate_global(test_loader)
                print(
                    f"Round {round_num:3d}/{num_rounds} | "
                    f"Clients: {round_info['num_clients']} | "
                    f"Local Loss: {round_info['avg_local_loss']:.4f} | "
                    f"Test Acc: {test_acc:.4f}"
                )

        print("Federated training complete!")


# ========== Non-IID Data Partitioning ==========

def create_non_iid_partition(
    dataset,
    num_clients: int,
    num_classes: int,
    alpha: float = 0.5
) -> List[List[int]]:
    """
    Non-IID data partitioning using Dirichlet distribution
    Lower alpha = more heterogeneous distribution
    """
    labels = np.array([dataset[i][1] for i in range(len(dataset))])
    client_indices = [[] for _ in range(num_clients)]

    for cls in range(num_classes):
        cls_indices = np.where(labels == cls)[0]
        np.random.shuffle(cls_indices)

        proportions = np.random.dirichlet([alpha] * num_clients)
        proportions = (proportions * len(cls_indices)).astype(int)
        proportions[-1] = len(cls_indices) - proportions[:-1].sum()

        start = 0
        for k, prop in enumerate(proportions):
            client_indices[k].extend(
                cls_indices[start:start + prop].tolist()
            )
            start += prop

    return client_indices


def run_fedavg_demo():
    """Run FedAvg demo"""
    import torchvision
    import torchvision.transforms as transforms

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = torchvision.datasets.MNIST(
        './data', train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.MNIST(
        './data', train=False, transform=transform
    )

    num_clients = 20
    client_indices = create_non_iid_partition(
        train_dataset, num_clients, num_classes=10, alpha=0.5
    )

    clients = []
    for k in range(num_clients):
        subset = Subset(train_dataset, client_indices[k])
        clients.append(FLClient(k, subset))

    print(f"Number of clients: {num_clients}")
    print(f"Avg samples per client: {np.mean([len(c.dataset) for c in clients]):.1f}")

    global_model = SimpleNet(784, 256, 10)
    server = FedAvgServer(global_model, clients, fraction=0.1)
    test_loader = DataLoader(test_dataset, batch_size=256)

    server.federated_train(
        num_rounds=100,
        local_epochs=5,
        batch_size=32,
        lr=0.01,
        test_loader=test_loader
    )

    _, final_acc = server.evaluate_global(test_loader)
    print(f"\nFinal Test Accuracy: {final_acc:.4f}")


if __name__ == '__main__':
    run_fedavg_demo()

4. Challenges in Federated Learning

4.1 Data Heterogeneity (Non-IID Problem)

The biggest technical challenge in FL is Non-IID (Non-Independent and Identically Distributed) data. In real environments, each client's data distribution is different.

Types of Non-IID

  • Feature distribution skew: Different input distributions per client (e.g., regional weather patterns)
  • Label distribution skew: Different class proportions per client (each smartphone user has different frequent words)
  • Concept drift: Different labels for the same input (different cultural contexts by region)
  • Quantity imbalance: Extremely different amounts of data per client

Non-IID data in FedAvg causes client drift. When each client's local optimum differs from the global optimum, aggregation becomes ineffective.

4.2 System Heterogeneity

In real FL systems, device performance, battery status, and network connectivity vary widely.

  • Compute heterogeneity: GPU-equipped servers and low-end mobile devices participating simultaneously
  • Network heterogeneity: High-speed wired and unstable mobile networks coexisting
  • Memory heterogeneity: Some clients may not be able to load the full model

4.3 The Straggler Problem

If some clients are slow or unresponsive, the entire training is delayed. In Synchronous FL, all selected clients must return their updates before the next round. Solutions include:

  • Asynchronous FL: Immediately aggregate updates from responding clients
  • Timeout setting: Only use clients that respond within a time limit
  • FedProx: Add a proximal term to local updates to tolerate stragglers
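The timeout approach can be sketched with standard concurrency primitives (client latencies here are simulated with sleeps; a real server would collect updates over the network):

```python
import concurrent.futures
import random
import time

# Timeout-based straggler handling: the server waits a fixed budget per
# round and aggregates whatever updates arrived in time.
def client_update(cid: int) -> float:
    time.sleep(random.uniform(0.01, 0.3))   # simulated compute + network delay
    return float(cid)                        # stand-in for a model update

random.seed(0)
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
    futures = {pool.submit(client_update, cid): cid for cid in range(8)}
    done, not_done = concurrent.futures.wait(futures, timeout=0.15)

updates = [f.result() for f in done]
print(f"aggregating {len(updates)} of 8 updates; {len(not_done)} stragglers dropped")
```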

5. Advanced FL Algorithms

5.1 FedProx: Handling Non-IID

FedProx, proposed by Li et al. in 2020, adds a proximal term to local optimization. This term prevents the local model from drifting too far from the global model.

FedProx Objective Function

A proximal term is added to the local objective function.

h_k(w; w^t) = F_k(w) + (mu/2) x ||w - w^t||^2

Here mu is a hyperparameter controlling the strength of the proximal term. When mu=0, it is equivalent to FedAvg.

class FedProxClient(FLClient):
    """FedProx client: adds proximal term"""

    def local_train_prox(
        self,
        model: nn.Module,
        global_weights: Dict,
        local_epochs: int,
        batch_size: int,
        lr: float,
        mu: float = 0.01
    ) -> Tuple[Dict, float]:
        """
        Local training with proximal term
        h_k(w) = F_k(w) + (mu/2) * ||w - w^t||^2
        """
        model = deepcopy(model).to(self.device)
        model.train()

        global_model = deepcopy(model)
        global_model.load_state_dict(global_weights)
        for param in global_model.parameters():
            param.requires_grad = False

        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True
        )
        optimizer = optim.SGD(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        n_batches = 0

        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()

                output = model(X)
                task_loss = criterion(output, y)

                # Proximal term: (mu/2) * ||w - w_global||^2
                prox_loss = 0.0
                for w, w_global in zip(
                    model.parameters(),
                    global_model.parameters()
                ):
                    prox_loss += (mu / 2) * torch.norm(w - w_global) ** 2

                total_batch_loss = task_loss + prox_loss
                total_batch_loss.backward()
                optimizer.step()

                total_loss += task_loss.item()
                n_batches += 1

        avg_loss = total_loss / max(n_batches, 1)
        return model.state_dict(), avg_loss

5.2 SCAFFOLD: Correcting Client Drift

SCAFFOLD (Stochastic Controlled Averaging for Federated Learning) uses control variates to directly correct client drift. Each client and server maintain control variates c_k and c to correct gradient bias.

class ScaffoldClient:
    """SCAFFOLD client"""

    def __init__(self, client_id, dataset, device='cpu'):
        self.client_id = client_id
        self.dataset = dataset
        self.device = device
        self.c_k = None  # client control variate

    def init_control_variate(self, model: nn.Module):
        """Initialize control variate"""
        self.c_k = {
            name: torch.zeros_like(param)
            for name, param in model.named_parameters()
        }

    def local_train_scaffold(
        self,
        model: nn.Module,
        server_control: Dict,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> Tuple[Dict, Dict, float]:
        """
        SCAFFOLD local training
        Returns: (updated weights, control variate update, loss)
        """
        if self.c_k is None:
            self.init_control_variate(model)

        model = deepcopy(model).to(self.device)
        model.train()

        initial_weights = deepcopy(model.state_dict())
        loader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
        criterion = nn.CrossEntropyLoss()

        total_loss = 0.0
        n_batches = 0
        total_steps = local_epochs * len(loader)

        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                output = model(X)
                loss = criterion(output, y)
                loss.backward()

                with torch.no_grad():
                    for name, param in model.named_parameters():
                        if param.grad is not None:
                            # SCAFFOLD correction: g - c_k + c
                            correction = (
                                server_control[name].to(self.device)
                                - self.c_k[name].to(self.device)
                            )
                            param -= lr * (param.grad + correction)
                            param.grad.zero_()

                total_loss += loss.item()
                n_batches += 1

        final_weights = model.state_dict()

        # Update control variate
        # c_k+ = c_k - c + (1 / (K * lr)) * (w_0 - w_K)
        new_c_k = {}
        c_k_diff = {}
        for name in self.c_k:
            w_diff = (
                initial_weights[name].float() - final_weights[name].float()
            )
            new_c_k[name] = (
                self.c_k[name]
                - server_control[name]
                + w_diff / (total_steps * lr)
            )
            c_k_diff[name] = new_c_k[name] - self.c_k[name]

        self.c_k = new_c_k
        avg_loss = total_loss / max(n_batches, 1)
        return final_weights, c_k_diff, avg_loss

6. Differential Privacy

6.1 Mathematical Definition of DP

Differential Privacy (DP) is a mathematical framework for privacy protection. Intuitively, it guarantees that "the output distribution does not change significantly when a single data point is added or removed."

Epsilon-Delta DP Definition

A randomized mechanism M satisfies (epsilon, delta)-DP if for all neighboring datasets D and D' and all output sets S:

Pr[M(D) in S] <= exp(epsilon) x Pr[M(D') in S] + delta

  • epsilon: privacy budget — lower values give stronger privacy guarantees
  • delta: failure probability — typically set below 1/|dataset|
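As a concrete illustration (a classic example, not specific to FL), randomized response satisfies epsilon-DP with epsilon = ln(p / (1 - p)):

```python
import math
import random

# Randomized response: report the true bit with probability p, the flipped
# bit otherwise. For p = 0.75, epsilon = ln(0.75 / 0.25) = ln(3): whichever
# the true bit is, each output is at most 3x more likely.
def randomized_response(true_bit: int, p: float = 0.75) -> int:
    return true_bit if random.random() < p else 1 - true_bit

epsilon = math.log(0.75 / 0.25)
print(f"epsilon = {epsilon:.3f}")   # ln(3) ~= 1.099

# Empirically, the likelihood ratio sits near the bound exp(epsilon) = 3.
random.seed(0)
n = 100_000
p1 = sum(randomized_response(1) for _ in range(n)) / n   # ~0.75
p0 = sum(randomized_response(0) for _ in range(n)) / n   # ~0.25
print(f"likelihood ratio ~= {p1 / p0:.2f} (bound: {math.exp(epsilon):.2f})")
```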

6.2 Gaussian Mechanism and Clipping

To apply DP in FL, noise must be added to each client's update.

Gradient Clipping: First clip the L2 norm of gradients to a maximum value C.

g_clipped = g x min(1, C / ||g||_2)

Noise Addition: Add Gaussian noise to the clipped gradient.

g_dp = g_clipped + N(0, sigma^2 x C^2 x I)

Here sigma is the noise multiplier.
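The two formulas translate directly into code. This is a sketch on a single gradient vector with an arbitrary sigma; DP-SGD, as used via Opacus in the next section, applies the clipping per example:

```python
import torch

# Clip the gradient's L2 norm to C, then add N(0, (sigma * C)^2) noise.
def privatize_gradient(g: torch.Tensor, C: float = 1.0,
                       sigma: float = 1.1) -> torch.Tensor:
    g_clipped = g * min(1.0, C / g.norm(p=2).item())     # norm clipping
    noise = torch.normal(0.0, sigma * C, size=g.shape)   # Gaussian mechanism
    return g_clipped + noise

torch.manual_seed(0)
g = torch.randn(1000) * 5.0          # a large raw gradient
g_dp = privatize_gradient(g)
print(f"raw norm: {g.norm():.1f}, clipped norm bound: 1.0")
```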

6.3 DP-FL Implementation

import torch
import torch.nn as nn
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator


def make_private_model(model: nn.Module) -> nn.Module:
    """Convert to Opacus-compatible model (BatchNorm -> GroupNorm)"""
    model = ModuleValidator.fix(model)
    return model


class DPFLClient:
    """FL client with differential privacy"""

    def __init__(
        self,
        client_id: int,
        dataset,
        target_epsilon: float = 1.0,
        target_delta: float = 1e-5,
        max_grad_norm: float = 1.0,
        device: str = 'cpu'
    ):
        self.client_id = client_id
        self.dataset = dataset
        self.target_epsilon = target_epsilon
        self.target_delta = target_delta
        self.max_grad_norm = max_grad_norm
        self.device = device

    def dp_local_train(
        self,
        model: nn.Module,
        local_epochs: int,
        batch_size: int,
        lr: float
    ) -> Tuple[Dict, float, float]:
        """
        DP local training
        Returns: (weights, loss, epsilon used)
        """
        model = make_private_model(deepcopy(model)).to(self.device)
        model.train()

        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True  # required by Opacus
        )
        optimizer = optim.SGD(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        privacy_engine = PrivacyEngine()
        model, optimizer, loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=loader,
            epochs=local_epochs,
            target_epsilon=self.target_epsilon,
            target_delta=self.target_delta,
            max_grad_norm=self.max_grad_norm
        )

        total_loss = 0.0
        n_batches = 0

        for epoch in range(local_epochs):
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                output = model(X)
                loss = criterion(output, y)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                n_batches += 1

        epsilon_used = privacy_engine.get_epsilon(self.target_delta)
        avg_loss = total_loss / max(n_batches, 1)

        # Remove Opacus wrapper to return clean weights
        clean_weights = {
            k.replace('_module.', ''): v
            for k, v in model.state_dict().items()
        }

        return clean_weights, avg_loss, epsilon_used


class DPFLServer:
    """DP-FL server: server-side DP aggregation"""

    def __init__(
        self,
        global_model: nn.Module,
        clients: List['DPFLClient'],
        noise_multiplier: float = 0.5,
        max_grad_norm: float = 1.0,
        device: str = 'cpu'
    ):
        self.global_model = global_model.to(device)
        self.clients = clients
        self.noise_multiplier = noise_multiplier
        self.max_grad_norm = max_grad_norm
        self.device = device

    def clip_and_aggregate(
        self,
        client_weights: List[Dict],
        reference_weights: Dict
    ) -> Dict:
        """
        Server-side clipping and noise addition
        Clip each update then add Gaussian noise
        """
        n_clients = len(client_weights)

        updates = []
        for w in client_weights:
            delta = {
                k: w[k].float() - reference_weights[k].float()
                for k in reference_weights
            }
            updates.append(delta)

        clipped_updates = []
        for delta in updates:
            total_norm = torch.sqrt(
                sum(torch.norm(v) ** 2 for v in delta.values())
            ).item()
            clip_factor = min(1.0, self.max_grad_norm / (total_norm + 1e-8))
            clipped = {k: v * clip_factor for k, v in delta.items()}
            clipped_updates.append(clipped)

        summed = {}
        for key in reference_weights:
            summed[key] = sum(u[key] for u in clipped_updates)

        sigma = self.noise_multiplier * self.max_grad_norm
        noisy = {}
        for key in summed:
            noise = torch.randn_like(summed[key]) * sigma
            noisy[key] = (summed[key] + noise) / n_clients

        aggregated = {
            k: reference_weights[k].float() + noisy[k]
            for k in reference_weights
        }
        return aggregated

7. Secure Aggregation

7.1 The Need for Encrypted Aggregation

DP provides statistical privacy, but the server can still see each client's update individually. Secure Aggregation (SecAgg) is a cryptographic protocol that ensures the server sees only the aggregated result.

7.2 Secret Sharing-Based SecAgg

The Secure Aggregation protocol by Bonawitz et al. (2017) uses the following principle:

  • Clients exchange masks (random number pairs) with each other.
  • Each client adds its mask to its update before sending to the server.
  • When the server sums all masked updates, the masks cancel out, leaving the pure sum of updates.

class SecureAggregation:
    """
    Simplified secure aggregation simulation
    Real implementations use Diffie-Hellman key exchange
    """

    def __init__(self, num_clients: int, seed_length: int = 32):
        self.num_clients = num_clients
        self.seed_length = seed_length

    def generate_pairwise_masks(
        self,
        client_id: int,
        model_shape: Dict[str, torch.Size],
        shared_seeds: Dict
    ) -> Dict:
        """
        Generate pairwise masks between clients
        For client pair i and j: add if i > j, subtract if i < j
        """
        mask = {k: torch.zeros(v) for k, v in model_shape.items()}

        for other_id, seed in shared_seeds.items():
            torch.manual_seed(seed)
            direction = 1 if client_id > other_id else -1

            for key in mask:
                mask[key] += direction * torch.randn(model_shape[key])

        return mask

    def client_mask_update(
        self,
        local_update: Dict,
        mask: Dict
    ) -> Dict:
        """Apply mask to update"""
        return {k: local_update[k] + mask[k] for k in local_update}

    def server_aggregate_masked(
        self,
        masked_updates: List[Dict]
    ) -> Dict:
        """Sum masked updates (masks cancel each other out)"""
        aggregated = {}
        for key in masked_updates[0]:
            aggregated[key] = sum(u[key] for u in masked_updates)
            aggregated[key] /= len(masked_updates)
        return aggregated
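The mask-cancellation property can be checked with a two-client toy run that mirrors the class above (plain Python, values made up): both clients derive the identical mask from a shared seed, the higher-id client adds it and the lower-id client subtracts it, so the server's sum equals the sum of the raw updates while neither masked vector alone reveals anything.

```python
import random


def pairwise_mask(seed: int, dim: int) -> list:
    """Both clients in a pair derive the same mask from their shared seed."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) for _ in range(dim)]


shared_seed, dim = 1234, 4
update_hi = [0.1, -0.2, 0.3, 0.4]   # true update of the higher-id client
update_lo = [0.5, 0.6, -0.7, 0.8]   # true update of the lower-id client

mask = pairwise_mask(shared_seed, dim)
masked_hi = [u + m for u, m in zip(update_hi, mask)]  # i > j: add mask
masked_lo = [u - m for u, m in zip(update_lo, mask)]  # i < j: subtract mask

server_sum = [a + b for a, b in zip(masked_hi, masked_lo)]
true_sum = [a + b for a, b in zip(update_hi, update_lo)]
# server_sum equals true_sum: the +mask and -mask cancel in the aggregate
```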

7.3 Homomorphic Encryption Overview

Homomorphic Encryption (HE) is a cryptographic scheme that allows computations to be performed on encrypted data. In FL, HE allows the server to aggregate client updates without decrypting them.

  • Partially Homomorphic Encryption (PHE): Supports only addition or multiplication (Paillier cryptosystem)
  • Fully Homomorphic Encryption (FHE): Supports arbitrary operations (CKKS, BGV) — very high computational cost

In practical FL, PHE supporting only addition is sufficient for aggregation (weighted average).
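As a concrete illustration, here is a textbook Paillier sketch with toy primes (far too small for any real deployment — production systems use keys of 2048+ bits, e.g. via the python-paillier library): ciphertexts are multiplied, yet the decrypted result is the sum of the plaintexts, which is exactly the additive aggregation a FedAvg server needs.

```python
import math
import random

# Toy Paillier keypair (p, q are tiny; for demonstration only)
p, q = 1009, 1013
n, n2 = p * q, (p * q) ** 2
g = n + 1
lam = math.lcm(p - 1, q - 1)


def L(x: int) -> int:
    return (x - 1) // n


mu = pow(L(pow(g, lam, n2)), -1, n)
_rng = random.Random(0)


def encrypt(m: int) -> int:
    """c = g^m * r^n mod n^2, with random r coprime to n."""
    r = _rng.randrange(1, n)
    while math.gcd(r, n) != 1:
        r = _rng.randrange(1, n)
    return (pow(g, m, n2) * pow(r, n, n2)) % n2


def decrypt(c: int) -> int:
    """m = L(c^lambda mod n^2) * mu mod n."""
    return (L(pow(c, lam, n2)) * mu) % n


c1, c2 = encrypt(42), encrypt(58)
total = decrypt((c1 * c2) % n2)  # multiplying ciphertexts adds plaintexts
```

The key property is the last line: the server never decrypts c1 or c2 individually, yet obtains 42 + 58 = 100 from their product.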


8. The Flower Framework

8.1 Introduction to Flower

Flower (flwr) is a Python framework for federated learning that is framework-agnostic. It works with PyTorch, TensorFlow, JAX, and other ML frameworks.

pip install flwr
pip install flwr[simulation]

8.2 Flower Client Implementation

import flwr as fl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Dict, List, Tuple
import numpy as np


class FlowerClient(fl.client.NumPyClient):
    """Flower FL Client"""

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: str = 'cpu'
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.criterion = nn.CrossEntropyLoss()

    def get_parameters(self, config: Dict) -> List[np.ndarray]:
        """Return current model parameters as NumPy arrays"""
        return [
            val.cpu().numpy()
            for _, val in self.model.state_dict().items()
        ]

    def set_parameters(self, parameters: List[np.ndarray]):
        """Set model parameters from NumPy arrays"""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(
        self,
        parameters: List[np.ndarray],
        config: Dict
    ) -> Tuple[List[np.ndarray], int, Dict]:
        """Local training with parameters received from server"""
        self.set_parameters(parameters)

        local_epochs = int(config.get('local_epochs', 1))
        lr = float(config.get('lr', 0.01))

        self.model.train()
        optimizer = torch.optim.SGD(
            self.model.parameters(), lr=lr, momentum=0.9
        )

        train_loss = 0.0
        for epoch in range(local_epochs):
            for X, y in self.train_loader:
                X, y = X.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                loss = self.criterion(self.model(X), y)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

        return (
            self.get_parameters(config={}),
            len(self.train_loader.dataset),
            {'train_loss': train_loss}
        )

    def evaluate(
        self,
        parameters: List[np.ndarray],
        config: Dict
    ) -> Tuple[float, int, Dict]:
        """Evaluate with parameters received from server"""
        self.set_parameters(parameters)
        self.model.eval()

        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for X, y in self.val_loader:
                X, y = X.to(self.device), y.to(self.device)
                output = self.model(X)
                val_loss += self.criterion(output, y).item()

                _, predicted = output.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()

        accuracy = correct / total
        return (
            val_loss / len(self.val_loader),
            len(self.val_loader.dataset),
            {'accuracy': accuracy}
        )

8.3 Flower Server Strategy

from typing import List, Tuple

from flwr.server.strategy import FedAvg
from flwr.common import Parameters, FitIns, EvaluateRes
from flwr.server.client_proxy import ClientProxy


class CustomFedAvgStrategy(FedAvg):
    """Custom FedAvg strategy"""

    def __init__(
        self,
        fraction_fit: float = 0.1,
        fraction_evaluate: float = 0.1,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        initial_parameters=None,
    ):
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            initial_parameters=initial_parameters,
        )
        self.round_metrics = []

    def configure_fit(
        self,
        server_round: int,
        parameters: Parameters,
        client_manager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure training for each round"""
        # Decay learning rate as rounds progress
        lr = 0.01 * (0.99 ** server_round)

        config = {
            'local_epochs': 5,
            'lr': lr,
            'round': server_round,
        }

        fit_ins = FitIns(parameters, config)
        clients = client_manager.sample(
            num_clients=max(
                int(client_manager.num_available() * self.fraction_fit), 1
            ),
            min_num_clients=self.min_fit_clients
        )
        return [(client, fit_ins) for client in clients]

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures
    ):
        """Aggregate evaluation results and log"""
        aggregated_loss, metrics = super().aggregate_evaluate(
            server_round, results, failures
        )

        if metrics:
            print(
                f"Round {server_round}: "
                f"Aggregated accuracy = {metrics.get('accuracy', 0):.4f}"
            )
            self.round_metrics.append({
                'round': server_round,
                'loss': aggregated_loss,
                'metrics': metrics
            })

        return aggregated_loss, metrics


def run_flower_simulation():
    """Run Flower simulation"""
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import Subset

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = torchvision.datasets.MNIST(
        './data', train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.MNIST(
        './data', train=False, transform=transform
    )

    num_clients = 10

    def client_fn(cid: str):
        cid = int(cid)
        n = len(train_dataset) // num_clients
        start = cid * n
        indices = list(range(start, min(start + n, len(train_dataset))))
        train_subset = Subset(train_dataset, indices)

        val_indices = list(range(cid * 100, (cid + 1) * 100))
        val_subset = Subset(test_dataset, val_indices)

        model = SimpleNet(784, 256, 10)
        return FlowerClient(
            model=model,
            train_loader=DataLoader(train_subset, batch_size=32),
            val_loader=DataLoader(val_subset, batch_size=64)
        ).to_client()

    global_model = SimpleNet(784, 256, 10)
    initial_params = fl.common.ndarrays_to_parameters(
        [val.numpy() for val in global_model.state_dict().values()]
    )

    strategy = CustomFedAvgStrategy(
        fraction_fit=0.5,
        fraction_evaluate=0.5,
        min_fit_clients=2,
        min_evaluate_clients=2,
        min_available_clients=num_clients,
        initial_parameters=initial_params
    )

    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=num_clients,
        config=fl.server.ServerConfig(num_rounds=50),
        strategy=strategy,
    )

    print("\n=== Flower Simulation Complete ===")
    print(f"Final distributed loss: {history.losses_distributed[-1]}")

9. Hands-On FL Project: Hospital Federated Learning

9.1 Multi-Hospital Chest X-Ray Diagnosis

A scenario where multiple hospitals jointly train a lung disease diagnostic model without sharing patient chest X-ray data.

import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import pandas as pd


class ChestXRayDataset(Dataset):
    """Chest X-Ray Dataset (per hospital)"""

    def __init__(
        self,
        data_dir: str,
        labels_file: str,
        transform=None,
        hospital_id: int = 0
    ):
        self.data_dir = data_dir
        self.labels = pd.read_csv(labels_file)
        self.transform = transform
        self.hospital_id = hospital_id

        # Filter to this hospital's data only
        self.labels = self.labels[
            self.labels['hospital_id'] == hospital_id
        ].reset_index(drop=True)

        self.classes = [
            'Normal', 'Pneumonia', 'COVID-19', 'Tuberculosis'
        ]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_path = os.path.join(
            self.data_dir, self.labels.iloc[idx]['filename']
        )
        image = Image.open(img_path).convert('RGB')
        label = self.labels.iloc[idx]['label']

        if self.transform:
            image = self.transform(image)

        return image, label


def build_chest_model(num_classes: int = 4, pretrained: bool = True):
    """ResNet50-based chest X-ray classification model"""
    # torchvision >= 0.13 prefers the `weights` argument over `pretrained`
    weights = models.ResNet50_Weights.DEFAULT if pretrained else None
    model = models.resnet50(weights=weights)

    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Linear(256, num_classes)
    )
    return model


def train_hospital_federation():
    """Run hospital federated learning"""
    print("=== Multi-Hospital Federated Learning ===")
    print("Patient data at each hospital never leaves its location.")
    print("Only model parameter updates are aggregated.\n")

    global_model = build_chest_model(num_classes=4)

    total_params = sum(p.numel() for p in global_model.parameters())
    print(f"Global model parameters: {total_params:,}")
    print(
        f"Communication cost (FP32): "
        f"{total_params * 4 / 1024 / 1024:.1f} MB/round"
    )
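The function above only sets up the global model and reports the per-round communication cost; the round loop itself follows the standard FedAvg recipe from Section 1.2. A minimal sketch of the aggregation step, with hospital models reduced to plain weight vectors and all hospital names and sample counts invented for illustration:

```python
def fedavg(hospital_weights: list, sample_counts: list) -> list:
    """Weighted average of hospital weights: each hospital contributes in
    proportion to its number of local samples (standard FedAvg)."""
    total = sum(sample_counts)
    dim = len(hospital_weights[0])
    agg = [0.0] * dim
    for w, n_k in zip(hospital_weights, sample_counts):
        for i in range(dim):
            agg[i] += (n_k / total) * w[i]
    return agg


# Three hypothetical hospitals after one round of local training
weights = [
    [0.10, 0.20],   # Hospital A, 1,000 X-rays
    [0.30, 0.10],   # Hospital B, 3,000 X-rays
    [0.20, 0.40],   # Hospital C, 1,000 X-rays
]
counts = [1000, 3000, 1000]
global_w = fedavg(weights, counts)
# global_w = [0.24, 0.18]: Hospital B dominates, holding 60% of the data
```

In the full system the same weighted average runs over every tensor in the ResNet's state_dict, and only these averaged parameters (not any X-ray) cross hospital boundaries.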

10. The Future of Federated Learning

Federated Learning is establishing itself as a key technology for resolving the conflict between AI and privacy. Key trends to watch:

Cross-Device vs Cross-Silo

  • Cross-device: Millions of mobile devices participate (Google's Gboard)
  • Cross-silo: A small number of institutions (hospitals, banks) participate — more trustworthy but smaller scale

FL + LLMs

With the rise of Large Language Models (LLMs), FL has become even more important. When fine-tuning models on user conversations, FL ensures that conversations never leave the user's device. Combining with Parameter-Efficient Fine-Tuning (PEFT, LoRA) further reduces communication costs.
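The communication saving from LoRA is easy to quantify: a rank-r adapter for a d x k weight matrix ships only r(d + k) parameters instead of d x k. A rough back-of-the-envelope calculation (layer sizes are illustrative):

```python
def lora_params(d: int, k: int, r: int) -> int:
    """Parameters in a rank-r LoRA adapter for a d x k weight:
    B is d x r and A is r x k, so r * (d + k) in total."""
    return r * (d + k)


d = k = 4096                     # one attention projection in a 7B-class model
full = d * k                     # 16,777,216 parameters for the full matrix
lora = lora_params(d, k, r=8)    # 65,536 parameters for the adapter
ratio = full / lora              # each FL round ships ~256x less per matrix
```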

Regulation-Friendly AI

As regulations like GDPR and HIPAA tighten, FL is becoming a practical compliance solution. FL adoption is expected to expand rapidly in healthcare, finance, and legal sectors.


References

  • McMahan, H. B., et al. (2017). Communication-Efficient Learning of Deep Networks from Decentralized Data. AISTATS 2017.
  • Li, T., et al. (2020). Federated Optimization in Heterogeneous Networks (FedProx). MLSys 2020.
  • Karimireddy, S. P., et al. (2020). SCAFFOLD: Stochastic Controlled Averaging for Federated Learning. ICML 2020.
  • Bonawitz, K., et al. (2017). Practical Secure Aggregation for Privacy-Preserving Machine Learning. CCS 2017.
  • Flower Framework: https://flower.ai/docs/
  • Opacus (PyTorch DP): https://opacus.ai/