- Authors

- Name
- Youngju Kim
- @fjvbn20031
연합 학습(Federated Learning) 완전 가이드: 프라이버시 보존 분산 AI
현대 AI의 가장 큰 아이러니 중 하나는 모델의 성능이 올라갈수록 더 많은 데이터가 필요하고, 더 많은 데이터를 모을수록 프라이버시 침해 위험이 커진다는 것이다. 병원은 환자 데이터를 공유할 수 없고, 스마트폰 제조사는 사용자의 타이핑 패턴을 서버로 보낼 수 없으며, 금융기관은 거래 내역을 경쟁사와 공유할 수 없다.
연합 학습(Federated Learning, FL)은 이 딜레마를 해결하는 패러다임이다. 데이터를 중앙 서버로 보내는 대신, 모델을 데이터가 있는 곳으로 보내서 학습한 뒤 모델의 업데이트(가중치 변화)만 집계한다. "데이터는 움직이지 않고, 지능만 움직인다"는 개념이다.
이 가이드에서는 FL의 이론적 기반부터 실전 구현까지 완전히 다룬다.
1. 연합 학습 기초
1.1 전통적 중앙집중식 학습의 문제
전통적인 머신러닝 파이프라인을 생각해보자. 수천 개의 병원에서 환자 데이터를 수집해 중앙 서버에 저장하고, 그 데이터로 진단 모델을 훈련한다. 이 방식의 문제는 무엇인가?
데이터 프라이버시 문제
환자 데이터, 금융 거래 내역, 개인 통신 내용 등은 매우 민감한 정보다. 이런 데이터를 중앙 서버로 전송하면 다음과 같은 위험이 생긴다.
- 전송 중 도청(eavesdropping) 위험
- 서버 해킹으로 인한 대규모 데이터 유출
- 데이터를 소유한 기관의 신뢰 상실
- 규제 위반으로 인한 법적 제재
법적 규제
전 세계적으로 데이터 프라이버시 규제가 강화되고 있다.
- GDPR (General Data Protection Regulation): 유럽연합의 개인정보보호법. 데이터 처리 목적의 명시, 동의 요건, 데이터 최소화 원칙 등을 규정한다.
- HIPAA (Health Insurance Portability and Accountability Act): 미국의 의료정보보호법. 환자의 건강 정보(PHI)를 보호한다.
- 개인정보보호법(Korea): 한국의 개인정보보호법은 개인정보의 수집, 이용, 제공에 엄격한 제한을 둔다.
통신 비용
수백만 개의 엣지 디바이스(스마트폰, IoT 센서)에서 데이터를 중앙으로 전송하려면 막대한 네트워크 대역폭이 필요하다. 이미지, 오디오 같은 대용량 데이터는 더욱 큰 문제가 된다.
1.2 연합 학습의 핵심 아이디어
연합 학습은 2016년 Google의 McMahan 등이 제안한 개념으로, 다음 원칙에 기반한다.
"데이터는 현장에 남고, 지식(모델 업데이트)만 이동한다"
FL의 기본 프로세스는 다음과 같다.
- 초기화: 중앙 서버가 글로벌 모델을 초기화한다.
- 배포: 서버가 선택된 클라이언트들에게 현재 글로벌 모델을 배포한다.
- 로컬 학습: 각 클라이언트는 자신의 로컬 데이터로 모델을 훈련한다.
- 업로드: 클라이언트들이 모델 업데이트(그래디언트 또는 가중치 차이)를 서버로 전송한다.
- 집계: 서버가 업데이트들을 집계(평균 등)하여 글로벌 모델을 갱신한다.
- 반복: 수렴할 때까지 2-5단계를 반복한다.
이 방식에서 원본 데이터는 클라이언트 장치를 떠나지 않는다. 전송되는 것은 모델 파라미터의 업데이트뿐이다.
1.3 연합 학습의 응용 분야
모바일 / 엣지 디바이스
Google은 Gboard(모바일 키보드)에 FL을 적용했다. 사용자의 타이핑 패턴을 서버로 보내지 않고, 기기에서 직접 다음 단어 예측 모델을 개선한다. 수백만 명의 사용자 데이터를 활용하면서도 개인 정보는 기기 밖으로 나가지 않는다.
의료 분야
여러 병원의 환자 데이터를 공유하지 않고도 더 나은 진단 AI를 만들 수 있다. 예를 들어 희귀 질환의 경우, 단일 병원의 데이터만으로는 모델 학습이 어렵지만, FL로 여러 병원의 지식을 결합할 수 있다.
금융
사기 탐지, 신용 평가 등에서 여러 금융기관이 고객 데이터를 공유하지 않고도 협력할 수 있다. 특히 국경을 넘는 금융 거래에서 각국의 데이터 주권을 지키면서 공동 모델을 학습할 수 있다.
자율주행
여러 자동차 제조사가 각자의 주행 데이터를 공유하지 않으면서 도로 위험 상황 감지 모델을 공동으로 개선할 수 있다.
2. 연합 학습 아키텍처
2.1 클라이언트-서버 구조
가장 일반적인 FL 아키텍처는 중앙 집계 서버(aggregation server)와 여러 클라이언트(client)로 구성된다.
┌─────────────────┐
│ 중앙 서버 │
│ (Aggregator) │
└────────┬────────┘
│ 모델 배포 / 업데이트 집계
┌────────┼────────┐
↓ ↓ ↓
┌─────────┐ ┌─────────┐ ┌─────────┐
│클라이언트1│ │클라이언트2│ │클라이언트3│
│(로컬 데이│ │(로컬 데이│ │(로컬 데이│
│터로 학습)│ │터로 학습)│ │터로 학습)│
└─────────┘ └─────────┘ └─────────┘
서버의 역할
- 글로벌 모델 유지 및 관리
- 각 라운드에서 참여할 클라이언트 선택
- 클라이언트 업데이트 집계
- 집계된 모델을 클라이언트에 배포
클라이언트의 역할
- 로컬 데이터 보유
- 서버로부터 받은 모델을 로컬 데이터로 미세 조정
- 업데이트된 모델(또는 그래디언트) 전송
2.2 수평 연합 학습 (Horizontal Federated Learning)
수평 FL은 모든 클라이언트가 동일한 특성(feature) 공간을 가지지만 다른 데이터 샘플을 보유한 경우에 사용된다. 예를 들어, 여러 병원이 동일한 진단 항목(혈압, 혈당, 나이 등)을 측정하지만 서로 다른 환자를 가진 경우이다.
클라이언트 1: [특성1, 특성2, 특성3] × [샘플 1~1000]
클라이언트 2: [특성1, 특성2, 특성3] × [샘플 1001~2000]
클라이언트 3: [특성1, 특성2, 특성3] × [샘플 2001~3000]
같은 특성 공간, 다른 샘플 공간. 가장 일반적인 FL 형태이다.
2.3 수직 연합 학습 (Vertical Federated Learning)
수직 FL은 클라이언트들이 동일한 사용자(데이터 샘플)를 보유하지만 서로 다른 특성을 가진 경우이다. 예를 들어, 은행은 사용자의 금융 정보를, 병원은 동일 사용자의 의료 정보를 가진 경우이다.
클라이언트 A (은행): [금융 특성] × [사용자 1~10000]
클라이언트 B (병원): [의료 특성] × [사용자 1~10000]
같은 샘플 공간, 다른 특성 공간.
수직 FL은 더 복잡한 프로토콜이 필요하다. 라벨을 가진 클라이언트와 특성만 가진 클라이언트 간의 협력을 위해 암호화 기술이 필요하다.
2.4 연합 이전 학습 (Federated Transfer Learning)
클라이언트들이 부분적으로 겹치는 샘플과 특성 공간을 가질 때, 이전 학습(Transfer Learning) 기법과 결합하는 방식이다. 데이터 오버랩이 거의 없는 상황에서도 FL을 적용할 수 있다.
3. FedAvg 알고리즘
3.1 McMahan et al. (2017) 원본 알고리즘
FedAvg(Federated Averaging)는 FL의 기초가 되는 알고리즘으로, 2017년 Google의 McMahan 등이 발표했다. 핵심 아이디어는 각 클라이언트가 여러 번의 로컬 SGD 업데이트를 수행한 후, 서버에서 가중치를 평균내는 것이다.
알고리즘 개요
서버 실행:
w_0 초기화
for 라운드 t = 1, 2, ..., T:
m = max(C × K, 1) // C: 참여 비율, K: 전체 클라이언트 수
S_t = 무작위로 m개 클라이언트 선택
for 각 클라이언트 k in S_t (병렬):
w_{t+1}^k = ClientUpdate(k, w_t)
w_{t+1} = Σ (n_k / n) × w_{t+1}^k // 가중 평균
클라이언트 k 실행:
B = 로컬 데이터를 배치로 분할
for 로컬 에폭 e = 1, ..., E:
for 배치 b in B:
w = w - η × ∇ℓ(w; b)
return w
핵심 파라미터
- C: 각 라운드에서 참여하는 클라이언트의 비율 (0 < C ≤ 1)
- E: 각 클라이언트의 로컬 에폭 수
- B: 로컬 미니배치 크기
- η: 학습률
E=1이고 B=전체 데이터이면 FedSGD와 동일하다. E를 늘릴수록 통신 횟수가 줄지만 클라이언트 드리프트(drift) 위험이 커진다.
3.2 완전한 FedAvg 구현
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import numpy as np
from copy import deepcopy
from typing import List, Dict, Tuple
import random
# ========== 모델 정의 ==========
class SimpleNet(nn.Module):
"""간단한 분류 네트워크"""
def __init__(self, input_dim: int, hidden_dim: int, num_classes: int):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, num_classes)
)
def forward(self, x):
return self.net(x)
# ========== 클라이언트 ==========
class FLClient:
"""연합 학습 클라이언트"""
def __init__(
self,
client_id: int,
dataset,
device: str = 'cpu'
):
self.client_id = client_id
self.dataset = dataset
self.device = device
def local_train(
self,
model: nn.Module,
local_epochs: int,
batch_size: int,
lr: float
) -> Tuple[Dict, float]:
"""
로컬 데이터로 모델 훈련
Returns: (업데이트된 가중치, 로컬 손실)
"""
model = deepcopy(model).to(self.device)
model.train()
loader = DataLoader(
self.dataset,
batch_size=batch_size,
shuffle=True
)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
criterion = nn.CrossEntropyLoss()
total_loss = 0.0
n_batches = 0
for epoch in range(local_epochs):
for X, y in loader:
X, y = X.to(self.device), y.to(self.device)
optimizer.zero_grad()
output = model(X)
loss = criterion(output, y)
loss.backward()
optimizer.step()
total_loss += loss.item()
n_batches += 1
avg_loss = total_loss / max(n_batches, 1)
return model.state_dict(), avg_loss
def evaluate(self, model: nn.Module) -> Tuple[float, float]:
"""로컬 데이터로 모델 평가"""
model = deepcopy(model).to(self.device)
model.eval()
loader = DataLoader(self.dataset, batch_size=64)
criterion = nn.CrossEntropyLoss()
total_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for X, y in loader:
X, y = X.to(self.device), y.to(self.device)
output = model(X)
loss = criterion(output, y)
total_loss += loss.item()
_, predicted = output.max(1)
total += y.size(0)
correct += predicted.eq(y).sum().item()
return total_loss / len(loader), correct / total
# ========== 서버 ==========
class FedAvgServer:
"""FedAvg 서버"""
def __init__(
self,
global_model: nn.Module,
clients: List[FLClient],
fraction: float = 0.1, # C
device: str = 'cpu'
):
self.global_model = global_model.to(device)
self.clients = clients
self.fraction = fraction
self.device = device
self.round_history = []
def select_clients(self) -> List[FLClient]:
"""각 라운드에서 클라이언트 선택"""
m = max(int(self.fraction * len(self.clients)), 1)
return random.sample(self.clients, m)
def aggregate(
self,
client_weights: List[Dict],
client_sizes: List[int]
) -> Dict:
"""
가중 평균으로 모델 집계 (FedAvg)
n_k / n 비율로 가중 평균
"""
total_size = sum(client_sizes)
aggregated = {}
for key in client_weights[0].keys():
aggregated[key] = torch.zeros_like(
client_weights[0][key], dtype=torch.float32
)
for w, size in zip(client_weights, client_sizes):
weight = size / total_size
aggregated[key] += weight * w[key].float()
return aggregated
def train_round(
self,
local_epochs: int = 5,
batch_size: int = 32,
lr: float = 0.01
) -> Dict:
"""하나의 FL 라운드 실행"""
selected = self.select_clients()
client_weights = []
client_sizes = []
client_losses = []
for client in selected:
weights, loss = client.local_train(
self.global_model, local_epochs, batch_size, lr
)
client_weights.append(weights)
client_sizes.append(len(client.dataset))
client_losses.append(loss)
# 집계
new_weights = self.aggregate(client_weights, client_sizes)
self.global_model.load_state_dict(new_weights)
round_info = {
'num_clients': len(selected),
'avg_local_loss': np.mean(client_losses),
'client_losses': client_losses
}
self.round_history.append(round_info)
return round_info
def evaluate_global(self, test_loader: DataLoader) -> Tuple[float, float]:
"""글로벌 모델 평가"""
self.global_model.eval()
criterion = nn.CrossEntropyLoss()
total_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for X, y in test_loader:
X, y = X.to(self.device), y.to(self.device)
output = self.global_model(X)
loss = criterion(output, y)
total_loss += loss.item()
_, predicted = output.max(1)
total += y.size(0)
correct += predicted.eq(y).sum().item()
return total_loss / len(test_loader), correct / total
def federated_train(
self,
num_rounds: int,
local_epochs: int = 5,
batch_size: int = 32,
lr: float = 0.01,
test_loader: DataLoader = None
):
"""전체 FL 훈련 루프"""
print(f"연합 학습 시작: {num_rounds} 라운드, {len(self.clients)} 클라이언트")
for round_num in range(1, num_rounds + 1):
round_info = self.train_round(local_epochs, batch_size, lr)
if test_loader and round_num % 10 == 0:
test_loss, test_acc = self.evaluate_global(test_loader)
print(
f"라운드 {round_num:3d}/{num_rounds} | "
f"클라이언트: {round_info['num_clients']} | "
f"로컬 손실: {round_info['avg_local_loss']:.4f} | "
f"테스트 정확도: {test_acc:.4f}"
)
print("연합 학습 완료!")
# ========== 데이터 분산 (Non-IID 시뮬레이션) ==========
def create_non_iid_partition(
dataset,
num_clients: int,
num_classes: int,
alpha: float = 0.5
) -> List[List[int]]:
"""
Dirichlet 분포를 사용한 Non-IID 데이터 분할
alpha가 낮을수록 더 불균형한 분포
"""
labels = np.array([dataset[i][1] for i in range(len(dataset))])
client_indices = [[] for _ in range(num_clients)]
for cls in range(num_classes):
cls_indices = np.where(labels == cls)[0]
np.random.shuffle(cls_indices)
# Dirichlet 분포로 클라이언트에 할당
proportions = np.random.dirichlet(
[alpha] * num_clients
)
proportions = (proportions * len(cls_indices)).astype(int)
proportions[-1] = len(cls_indices) - proportions[:-1].sum()
start = 0
for k, prop in enumerate(proportions):
client_indices[k].extend(
cls_indices[start:start + prop].tolist()
)
start += prop
return client_indices
# ========== 메인 실행 ==========
def run_fedavg_demo():
"""FedAvg 데모 실행"""
import torchvision
import torchvision.transforms as transforms
# MNIST 데이터셋 로드
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = torchvision.datasets.MNIST(
'./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
'./data', train=False, transform=transform
)
# Non-IID 데이터 분할
num_clients = 20
client_indices = create_non_iid_partition(
train_dataset, num_clients, num_classes=10, alpha=0.5
)
# 클라이언트 생성
clients = []
for k in range(num_clients):
subset = Subset(train_dataset, client_indices[k])
clients.append(FLClient(k, subset))
print(f"클라이언트 수: {num_clients}")
print(f"클라이언트당 평균 샘플: {np.mean([len(c.dataset) for c in clients]):.1f}")
# 글로벌 모델 생성
global_model = SimpleNet(784, 256, 10)
# 서버 생성 (10% 클라이언트 선택)
server = FedAvgServer(global_model, clients, fraction=0.1)
# 테스트 로더
test_loader = DataLoader(test_dataset, batch_size=256)
# FL 훈련
server.federated_train(
num_rounds=100,
local_epochs=5,
batch_size=32,
lr=0.01,
test_loader=test_loader
)
# 최종 평가
_, final_acc = server.evaluate_global(test_loader)
print(f"\n최종 테스트 정확도: {final_acc:.4f}")
if __name__ == '__main__':
run_fedavg_demo()
4. FL의 도전 과제
4.1 데이터 이질성 (Non-IID 문제)
FL에서 가장 큰 기술적 도전은 Non-IID(Non-Independent and Identically Distributed) 데이터이다. 실제 환경에서 각 클라이언트의 데이터 분포는 다르다.
Non-IID의 유형
- 특성 분포 차이: 클라이언트마다 입력 분포가 다름 (예: 지역별 날씨 패턴)
- 레이블 분포 차이: 클라이언트마다 클래스 비율이 다름 (스마트폰 사용자마다 자주 쓰는 단어가 다름)
- 개념 드리프트: 같은 입력에 대해 다른 레이블 (지역마다 다른 문화적 맥락)
- 수량 불균형: 클라이언트마다 데이터 양이 극단적으로 다름
Non-IID 데이터에서 FedAvg를 사용하면 클라이언트 드리프트(client drift) 문제가 발생한다. 각 클라이언트의 로컬 최적점이 글로벌 최적점과 달라지면서 집계가 효과적이지 않게 된다.
4.2 시스템 이질성
실제 FL 시스템에서는 디바이스의 성능, 배터리 상태, 네트워크 연결 상태가 매우 다양하다.
- 계산 이질성: GPU가 있는 서버와 저사양 모바일 기기가 동시에 참여
- 네트워크 이질성: 고속 유선과 불안정한 모바일 네트워크 혼재
- 메모리 이질성: 일부 클라이언트는 전체 모델을 올릴 수 없을 수 있음
4.3 스트래글러(Straggler) 문제
일부 클라이언트가 느리거나 응답하지 않으면 전체 훈련이 지연된다. Synchronous FL에서는 모든 선택된 클라이언트가 업데이트를 반환할 때까지 기다려야 한다. 이에 대한 해결책으로:
- 비동기 FL (Asynchronous FL): 응답한 클라이언트의 업데이트만 즉시 집계
- 타임아웃 설정: 일정 시간 내에 응답한 클라이언트만 사용
- FedProx: 로컬 업데이트에 근접 항을 추가하여 스트래글러 허용
5. 고급 FL 알고리즘
5.1 FedProx: Non-IID 문제 해결
FedProx는 Li 등이 2020년에 제안한 알고리즘으로, 로컬 최적화에 **근접 항(proximal term)**을 추가한다. 이 항은 로컬 모델이 글로벌 모델에서 너무 멀리 벗어나지 않도록 제한한다.
FedProx 목적 함수
로컬 목적 함수에 근접 항을 추가한다.
h_k(w; w^t) = F_k(w) + (mu/2) × ||w - w^t||^2
여기서 mu는 근접 항의 강도를 조절하는 하이퍼파라미터이다. mu=0이면 FedAvg와 동일하다.
class FedProxClient(FLClient):
"""FedProx 클라이언트: 근접 항 추가"""
def local_train_prox(
self,
model: nn.Module,
global_weights: Dict,
local_epochs: int,
batch_size: int,
lr: float,
mu: float = 0.01 # 근접 항 가중치
) -> Tuple[Dict, float]:
"""
근접 항을 포함한 로컬 훈련
h_k(w) = F_k(w) + (mu/2) * ||w - w^t||^2
"""
model = deepcopy(model).to(self.device)
model.train()
# 글로벌 가중치를 기준점으로 저장
global_model = deepcopy(model)
global_model.load_state_dict(global_weights)
for param in global_model.parameters():
param.requires_grad = False
loader = DataLoader(
self.dataset,
batch_size=batch_size,
shuffle=True
)
optimizer = optim.SGD(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
total_loss = 0.0
n_batches = 0
for epoch in range(local_epochs):
for X, y in loader:
X, y = X.to(self.device), y.to(self.device)
optimizer.zero_grad()
output = model(X)
task_loss = criterion(output, y)
# 근접 항: (mu/2) * ||w - w_global||^2
prox_loss = 0.0
for w, w_global in zip(
model.parameters(),
global_model.parameters()
):
prox_loss += (mu / 2) * torch.norm(w - w_global) ** 2
total_batch_loss = task_loss + prox_loss
total_batch_loss.backward()
optimizer.step()
total_loss += task_loss.item()
n_batches += 1
avg_loss = total_loss / max(n_batches, 1)
return model.state_dict(), avg_loss
5.2 SCAFFOLD: 클라이언트 드리프트 수정
SCAFFOLD(Stochastic Controlled Averaging for Federated Learning)는 제어 변량(control variates)을 사용하여 클라이언트 드리프트를 직접 수정한다. 각 클라이언트와 서버가 제어 변량 c_k와 c를 유지하여 그래디언트 편향을 보정한다.
class ScaffoldClient:
"""SCAFFOLD 클라이언트"""
def __init__(self, client_id, dataset, device='cpu'):
self.client_id = client_id
self.dataset = dataset
self.device = device
# 클라이언트 제어 변량 초기화
self.c_k = None
def init_control_variate(self, model: nn.Module):
"""제어 변량 초기화"""
self.c_k = {
name: torch.zeros_like(param)
for name, param in model.named_parameters()
}
def local_train_scaffold(
self,
model: nn.Module,
server_control: Dict,
local_epochs: int,
batch_size: int,
lr: float
) -> Tuple[Dict, Dict, float]:
"""
SCAFFOLD 로컬 훈련
Returns: (업데이트 가중치, 제어 변량 업데이트, 손실)
"""
if self.c_k is None:
self.init_control_variate(model)
model = deepcopy(model).to(self.device)
model.train()
initial_weights = deepcopy(model.state_dict())
loader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
criterion = nn.CrossEntropyLoss()
total_loss = 0.0
n_batches = 0
total_steps = local_epochs * len(loader)
for epoch in range(local_epochs):
for X, y in loader:
X, y = X.to(self.device), y.to(self.device)
# 수동 파라미터 업데이트 (SCAFFOLD 보정 포함)
output = model(X)
loss = criterion(output, y)
loss.backward()
with torch.no_grad():
for name, param in model.named_parameters():
if param.grad is not None:
# SCAFFOLD 보정: g - c_k + c
correction = (
self.c_k[name].to(self.device)
- server_control[name].to(self.device)
)
param -= lr * (param.grad + correction)
param.grad.zero_()
total_loss += loss.item()
n_batches += 1
final_weights = model.state_dict()
# 제어 변량 업데이트
# c_k+ = c_k - c + (1 / (K * lr)) * (w_0 - w_K)
new_c_k = {}
c_k_diff = {}
for name in self.c_k:
w_diff = (
initial_weights[name].float() - final_weights[name].float()
)
new_c_k[name] = (
self.c_k[name]
- server_control[name]
+ w_diff / (total_steps * lr)
)
c_k_diff[name] = new_c_k[name] - self.c_k[name]
self.c_k = new_c_k
avg_loss = total_loss / max(n_batches, 1)
return final_weights, c_k_diff, avg_loss
6. 차등 프라이버시 (Differential Privacy)
6.1 DP의 수학적 정의
차등 프라이버시(DP)는 개인정보 보호를 위한 수학적 프레임워크이다. 직관적으로, "하나의 데이터 포인트가 추가되거나 제거되어도 출력의 분포가 크게 변하지 않는다"는 보장이다.
엡실론-델타 DP 정의
랜덤화 메커니즘 M이 다음 조건을 만족하면 (엡실론, 델타)-DP를 만족한다.
모든 인접 데이터셋 D, D' 및 출력 집합 S에 대해:
Pr[M(D) ∈ S] ≤ exp(ε) × Pr[M(D') ∈ S] + δ
- ε (엡실론): 프라이버시 예산. 값이 낮을수록 강한 프라이버시 보장
- δ (델타): 실패 확률. 보통 1/|데이터셋| 이하로 설정
6.2 가우시안 메커니즘과 클리핑
FL에서 DP를 적용하려면 각 클라이언트의 업데이트에 노이즈를 추가해야 한다.
그래디언트 클리핑: 먼저 그래디언트의 L2 노름을 최대값 C로 클리핑한다.
g_clipped = g × min(1, C / ||g||_2)
노이즈 추가: 클리핑된 그래디언트에 가우시안 노이즈를 추가한다.
g_dp = g_clipped + N(0, σ^2 × C^2 × I)
여기서 σ는 노이즈 배율(noise multiplier)이다.
6.3 DP-FL 구현
import torch
import torch.nn as nn
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator
def make_private_model(model: nn.Module) -> nn.Module:
"""Opacus 호환 모델로 변환 (BatchNorm -> GroupNorm)"""
model = ModuleValidator.fix(model)
return model
class DPFLClient:
"""차등 프라이버시를 적용한 FL 클라이언트"""
def __init__(
self,
client_id: int,
dataset,
target_epsilon: float = 1.0, # 프라이버시 예산
target_delta: float = 1e-5, # 실패 확률
max_grad_norm: float = 1.0, # 그래디언트 클리핑 임계값
device: str = 'cpu'
):
self.client_id = client_id
self.dataset = dataset
self.target_epsilon = target_epsilon
self.target_delta = target_delta
self.max_grad_norm = max_grad_norm
self.device = device
def dp_local_train(
self,
model: nn.Module,
local_epochs: int,
batch_size: int,
lr: float
) -> Tuple[Dict, float, float]:
"""
차등 프라이버시 로컬 훈련
Returns: (가중치, 손실, 소비된 엡실론)
"""
model = make_private_model(deepcopy(model)).to(self.device)
model.train()
loader = DataLoader(
self.dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True # Opacus 요구사항
)
optimizer = optim.SGD(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
# Opacus PrivacyEngine 부착
privacy_engine = PrivacyEngine()
model, optimizer, loader = privacy_engine.make_private_with_epsilon(
module=model,
optimizer=optimizer,
data_loader=loader,
epochs=local_epochs,
target_epsilon=self.target_epsilon,
target_delta=self.target_delta,
max_grad_norm=self.max_grad_norm
)
total_loss = 0.0
n_batches = 0
for epoch in range(local_epochs):
for X, y in loader:
X, y = X.to(self.device), y.to(self.device)
optimizer.zero_grad()
output = model(X)
loss = criterion(output, y)
loss.backward()
optimizer.step()
total_loss += loss.item()
n_batches += 1
epsilon_used = privacy_engine.get_epsilon(self.target_delta)
avg_loss = total_loss / max(n_batches, 1)
# Opacus 래퍼 제거하여 순수 가중치 반환
clean_weights = {
k.replace('_module.', ''): v
for k, v in model.state_dict().items()
}
return clean_weights, avg_loss, epsilon_used
class DPFLServer:
"""DP-FL 서버: 서버 측 DP 집계"""
def __init__(
self,
global_model: nn.Module,
clients: List['DPFLClient'],
noise_multiplier: float = 0.5,
max_grad_norm: float = 1.0,
device: str = 'cpu'
):
self.global_model = global_model.to(device)
self.clients = clients
self.noise_multiplier = noise_multiplier
self.max_grad_norm = max_grad_norm
self.device = device
def clip_and_aggregate(
self,
client_weights: List[Dict],
reference_weights: Dict
) -> Dict:
"""
서버 측 클리핑 및 노이즈 추가
각 업데이트를 클리핑한 후 가우시안 노이즈 추가
"""
n_clients = len(client_weights)
# 업데이트 계산 (delta_w = w_k - w_global)
updates = []
for w in client_weights:
delta = {
k: w[k].float() - reference_weights[k].float()
for k in reference_weights
}
updates.append(delta)
# L2 노름 클리핑
clipped_updates = []
for delta in updates:
total_norm = torch.sqrt(
sum(torch.norm(v) ** 2 for v in delta.values())
)
clip_factor = min(1.0, self.max_grad_norm / (total_norm + 1e-8))
clipped = {k: v * clip_factor for k, v in delta.items()}
clipped_updates.append(clipped)
# 합산
summed = {}
for key in reference_weights:
summed[key] = sum(u[key] for u in clipped_updates)
# 가우시안 노이즈 추가
sigma = self.noise_multiplier * self.max_grad_norm
noisy = {}
for key in summed:
noise = torch.randn_like(summed[key]) * sigma
noisy[key] = (summed[key] + noise) / n_clients
# 글로벌 모델에 적용
aggregated = {
k: reference_weights[k].float() + noisy[k]
for k in reference_weights
}
return aggregated
7. 보안 집계 (Secure Aggregation)
7.1 암호화된 집계의 필요성
DP는 통계적 프라이버시를 보장하지만, 서버는 여전히 각 클라이언트의 업데이트를 개별적으로 볼 수 있다. 보안 집계(SecAgg)는 서버가 집계된 결과만 볼 수 있도록 하는 암호화 프로토콜이다.
7.2 비밀 분산 (Secret Sharing) 기반 SecAgg
Bonawitz et al. (2017)의 Secure Aggregation 프로토콜은 다음 원리를 사용한다.
- 클라이언트들이 서로 마스크(난수 쌍)를 교환한다.
- 각 클라이언트는 자신의 업데이트에 마스크를 더해 서버로 전송한다.
- 서버가 합산하면 마스크가 서로 상쇄되어 순수한 합산 업데이트가 나온다.
import secrets
import hashlib
class SecureAggregation:
"""
간략화된 보안 집계 시뮬레이션
실제 구현에서는 Diffie-Hellman 키 교환 사용
"""
def __init__(self, num_clients: int, seed_length: int = 32):
self.num_clients = num_clients
self.seed_length = seed_length
def generate_pairwise_masks(
self,
client_id: int,
model_shape: Dict[str, torch.Size],
shared_seeds: Dict
) -> Dict:
"""
클라이언트 쌍 간의 마스크 생성
client i와 j 사이: i>j면 더하고, i<j면 빼기
"""
mask = {k: torch.zeros(v) for k, v in model_shape.items()}
for other_id, seed in shared_seeds.items():
# 공유된 시드에서 의사 난수 마스크 생성
torch.manual_seed(seed)
direction = 1 if client_id > other_id else -1
for key in mask:
mask[key] += direction * torch.randn(model_shape[key])
return mask
def client_mask_update(
self,
local_update: Dict,
mask: Dict
) -> Dict:
"""업데이트에 마스크 적용"""
return {k: local_update[k] + mask[k] for k in local_update}
def server_aggregate_masked(
self,
masked_updates: List[Dict]
) -> Dict:
"""마스크된 업데이트 합산 (마스크는 서로 상쇄됨)"""
aggregated = {}
for key in masked_updates[0]:
aggregated[key] = sum(u[key] for u in masked_updates)
aggregated[key] /= len(masked_updates)
return aggregated
7.3 동형 암호 (Homomorphic Encryption) 개요
동형 암호(HE)는 암호화된 상태에서 연산을 수행할 수 있는 암호 체계이다. FL에서 HE를 사용하면 서버가 클라이언트 업데이트를 복호화하지 않고도 집계할 수 있다.
부분 동형 암호(PHE): 덧셈 또는 곱셈 중 하나만 지원 (Paillier 암호) 완전 동형 암호(FHE): 임의의 연산 지원 (CKKS, BGV) - 계산 비용이 매우 높음
실용적인 FL에서는 덧셈만 지원하는 PHE로도 집계(가중 평균)를 구현할 수 있다.
8. Flower 프레임워크
8.1 Flower 소개
Flower(flwr)는 연합 학습을 위한 파이썬 프레임워크로, 프레임워크 독립적이다. PyTorch, TensorFlow, JAX 등 다양한 ML 프레임워크와 함께 사용할 수 있다.
pip install flwr
pip install flwr[simulation]
8.2 Flower 클라이언트 구현
import flwr as fl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Dict, List, Tuple
import numpy as np
class FlowerClient(fl.client.NumPyClient):
"""Flower FL 클라이언트"""
def __init__(
self,
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
device: str = 'cpu'
):
self.model = model.to(device)
self.train_loader = train_loader
self.val_loader = val_loader
self.device = device
self.criterion = nn.CrossEntropyLoss()
def get_parameters(self, config: Dict) -> List[np.ndarray]:
"""현재 모델 파라미터를 NumPy 배열로 반환"""
return [
val.cpu().numpy()
for _, val in self.model.state_dict().items()
]
def set_parameters(self, parameters: List[np.ndarray]):
"""NumPy 배열로 모델 파라미터 설정"""
params_dict = zip(self.model.state_dict().keys(), parameters)
state_dict = {
k: torch.tensor(v)
for k, v in params_dict
}
self.model.load_state_dict(state_dict, strict=True)
def fit(
self,
parameters: List[np.ndarray],
config: Dict
) -> Tuple[List[np.ndarray], int, Dict]:
"""서버로부터 받은 파라미터로 로컬 훈련"""
# 글로벌 파라미터 설정
self.set_parameters(parameters)
# 설정 파라미터 파싱
local_epochs = int(config.get('local_epochs', 1))
lr = float(config.get('lr', 0.01))
# 로컬 훈련
self.model.train()
optimizer = torch.optim.SGD(
self.model.parameters(), lr=lr, momentum=0.9
)
train_loss = 0.0
for epoch in range(local_epochs):
for X, y in self.train_loader:
X, y = X.to(self.device), y.to(self.device)
optimizer.zero_grad()
loss = self.criterion(self.model(X), y)
loss.backward()
optimizer.step()
train_loss += loss.item()
return (
self.get_parameters(config={}),
len(self.train_loader.dataset),
{'train_loss': train_loss}
)
def evaluate(
self,
parameters: List[np.ndarray],
config: Dict
) -> Tuple[float, int, Dict]:
"""서버로부터 받은 파라미터로 평가"""
self.set_parameters(parameters)
self.model.eval()
val_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for X, y in self.val_loader:
X, y = X.to(self.device), y.to(self.device)
output = self.model(X)
val_loss += self.criterion(output, y).item()
_, predicted = output.max(1)
total += y.size(0)
correct += predicted.eq(y).sum().item()
accuracy = correct / total
return (
val_loss / len(self.val_loader),
len(self.val_loader.dataset),
{'accuracy': accuracy}
)
8.3 Flower 서버 전략 구현
from flwr.server.strategy import FedAvg
from flwr.common import Parameters, FitIns, FitRes, EvaluateIns, EvaluateRes
from flwr.server.client_proxy import ClientProxy
class CustomFedAvgStrategy(FedAvg):
"""커스텀 FedAvg 전략"""
def __init__(
self,
fraction_fit: float = 0.1,
fraction_evaluate: float = 0.1,
min_fit_clients: int = 2,
min_evaluate_clients: int = 2,
min_available_clients: int = 2,
initial_parameters=None,
):
super().__init__(
fraction_fit=fraction_fit,
fraction_evaluate=fraction_evaluate,
min_fit_clients=min_fit_clients,
min_evaluate_clients=min_evaluate_clients,
min_available_clients=min_available_clients,
initial_parameters=initial_parameters,
)
self.round_metrics = []
def configure_fit(
self,
server_round: int,
parameters: Parameters,
client_manager
) -> List[Tuple[ClientProxy, FitIns]]:
"""각 라운드의 훈련 설정"""
# 라운드가 늘어날수록 학습률 감소
lr = 0.01 * (0.99 ** server_round)
config = {
'local_epochs': 5,
'lr': lr,
'round': server_round,
}
fit_ins = FitIns(parameters, config)
clients = client_manager.sample(
num_clients=max(
int(client_manager.num_available() * self.fraction_fit), 1
),
min_num_clients=self.min_fit_clients
)
return [(client, fit_ins) for client in clients]
def aggregate_evaluate(
self,
server_round: int,
results: List[Tuple[ClientProxy, EvaluateRes]],
failures
):
"""평가 결과 집계 및 로깅"""
aggregated_loss, metrics = super().aggregate_evaluate(
server_round, results, failures
)
if metrics:
print(f"라운드 {server_round}: 집계된 정확도 = {metrics.get('accuracy', 0):.4f}")
self.round_metrics.append({
'round': server_round,
'loss': aggregated_loss,
'metrics': metrics
})
return aggregated_loss, metrics
def run_flower_simulation():
"""Flower 시뮬레이션 실행"""
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset
# 데이터 준비
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = torchvision.datasets.MNIST(
'./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
'./data', train=False, transform=transform
)
num_clients = 10
def client_fn(cid: str):
"""클라이언트 생성 함수"""
cid = int(cid)
# 데이터 분할
n = len(train_dataset) // num_clients
start = cid * n
indices = list(range(start, min(start + n, len(train_dataset))))
train_subset = Subset(train_dataset, indices)
val_indices = list(range(cid * 100, (cid + 1) * 100))
val_subset = Subset(test_dataset, val_indices)
model = SimpleNet(784, 256, 10)
return FlowerClient(
model=model,
train_loader=DataLoader(train_subset, batch_size=32),
val_loader=DataLoader(val_subset, batch_size=64)
).to_client()
# 글로벌 모델 파라미터 초기화
global_model = SimpleNet(784, 256, 10)
initial_params = fl.common.ndarrays_to_parameters(
[val.numpy() for val in global_model.state_dict().values()]
)
# 전략 설정
strategy = CustomFedAvgStrategy(
fraction_fit=0.5,
fraction_evaluate=0.5,
min_fit_clients=2,
min_evaluate_clients=2,
min_available_clients=num_clients,
initial_parameters=initial_params
)
# 시뮬레이션 실행
history = fl.simulation.start_simulation(
client_fn=client_fn,
num_clients=num_clients,
config=fl.server.ServerConfig(num_rounds=50),
strategy=strategy,
)
print("\n=== Flower 시뮬레이션 완료 ===")
print(f"최종 분산 손실: {history.losses_distributed[-1]}")
9. 실전 FL 프로젝트: 병원 연합 학습
9.1 멀티-병원 흉부 X-ray 진단
여러 병원이 흉부 X-ray 데이터를 공유하지 않고 폐 질환 진단 모델을 공동 훈련하는 시나리오이다.
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import pandas as pd
class ChestXRayDataset(Dataset):
"""흉부 X-ray 데이터셋 (병원별)"""
def __init__(
self,
data_dir: str,
labels_file: str,
transform=None,
hospital_id: int = 0
):
self.data_dir = data_dir
self.labels = pd.read_csv(labels_file)
self.transform = transform
self.hospital_id = hospital_id
# 각 병원의 데이터만 필터링
self.labels = self.labels[
self.labels['hospital_id'] == hospital_id
].reset_index(drop=True)
self.classes = [
'Normal', 'Pneumonia', 'COVID-19', 'Tuberculosis'
]
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
img_path = os.path.join(
self.data_dir, self.labels.iloc[idx]['filename']
)
image = Image.open(img_path).convert('RGB')
label = self.labels.iloc[idx]['label']
if self.transform:
image = self.transform(image)
return image, label
def build_chest_model(num_classes: int = 4, pretrained: bool = True):
"""ResNet50 기반 흉부 X-ray 분류 모델"""
model = models.resnet50(pretrained=pretrained)
# 마지막 레이어 교체
in_features = model.fc.in_features
model.fc = nn.Sequential(
nn.Dropout(0.3),
nn.Linear(in_features, 256),
nn.ReLU(),
nn.Linear(256, num_classes)
)
return model
class HospitalFLClient(FlowerClient):
"""병원용 FL 클라이언트 (클래스 불균형 처리)"""
def __init__(
self,
hospital_id: int,
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
class_weights: torch.Tensor = None,
device: str = 'cpu'
):
super().__init__(model, train_loader, val_loader, device)
self.hospital_id = hospital_id
if class_weights is not None:
self.criterion = nn.CrossEntropyLoss(
weight=class_weights.to(device)
)
else:
self.criterion = nn.CrossEntropyLoss()
def compute_class_weights(self, dataset) -> torch.Tensor:
"""클래스 불균형 보정을 위한 가중치 계산"""
labels = [dataset[i][1] for i in range(len(dataset))]
class_counts = torch.bincount(torch.tensor(labels))
weights = 1.0 / class_counts.float()
weights = weights / weights.sum()
return weights
def setup_hospital_federation(num_hospitals: int = 5):
"""
다중 병원 연합 학습 설정
실제로는 각 병원의 데이터가 로컬에 있음
"""
import torchvision.transforms as transforms
# 데이터 전처리
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.ColorJitter(brightness=0.2, contrast=0.2),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
val_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
hospitals = []
for h_id in range(num_hospitals):
# 실제로는 각 병원 서버에서 실행됨
# 여기서는 시뮬레이션을 위해 같은 서버에서 실행
print(f"병원 {h_id} 데이터 준비 중...")
# 더미 데이터 생성 (실제 구현 시 실제 데이터로 교체)
# 여기서는 예시 구조만 보여줌
return hospitals
# 연합 학습 훈련 루프
def train_hospital_federation():
"""병원 연합 학습 실행"""
print("=== 멀티-병원 연합 학습 시작 ===")
print("각 병원의 환자 데이터는 외부로 전송되지 않습니다.")
print("오직 모델 파라미터 업데이트만 집계됩니다.")
global_model = build_chest_model(num_classes=4)
print(f"\n글로벌 모델 파라미터 수: {sum(p.numel() for p in global_model.parameters()):,}")
print(f"통신 비용 (FP32): {sum(p.numel() for p in global_model.parameters()) * 4 / 1024 / 1024:.1f} MB/라운드")
10. 마무리: 연합 학습의 미래
연합 학습은 AI와 프라이버시의 갈등을 해소하는 핵심 기술로 자리잡고 있다. 특히 다음 트렌드에 주목해야 한다.
크로스-디바이스 vs 크로스-사일로
- 크로스-디바이스: 수백만 개의 모바일 기기가 참여 (Google의 Gboard)
- 크로스-사일로: 소수의 기관(병원, 은행)이 참여 - 더 신뢰할 수 있지만 규모가 작음
FL + LLM
대형 언어 모델(LLM)의 등장으로 FL이 더욱 중요해졌다. 사용자의 대화 내용으로 모델을 파인튜닝할 때, FL을 사용하면 대화 내용이 서버로 전송되지 않는다. 파라미터 효율적 파인튜닝(PEFT, LoRA)과 결합하면 통신 비용도 줄일 수 있다.
규제 친화적 AI
GDPR, HIPAA 등 규제가 강화되면서 FL은 규정 준수를 위한 실용적 해결책이 되고 있다. 의료, 금융, 법률 분야에서 FL 적용이 빠르게 확산될 것으로 예상된다.
참고 자료
- McMahan, H. B., et al. (2017). Communication-Efficient Learning of Deep Networks from Decentralized Data. AISTATS 2017.
- Li, T., et al. (2020). Federated Optimization in Heterogeneous Networks (FedProx). MLSys 2020.
- Karimireddy, S. P., et al. (2020). SCAFFOLD: Stochastic Controlled Averaging for Federated Learning. ICML 2020.
- Bonawitz, K., et al. (2017). Practical Secure Aggregation for Privacy-Preserving Machine Learning. CCS 2017.
- Flower Framework: https://flower.ai/docs/
- Opacus (PyTorch DP): https://opacus.ai/