- Authors

- Name
- Youngju Kim
- @fjvbn20031
連合学習完全ガイド: プライバシー保護分散AI
現代AIの大きな皮肉の一つは、モデルの性能が向上するほど多くのデータが必要となり、データが増えるほどプライバシー侵害のリスクが高まるという点です。病院は患者データを共有できず、スマートフォンメーカーはユーザーのタイピングパターンをサーバーに送信できず、金融機関は競合他社と取引履歴を共有できません。
連合学習(Federated Learning、FL)はこのジレンマを解決します。データを中央サーバーに送信する代わりに、データのある場所にモデルを送り、ローカルで学習させ、モデルの更新(重みの変化)のみを集約します。コンセプトは「データは動かず、知識(モデル更新)だけが動く」というものです。
このガイドでは、FLの理論的基礎から実践的な実装まで、すべてを解説します。
1. 連合学習の基礎
1.1 従来の中央集権的学習の問題点
従来のMLパイプラインを考えてみましょう:数千の病院から患者データを収集し、中央サーバーに保存し、そのデータで診断モデルを学習する。問題は何でしょうか?
データプライバシーの問題
患者記録、金融取引履歴、個人的なコミュニケーションは非常に機密性が高いです。このようなデータを中央サーバーに送信すると、以下のリスクが生じます:
- 送信中の盗聴リスク
- サーバーへのハッキングによる大規模データ漏洩
- データを保有する組織への信頼失墜
- 規制違反による法的制裁
法的規制
データプライバシー規制は世界中で強化されています。
- GDPR(一般データ保護規則):処理目的の開示、同意要件、データ最小化原則を定めたEUのデータ保護法。
- HIPAA(医療保険携行性・責任法):患者の健康情報(PHI)を保護する米国法。
- CCPA(カリフォルニア州消費者プライバシー法):企業が収集した個人情報に対するカリフォルニア州住民の権利を保障。
通信コスト
何百万ものエッジデバイス(スマートフォン、IoTセンサー)から中央拠点にデータを送信するには、膨大なネットワーク帯域幅が必要です。画像や音声などの大きなデータ形式はさらに問題を悪化させます。
1.2 連合学習の核心的なアイデア
2016年にGoogleのMcMahanらが提案した連合学習は、以下の原則に基づいています:
「データはローカルに留まる;知識(モデル更新)だけが移動する。」
基本的なFLプロセス:
- 初期化:中央サーバーがグローバルモデルを初期化。
- 配布:サーバーが現在のグローバルモデルを選択したクライアントに配布。
- ローカル学習:各クライアントがローカルデータでモデルを学習。
- アップロード:クライアントがモデル更新(勾配または重みの差分)をサーバーに送信。
- 集約:サーバーが更新を集約(例:平均化)してグローバルモデルを更新。
- 繰り返し:収束するまでステップ2〜5を繰り返す。
このアプローチでは、生データはクライアントデバイスから出ることがありません。モデルパラメータの更新だけが送信されます。
1.3 連合学習の応用
モバイル / エッジデバイス
GoogleはGboard(モバイルキーボード)にFLを適用しました。ユーザーのタイピングパターンはサーバーに送信されず、デバイス上で直接次の単語予測モデルが改善されます。何億人ものユーザーのデータが、個人情報をデバイスから出すことなく貢献します。
ヘルスケア
複数の病院から患者データを共有せずに、より良い診断AIを構築できます。希少疾患については、一つの病院のデータだけではモデル学習に不十分ですが、FLで複数病院の知識を組み合わせることができます。
金融
複数の金融機関が顧客データを共有せずに、不正検出やクレジットスコアリングで協力できます。特に国境を越えた金融取引では、各国のデータ主権を尊重しながら共同モデルを学習できます。
自動運転
複数の自動車メーカーが独自の走行データを共有せずに、道路危険検出モデルを共同改善できます。
2. 連合学習アーキテクチャ
2.1 クライアント・サーバー構造
最も一般的なFLアーキテクチャは、中央集約サーバーと複数のクライアントで構成されます。
┌─────────────────┐
│ 中央サーバー │
│ (アグリゲーター) │
└────────┬─────────┘
│ モデル配布 / 更新集約
┌────────┼────────┐
↓ ↓ ↓
┌─────────┐ ┌─────────┐ ┌─────────┐
│クライアント1│ │クライアント2│ │クライアント3│
│(ローカル │ │(ローカル │ │(ローカル │
│学習) │ │学習) │ │学習) │
└─────────┘ └─────────┘ └─────────┘
サーバーの責任
- グローバルモデルの保守と管理
- 各ラウンドの参加クライアントを選択
- クライアントの更新を集約
- 集約されたモデルをクライアントに配布
クライアントの責任
- ローカルデータを保持
- 受け取ったモデルをローカルデータでファインチューニング
- 更新されたモデル(または勾配)を送信
2.2 水平連合学習
水平FLは、すべてのクライアントが同じ特徴空間を持つが、異なるデータサンプルを持つ場合に使用されます。例えば、複数の病院が同じ診断項目(血圧、血糖、年齢)を測定しているが、異なる患者を持つ場合です。
クライアント1: [feature1, feature2, feature3] x [サンプル 1-1000]
クライアント2: [feature1, feature2, feature3] x [サンプル 1001-2000]
クライアント3: [feature1, feature2, feature3] x [サンプル 2001-3000]
同じ特徴空間、異なるサンプル空間。最も一般的なFLの形態です。
2.3 垂直連合学習
垂直FLは、クライアントが同じユーザー(データサンプル)を持つが、異なる特徴を持つ場合に使用されます。例えば、銀行がユーザーの財務情報を持ち、病院が同じユーザーの医療情報を持つ場合です。
クライアントA (銀行): [財務特徴] x [ユーザー 1-10000]
クライアントB (病院): [医療特徴] x [ユーザー 1-10000]
同じサンプル空間、異なる特徴空間。
垂直FLにはより複雑なプロトコルが必要です。ラベルを持つクライアントと特徴だけを持つクライアント間の協力には暗号化技術が必要です。
2.4 連合転移学習
クライアントのサンプルと特徴空間が部分的に重なる場合、転移学習技術を組み合わせます。これにより、データの重複が最小限の場合でもFLを適用できます。
3. FedAvgアルゴリズム
3.1 McMahan et al. (2017) オリジナルアルゴリズム
FedAvg(Federated Averaging)は基本的なFLアルゴリズムで、2017年にGoogleのMcMahanらが発表しました。核心的なアイデアは、各クライアントが複数のローカルSGD更新を実行し、サーバーが重みを平均化するというものです。
アルゴリズム概要
サーバーが実行:
w_0を初期化
ラウンド t = 1, 2, ..., T:
m = max(C x K, 1) // C: 参加率, K: 全クライアント数
S_t = mクライアントをランダムに選択
S_tの各クライアント k(並行):
w_{t+1}^k = ClientUpdate(k, w_t)
w_{t+1} = sum (n_k / n) x w_{t+1}^k // 重み付き平均
クライアント k が実行:
B = ローカルデータをバッチに分割
ローカルエポック e = 1, ..., E:
バッチ b in B:
w = w - lr x grad_loss(w; b)
return w
主要パラメータ
- C: 各ラウンドに参加するクライアントの割合(0より大きく1以下)
- E: クライアントあたりのローカルエポック数
- B: ローカルミニバッチサイズ
- lr: 学習率
E=1でBが全データの場合はFedSGDと同等です。Eを増やすと通信ラウンドは減りますが、クライアントドリフトのリスクが高まります。
3.2 完全なFedAvg実装
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import numpy as np
from copy import deepcopy
from typing import List, Dict, Tuple
import random
# ========== モデル定義 ==========
class SimpleNet(nn.Module):
"""シンプルな分類ネットワーク"""
def __init__(self, input_dim: int, hidden_dim: int, num_classes: int):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, num_classes)
)
def forward(self, x):
return self.net(x)
# ========== クライアント ==========
class FLClient:
"""連合学習クライアント"""
def __init__(
self,
client_id: int,
dataset,
device: str = 'cpu'
):
self.client_id = client_id
self.dataset = dataset
self.device = device
def local_train(
self,
model: nn.Module,
local_epochs: int,
batch_size: int,
lr: float
) -> Tuple[Dict, float]:
"""
ローカルデータでモデルを学習
戻り値: (更新された重み, ローカル損失)
"""
model = deepcopy(model).to(self.device)
model.train()
loader = DataLoader(
self.dataset,
batch_size=batch_size,
shuffle=True
)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
criterion = nn.CrossEntropyLoss()
total_loss = 0.0
n_batches = 0
for epoch in range(local_epochs):
for X, y in loader:
X, y = X.to(self.device), y.to(self.device)
optimizer.zero_grad()
output = model(X)
loss = criterion(output, y)
loss.backward()
optimizer.step()
total_loss += loss.item()
n_batches += 1
avg_loss = total_loss / max(n_batches, 1)
return model.state_dict(), avg_loss
def evaluate(self, model: nn.Module) -> Tuple[float, float]:
"""ローカルデータでモデルを評価"""
model = deepcopy(model).to(self.device)
model.eval()
loader = DataLoader(self.dataset, batch_size=64)
criterion = nn.CrossEntropyLoss()
total_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for X, y in loader:
X, y = X.to(self.device), y.to(self.device)
output = model(X)
loss = criterion(output, y)
total_loss += loss.item()
_, predicted = output.max(1)
total += y.size(0)
correct += predicted.eq(y).sum().item()
return total_loss / len(loader), correct / total
# ========== サーバー ==========
class FedAvgServer:
"""FedAvgサーバー"""
def __init__(
self,
global_model: nn.Module,
clients: List[FLClient],
fraction: float = 0.1,
device: str = 'cpu'
):
self.global_model = global_model.to(device)
self.clients = clients
self.fraction = fraction
self.device = device
self.round_history = []
def select_clients(self) -> List[FLClient]:
"""各ラウンドのクライアントを選択"""
m = max(int(self.fraction * len(self.clients)), 1)
return random.sample(self.clients, m)
def aggregate(
self,
client_weights: List[Dict],
client_sizes: List[int]
) -> Dict:
"""
重み付き平均集約(FedAvg)
n_k / n で重み付け
"""
total_size = sum(client_sizes)
aggregated = {}
for key in client_weights[0].keys():
aggregated[key] = torch.zeros_like(
client_weights[0][key], dtype=torch.float32
)
for w, size in zip(client_weights, client_sizes):
weight = size / total_size
aggregated[key] += weight * w[key].float()
return aggregated
def train_round(
self,
local_epochs: int = 5,
batch_size: int = 32,
lr: float = 0.01
) -> Dict:
"""1回のFLラウンドを実行"""
selected = self.select_clients()
client_weights = []
client_sizes = []
client_losses = []
for client in selected:
weights, loss = client.local_train(
self.global_model, local_epochs, batch_size, lr
)
client_weights.append(weights)
client_sizes.append(len(client.dataset))
client_losses.append(loss)
new_weights = self.aggregate(client_weights, client_sizes)
self.global_model.load_state_dict(new_weights)
round_info = {
'num_clients': len(selected),
'avg_local_loss': np.mean(client_losses),
'client_losses': client_losses
}
self.round_history.append(round_info)
return round_info
def evaluate_global(self, test_loader: DataLoader) -> Tuple[float, float]:
"""グローバルモデルを評価"""
self.global_model.eval()
criterion = nn.CrossEntropyLoss()
total_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for X, y in test_loader:
X, y = X.to(self.device), y.to(self.device)
output = self.global_model(X)
loss = criterion(output, y)
total_loss += loss.item()
_, predicted = output.max(1)
total += y.size(0)
correct += predicted.eq(y).sum().item()
return total_loss / len(test_loader), correct / total
def federated_train(
self,
num_rounds: int,
local_epochs: int = 5,
batch_size: int = 32,
lr: float = 0.01,
test_loader: DataLoader = None
):
"""完全なFL学習ループ"""
print(f"FL学習: {num_rounds}ラウンド, {len(self.clients)}クライアント")
for round_num in range(1, num_rounds + 1):
round_info = self.train_round(local_epochs, batch_size, lr)
if test_loader and round_num % 10 == 0:
test_loss, test_acc = self.evaluate_global(test_loader)
print(
f"ラウンド {round_num:3d}/{num_rounds} | "
f"クライアント: {round_info['num_clients']} | "
f"ローカル損失: {round_info['avg_local_loss']:.4f} | "
f"テスト精度: {test_acc:.4f}"
)
print("連合学習完了!")
# ========== Non-IIDデータ分割 ==========
def create_non_iid_partition(
dataset,
num_clients: int,
num_classes: int,
alpha: float = 0.5
) -> List[List[int]]:
"""
ディリクレ分布を使ったNon-IIDデータ分割
alphaが低いほど不均一な分布
"""
labels = np.array([dataset[i][1] for i in range(len(dataset))])
client_indices = [[] for _ in range(num_clients)]
for cls in range(num_classes):
cls_indices = np.where(labels == cls)[0]
np.random.shuffle(cls_indices)
proportions = np.random.dirichlet([alpha] * num_clients)
proportions = (proportions * len(cls_indices)).astype(int)
proportions[-1] = len(cls_indices) - proportions[:-1].sum()
start = 0
for k, prop in enumerate(proportions):
client_indices[k].extend(
cls_indices[start:start + prop].tolist()
)
start += prop
return client_indices
4. 高度なFLアルゴリズム
4.1 FedProx - 不均一データへの対応
FedAvgはデータが非常に不均一(Non-IID)な場合に収束しないことがあります。FedProxはこれを解決します。
核心的なアイデア:ローカル損失に近接項(Proximal Term)を追加し、クライアントモデルがグローバルモデルから大きく離れないように制約します。
class FedProxClient:
"""FedProxクライアント(近接項付き)"""
def __init__(self, client_id: int, dataset, mu: float = 0.01, device: str = 'cpu'):
self.client_id = client_id
self.dataset = dataset
self.mu = mu # 近接項の強さ
self.device = device
def local_train_proximal(
self,
model: nn.Module,
global_model: nn.Module,
local_epochs: int,
batch_size: int,
lr: float
) -> Tuple[Dict, float]:
"""
近接項付きローカル学習
L_FedProx = L_CE + (mu/2) * ||w - w_global||^2
"""
local_model = deepcopy(model).to(self.device)
global_weights = deepcopy(global_model).to(self.device)
local_model.train()
loader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
optimizer = optim.SGD(local_model.parameters(), lr=lr, momentum=0.9)
criterion = nn.CrossEntropyLoss()
total_loss = 0.0
n_batches = 0
for epoch in range(local_epochs):
for X, y in loader:
X, y = X.to(self.device), y.to(self.device)
optimizer.zero_grad()
output = local_model(X)
ce_loss = criterion(output, y)
# 近接項: (mu/2) * ||w - w_global||^2
proximal_term = 0.0
for w, w_g in zip(
local_model.parameters(),
global_weights.parameters()
):
proximal_term += (w - w_g.detach()).norm() ** 2
loss = ce_loss + (self.mu / 2) * proximal_term
loss.backward()
optimizer.step()
total_loss += loss.item()
n_batches += 1
avg_loss = total_loss / max(n_batches, 1)
return local_model.state_dict(), avg_loss
4.2 SCAFFOLD - クライアントドリフトの修正
FedAvgでE > 1の場合、クライアントドリフトが発生します。SCAFFOLDはコントロールバリエートを使ってこれを修正します。
class SCAFFOLDClient:
"""SCAFFOLDクライアント(コントロールバリエート付き)"""
def __init__(self, client_id: int, dataset, device: str = 'cpu'):
self.client_id = client_id
self.dataset = dataset
self.device = device
self.c_k = None # クライアントコントロールバリエート
def scaffold_update(
self,
model: nn.Module,
server_control: Dict,
local_epochs: int,
batch_size: int,
lr: float
) -> Tuple[Dict, Dict, float]:
"""
SCAFFOLDローカル更新
戻り値: (最終重み, c_k差分, 平均損失)
"""
local_model = deepcopy(model).to(self.device)
# c_kの初期化
if self.c_k is None:
self.c_k = {
k: torch.zeros_like(v)
for k, v in server_control.items()
}
initial_weights = deepcopy(local_model.state_dict())
loader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
optimizer = optim.SGD(local_model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
total_loss = 0.0
n_batches = 0
total_steps = 0
local_model.train()
for epoch in range(local_epochs):
for X, y in loader:
X, y = X.to(self.device), y.to(self.device)
optimizer.zero_grad()
output = local_model(X)
loss = criterion(output, y)
loss.backward()
# コントロールバリエートで勾配を修正
for name, param in local_model.named_parameters():
if param.grad is not None:
param.grad.data += (
server_control[name].to(self.device)
- self.c_k[name].to(self.device)
)
optimizer.step()
total_loss += loss.item()
n_batches += 1
total_steps += 1
final_weights = local_model.state_dict()
# c_kを更新
new_c_k = {}
c_k_diff = {}
for name in self.c_k:
w_diff = (
initial_weights[name].float() - final_weights[name].float()
)
new_c_k[name] = (
self.c_k[name]
- server_control[name]
+ w_diff / (total_steps * lr)
)
c_k_diff[name] = new_c_k[name] - self.c_k[name]
self.c_k = new_c_k
avg_loss = total_loss / max(n_batches, 1)
return final_weights, c_k_diff, avg_loss
5. 差分プライバシー(DP)
5.1 DPの数学的定義
差分プライバシー(DP)はプライバシー保護のための数学的フレームワークです。直感的には「単一のデータポイントを追加または削除しても、出力分布が大きく変わらない」ことを保証します。
イプシロン-デルタDP定義
ランダム化メカニズムMが全ての隣接データセットDとD'および全ての出力集合Sに対して以下を満たす場合、(epsilon, delta)-DPを満たすとします:
Pr[M(D) in S] <= exp(epsilon) x Pr[M(D') in S] + delta
- epsilon: プライバシー予算——低いほど強いプライバシー保証
- delta: 失敗確率——通常1/|データセット|以下に設定
5.2 ガウスメカニズムとクリッピング
FLにDPを適用するには、各クライアントの更新にノイズを追加する必要があります。
勾配クリッピング:まず勾配のL2ノルムを最大値Cにクリップします。
g_clipped = g x min(1, C / ||g||_2)
ノイズ追加:クリップされた勾配にガウスノイズを追加します。
g_dp = g_clipped + N(0, sigma^2 x C^2 x I)
ここでsigmaはノイズ乗数です。
5.3 DP-FL実装
import torch
import torch.nn as nn
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator
def make_private_model(model: nn.Module) -> nn.Module:
"""Opacus互換モデルに変換(BatchNorm -> GroupNorm)"""
model = ModuleValidator.fix(model)
return model
class DPFLClient:
"""差分プライバシー付きFLクライアント"""
def __init__(
self,
client_id: int,
dataset,
target_epsilon: float = 1.0,
target_delta: float = 1e-5,
max_grad_norm: float = 1.0,
device: str = 'cpu'
):
self.client_id = client_id
self.dataset = dataset
self.target_epsilon = target_epsilon
self.target_delta = target_delta
self.max_grad_norm = max_grad_norm
self.device = device
def dp_local_train(
self,
model: nn.Module,
local_epochs: int,
batch_size: int,
lr: float
) -> Tuple[Dict, float, float]:
"""
DPローカル学習
戻り値: (重み, 損失, 使用されたepsilon)
"""
model = make_private_model(deepcopy(model)).to(self.device)
model.train()
loader = DataLoader(
self.dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True # Opacusが必要とする
)
optimizer = optim.SGD(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
privacy_engine = PrivacyEngine()
model, optimizer, loader = privacy_engine.make_private_with_epsilon(
module=model,
optimizer=optimizer,
data_loader=loader,
epochs=local_epochs,
target_epsilon=self.target_epsilon,
target_delta=self.target_delta,
max_grad_norm=self.max_grad_norm
)
total_loss = 0.0
n_batches = 0
for epoch in range(local_epochs):
for X, y in loader:
X, y = X.to(self.device), y.to(self.device)
optimizer.zero_grad()
output = model(X)
loss = criterion(output, y)
loss.backward()
optimizer.step()
total_loss += loss.item()
n_batches += 1
epsilon_used = privacy_engine.get_epsilon(self.target_delta)
avg_loss = total_loss / max(n_batches, 1)
# Opacusラッパーを削除してクリーンな重みを返す
clean_weights = {
k.replace('_module.', ''): v
for k, v in model.state_dict().items()
}
return clean_weights, avg_loss, epsilon_used
6. Secure Aggregation
Secure Aggregationは、サーバーが個々のクライアント更新を見ることなく、集約されたモデル更新のみを学習できるようにします。
6.1 Secure Aggregationの核心原理
加法的秘密分散に基づいて:
- 各クライアントは自分の更新を秘密鍵で暗号化
- サーバーは暗号化された更新を集約
- 集約された結果のみが復号化される
- 個々の更新は秘密のまま
import secrets
import hashlib
from typing import List
class SecureAggregation:
"""
シンプルな加法的マスクによるSecure Aggregation
実際の環境ではPySyftまたはFlowerの安全な集約を使用
"""
def __init__(self, num_clients: int, model_dim: int, prime: int = 2**31 - 1):
self.num_clients = num_clients
self.model_dim = model_dim
self.prime = prime
def generate_pairwise_masks(self, client_id: int, n_clients: int) -> List[int]:
"""他のクライアントとの対称マスクを生成"""
masks = []
for other_id in range(n_clients):
if other_id == client_id:
masks.append(0)
else:
# 決定論的シードから共有マスクを生成
seed = min(client_id, other_id) * 1000 + max(client_id, other_id)
rng = np.random.default_rng(seed)
mask = rng.integers(0, self.prime, size=self.model_dim)
if client_id < other_id:
masks.append(mask)
else:
masks.append(-mask)
return masks
def client_mask_update(
self,
client_id: int,
model_update: np.ndarray,
n_clients: int
) -> np.ndarray:
"""モデル更新にマスクを適用"""
masks = self.generate_pairwise_masks(client_id, n_clients)
masked_update = model_update.copy()
for mask in masks:
if isinstance(mask, np.ndarray):
masked_update = (masked_update + mask) % self.prime
return masked_update
def server_aggregate(self, masked_updates: List[np.ndarray]) -> np.ndarray:
"""マスクされた更新を集約(マスクはキャンセルされる)"""
aggregated = np.zeros(self.model_dim)
for update in masked_updates:
aggregated = (aggregated + update) % self.prime
# プライムフィールドからの変換
aggregated = np.where(
aggregated > self.prime // 2,
aggregated - self.prime,
aggregated
)
return aggregated / len(masked_updates)
7. Non-IIDデータの課題
7.1 Non-IIDとは?
連合学習でクライアントのデータが独立同一分布(IID)でない場合、学習が困難になります。
Non-IIDの種類:
- 特徴分布の不均一性:各クライアントが異なる特徴分布を持つ
- ラベル分布の不均一性:各クライアントが異なるクラス分布を持つ(最も一般的)
- 数量の不均一性:クライアントごとにデータ量が大きく異なる
- 概念ドリフト:同じ特徴でも異なるラベルを持つ
import matplotlib.pyplot as plt
import numpy as np
def visualize_data_distribution(client_indices, dataset, num_classes=10, num_clients=10):
"""クライアントのデータ分布を可視化"""
distribution = np.zeros((num_clients, num_classes))
for client_id, indices in enumerate(client_indices):
for idx in indices:
label = dataset[idx][1]
distribution[client_id][label] += 1
fig, ax = plt.subplots(figsize=(14, 8))
im = ax.imshow(distribution, aspect='auto', cmap='Blues')
ax.set_xlabel('クラス')
ax.set_ylabel('クライアントID')
ax.set_title('クライアントごとのクラス分布')
plt.colorbar(im)
plt.tight_layout()
plt.savefig('data_distribution.png')
plt.show()
# 各クライアントのエントロピーを計算(均一性の尺度)
entropies = []
for i in range(num_clients):
probs = distribution[i] / distribution[i].sum()
probs = probs[probs > 0]
entropy = -np.sum(probs * np.log(probs))
entropies.append(entropy)
print(f"クライアント {i}: サンプル数={int(distribution[i].sum())}, "
f"エントロピー={entropy:.3f}")
return distribution, entropies
8. Flower(flwr)フレームワーク
Flowerは連合学習の研究と本番環境向けの最も人気のあるフレームワークです。
8.1 Flowerクライアント実装
import flwr as fl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from typing import Dict, List, Tuple
class FlowerClient(fl.client.NumPyClient):
"""Flower連合学習クライアント"""
def __init__(
self,
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
device: str = 'cpu'
):
self.model = model
self.train_loader = train_loader
self.val_loader = val_loader
self.device = device
self.criterion = nn.CrossEntropyLoss()
def get_parameters(self, config: Dict) -> List[np.ndarray]:
"""モデルパラメータを返す"""
return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
def set_parameters(self, parameters: List[np.ndarray]):
"""モデルパラメータを設定"""
params_dict = zip(self.model.state_dict().keys(), parameters)
state_dict = {k: torch.Tensor(v) for k, v in params_dict}
self.model.load_state_dict(state_dict, strict=True)
def fit(
self,
parameters: List[np.ndarray],
config: Dict
) -> Tuple[List[np.ndarray], int, Dict]:
"""サーバーから受け取ったパラメータで学習"""
self.set_parameters(parameters)
local_epochs = int(config.get('local_epochs', 5))
lr = float(config.get('lr', 0.01))
self.model.train()
optimizer = torch.optim.SGD(
self.model.parameters(), lr=lr, momentum=0.9
)
train_loss = 0.0
for epoch in range(local_epochs):
for X, y in self.train_loader:
X, y = X.to(self.device), y.to(self.device)
optimizer.zero_grad()
loss = self.criterion(self.model(X), y)
loss.backward()
optimizer.step()
train_loss += loss.item()
return (
self.get_parameters(config={}),
len(self.train_loader.dataset),
{'train_loss': train_loss}
)
def evaluate(
self,
parameters: List[np.ndarray],
config: Dict
) -> Tuple[float, int, Dict]:
"""サーバーから受け取ったパラメータで評価"""
self.set_parameters(parameters)
self.model.eval()
val_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for X, y in self.val_loader:
X, y = X.to(self.device), y.to(self.device)
output = self.model(X)
val_loss += self.criterion(output, y).item()
_, predicted = output.max(1)
total += y.size(0)
correct += predicted.eq(y).sum().item()
accuracy = correct / total
return (
val_loss / len(self.val_loader),
len(self.val_loader.dataset),
{'accuracy': accuracy}
)
8.2 Flowerサーバー戦略
from flwr.server.strategy import FedAvg
from flwr.common import Parameters, FitIns, EvaluateRes
from flwr.server.client_proxy import ClientProxy
class CustomFedAvgStrategy(FedAvg):
"""カスタムFedAvg戦略"""
def __init__(
self,
fraction_fit: float = 0.1,
fraction_evaluate: float = 0.1,
min_fit_clients: int = 2,
min_evaluate_clients: int = 2,
min_available_clients: int = 2,
initial_parameters=None,
):
super().__init__(
fraction_fit=fraction_fit,
fraction_evaluate=fraction_evaluate,
min_fit_clients=min_fit_clients,
min_evaluate_clients=min_evaluate_clients,
min_available_clients=min_available_clients,
initial_parameters=initial_parameters,
)
self.round_metrics = []
def configure_fit(
self,
server_round: int,
parameters: Parameters,
client_manager
) -> List[Tuple[ClientProxy, FitIns]]:
"""各ラウンドの学習設定"""
# ラウンドが進むにつれて学習率を減衰
lr = 0.01 * (0.99 ** server_round)
config = {
'local_epochs': 5,
'lr': lr,
'round': server_round,
}
fit_ins = FitIns(parameters, config)
clients = client_manager.sample(
num_clients=max(
int(client_manager.num_available() * self.fraction_fit), 1
),
min_num_clients=self.min_fit_clients
)
return [(client, fit_ins) for client in clients]
def aggregate_evaluate(
self,
server_round: int,
results: List[Tuple[ClientProxy, EvaluateRes]],
failures
):
"""評価結果を集約してログ"""
aggregated_loss, metrics = super().aggregate_evaluate(
server_round, results, failures
)
if metrics:
print(
f"ラウンド {server_round}: "
f"集約精度 = {metrics.get('accuracy', 0):.4f}"
)
self.round_metrics.append({
'round': server_round,
'loss': aggregated_loss,
'metrics': metrics
})
return aggregated_loss, metrics
9. 実践FLプロジェクト:病院連合学習
9.1 多施設胸部X線診断
複数の病院が患者の胸部X線データを共有せずに、肺疾患診断モデルを共同学習するシナリオです。
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import pandas as pd
class ChestXRayDataset(Dataset):
"""胸部X線データセット(病院ごと)"""
def __init__(
self,
data_dir: str,
labels_file: str,
transform=None,
hospital_id: int = 0
):
self.data_dir = data_dir
self.labels = pd.read_csv(labels_file)
self.transform = transform
self.hospital_id = hospital_id
# この病院のデータのみにフィルタリング
self.labels = self.labels[
self.labels['hospital_id'] == hospital_id
].reset_index(drop=True)
self.classes = [
'Normal', 'Pneumonia', 'COVID-19', 'Tuberculosis'
]
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
img_path = os.path.join(
self.data_dir, self.labels.iloc[idx]['filename']
)
image = Image.open(img_path).convert('RGB')
label = self.labels.iloc[idx]['label']
if self.transform:
image = self.transform(image)
return image, label
def build_chest_model(num_classes: int = 4, pretrained: bool = True):
"""ResNet50ベースの胸部X線分類モデル"""
model = models.resnet50(pretrained=pretrained)
in_features = model.fc.in_features
model.fc = nn.Sequential(
nn.Dropout(0.3),
nn.Linear(in_features, 256),
nn.ReLU(),
nn.Linear(256, num_classes)
)
return model
def train_hospital_federation():
"""病院連合学習を実行"""
print("=== 多施設連合学習 ===")
print("各病院の患者データはその場所を離れません。")
print("モデルパラメータの更新のみが集約されます。\n")
global_model = build_chest_model(num_classes=4)
total_params = sum(p.numel() for p in global_model.parameters())
print(f"グローバルモデルパラメータ数: {total_params:,}")
print(
f"通信コスト (FP32): "
f"{total_params * 4 / 1024 / 1024:.1f} MB/ラウンド"
)
10. 連合学習の未来
連合学習はAIとプライバシーの対立を解消するための主要技術として確立されています。注目すべきトレンド:
クロスデバイス vs クロスサイロ
- クロスデバイス:数百万のモバイルデバイスが参加(GoogleのGboard)
- クロスサイロ:少数の機関(病院、銀行)が参加——信頼性が高いが規模は小さい
FL + LLM
大規模言語モデル(LLM)の台頭により、FLはさらに重要になっています。ユーザーの会話でモデルをファインチューニングする場合、FLにより会話がユーザーのデバイスから出ないことが保証されます。パラメータ効率的ファインチューニング(PEFT、LoRA)との組み合わせで通信コストをさらに削減できます。
規制フレンドリーなAI
GDPRやHIPAAなどの規制が強化される中、FLは実用的なコンプライアンスソリューションになりつつあります。ヘルスケア、金融、法律分野でのFL採用が急速に拡大することが期待されます。
参考文献
- McMahan, H. B., et al. (2017). Communication-Efficient Learning of Deep Networks from Decentralized Data. AISTATS 2017.
- Li, T., et al. (2020). Federated Optimization in Heterogeneous Networks (FedProx). MLSys 2020.
- Karimireddy, S. P., et al. (2020). SCAFFOLD: Stochastic Controlled Averaging for Federated Learning. ICML 2020.
- Bonawitz, K., et al. (2017). Practical Secure Aggregation for Privacy-Preserving Machine Learning. CCS 2017.
- Flower Framework: https://flower.ai/docs/
- Opacus (PyTorch DP): https://opacus.ai/