19 Privacy & Federated Learning
Clinical Context: Five major hospital systems want to collaborate on an AI model for early sepsis detection. Together, they have 500,000 patient encounters—enough data to build a robust model. But legal teams at each institution won’t allow patient data to leave their systems. How can they collaborate without sharing data? This chapter covers the technical approaches that make privacy-preserving medical AI possible.
Medical AI faces a fundamental tension: better models require more data, but patient data is sensitive, regulated, and siloed across institutions. A single hospital may have thousands of cases of a rare disease; combined data across hundreds of hospitals could provide millions. Yet data sharing agreements are slow, legally complex, and sometimes impossible.
Privacy-preserving machine learning offers a path forward. Differential privacy provides mathematical guarantees about what an adversary can learn from model outputs. Federated learning enables model training without centralizing data. Secure computation allows analysis of encrypted data. These techniques don’t eliminate privacy risks, but they provide tools for managing them.
This chapter focuses on the technical approaches to privacy-preserving medical AI. We cover the mathematics of differential privacy with enough depth to reason about privacy-utility tradeoffs. We implement federated learning with working code. And we discuss how these techniques combine to enable collaborative medical AI.
19.1 The Data Access Problem
19.1.1 Why Medical AI Needs More Data
Machine learning is data-hungry. Deep learning especially benefits from scale—models trained on larger datasets generally perform better. For medical AI:
- Rare diseases: A single hospital might see 50 cases per year; useful AI needs thousands
- Subpopulation performance: Ensuring models work across demographics requires diverse data
- Generalization: Models trained at one institution often fail at others; multi-site data improves robustness
- Validation: External validation requires access to data from other sources
The institutions that need AI most—smaller hospitals, rural clinics—often have the least data to develop it.
19.1.2 Why Data Sharing Is Hard
Despite the benefits, sharing medical data faces significant barriers:
Regulatory constraints: HIPAA, GDPR, and other regulations restrict data transfer. Even “de-identified” data requires careful handling.
Institutional concerns: Hospitals view their data as a competitive asset. Sharing with academic medical centers or competitors raises strategic concerns.
Technical challenges: Healthcare data is messy, inconsistent across institutions, and difficult to harmonize.
Patient trust: Patients may consent to their data being used at their hospital but object to broader sharing.
Liability: If shared data is breached or misused, who is responsible?
19.1.3 The Privacy-Preserving Paradigm
Privacy-preserving techniques change the question from “how do we share data?” to “how do we share insights without sharing data?”
| Approach | Key Idea | When to Use |
|---|---|---|
| De-identification | Remove identifying information | Simple analyses, low-risk data |
| Differential privacy | Add calibrated noise | Query results, model training |
| Federated learning | Train at data sources, share only models | Multi-site collaboration |
| Secure computation | Compute on encrypted data | High-security requirements |
These approaches aren’t mutually exclusive—the most robust systems combine multiple techniques.
19.2 Privacy Regulations and De-identification
19.2.1 HIPAA Basics
The Health Insurance Portability and Accountability Act (HIPAA) governs protected health information (PHI) in the United States. For AI development, the key provisions are:
Protected Health Information (PHI) includes any individually identifiable health information held by covered entities. This encompasses demographics, medical records, billing information—essentially anything that connects health data to an individual.
De-identification removes the legal protections of HIPAA. HIPAA provides two paths:
Safe Harbor: Remove 18 specific identifiers (names, dates, locations, SSN, etc.) and have no actual knowledge the remaining data could identify individuals.
Expert Determination: A qualified expert certifies that re-identification risk is “very small.”
De-identified data under HIPAA is no longer PHI and can be shared more freely.
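To make the Safe Harbor idea concrete, here is a rough sketch that drops a few of the 18 identifier fields and coarsens quasi-identifiers in a hypothetical patient table. The column names and rules shown are illustrative assumptions, not a compliance tool—real de-identification must cover all 18 identifier categories and is usually reviewed by privacy officers or a qualified expert.

```python
import pandas as pd

# Hypothetical patient table; column names are assumptions for illustration only.
records = pd.DataFrame({
    "name": ["A. Patel", "B. Jones"],
    "ssn": ["123-45-6789", "987-65-4321"],
    "zip_code": ["02139", "10001"],
    "admission_date": pd.to_datetime(["2023-03-14", "2023-07-02"]),
    "age": [34, 93],
    "diagnosis_code": ["E11.9", "I10"],
})

def safe_harbor_sketch(df):
    """Rough Safe Harbor-style transformation (illustrative, not compliant by itself)."""
    out = df.drop(columns=["name", "ssn"])                      # remove direct identifiers
    out["zip3"] = out.pop("zip_code").str[:3]                   # truncate ZIP to 3 digits
    out["admission_year"] = out.pop("admission_date").dt.year   # keep only the year of dates
    out["age"] = out["age"].clip(upper=90)                      # aggregate ages over 89 as 90+
    return out

print(safe_harbor_sketch(records))
```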
19.2.2 The Limits of De-identification
De-identification is necessary but often insufficient:
Re-identification attacks: Researchers have re-identified individuals from “anonymous” datasets by linking with public records. The Netflix Prize dataset, AOL search logs, and hospital discharge data have all been re-identified.
Medical imaging: Facial structure can be reconstructed from head CT/MRI scans. Removing obvious identifiers doesn’t prevent recognition.
Genetic data: Genomic sequences are inherently identifying—you can’t de-identify a genome.
Rare conditions: If only one patient at a hospital has a specific rare disease, any data about that disease identifies them.
Longitudinal data: Combining multiple de-identified records can enable re-identification.
19.2.3 Beyond Compliance
HIPAA compliance is a floor, not a ceiling. Even legally de-identified data can pose privacy risks. Technical privacy measures provide additional protection:
- What if an adversary has auxiliary information about a patient?
- What can be inferred from model outputs even without direct data access?
- How do we protect privacy against future re-identification techniques?
Differential privacy and federated learning address these questions.
19.3 Differential Privacy
Clinical Context: A researcher queries a hospital database: “What percentage of diabetic patients developed kidney disease within 5 years?” The answer—42%—seems innocuous. But if the researcher knows their neighbor is diabetic and is in this database, they’ve just learned something about their neighbor’s health. Differential privacy prevents such inference.
19.3.1 The Problem with “Anonymized” Data
Consider a database of patient health records. Even without names, queries can reveal individual information:
- Query 1: “How many patients have HIV?”
- Query 2: “How many patients have HIV, excluding patient 12345?”
- The difference between the two answers reveals whether patient 12345 has HIV
This differencing attack works for any query. As long as queries reveal information about the dataset, they reveal information about individuals in it.
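To make the differencing attack concrete, the sketch below runs both counting queries against a small synthetic table (the IDs and statuses are made up) and recovers one patient's status from the difference between two seemingly harmless aggregates.

```python
import pandas as pd

# Synthetic records; patient IDs and statuses are made up for illustration.
db = pd.DataFrame({
    "patient_id": [12340, 12341, 12345, 12346],
    "hiv_positive": [False, True, True, False],
})

q1 = int(db["hiv_positive"].sum())                                 # "How many patients have HIV?"
q2 = int(db.loc[db["patient_id"] != 12345, "hiv_positive"].sum())  # same query, excluding 12345

print(f"Query 1: {q1}, Query 2: {q2}")
print(f"Patient 12345 is HIV positive: {bool(q1 - q2)}")           # exact disclosure
```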
19.3.2 The Differential Privacy Definition
Differential privacy provides a mathematical framework that limits what any query can reveal about any individual (Dwork and Roth 2014).
A mechanism \(M\) satisfies ε-differential privacy if for all datasets \(D\) and \(D'\) differing in one record, and all possible outputs \(S\):
\[\frac{P(M(D) \in S)}{P(M(D') \in S)} \leq e^\varepsilon\]
In words: the probability of any output changes by at most a factor of \(e^\varepsilon\) whether or not any individual is in the dataset.
What ε means:
- ε = 0: Perfect privacy (no information revealed), but useless outputs
- ε = 0.1: Strong privacy, but utility may suffer
- ε = 1: Moderate privacy, reasonable utility
- ε = 10: Weak privacy, good utility
The parameter ε is called the privacy budget. Smaller ε means stronger privacy but typically worse utility.
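The definition can be checked numerically for the Laplace mechanism introduced in Section 19.3.4: for two neighboring count queries (counts differing by one record, sensitivity 1), the ratio of output densities never exceeds e^ε. The specific counts below are arbitrary.

```python
import numpy as np

epsilon = 0.5
sensitivity = 1
scale = sensitivity / epsilon

def laplace_pdf(x, loc, scale):
    """Density of the Laplace distribution centered at loc with the given scale."""
    return np.exp(-np.abs(x - loc) / scale) / (2 * scale)

# Neighboring datasets: the true counts differ by exactly one record
count_D, count_D_prime = 100, 101

outputs = np.linspace(80, 120, 401)
ratio = laplace_pdf(outputs, count_D, scale) / laplace_pdf(outputs, count_D_prime, scale)

print(f"Max density ratio over these outputs: {ratio.max():.4f}")
print(f"e^epsilon bound:                      {np.exp(epsilon):.4f}")
```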
19.3.3 Intuition: Plausible Deniability
Differential privacy provides plausible deniability. If you’re in a database and the mechanism outputs some result, you can plausibly claim that the same result would have occurred even if you weren’t in the database.
An adversary who sees the output and knows everything except whether you’re in the database can’t confidently determine your presence—the output is almost as likely either way.
19.3.4 The Laplace Mechanism
How do we achieve differential privacy? The simplest approach: add noise.
The Laplace mechanism adds noise drawn from a Laplace distribution:
\[M(D) = f(D) + \text{Laplace}\left(\frac{\Delta f}{\varepsilon}\right)\]
where \(\Delta f\) is the sensitivity of the function—the maximum amount the output can change when a single record is added or removed.
import numpy as np

def laplace_mechanism(true_value, sensitivity, epsilon):
    """
    Add Laplace noise to achieve ε-differential privacy.

    Args:
        true_value: The actual query result
        sensitivity: Maximum change from adding/removing one record
        epsilon: Privacy parameter (smaller = more private)

    Returns:
        Noisy value satisfying ε-differential privacy
    """
    scale = sensitivity / epsilon
    noise = np.random.laplace(0, scale)
    return true_value + noise

# Example: Count query
# Sensitivity = 1 (adding/removing one person changes the count by at most 1)
true_count = 1000
epsilon = 0.1
noisy_count = laplace_mechanism(true_count, sensitivity=1, epsilon=epsilon)

print(f"True count: {true_count}")
print(f"Noisy count (ε={epsilon}): {noisy_count:.0f}")

# Multiple runs show the noise distribution
noisy_counts = [laplace_mechanism(true_count, 1, epsilon) for _ in range(1000)]
print(f"Mean of noisy counts: {np.mean(noisy_counts):.1f}")
print(f"Std of noisy counts: {np.std(noisy_counts):.1f}")
19.3.5 The Privacy-Utility Tradeoff
Smaller ε means more noise, which means less accurate results:
import matplotlib.pyplot as plt
true_value = 100
sensitivity = 1
epsilons = [0.01, 0.1, 0.5, 1.0, 5.0]
n_samples = 1000
fig, axes = plt.subplots(1, len(epsilons), figsize=(15, 3))
for ax, eps in zip(axes, epsilons):
noisy_values = [laplace_mechanism(true_value, sensitivity, eps)
for _ in range(n_samples)]
ax.hist(noisy_values, bins=50, density=True)
ax.axvline(true_value, color='red', linestyle='--', label='True value')
ax.set_title(f'ε = {eps}')
ax.set_xlim([0, 200])
plt.tight_layout()
plt.savefig('dp_tradeoff.png')
plt.show()

This is the fundamental tradeoff: stronger privacy (smaller ε) requires accepting less accurate results.
19.3.6 Composition: Privacy Budgets
Each query “spends” privacy budget. Multiple queries compose:
Sequential composition: If you run k mechanisms with privacy parameters ε₁, ε₂, …, εₖ on the same dataset, the total privacy loss is at most ε₁ + ε₂ + … + εₖ.
This means you must budget your ε across all queries. If you want total privacy budget ε = 1 and plan to make 10 queries, each query can use at most ε = 0.1.
Advanced composition provides tighter bounds but the principle remains: privacy degrades with more queries.
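A quick sketch of what sequential composition costs in practice: splitting a total budget of ε = 1 across 10 count queries forces each query to use ε = 0.1, which for sensitivity 1 makes the Laplace noise ten times larger. The numbers are arbitrary.

```python
import numpy as np

total_epsilon = 1.0
n_queries = 10
sensitivity = 1

# Per-query budget under sequential composition
epsilon_per_query = total_epsilon / n_queries

# Laplace scale is sensitivity / epsilon; the standard deviation is scale * sqrt(2)
std_single = (sensitivity / total_epsilon) * np.sqrt(2)
std_split = (sensitivity / epsilon_per_query) * np.sqrt(2)

print(f"One query at ε = {total_epsilon}: noise std ≈ {std_single:.2f}")
print(f"Each of {n_queries} queries at ε = {epsilon_per_query}: noise std ≈ {std_split:.2f}")
```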
19.3.7 Differentially Private Machine Learning
For machine learning, we apply differential privacy to the training process. The most common approach is DP-SGD (Differentially Private Stochastic Gradient Descent):
- Clip gradients: Limit each sample’s contribution to the gradient
- Add noise: Add Gaussian noise to the aggregated gradient
- Track privacy: Account for privacy spent each step
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
# Simple DP-SGD implementation (conceptual)
def dp_sgd_step(model, batch_data, batch_labels, optimizer,
max_grad_norm, noise_multiplier):
"""
One step of differentially private SGD.
Args:
max_grad_norm: Clip individual gradients to this norm
noise_multiplier: Noise scale relative to sensitivity
"""
model.zero_grad()
# Compute per-sample gradients
batch_size = batch_data.shape[0]
per_sample_grads = []
for i in range(batch_size):
model.zero_grad()
output = model(batch_data[i:i+1])
loss = nn.functional.cross_entropy(output, batch_labels[i:i+1])
loss.backward()
# Collect gradients for this sample
sample_grad = []
for param in model.parameters():
if param.grad is not None:
sample_grad.append(param.grad.clone())
per_sample_grads.append(sample_grad)
# Clip each sample's gradient
clipped_grads = []
for sample_grad in per_sample_grads:
total_norm = torch.sqrt(sum(g.norm()**2 for g in sample_grad))
clip_factor = min(1.0, max_grad_norm / (total_norm + 1e-6))
clipped = [g * clip_factor for g in sample_grad]
clipped_grads.append(clipped)
# Aggregate and add noise
model.zero_grad()
for param_idx, param in enumerate(model.parameters()):
if param.grad is None:
continue
# Sum clipped gradients
aggregated = sum(grads[param_idx] for grads in clipped_grads) / batch_size
# Add Gaussian noise
noise_std = noise_multiplier * max_grad_norm / batch_size
noise = torch.randn_like(aggregated) * noise_std
param.grad = aggregated + noise
    optimizer.step()

19.3.8 Using Opacus for DP Training
In practice, use the Opacus library for differentially private PyTorch training:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine
# Create a simple model
class SimpleClassifier(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, output_dim)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.fc1(x))
return self.fc2(x)
# Generate synthetic medical data
torch.manual_seed(42)
n_samples = 10000
n_features = 50
X = torch.randn(n_samples, n_features)
y = (X[:, 0] + X[:, 1] + torch.randn(n_samples) * 0.5 > 0).long()
dataset = TensorDataset(X, y)
train_loader = DataLoader(dataset, batch_size=256, shuffle=True)
# Create model and optimizer
model = SimpleClassifier(n_features, 64, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Attach privacy engine
privacy_engine = PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
module=model,
optimizer=optimizer,
data_loader=train_loader,
epochs=10,
target_epsilon=1.0, # Privacy budget
target_delta=1e-5, # Probability of privacy failure
max_grad_norm=1.0, # Gradient clipping bound
)
# Training loop
criterion = nn.CrossEntropyLoss()
for epoch in range(10):
model.train()
total_loss = 0
for batch_X, batch_y in train_loader:
optimizer.zero_grad()
outputs = model(batch_X)
loss = criterion(outputs, batch_y)
loss.backward()
optimizer.step()
total_loss += loss.item()
# Get current privacy spent
epsilon = privacy_engine.get_epsilon(delta=1e-5)
print(f"Epoch {epoch+1}: Loss = {total_loss/len(train_loader):.4f}, ε = {epsilon:.2f}")
print(f"\nFinal privacy guarantee: (ε={epsilon:.2f}, δ=1e-5)-differential privacy")19.3.9 Choosing ε: Practical Guidance
What ε should you use? There’s no universal answer, but guidelines exist:
| ε Value | Privacy Level | Typical Use Case |
|---|---|---|
| 0.1 - 0.5 | Strong | Highly sensitive data, regulatory requirements |
| 1 - 3 | Moderate | Research publications, general medical AI |
| 5 - 10 | Weak | Internal analytics, low-sensitivity data |
Consider:
- What's the threat model? Stronger adversaries require smaller ε
- How sensitive is the data? Genetic data vs. step counts
- What utility do you need? Smaller ε means more noise
- How many queries? The budget must cover all analyses
For medical AI, ε between 1 and 8 is common, with δ = 1/n² where n is dataset size.
19.4 Federated Learning
Clinical Context: Ten hospitals want to train a shared radiology AI model. Each has 50,000 chest X-rays. Combining them would create a powerful dataset, but data can’t leave each hospital. Federated learning enables training on all 500,000 images without any image leaving its source institution.
19.4.1 The Federated Paradigm
Federated learning inverts the traditional ML paradigm (Rieke et al. 2020):
- Traditional: Bring data to computation (centralize data, then train)
- Federated: Bring computation to data (train locally, share only models)
Instead of collecting data in one place, federated learning:
1. Sends the model to each data source
2. Trains locally at each source
3. Aggregates model updates (not data) centrally
4. Repeats until convergence
19.4.2 FedAvg: The Basic Algorithm
Federated Averaging (FedAvg) is the foundational federated learning algorithm:
For each round t:
  1. Server sends the current model θ_t to selected clients
  2. Each client k:
     - Trains on local data for E epochs, producing θ_local
     - Computes its update: Δ_k = θ_local - θ_t
  3. Server aggregates: θ_{t+1} = θ_t + Σ_k (n_k / n) Δ_k
     where n_k is client k's data size and n is the total
The key insight: averaging model parameters (or gradients) doesn’t require sharing data.
import torch
import torch.nn as nn
import copy
from torch.utils.data import DataLoader, TensorDataset, Subset
import numpy as np
class FederatedServer:
"""Central server for federated learning."""
def __init__(self, model):
self.global_model = model
def aggregate(self, client_models, client_sizes):
"""Weighted average of client models."""
total_size = sum(client_sizes)
# Initialize aggregated state dict
aggregated_state = {}
for key in self.global_model.state_dict():
aggregated_state[key] = torch.zeros_like(
self.global_model.state_dict()[key], dtype=torch.float32
)
# Weighted sum
for client_model, size in zip(client_models, client_sizes):
weight = size / total_size
for key in aggregated_state:
aggregated_state[key] += weight * client_model.state_dict()[key]
# Update global model
self.global_model.load_state_dict(aggregated_state)
return copy.deepcopy(self.global_model)
class FederatedClient:
"""Client in federated learning."""
def __init__(self, client_id, data_loader, device='cpu'):
self.client_id = client_id
self.data_loader = data_loader
self.device = device
def train(self, model, epochs, lr=0.01):
"""Local training on client data."""
model = copy.deepcopy(model).to(self.device)
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
for epoch in range(epochs):
for batch_X, batch_y in self.data_loader:
batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device)
optimizer.zero_grad()
outputs = model(batch_X)
loss = criterion(outputs, batch_y)
loss.backward()
optimizer.step()
return model
@property
def data_size(self):
return len(self.data_loader.dataset)
def federated_training(server, clients, rounds, local_epochs):
"""Run federated learning for specified rounds."""
history = {'round': [], 'accuracy': []}
for round_num in range(rounds):
# Each client trains locally
client_models = []
client_sizes = []
for client in clients:
# Get current global model
local_model = client.train(
server.global_model,
epochs=local_epochs
)
client_models.append(local_model)
client_sizes.append(client.data_size)
# Server aggregates
server.aggregate(client_models, client_sizes)
# Evaluate (on combined test data)
accuracy = evaluate_global_model(server.global_model, test_loader)
history['round'].append(round_num)
history['accuracy'].append(accuracy)
print(f"Round {round_num + 1}: Global accuracy = {accuracy:.4f}")
return history
def evaluate_global_model(model, test_loader, device='cpu'):
"""Evaluate model on test data."""
model.eval()
correct = 0
total = 0
with torch.no_grad():
for X, y in test_loader:
X, y = X.to(device), y.to(device)
outputs = model(X)
_, predicted = outputs.max(1)
total += y.size(0)
correct += predicted.eq(y).sum().item()
    return correct / total

19.4.3 Complete Federated Learning Example
# Simulate multi-hospital scenario
torch.manual_seed(42)
np.random.seed(42)
# Create synthetic data for 5 "hospitals"
n_hospitals = 5
samples_per_hospital = 2000
n_features = 30
# Each hospital has slightly different data distribution (non-IID)
all_data = []
for h in range(n_hospitals):
# Hospital-specific mean shift
mean_shift = np.random.randn(n_features) * 0.5
X = torch.randn(samples_per_hospital, n_features) + torch.tensor(mean_shift, dtype=torch.float32)
y = (X[:, 0] + X[:, 1] + torch.randn(samples_per_hospital) * 0.5 > 0).long()
all_data.append((X, y))
# Create federated clients (one per hospital)
clients = []
for h in range(n_hospitals):
X, y = all_data[h]
# Split into train (80%) and contribute to test (20%)
n_train = int(0.8 * samples_per_hospital)
train_dataset = TensorDataset(X[:n_train], y[:n_train])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
clients.append(FederatedClient(f"Hospital_{h}", train_loader))
# Create global test set from all hospitals
test_X = torch.cat([all_data[h][0][int(0.8*samples_per_hospital):] for h in range(n_hospitals)])
test_y = torch.cat([all_data[h][1][int(0.8*samples_per_hospital):] for h in range(n_hospitals)])
test_loader = DataLoader(TensorDataset(test_X, test_y), batch_size=64)
# Initialize global model and server
global_model = SimpleClassifier(n_features, 32, 2)
server = FederatedServer(global_model)
# Run federated learning
print("Starting Federated Learning")
print(f"Number of hospitals: {n_hospitals}")
print(f"Samples per hospital: {samples_per_hospital}")
print("-" * 40)
history = federated_training(
server=server,
clients=clients,
rounds=20,
local_epochs=3
)
print("-" * 40)
print(f"Final global model accuracy: {history['accuracy'][-1]:.4f}")19.4.4 Challenges in Federated Learning
Non-IID Data: Real hospitals have different patient populations. Data isn’t “independent and identically distributed” across clients.
# Simulating non-IID: each hospital specializes in different conditions
# Hospital 0: mostly young patients (class 0)
# Hospital 1: mostly elderly patients (class 1)
# This makes federated learning harder
def create_non_iid_data(n_hospitals, samples_per_hospital, skew=0.8):
"""Create non-IID data where each hospital has skewed class distribution."""
all_data = []
for h in range(n_hospitals):
X = torch.randn(samples_per_hospital, n_features)
# Skew: hospital h prefers class (h % 2)
preferred_class = h % 2
probs = torch.rand(samples_per_hospital)
y = torch.where(
probs < skew,
torch.full((samples_per_hospital,), preferred_class),
torch.randint(0, 2, (samples_per_hospital,))
)
all_data.append((X, y.long()))
return all_data
# Non-IID data is harder to learn from
# Solutions: FedProx, local adaptation, personalization

Communication Efficiency: Sending full models every round is expensive. Techniques include (a sketch of top-k gradient compression follows below):
- Gradient compression
- Less frequent communication
- Partial model updates
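As one concrete illustration of gradient compression, the sketch below keeps only the k largest-magnitude entries of a flattened update and transmits them as (index, value) pairs. This is a generic sparsification sketch, not a specific production protocol; error feedback and other refinements are omitted.

```python
import torch

def top_k_compress(update: torch.Tensor, k: int):
    """Keep the k largest-magnitude entries of a flattened update."""
    flat = update.flatten()
    _, idx = flat.abs().topk(k)
    return idx, flat[idx]                      # transmit only indices and values

def top_k_decompress(idx, values, shape):
    """Rebuild a dense (mostly zero) update from the transmitted pairs."""
    flat = torch.zeros(shape).flatten()
    flat[idx] = values
    return flat.reshape(shape)

# Example: send only 1% of a 1000-parameter update
update = torch.randn(1000)
idx, vals = top_k_compress(update, k=10)
restored = top_k_decompress(idx, vals, update.shape)
print(f"Transmitted {idx.numel()} of {update.numel()} values")
```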
Stragglers: Slow clients delay training. Solutions:
- Asynchronous aggregation
- Client selection strategies
- Timeout mechanisms
19.4.5 Medical Applications
Federated learning enables previously impossible collaborations:
Multi-site clinical trials: Train AI on data from trial sites worldwide without centralizing patient data.
Rare disease research: Combine data across hospitals to reach meaningful sample sizes for rare conditions.
International collaboration: Work across jurisdictions with different privacy laws.
Real-world example: The Federated Tumor Segmentation (FeTS) challenge demonstrated federated learning for brain tumor segmentation across 30+ institutions globally.
19.5 Privacy Attacks and Defenses
Clinical Context: Your hospital participates in a federated learning consortium. You share model updates, not patient data. Is this actually private? Unfortunately, model updates can leak information about training data. Understanding attacks is essential for building robust defenses.
19.5.1 Membership Inference Attacks
Question: Was a specific patient’s data used to train this model?
Attack: Train a classifier that distinguishes model behavior on training data vs. non-training data. Models often have higher confidence on examples they were trained on.
def membership_inference_attack(target_model, member_data, non_member_data):
"""
Simple membership inference based on prediction confidence.
Higher confidence often indicates membership in training set.
"""
target_model.eval()
def get_confidence(model, data):
with torch.no_grad():
outputs = model(data)
probs = torch.softmax(outputs, dim=1)
confidence = probs.max(dim=1).values
return confidence
member_conf = get_confidence(target_model, member_data)
non_member_conf = get_confidence(target_model, non_member_data)
# Members typically have higher confidence
print(f"Member confidence: {member_conf.mean():.4f} +/- {member_conf.std():.4f}")
print(f"Non-member confidence: {non_member_conf.mean():.4f} +/- {non_member_conf.std():.4f}")
# Simple threshold attack
threshold = (member_conf.mean() + non_member_conf.mean()) / 2
member_correct = (member_conf > threshold).float().mean()
non_member_correct = (non_member_conf <= threshold).float().mean()
attack_accuracy = (member_correct + non_member_correct) / 2
print(f"Attack accuracy: {attack_accuracy:.4f}")
    return attack_accuracy

Defense: Differential privacy limits membership inference by ensuring model behavior is similar whether or not any individual was in training.
19.5.2 Model Inversion Attacks
Question: Can we reconstruct training data from the model?
Attack: Optimize an input to maximize model confidence for a specific class, potentially revealing what training examples looked like.
For medical imaging, this could mean reconstructing patient faces from a facial recognition model or patient X-rays from a diagnostic model.
Defense: Differential privacy, limiting model access, adding noise to outputs.
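A minimal sketch of the optimization at the core of model inversion, assuming white-box access to a classifier: start from random noise and ascend the target class's logit. It reuses the SimpleClassifier defined earlier in this chapter, untrained here, so the output is meaningless—real attacks target trained models and add priors and regularization to produce recognizable reconstructions.

```python
import torch

def model_inversion_sketch(model, input_dim, target_class, steps=500, lr=0.1):
    """Gradient-ascend an input to maximize the model's score for target_class."""
    model.eval()
    x = torch.randn(1, input_dim, requires_grad=True)
    optimizer = torch.optim.Adam([x], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        logits = model(x)
        loss = -logits[0, target_class]   # maximize the target logit
        loss.backward()
        optimizer.step()
    return x.detach()

# Illustrative call against the (untrained) SimpleClassifier from earlier sections
model = SimpleClassifier(50, 64, 2)
reconstruction = model_inversion_sketch(model, input_dim=50, target_class=1)
print(reconstruction.shape)
```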
19.5.3 Gradient Leakage in Federated Learning
Question: Can shared gradients reveal training data?
Attack: Given gradient updates, reconstruct the training batch that produced them. This is especially effective for small batches and early training rounds.
def gradient_leakage_attack(gradient, model, input_shape):
"""
Attempt to reconstruct training data from gradients.
This is a simplified version - real attacks are more sophisticated.
"""
# Initialize random input
dummy_input = torch.randn(input_shape, requires_grad=True)
dummy_label = torch.randint(0, 2, (input_shape[0],))
optimizer = torch.optim.LBFGS([dummy_input])
def closure():
optimizer.zero_grad()
# Compute gradients for dummy input
model.zero_grad()
output = model(dummy_input)
loss = nn.functional.cross_entropy(output, dummy_label)
dummy_gradient = torch.autograd.grad(loss, model.parameters(), create_graph=True)
# Minimize distance between dummy gradient and actual gradient
grad_diff = sum(
((dg - g) ** 2).sum()
for dg, g in zip(dummy_gradient, gradient)
)
grad_diff.backward()
return grad_diff
# Optimize to find input that produces similar gradient
for _ in range(100):
optimizer.step(closure)
    return dummy_input.detach()

Defense: Secure aggregation ensures the server only sees aggregated gradients, not individual client contributions.
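A toy sketch of the idea behind secure aggregation, under simplified assumptions (an honest-but-curious server, no client dropouts, no finite-field arithmetic): every pair of clients agrees on a random mask that one adds and the other subtracts, so each masked update looks random to the server while the masks cancel exactly in the sum.

```python
import torch

torch.manual_seed(0)
n_clients, dim = 3, 5
updates = [torch.randn(dim) for _ in range(n_clients)]

# Pairwise masks: for each pair (i, j) with i < j, client i adds the mask, client j subtracts it
masks = {(i, j): torch.randn(dim) for i in range(n_clients) for j in range(i + 1, n_clients)}

masked_updates = []
for i in range(n_clients):
    m = updates[i].clone()
    for (a, b), mask in masks.items():
        if a == i:
            m += mask
        elif b == i:
            m -= mask
    masked_updates.append(m)                     # this is all the server ever sees

aggregate = sum(masked_updates)                  # the pairwise masks cancel in the sum
print(torch.allclose(aggregate, sum(updates)))   # True: correct sum, no individual update revealed
```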
19.5.4 Combining Defenses
Robust privacy requires multiple layers:
- Differential privacy: Fundamental privacy guarantee
- Secure aggregation: Protect individual client updates
- Gradient clipping: Limit information in each update
- Minimum batch sizes: Prevent single-sample gradient reconstruction
19.6 Secure Multi-Party Computation
Clinical Context: Two hospitals want to compute the combined prevalence of a rare disease without either learning the other’s patient count. Secure computation enables this—computing on data while keeping inputs private.
19.6.1 What Secure Computation Enables
Secure Multi-Party Computation (MPC) allows parties to jointly compute a function over their inputs without revealing inputs to each other.
Example: Hospitals A and B have patient counts \(n_A\) and \(n_B\). They want to compute \(n_A + n_B\) such that:
- A doesn't learn \(n_B\)
- B doesn't learn \(n_A\)
- Both learn \(n_A + n_B\)
This sounds impossible, but cryptographic techniques make it achievable.
19.6.2 Secret Sharing Intuition
Secret sharing splits a secret into shares such that:
- Any single share reveals nothing about the secret
- Combined shares reconstruct the secret
Simple example (additive sharing):
- Secret: 42
- Share 1: a random number r = 17
- Share 2: 42 - 17 = 25
- Neither 17 nor 25 reveals 42, but 17 + 25 = 42
For secure computation:
1. Each party secret-shares their input
2. Computation happens on shares (requires special protocols)
3. Result shares are combined to get the final answer
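A minimal sketch of the two-hospital sum using additive secret sharing over a modulus. The counts and modulus below are made up for illustration; real MPC protocols handle more parties, malicious behavior, and richer computations.

```python
import secrets

MODULUS = 2**61 - 1   # arbitrary large modulus for illustration

def share(value, modulus=MODULUS):
    """Split a value into two additive shares; each share alone looks uniformly random."""
    r = secrets.randbelow(modulus)
    return r, (value - r) % modulus

def reconstruct(s1, s2, modulus=MODULUS):
    return (s1 + s2) % modulus

# Hospital A and Hospital B each split their private count into two shares
n_A, n_B = 137, 412
a1, a2 = share(n_A)   # A keeps a1, sends a2 to B
b1, b2 = share(n_B)   # B keeps b2, sends b1 to A

# Each side adds the shares it holds; neither ever sees the other's raw count
partial_A = (a1 + b1) % MODULUS
partial_B = (a2 + b2) % MODULUS

print(reconstruct(partial_A, partial_B))   # 549 == n_A + n_B
```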
19.6.3 Homomorphic Encryption
Homomorphic encryption allows computation on encrypted data:
\[E(a) \oplus E(b) = E(a + b)\]
You can add (or multiply, depending on scheme) encrypted values without decrypting. The result, when decrypted, equals what you’d get from computing on plaintext.
Fully homomorphic encryption (FHE) supports arbitrary computation on encrypted data, but it’s computationally expensive—often 1000x+ slower than plaintext.
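To make the homomorphic property concrete, here is a toy Paillier cryptosystem with tiny hard-coded primes. It illustrates that multiplying ciphertexts adds plaintexts, the Paillier instance of \(E(a) \oplus E(b) = E(a+b)\), but it is nowhere near secure or efficient—real deployments use vetted libraries such as those listed in the next subsection. The primes and messages are arbitrary choices for readability.

```python
import math
import secrets

# Toy Paillier cryptosystem (tiny primes, illustration only -- NOT secure)
p, q = 293, 433                  # small primes chosen for readability
n, n_sq = p * q, (p * q) ** 2
g = n + 1                        # standard simple generator choice
lam = math.lcm(p - 1, q - 1)     # Carmichael function of n

def L(x):
    return (x - 1) // n

mu = pow(L(pow(g, lam, n_sq)), -1, n)   # precomputed decryption constant

def encrypt(m):
    r = secrets.randbelow(n - 1) + 1
    while math.gcd(r, n) != 1:
        r = secrets.randbelow(n - 1) + 1
    return (pow(g, m, n_sq) * pow(r, n, n_sq)) % n_sq

def decrypt(c):
    return (L(pow(c, lam, n_sq)) * mu) % n

# Additive homomorphism: multiplying ciphertexts adds the underlying plaintexts
a, b = 1500, 2700
c_sum = (encrypt(a) * encrypt(b)) % n_sq
print(decrypt(c_sum))   # 4200
```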
19.6.4 Practical Considerations
Secure computation is powerful but has limitations:
Performance: MPC and FHE are orders of magnitude slower than plaintext computation. Training neural networks with MPC is possible but expensive.
Complexity: Implementing secure protocols correctly is hard. Use established libraries (MP-SPDZ, CrypTen, TenSEAL).
Trust model: MPC requires assumptions about adversary capabilities (honest-but-curious vs. malicious).
When to use:
- High-value, high-sensitivity computations
- Regulatory requirements mandate it
- Performance cost is acceptable
- Simpler methods (DP, FL) are insufficient
For most medical AI applications, differential privacy and federated learning provide sufficient protection at much lower cost.
19.7 Putting It Together: Private Medical AI Pipeline
Clinical Context: You’re designing a multi-hospital collaboration for sepsis prediction. Here’s how to combine privacy techniques for a robust solution.
19.7.1 Architecture Overview
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│ Hospital A │ │ Hospital B │ │ Hospital C │
│ │ │ │ │ │
│ Local Data │ │ Local Data │ │ Local Data │
│ ↓ │ │ ↓ │ │ ↓ │
│ DP-SGD │ │ DP-SGD │ │ DP-SGD │
│ Training │ │ Training │ │ Training │
│ ↓ │ │ ↓ │ │ ↓ │
│ Noisy │ │ Noisy │ │ Noisy │
│ Gradients │ │ Gradients │ │ Gradients │
└──────┬──────┘ └──────┬──────┘ └──────┬──────┘
│ │ │
└───────────────────┼───────────────────┘
│
┌──────▼──────┐
│ Secure │
│ Aggregator │
│ │
│ FedAvg on │
│ Noisy │
│ Updates │
└──────┬──────┘
│
┌──────▼──────┐
│ Global │
│ Model │
└─────────────┘
19.7.2 Implementation
class PrivateFederatedLearning:
"""
Federated learning with differential privacy.
Combines:
- Local DP-SGD training at each client
- Secure aggregation of updates
- Privacy budget tracking
"""
def __init__(self, model, n_clients, target_epsilon, target_delta,
rounds, local_epochs):
self.global_model = model
self.n_clients = n_clients
self.target_epsilon = target_epsilon
self.target_delta = target_delta
self.rounds = rounds
self.local_epochs = local_epochs
# Per-round privacy budget
self.epsilon_per_round = target_epsilon / rounds
def train_round(self, clients):
"""One round of private federated training."""
client_models = []
client_sizes = []
for client in clients:
# Each client trains with DP-SGD
local_model = client.private_train(
model=copy.deepcopy(self.global_model),
epochs=self.local_epochs,
epsilon=self.epsilon_per_round,
delta=self.target_delta / self.rounds
)
client_models.append(local_model)
client_sizes.append(client.data_size)
# Aggregate (with secure aggregation in production)
self.aggregate_models(client_models, client_sizes)
def aggregate_models(self, client_models, client_sizes):
"""Weighted averaging of client models."""
total_size = sum(client_sizes)
with torch.no_grad():
for key in self.global_model.state_dict():
weighted_sum = sum(
(size / total_size) * model.state_dict()[key]
for model, size in zip(client_models, client_sizes)
)
self.global_model.state_dict()[key].copy_(weighted_sum)
class PrivateClient:
"""Client with differential privacy."""
def __init__(self, client_id, data_loader):
self.client_id = client_id
self.data_loader = data_loader
def private_train(self, model, epochs, epsilon, delta):
"""Train with differential privacy using Opacus."""
from opacus import PrivacyEngine
model = model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private_with_epsilon(
module=model,
optimizer=optimizer,
data_loader=self.data_loader,
epochs=epochs,
target_epsilon=epsilon,
target_delta=delta,
max_grad_norm=1.0
)
criterion = nn.CrossEntropyLoss()
for epoch in range(epochs):
for batch_X, batch_y in data_loader:
optimizer.zero_grad()
loss = criterion(model(batch_X), batch_y)
loss.backward()
optimizer.step()
return model
@property
def data_size(self):
        return len(self.data_loader.dataset)

19.7.3 Privacy Accounting
Track total privacy spent across all operations:
class PrivacyAccountant:
"""Track cumulative privacy loss."""
def __init__(self, target_epsilon, target_delta):
self.target_epsilon = target_epsilon
self.target_delta = target_delta
self.spent_epsilon = 0
self.operations = []
def spend(self, epsilon, operation_name):
"""Record privacy expenditure."""
self.spent_epsilon += epsilon
self.operations.append({
'operation': operation_name,
'epsilon': epsilon,
'cumulative': self.spent_epsilon
})
if self.spent_epsilon > self.target_epsilon:
raise ValueError(
f"Privacy budget exceeded! Spent {self.spent_epsilon:.2f}, "
f"budget was {self.target_epsilon:.2f}"
)
def remaining(self):
return self.target_epsilon - self.spent_epsilon
def report(self):
print(f"Privacy Budget Report")
print(f"Target: ε = {self.target_epsilon}, δ = {self.target_delta}")
print(f"Spent: ε = {self.spent_epsilon:.4f}")
print(f"Remaining: ε = {self.remaining():.4f}")
print(f"\nOperations:")
for op in self.operations:
print(f" {op['operation']}: ε = {op['epsilon']:.4f} "
f"(cumulative: {op['cumulative']:.4f})")19.8 Making Privacy Decisions
19.8 Making Privacy Decisions
19.8.1 Decision Framework
When designing a privacy-preserving system:
1. What's the threat model?
- Who might attack? (curious collaborators, external adversaries)
- What's their capability? (access to model, auxiliary data)
- What's the harm if privacy fails?
2. What are the regulatory requirements?
- HIPAA, GDPR, institutional policies
- What level of de-identification is required?
3. What utility do you need?
- Acceptable accuracy loss?
- Query/training budget?
4. Choose techniques:
- Data stays local → Federated learning
- Need formal guarantees → Differential privacy
- High-security requirements → Add secure computation
- Combine as needed
19.8.2 Technique Selection Guide
| Scenario | Recommended Approach |
|---|---|
| Multi-site collaboration, moderate trust | Federated learning |
| Releasing aggregate statistics | Differential privacy |
| Sharing trained model publicly | DP training |
| Very high sensitivity (genomics) | FL + DP + secure aggregation |
| Single institution, internal use | May not need privacy ML |
19.8.3 Communicating Privacy
Stakeholders need to understand privacy guarantees:
For technical audiences: Report ε, δ, and methodology. “This model was trained with (ε=2, δ=10⁻⁵)-differential privacy using DP-SGD with gradient clipping at norm 1.0.”
For non-technical audiences: Explain implications. “Even an attacker who knows everything about the training data except for one patient cannot determine with confidence whether that patient was included.”
For patients: Focus on protections. “Your data helps improve medical AI, but the analysis is designed so that your individual information cannot be extracted.”
19.8.4 When Privacy-Preserving ML Isn’t Enough
Privacy techniques have limits:
- Utility loss: Strong privacy may degrade model quality unacceptably
- Implementation complexity: Bugs in privacy implementations can void guarantees
- Scope: DP protects training data; other privacy risks (inference attacks, model outputs) need separate handling
- Trust assumptions: FL requires trusting the aggregator; MPC requires trusting protocol implementation
Sometimes the answer is “don’t build this” or “this requires a fundamentally different approach.”
19.9 Chapter Summary
Privacy-preserving techniques enable medical AI development while protecting patient data.
The data access problem:
- Medical AI benefits from scale, but data is siloed and sensitive
- Regulations (HIPAA, GDPR) and institutional concerns limit sharing
- Technical solutions enable collaboration without centralized data

Differential privacy:
- Mathematical framework limiting what queries reveal about individuals
- ε controls the privacy-utility tradeoff
- DP-SGD enables private model training
- Privacy budgets must be tracked across all operations

Federated learning:
- Train models without centralizing data
- FedAvg: local training, aggregate updates
- Challenges: non-IID data, communication, stragglers
- Enables multi-site medical AI collaboration

Privacy attacks:
- Membership inference: was this patient in the training data?
- Gradient leakage: reconstruct data from updates
- Defenses: DP, secure aggregation, gradient clipping

Secure computation:
- MPC and homomorphic encryption enable computation on private data
- Powerful but expensive; use when simpler methods are insufficient

Practical deployment:
- Combine techniques: federated learning + differential privacy
- Track the privacy budget throughout the pipeline
- Match the approach to the threat model and requirements
These techniques don’t eliminate privacy risk, but they provide principled tools for managing it—enabling medical AI that would otherwise be impossible.
19.10 Exercises
1. Implement the Laplace mechanism for a count query. Experiment with different ε values and visualize the accuracy vs. privacy tradeoff.
2. Using Opacus, train a classifier on a medical dataset (e.g., UCI Heart Disease) with ε = 1.0. Compare accuracy to non-private training. How much utility is lost?
3. Implement a simple federated learning simulation with 5 clients. Create non-IID data distributions and observe how this affects convergence compared to IID data.
4. Implement a basic membership inference attack against a model you’ve trained. How effective is it? How does differential privacy training affect attack success?
5. Design a privacy-preserving pipeline for a specific scenario: three competing hospitals want to train a shared stroke prediction model without revealing their patient populations or model details to each other. What techniques would you use and why?