19  Privacy & Federated Learning

Clinical Context: Five major hospital systems want to collaborate on an AI model for early sepsis detection. Together, they have 500,000 patient encounters—enough data to build a robust model. But legal teams at each institution won’t allow patient data to leave their systems. How can they collaborate without sharing data? This chapter covers the technical approaches that make privacy-preserving medical AI possible.

Medical AI faces a fundamental tension: better models require more data, but patient data is sensitive, regulated, and siloed across institutions. A single hospital may have thousands of cases of a rare disease; combined data across hundreds of hospitals could provide millions. Yet data sharing agreements are slow, legally complex, and sometimes impossible.

Privacy-preserving machine learning offers a path forward. Differential privacy provides mathematical guarantees about what an adversary can learn from model outputs. Federated learning enables model training without centralizing data. Secure computation allows analysis of encrypted data. These techniques don’t eliminate privacy risks, but they provide tools for managing them.

This chapter focuses on the technical approaches to privacy-preserving medical AI. We cover the mathematics of differential privacy with enough depth to reason about privacy-utility tradeoffs. We implement federated learning with working code. And we discuss how these techniques combine to enable collaborative medical AI.

19.1 The Data Access Problem

19.1.1 Why Medical AI Needs More Data

Machine learning is data-hungry. Deep learning especially benefits from scale—models trained on larger datasets generally perform better. For medical AI:

  • Rare diseases: A single hospital might see 50 cases per year; useful AI needs thousands
  • Subpopulation performance: Ensuring models work across demographics requires diverse data
  • Generalization: Models trained at one institution often fail at others; multi-site data improves robustness
  • Validation: External validation requires access to data from other sources

The institutions that need AI most—smaller hospitals, rural clinics—often have the least data to develop it.

19.1.2 Why Data Sharing Is Hard

Despite the benefits, sharing medical data faces significant barriers:

Regulatory constraints: HIPAA, GDPR, and other regulations restrict data transfer. Even “de-identified” data requires careful handling.

Institutional concerns: Hospitals view their data as a competitive asset. Sharing with academic medical centers or competitors raises strategic concerns.

Technical challenges: Healthcare data is messy, inconsistent across institutions, and difficult to harmonize.

Patient trust: Patients may consent to their data being used at their hospital but object to broader sharing.

Liability: If shared data is breached or misused, who is responsible?

19.1.3 The Privacy-Preserving Paradigm

Privacy-preserving techniques change the question from “how do we share data?” to “how do we share insights without sharing data?”

Approach                Key Idea                                    When to Use
De-identification       Remove identifying information              Simple analyses, low-risk data
Differential privacy    Add calibrated noise                        Query results, model training
Federated learning      Train at data sources, share only models    Multi-site collaboration
Secure computation      Compute on encrypted data                   High-security requirements

These approaches aren’t mutually exclusive—the most robust systems combine multiple techniques.

19.2 Privacy Regulations and De-identification

19.2.1 HIPAA Basics

The Health Insurance Portability and Accountability Act (HIPAA) governs protected health information (PHI) in the United States. For AI development, the key provisions are:

Protected Health Information (PHI) includes any individually identifiable health information held by covered entities. This encompasses demographics, medical records, billing information—essentially anything that connects health data to an individual.

De-identification takes data outside HIPAA’s restrictions. HIPAA provides two paths:

Safe Harbor: Remove 18 specific identifiers (names, dates, locations, SSN, etc.) and have no actual knowledge the remaining data could identify individuals.

Expert Determination: A qualified expert certifies that re-identification risk is “very small.”

De-identified data under HIPAA is no longer PHI and can be shared more freely.
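
As an illustration, a minimal Safe Harbor-style scrub might look like the following sketch. The column names are hypothetical, and a real pipeline must address all 18 identifier categories, including free-text fields:

import pandas as pd

def safe_harbor_scrub(df: pd.DataFrame) -> pd.DataFrame:
    """Illustrative removal/generalization of common HIPAA identifiers."""
    # Drop direct identifiers (hypothetical column names)
    df = df.drop(columns=["name", "mrn", "ssn", "street_address", "phone", "email"],
                 errors="ignore")
    # Dates: Safe Harbor allows the year but not full dates
    if "admission_date" in df.columns:
        df["admission_year"] = pd.to_datetime(df["admission_date"]).dt.year
        df = df.drop(columns=["admission_date"])
    # Ages 90 and above must be aggregated into a single category
    if "age" in df.columns:
        df["age"] = df["age"].clip(upper=90)
    # Geography: retain at most the first three ZIP digits
    if "zip" in df.columns:
        df["zip3"] = df["zip"].astype(str).str[:3]
        df = df.drop(columns=["zip"])
    return df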

19.2.2 The Limits of De-identification

De-identification is necessary but often insufficient:

Re-identification attacks: Researchers have re-identified individuals from “anonymous” datasets by linking with public records. The Netflix Prize dataset, AOL search logs, and hospital discharge data have all been re-identified.

Medical imaging: Facial structure can be reconstructed from head CT/MRI scans. Removing obvious identifiers doesn’t prevent recognition.

Genetic data: Genomic sequences are inherently identifying—you can’t de-identify a genome.

Rare conditions: If only one patient at a hospital has a specific rare disease, any data about that disease identifies them.

Longitudinal data: Combining multiple de-identified records can enable re-identification.

19.2.3 Beyond Compliance

HIPAA compliance is a floor, not a ceiling. Even legally de-identified data can pose privacy risks. Technical privacy measures provide additional protection:

  • What if an adversary has auxiliary information about a patient?
  • What can be inferred from model outputs even without direct data access?
  • How do we protect privacy against future re-identification techniques?

Differential privacy and federated learning address these questions.

19.3 Differential Privacy

Clinical Context: A researcher queries a hospital database: “What percentage of diabetic patients developed kidney disease within 5 years?” The answer—42%—seems innocuous. But if the researcher knows their neighbor is diabetic and is in this database, they’ve just learned something about their neighbor’s health. Differential privacy prevents such inference.

19.3.1 The Problem with “Anonymized” Data

Consider a database of patient health records. Even without names, queries can reveal individual information:

  • “How many patients have HIV?” → “How many patients have HIV, excluding patient 12345?”
  • The difference reveals whether patient 12345 has HIV

This differencing attack works for any query. As long as queries reveal information about the dataset, they reveal information about individuals in it.
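
A minimal sketch of this differencing attack on exact counts, using a hypothetical five-patient table:

import pandas as pd

# Toy database: no names, but exact counts still leak individual information
records = pd.DataFrame({
    "patient_id": [101, 102, 103, 104, 105],
    "has_hiv":    [False, True, False, True, False],
})

count_all = records["has_hiv"].sum()
count_without_104 = records.loc[records["patient_id"] != 104, "has_hiv"].sum()

# The difference of two "aggregate" answers reveals one patient's status
print(f"Patient 104 has HIV: {bool(count_all - count_without_104)}")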

19.3.2 The Differential Privacy Definition

Differential privacy provides a mathematical framework that limits what any query can reveal about any individual (Dwork and Roth 2014).

A mechanism \(M\) satisfies ε-differential privacy if for all datasets \(D\) and \(D'\) differing in one record, and all possible outputs \(S\):

\[\frac{P(M(D) \in S)}{P(M(D') \in S)} \leq e^\varepsilon\]

In words: the probability of any output changes by at most a factor of \(e^\varepsilon\) whether or not any individual is in the dataset.

What ε means:

  • ε = 0: Perfect privacy (no information revealed), but useless outputs
  • ε = 0.1: Strong privacy, but utility may suffer
  • ε = 1: Moderate privacy, reasonable utility
  • ε = 10: Weak privacy, good utility

The parameter ε is called the privacy budget. Smaller ε means stronger privacy but typically worse utility.

19.3.3 Intuition: Plausible Deniability

Differential privacy provides plausible deniability. If you’re in a database and the mechanism outputs some result, you can plausibly claim that the same result would have occurred even if you weren’t in the database.

An adversary who sees the output and knows everything except whether you’re in the database can’t confidently determine your presence—the output is almost as likely either way.

19.3.4 The Laplace Mechanism

How do we achieve differential privacy? The simplest approach: add noise.

The Laplace mechanism adds noise drawn from a Laplace distribution:

\[M(D) = f(D) + \text{Laplace}\left(\frac{\Delta f}{\varepsilon}\right)\]

where \(\Delta f\) is the sensitivity of the function—how much the output can change when one record changes.

import numpy as np

def laplace_mechanism(true_value, sensitivity, epsilon):
    """
    Add Laplace noise to achieve ε-differential privacy.

    Args:
        true_value: The actual query result
        sensitivity: Maximum change from adding/removing one record
        epsilon: Privacy parameter (smaller = more private)

    Returns:
        Noisy value satisfying ε-differential privacy
    """
    scale = sensitivity / epsilon
    noise = np.random.laplace(0, scale)
    return true_value + noise

# Example: Count query
# Sensitivity = 1 (adding/removing one person changes count by at most 1)
true_count = 1000
epsilon = 0.1

noisy_count = laplace_mechanism(true_count, sensitivity=1, epsilon=epsilon)
print(f"True count: {true_count}")
print(f"Noisy count (ε={epsilon}): {noisy_count:.0f}")

# Multiple runs show the noise distribution
noisy_counts = [laplace_mechanism(true_count, 1, epsilon) for _ in range(1000)]
print(f"Mean of noisy counts: {np.mean(noisy_counts):.1f}")
print(f"Std of noisy counts: {np.std(noisy_counts):.1f}")

19.3.5 The Privacy-Utility Tradeoff

Smaller ε means more noise, which means less accurate results:

import matplotlib.pyplot as plt

true_value = 100
sensitivity = 1

epsilons = [0.01, 0.1, 0.5, 1.0, 5.0]
n_samples = 1000

fig, axes = plt.subplots(1, len(epsilons), figsize=(15, 3))

for ax, eps in zip(axes, epsilons):
    noisy_values = [laplace_mechanism(true_value, sensitivity, eps)
                    for _ in range(n_samples)]
    ax.hist(noisy_values, bins=50, density=True)
    ax.axvline(true_value, color='red', linestyle='--', label='True value')
    ax.set_title(f'ε = {eps}')
    ax.set_xlim([0, 200])

plt.tight_layout()
plt.savefig('dp_tradeoff.png')
plt.show()

This is the fundamental tradeoff: stronger privacy (smaller ε) requires accepting less accurate results.

19.3.6 Composition: Privacy Budgets

Each query “spends” privacy budget. Multiple queries compose:

Sequential composition: If you run k mechanisms with privacy parameters ε₁, ε₂, …, εₖ on the same dataset, the total privacy loss is at most ε₁ + ε₂ + … + εₖ.

This means you must budget your ε across all queries. If you want total privacy budget ε = 1 and plan to make 10 queries, each query can use at most ε = 0.1.

Advanced composition provides tighter bounds but the principle remains: privacy degrades with more queries.
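
A minimal sketch of budgeting under basic sequential composition, reusing the laplace_mechanism function from above:

# Split a total budget of ε = 1 evenly across 10 planned count queries
total_epsilon = 1.0
n_queries = 10
epsilon_per_query = total_epsilon / n_queries  # 0.1 per query

true_counts = [1000, 250, 42]  # three of the planned queries (sensitivity 1 each)
noisy_counts = [laplace_mechanism(c, sensitivity=1, epsilon=epsilon_per_query)
                for c in true_counts]

spent = len(true_counts) * epsilon_per_query
print(f"Privacy spent so far: ε = {spent:.2f} of {total_epsilon:.2f}")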

19.3.7 Differentially Private Machine Learning

For machine learning, we apply differential privacy to the training process. The most common approach is DP-SGD (Differentially Private Stochastic Gradient Descent):

  1. Clip gradients: Limit each sample’s contribution to the gradient
  2. Add noise: Add Gaussian noise to the aggregated gradient
  3. Track privacy: Account for privacy spent each step

import torch
import torch.nn as nn

# Simple DP-SGD implementation (conceptual)
def dp_sgd_step(model, batch_data, batch_labels, optimizer,
                max_grad_norm, noise_multiplier):
    """
    One step of differentially private SGD.

    Args:
        max_grad_norm: Clip each sample's gradient to this L2 norm
        noise_multiplier: Noise scale relative to the clipping bound
    """
    params = [p for p in model.parameters() if p.requires_grad]
    batch_size = batch_data.shape[0]

    # Compute per-sample gradients
    per_sample_grads = []
    for i in range(batch_size):
        model.zero_grad()
        output = model(batch_data[i:i+1])
        loss = nn.functional.cross_entropy(output, batch_labels[i:i+1])
        loss.backward()
        per_sample_grads.append([p.grad.detach().clone() for p in params])

    # Clip each sample's gradient to bound its contribution (the sensitivity)
    clipped_grads = []
    for sample_grad in per_sample_grads:
        total_norm = torch.sqrt(sum(g.norm()**2 for g in sample_grad))
        clip_factor = min(1.0, max_grad_norm / (total_norm + 1e-6))
        clipped_grads.append([g * clip_factor for g in sample_grad])

    # Average the clipped gradients and add calibrated Gaussian noise
    model.zero_grad()
    for param_idx, param in enumerate(params):
        aggregated = sum(grads[param_idx] for grads in clipped_grads) / batch_size
        noise_std = noise_multiplier * max_grad_norm / batch_size
        noise = torch.randn_like(aggregated) * noise_std
        param.grad = aggregated + noise

    optimizer.step()

19.3.8 Using Opacus for DP Training

In practice, use the Opacus library for differentially private PyTorch training:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

# Create a simple model
class SimpleClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        return self.fc2(x)

# Generate synthetic medical data
torch.manual_seed(42)
n_samples = 10000
n_features = 50

X = torch.randn(n_samples, n_features)
y = (X[:, 0] + X[:, 1] + torch.randn(n_samples) * 0.5 > 0).long()

dataset = TensorDataset(X, y)
train_loader = DataLoader(dataset, batch_size=256, shuffle=True)

# Create model and optimizer
model = SimpleClassifier(n_features, 64, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Attach privacy engine
privacy_engine = PrivacyEngine()

model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    epochs=10,
    target_epsilon=1.0,  # Privacy budget
    target_delta=1e-5,   # Probability of privacy failure
    max_grad_norm=1.0,   # Gradient clipping bound
)

# Training loop
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    model.train()
    total_loss = 0

    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Get current privacy spent
    epsilon = privacy_engine.get_epsilon(delta=1e-5)
    print(f"Epoch {epoch+1}: Loss = {total_loss/len(train_loader):.4f}, ε = {epsilon:.2f}")

print(f"\nFinal privacy guarantee: (ε={epsilon:.2f}, δ=1e-5)-differential privacy")

19.3.9 Choosing ε: Practical Guidance

What ε should you use? There’s no universal answer, but guidelines exist:

ε Value     Privacy Level   Typical Use Case
0.1 - 0.5   Strong          Highly sensitive data, regulatory requirements
1 - 3       Moderate        Research publications, general medical AI
5 - 10      Weak            Internal analytics, low-sensitivity data

Consider:

  • What’s the threat model? Stronger adversaries require smaller ε
  • How sensitive is the data? Genetic data vs. step counts
  • What utility do you need? Smaller ε means more noise
  • How many queries? Budget must cover all analyses

For medical AI, ε between 1 and 8 is common, with δ = 1/n² where n is dataset size.

19.4 Federated Learning

Clinical Context: Ten hospitals want to train a shared radiology AI model. Each has 50,000 chest X-rays. Combining them would create a powerful dataset, but data can’t leave each hospital. Federated learning enables training on all 500,000 images without any image leaving its source institution.

19.4.1 The Federated Paradigm

Federated learning inverts the traditional ML paradigm (Rieke et al. 2020):

  • Traditional: Bring data to computation (centralize data, then train)
  • Federated: Bring computation to data (train locally, share only models)

Instead of collecting data in one place, federated learning:

  1. Sends the model to each data source
  2. Trains locally at each source
  3. Aggregates model updates (not data) centrally
  4. Repeats until convergence

19.4.2 FedAvg: The Basic Algorithm

Federated Averaging (FedAvg) is the foundational federated learning algorithm:

For each round t:
    1. Server sends current model θ_t to selected clients
    2. Each client k:
       - Trains on local data for E epochs
       - Computes update: Δ_k = θ_local - θ_t
    3. Server aggregates: θ_{t+1} = θ_t + Σ_k (n_k / n) · Δ_k
       where n_k is client k's data size and n is the total

The key insight: averaging model parameters (or gradients) doesn’t require sharing data.

import torch
import torch.nn as nn
import copy
from torch.utils.data import DataLoader, TensorDataset, Subset
import numpy as np

class FederatedServer:
    """Central server for federated learning."""

    def __init__(self, model):
        self.global_model = model

    def aggregate(self, client_models, client_sizes):
        """Weighted average of client models."""
        total_size = sum(client_sizes)

        # Initialize aggregated state dict
        aggregated_state = {}
        for key in self.global_model.state_dict():
            aggregated_state[key] = torch.zeros_like(
                self.global_model.state_dict()[key], dtype=torch.float32
            )

        # Weighted sum
        for client_model, size in zip(client_models, client_sizes):
            weight = size / total_size
            for key in aggregated_state:
                aggregated_state[key] += weight * client_model.state_dict()[key]

        # Update global model
        self.global_model.load_state_dict(aggregated_state)
        return copy.deepcopy(self.global_model)


class FederatedClient:
    """Client in federated learning."""

    def __init__(self, client_id, data_loader, device='cpu'):
        self.client_id = client_id
        self.data_loader = data_loader
        self.device = device

    def train(self, model, epochs, lr=0.01):
        """Local training on client data."""
        model = copy.deepcopy(model).to(self.device)
        model.train()

        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            for batch_X, batch_y in self.data_loader:
                batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device)

                optimizer.zero_grad()
                outputs = model(batch_X)
                loss = criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()

        return model

    @property
    def data_size(self):
        return len(self.data_loader.dataset)


def federated_training(server, clients, rounds, local_epochs, test_loader):
    """Run federated learning for specified rounds."""

    history = {'round': [], 'accuracy': []}

    for round_num in range(rounds):
        # Each client trains locally
        client_models = []
        client_sizes = []

        for client in clients:
            # Get current global model
            local_model = client.train(
                server.global_model,
                epochs=local_epochs
            )
            client_models.append(local_model)
            client_sizes.append(client.data_size)

        # Server aggregates
        server.aggregate(client_models, client_sizes)

        # Evaluate (on combined test data)
        accuracy = evaluate_global_model(server.global_model, test_loader)
        history['round'].append(round_num)
        history['accuracy'].append(accuracy)

        print(f"Round {round_num + 1}: Global accuracy = {accuracy:.4f}")

    return history


def evaluate_global_model(model, test_loader, device='cpu'):
    """Evaluate model on test data."""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            outputs = model(X)
            _, predicted = outputs.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()

    return correct / total

19.4.3 Complete Federated Learning Example

# Simulate multi-hospital scenario
torch.manual_seed(42)
np.random.seed(42)

# Create synthetic data for 5 "hospitals"
n_hospitals = 5
samples_per_hospital = 2000
n_features = 30

# Each hospital has slightly different data distribution (non-IID)
all_data = []
for h in range(n_hospitals):
    # Hospital-specific mean shift
    mean_shift = np.random.randn(n_features) * 0.5
    X = torch.randn(samples_per_hospital, n_features) + torch.tensor(mean_shift, dtype=torch.float32)
    y = (X[:, 0] + X[:, 1] + torch.randn(samples_per_hospital) * 0.5 > 0).long()
    all_data.append((X, y))

# Create federated clients (one per hospital)
clients = []
for h in range(n_hospitals):
    X, y = all_data[h]
    # Split into train (80%) and contribute to test (20%)
    n_train = int(0.8 * samples_per_hospital)
    train_dataset = TensorDataset(X[:n_train], y[:n_train])
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    clients.append(FederatedClient(f"Hospital_{h}", train_loader))

# Create global test set from all hospitals
test_X = torch.cat([all_data[h][0][int(0.8*samples_per_hospital):] for h in range(n_hospitals)])
test_y = torch.cat([all_data[h][1][int(0.8*samples_per_hospital):] for h in range(n_hospitals)])
test_loader = DataLoader(TensorDataset(test_X, test_y), batch_size=64)

# Initialize global model and server
global_model = SimpleClassifier(n_features, 32, 2)
server = FederatedServer(global_model)

# Run federated learning
print("Starting Federated Learning")
print(f"Number of hospitals: {n_hospitals}")
print(f"Samples per hospital: {samples_per_hospital}")
print("-" * 40)

history = federated_training(
    server=server,
    clients=clients,
    rounds=20,
    local_epochs=3,
    test_loader=test_loader
)

print("-" * 40)
print(f"Final global model accuracy: {history['accuracy'][-1]:.4f}")

19.4.4 Challenges in Federated Learning

Non-IID Data: Real hospitals have different patient populations. Data isn’t “independent and identically distributed” across clients.

# Simulating non-IID: each hospital specializes in different conditions
# Hospital 0: mostly young patients (class 0)
# Hospital 1: mostly elderly patients (class 1)
# This makes federated learning harder

def create_non_iid_data(n_hospitals, samples_per_hospital, skew=0.8):
    """Create non-IID data where each hospital has skewed class distribution."""
    all_data = []

    for h in range(n_hospitals):
        X = torch.randn(samples_per_hospital, n_features)

        # Skew: hospital h prefers class (h % 2)
        preferred_class = h % 2
        probs = torch.rand(samples_per_hospital)
        y = torch.where(
            probs < skew,
            torch.full((samples_per_hospital,), preferred_class),
            torch.randint(0, 2, (samples_per_hospital,))
        )
        all_data.append((X, y.long()))

    return all_data

# Non-IID data is harder to learn from
# Solutions: FedProx, local adaptation, personalization

Communication Efficiency: Sending full models every round is expensive. Techniques include:

  • Gradient compression (see the sketch below)
  • Less frequent communication
  • Partial model updates
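
As a sketch of the first idea, top-k gradient sparsification keeps only the largest-magnitude gradient entries (the 1% keep-rate below is an arbitrary illustration):

import torch

def topk_compress(grad, k_fraction=0.01):
    """Keep only the largest-magnitude entries; transmit values + indices."""
    flat = grad.flatten()
    k = max(1, int(k_fraction * flat.numel()))
    _, indices = torch.topk(flat.abs(), k)
    return flat[indices], indices, grad.shape

def topk_decompress(values, indices, shape):
    flat = torch.zeros(shape).flatten()
    flat[indices] = values
    return flat.reshape(shape)

grad = torch.randn(1000, 100)
values, indices, shape = topk_compress(grad)
print(f"Transmitting {values.numel()} of {grad.numel()} values "
      f"({100 * values.numel() / grad.numel():.1f}%)")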

Stragglers: Slow clients delay training. Solutions:

  • Asynchronous aggregation
  • Client selection strategies
  • Timeout mechanisms

19.4.5 Medical Applications

Federated learning enables previously impossible collaborations:

Multi-site clinical trials: Train AI on data from trial sites worldwide without centralizing patient data.

Rare disease research: Combine data across hospitals to reach meaningful sample sizes for rare conditions.

International collaboration: Work across jurisdictions with different privacy laws.

Real-world example: The Federated Tumor Segmentation (FeTS) challenge demonstrated federated learning for brain tumor segmentation across 30+ institutions globally.

19.5 Privacy Attacks and Defenses

Clinical Context: Your hospital participates in a federated learning consortium. You share model updates, not patient data. Is this actually private? Unfortunately, model updates can leak information about training data. Understanding attacks is essential for building robust defenses.

19.5.1 Membership Inference Attacks

Question: Was a specific patient’s data used to train this model?

Attack: Train a classifier that distinguishes model behavior on training data vs. non-training data. Models often have higher confidence on examples they were trained on.

def membership_inference_attack(target_model, member_data, non_member_data):
    """
    Simple membership inference based on prediction confidence.

    Higher confidence often indicates membership in training set.
    """
    target_model.eval()

    def get_confidence(model, data):
        with torch.no_grad():
            outputs = model(data)
            probs = torch.softmax(outputs, dim=1)
            confidence = probs.max(dim=1).values
        return confidence

    member_conf = get_confidence(target_model, member_data)
    non_member_conf = get_confidence(target_model, non_member_data)

    # Members typically have higher confidence
    print(f"Member confidence: {member_conf.mean():.4f} +/- {member_conf.std():.4f}")
    print(f"Non-member confidence: {non_member_conf.mean():.4f} +/- {non_member_conf.std():.4f}")

    # Simple threshold attack
    threshold = (member_conf.mean() + non_member_conf.mean()) / 2

    member_correct = (member_conf > threshold).float().mean()
    non_member_correct = (non_member_conf <= threshold).float().mean()

    attack_accuracy = (member_correct + non_member_correct) / 2
    print(f"Attack accuracy: {attack_accuracy:.4f}")

    return attack_accuracy

Defense: Differential privacy limits membership inference by ensuring model behavior is similar whether or not any individual was in training.

19.5.2 Model Inversion Attacks

Question: Can we reconstruct training data from the model?

Attack: Optimize an input to maximize model confidence for a specific class, potentially revealing what training examples looked like.

For medical imaging, this could mean reconstructing patient faces from a facial recognition model or patient X-rays from a diagnostic model.
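
A minimal sketch of the attack idea against a tabular classifier, assuming the model and 50-feature inputs from the earlier Opacus example; real attacks on images add priors and regularization:

def invert_class(model, input_dim, target_class, steps=200, lr=0.1):
    """Optimize an input to maximize the model's score for target_class."""
    x = torch.zeros(1, input_dim, requires_grad=True)
    optimizer = torch.optim.Adam([x], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        logits = model(x)
        loss = -logits[0, target_class]  # push the target logit up
        loss.backward()
        optimizer.step()
    return x.detach()

# What does the model "think" a class-1 patient looks like?
reconstruction = invert_class(model, input_dim=50, target_class=1)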

Defense: Differential privacy, limiting model access, adding noise to outputs.

19.5.3 Gradient Leakage in Federated Learning

Question: Can shared gradients reveal training data?

Attack: Given gradient updates, reconstruct the training batch that produced them. This is especially effective for small batches and early training rounds.

def gradient_leakage_attack(gradient, model, input_shape):
    """
    Attempt to reconstruct training data from gradients.

    This is a simplified version - real attacks are more sophisticated.
    """
    # Initialize random input
    dummy_input = torch.randn(input_shape, requires_grad=True)
    dummy_label = torch.randint(0, 2, (input_shape[0],))

    optimizer = torch.optim.LBFGS([dummy_input])

    def closure():
        optimizer.zero_grad()

        # Compute gradients for dummy input
        model.zero_grad()
        output = model(dummy_input)
        loss = nn.functional.cross_entropy(output, dummy_label)
        dummy_gradient = torch.autograd.grad(loss, model.parameters(), create_graph=True)

        # Minimize distance between dummy gradient and actual gradient
        grad_diff = sum(
            ((dg - g) ** 2).sum()
            for dg, g in zip(dummy_gradient, gradient)
        )
        grad_diff.backward()

        return grad_diff

    # Optimize to find input that produces similar gradient
    for _ in range(100):
        optimizer.step(closure)

    return dummy_input.detach()

Defense: Secure aggregation ensures the server only sees aggregated gradients, not individual client contributions.
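
A minimal sketch of the pairwise-masking idea behind secure aggregation, assuming each pair of clients has agreed on a shared random seed in advance:

import numpy as np

def masked_update(client_id, update, peer_seeds):
    """Add pairwise masks that cancel when all clients' masked updates are summed."""
    masked = update.copy()
    for peer_id, seed in peer_seeds.items():
        mask = np.random.default_rng(seed).normal(size=update.shape)
        # Convention: the lower-id client adds the mask, the higher-id subtracts it
        masked += mask if client_id < peer_id else -mask
    return masked

updates = {0: np.array([1.0, 2.0]), 1: np.array([0.5, -1.0]), 2: np.array([2.0, 0.0])}
pair_seeds = {(0, 1): 11, (0, 2): 22, (1, 2): 33}  # one shared seed per client pair

def seeds_for(client_id):
    return {(b if a == client_id else a): s
            for (a, b), s in pair_seeds.items() if client_id in (a, b)}

masked = [masked_update(c, u, seeds_for(c)) for c, u in updates.items()]
# Individually, masked updates look random to the server;
# their sum matches the true sum (up to floating-point rounding).
print("Sum of masked updates:", sum(masked))
print("True sum of updates:  ", sum(updates.values()))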

19.5.4 Combining Defenses

Robust privacy requires multiple layers:

  1. Differential privacy: Fundamental privacy guarantee
  2. Secure aggregation: Protect individual client updates
  3. Gradient clipping: Limit information in each update
  4. Minimum batch sizes: Prevent single-sample gradient reconstruction

19.6 Secure Multi-Party Computation

Clinical Context: Two hospitals want to compute the combined prevalence of a rare disease without either learning the other’s patient count. Secure computation enables this—computing on data while keeping inputs private.

19.6.1 What Secure Computation Enables

Secure Multi-Party Computation (MPC) allows parties to jointly compute a function over their inputs without revealing inputs to each other.

Example: Hospitals A and B have patient counts \(n_A\) and \(n_B\). They want to compute \(n_A + n_B\) such that:

  • A doesn’t learn \(n_B\)
  • B doesn’t learn \(n_A\)
  • Both learn \(n_A + n_B\)

This sounds impossible, but cryptographic techniques make it achievable.

19.6.2 Secret Sharing Intuition

Secret sharing splits a secret into shares such that:

  • Any single share reveals nothing about the secret
  • Combined shares reconstruct the secret

Simple example (additive sharing):

  • Secret: 42
  • Share 1: random number r = 17
  • Share 2: 42 - 17 = 25
  • Neither 17 nor 25 reveals 42, but 17 + 25 = 42

For secure computation:

  1. Each party secret-shares their input
  2. Computation happens on shares (requires special protocols)
  3. Result shares are combined to get the final answer
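
A minimal additive-sharing sketch of the two-hospital sum from the example above, working modulo a large public value Q:

import random

Q = 2**61 - 1  # public modulus, chosen much larger than any possible count

def share(secret, n_parties=2):
    """Split a secret into n additive shares that sum to the secret mod Q."""
    shares = [random.randrange(Q) for _ in range(n_parties - 1)]
    shares.append((secret - sum(shares)) % Q)
    return shares

n_A, n_B = 37, 54                # each hospital's private count
a_shares = share(n_A)            # A keeps a_shares[0], sends a_shares[1] to B
b_shares = share(n_B)            # B keeps b_shares[0], sends b_shares[1] to A

# Each party adds the shares it holds and reveals only that partial sum
partial_A = (a_shares[0] + b_shares[1]) % Q
partial_B = (b_shares[0] + a_shares[1]) % Q

print("Combined count:", (partial_A + partial_B) % Q)  # 91, neither input revealed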

19.6.3 Homomorphic Encryption

Homomorphic encryption allows computation on encrypted data:

\[E(a) \oplus E(b) = E(a + b)\]

You can add (or multiply, depending on scheme) encrypted values without decrypting. The result, when decrypted, equals what you’d get from computing on plaintext.

Fully homomorphic encryption (FHE) supports arbitrary computation on encrypted data, but it’s computationally expensive—often 1000x+ slower than plaintext.
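
A toy Paillier example illustrates the additive property (multiplying ciphertexts adds the underlying plaintexts); the tiny key below is for illustration only and provides no real security:

import random
from math import gcd

# Toy Paillier keypair (insecure: real keys use 2048+ bit primes)
p, q = 293, 433
n, n2 = p * q, (p * q) ** 2
g = n + 1                        # standard generator choice
phi = (p - 1) * (q - 1)
mu = pow(phi, -1, n)             # modular inverse of phi mod n

def encrypt(m):
    r = random.randrange(2, n)
    while gcd(r, n) != 1:
        r = random.randrange(2, n)
    return (pow(g, m, n2) * pow(r, n, n2)) % n2

def decrypt(c):
    x = pow(c, phi, n2)
    return (((x - 1) // n) * mu) % n

# Two hospitals encrypt their counts; anyone can add the ciphertexts
c_a, c_b = encrypt(1200), encrypt(850)
c_sum = (c_a * c_b) % n2         # multiplying ciphertexts adds plaintexts

print("Decrypted sum:", decrypt(c_sum))  # 2050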

19.6.4 Practical Considerations

Secure computation is powerful but has limitations:

Performance: MPC and FHE are orders of magnitude slower than plaintext computation. Training neural networks with MPC is possible but expensive.

Complexity: Implementing secure protocols correctly is hard. Use established libraries (MP-SPDZ, CrypTen, TenSEAL).

Trust model: MPC requires assumptions about adversary capabilities (honest-but-curious vs. malicious).

When to use:

  • High-value, high-sensitivity computations
  • Regulatory requirements mandate it
  • Performance cost is acceptable
  • Simpler methods (DP, FL) are insufficient

For most medical AI applications, differential privacy and federated learning provide sufficient protection at much lower cost.

19.7 Putting It Together: Private Medical AI Pipeline

Clinical Context: You’re designing a multi-hospital collaboration for sepsis prediction. Here’s how to combine privacy techniques for a robust solution.

19.7.1 Architecture Overview

┌─────────────┐     ┌─────────────┐     ┌─────────────┐
│  Hospital A │     │  Hospital B │     │  Hospital C │
│             │     │             │     │             │
│  Local Data │     │  Local Data │     │  Local Data │
│      ↓      │     │      ↓      │     │      ↓      │
│  DP-SGD     │     │  DP-SGD     │     │  DP-SGD     │
│  Training   │     │  Training   │     │  Training   │
│      ↓      │     │      ↓      │     │      ↓      │
│  Noisy      │     │  Noisy      │     │  Noisy      │
│  Gradients  │     │  Gradients  │     │  Gradients  │
└──────┬──────┘     └──────┬──────┘     └──────┬──────┘
       │                   │                   │
       └───────────────────┼───────────────────┘
                           │
                    ┌──────▼──────┐
                    │   Secure    │
                    │ Aggregator  │
                    │             │
                    │  FedAvg on  │
                    │   Noisy     │
                    │  Updates    │
                    └──────┬──────┘
                           │
                    ┌──────▼──────┐
                    │   Global    │
                    │   Model     │
                    └─────────────┘

19.7.2 Implementation

class PrivateFederatedLearning:
    """
    Federated learning with differential privacy.

    Combines:
    - Local DP-SGD training at each client
    - Secure aggregation of updates
    - Privacy budget tracking
    """

    def __init__(self, model, n_clients, target_epsilon, target_delta,
                 rounds, local_epochs):
        self.global_model = model
        self.n_clients = n_clients
        self.target_epsilon = target_epsilon
        self.target_delta = target_delta
        self.rounds = rounds
        self.local_epochs = local_epochs

        # Per-round privacy budget
        self.epsilon_per_round = target_epsilon / rounds

    def train_round(self, clients):
        """One round of private federated training."""
        client_models = []
        client_sizes = []

        for client in clients:
            # Each client trains with DP-SGD
            local_model = client.private_train(
                model=copy.deepcopy(self.global_model),
                epochs=self.local_epochs,
                epsilon=self.epsilon_per_round,
                delta=self.target_delta / self.rounds
            )
            client_models.append(local_model)
            client_sizes.append(client.data_size)

        # Aggregate (with secure aggregation in production)
        self.aggregate_models(client_models, client_sizes)

    def aggregate_models(self, client_models, client_sizes):
        """Weighted averaging of client models."""
        total_size = sum(client_sizes)

        with torch.no_grad():
            for key in self.global_model.state_dict():
                weighted_sum = sum(
                    (size / total_size) * model.state_dict()[key]
                    for model, size in zip(client_models, client_sizes)
                )
                self.global_model.state_dict()[key].copy_(weighted_sum)


class PrivateClient:
    """Client with differential privacy."""

    def __init__(self, client_id, data_loader):
        self.client_id = client_id
        self.data_loader = data_loader

    def private_train(self, model, epochs, epsilon, delta):
        """Train with differential privacy using Opacus."""
        from opacus import PrivacyEngine

        model = model.train()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

        privacy_engine = PrivacyEngine()
        model, optimizer, data_loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=self.data_loader,
            epochs=epochs,
            target_epsilon=epsilon,
            target_delta=delta,
            max_grad_norm=1.0
        )

        criterion = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            for batch_X, batch_y in data_loader:
                optimizer.zero_grad()
                loss = criterion(model(batch_X), batch_y)
                loss.backward()
                optimizer.step()

        return model

    @property
    def data_size(self):
        return len(self.data_loader.dataset)
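
Illustrative wiring of these pieces, assuming the SimpleClassifier from earlier and a hypothetical list of per-hospital DataLoaders named hospital_loaders:

# Hypothetical driver: three hospitals, total budget ε = 2 over 10 rounds
clients = [PrivateClient(f"Hospital_{i}", hospital_loaders[i]) for i in range(3)]

pfl = PrivateFederatedLearning(
    model=SimpleClassifier(input_dim=30, hidden_dim=32, output_dim=2),
    n_clients=3,
    target_epsilon=2.0,
    target_delta=1e-5,
    rounds=10,
    local_epochs=1,
)

for round_num in range(pfl.rounds):
    pfl.train_round(clients)
    print(f"Completed round {round_num + 1} "
          f"(per-round budget ε = {pfl.epsilon_per_round:.2f})")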

19.7.3 Privacy Accounting

Track total privacy spent across all operations:

class PrivacyAccountant:
    """Track cumulative privacy loss."""

    def __init__(self, target_epsilon, target_delta):
        self.target_epsilon = target_epsilon
        self.target_delta = target_delta
        self.spent_epsilon = 0
        self.operations = []

    def spend(self, epsilon, operation_name):
        """Record privacy expenditure."""
        self.spent_epsilon += epsilon
        self.operations.append({
            'operation': operation_name,
            'epsilon': epsilon,
            'cumulative': self.spent_epsilon
        })

        if self.spent_epsilon > self.target_epsilon:
            raise ValueError(
                f"Privacy budget exceeded! Spent {self.spent_epsilon:.2f}, "
                f"budget was {self.target_epsilon:.2f}"
            )

    def remaining(self):
        return self.target_epsilon - self.spent_epsilon

    def report(self):
        print(f"Privacy Budget Report")
        print(f"Target: ε = {self.target_epsilon}, δ = {self.target_delta}")
        print(f"Spent: ε = {self.spent_epsilon:.4f}")
        print(f"Remaining: ε = {self.remaining():.4f}")
        print(f"\nOperations:")
        for op in self.operations:
            print(f"  {op['operation']}: ε = {op['epsilon']:.4f} "
                  f"(cumulative: {op['cumulative']:.4f})")

19.8 Making Privacy Decisions

19.8.1 Decision Framework

When designing a privacy-preserving system:

1. What's the threat model?
   - Who might attack? (curious collaborators, external adversaries)
   - What's their capability? (access to model, auxiliary data)
   - What's the harm if privacy fails?

2. What are the regulatory requirements?
   - HIPAA, GDPR, institutional policies
   - What level of de-identification is required?

3. What utility do you need?
   - Acceptable accuracy loss?
   - Query/training budget?

4. Choose techniques:
   - Data stays local → Federated learning
   - Need formal guarantees → Differential privacy
   - High-security requirements → Add secure computation
   - Combine as needed

19.8.2 Technique Selection Guide

Scenario                                    Recommended Approach
Multi-site collaboration, moderate trust    Federated learning
Releasing aggregate statistics              Differential privacy
Sharing trained model publicly              DP training
Very high sensitivity (genomics)            FL + DP + secure aggregation
Single institution, internal use            May not need privacy ML

19.8.3 Communicating Privacy

Stakeholders need to understand privacy guarantees:

For technical audiences: Report ε, δ, and methodology. “This model was trained with (ε=2, δ=10⁻⁵)-differential privacy using DP-SGD with gradient clipping at norm 1.0.”

For non-technical audiences: Explain implications. “Even an attacker who knows everything about the training data except for one patient cannot determine with confidence whether that patient was included.”

For patients: Focus on protections. “Your data helps improve medical AI, but the analysis is designed so that your individual information cannot be extracted.”

19.8.4 When Privacy-Preserving ML Isn’t Enough

Privacy techniques have limits:

  • Utility loss: Strong privacy may degrade model quality unacceptably
  • Implementation complexity: Bugs in privacy implementations can void guarantees
  • Scope: DP protects training data; other privacy risks (inference attacks, model outputs) need separate handling
  • Trust assumptions: FL requires trusting the aggregator; MPC requires trusting protocol implementation

Sometimes the answer is “don’t build this” or “this requires a fundamentally different approach.”

19.9 Chapter Summary

Privacy-preserving techniques enable medical AI development while protecting patient data.

The data access problem:

  • Medical AI benefits from scale, but data is siloed and sensitive
  • Regulations (HIPAA, GDPR) and institutional concerns limit sharing
  • Technical solutions enable collaboration without centralized data

Differential privacy:

  • Mathematical framework limiting what queries reveal about individuals
  • ε controls the privacy-utility tradeoff
  • DP-SGD enables private model training
  • Privacy budgets must be tracked across all operations

Federated learning:

  • Train models without centralizing data
  • FedAvg: local training, aggregate updates
  • Challenges: non-IID data, communication, stragglers
  • Enables multi-site medical AI collaboration

Privacy attacks:

  • Membership inference: was this patient in the training data?
  • Gradient leakage: reconstruct data from updates
  • Defenses: DP, secure aggregation, gradient clipping

Secure computation:

  • MPC and homomorphic encryption enable computation on private data
  • Powerful but expensive; use when simpler methods are insufficient

Practical deployment:

  • Combine techniques: federated learning + differential privacy
  • Track the privacy budget throughout the pipeline
  • Match the approach to the threat model and requirements

These techniques don’t eliminate privacy risk, but they provide principled tools for managing it—enabling medical AI that would otherwise be impossible.

19.10 Exercises

  1. Implement the Laplace mechanism for a count query. Experiment with different ε values and visualize the accuracy vs. privacy tradeoff.

  2. Using Opacus, train a classifier on a medical dataset (e.g., UCI Heart Disease) with ε = 1.0. Compare accuracy to non-private training. How much utility is lost?

  3. Implement a simple federated learning simulation with 5 clients. Create non-IID data distributions and observe how this affects convergence compared to IID data.

  4. Implement a basic membership inference attack against a model you’ve trained. How effective is it? How does differential privacy training affect attack success?

  5. Design a privacy-preserving pipeline for a specific scenario: three competing hospitals want to train a shared stroke prediction model without revealing their patient populations or model details to each other. What techniques would you use and why?