9 Generative Models
Generative models learn to create new data that resembles their training distribution. In medical imaging, this capability addresses fundamental challenges: scarce labeled data, privacy constraints, and the need to detect rare abnormalities. This chapter covers four generative approaches—autoencoders, VAEs, GANs, and diffusion models—with emphasis on practical applications and the critical question of when synthetic medical data can be trusted.
9.1 Why Generate Medical Images?
Clinical Context: Your hospital wants to build an AI system to detect rare pediatric bone tumors from X-rays. You have 50 confirmed cases and 10,000 normal images. No amount of clever training tricks will overcome this data imbalance—you simply don’t have enough positive examples. Generative models offer a potential solution: create synthetic tumor images to augment your training set.
9.1.1 The Data Scarcity Problem
Medical AI faces a fundamental constraint: labeled data is expensive and scarce. Unlike natural images (billions available on the internet), medical images require:
- Expert annotation: A radiologist’s time to label each image
- Patient consent: IRB approval and privacy protections
- Rare conditions: Some diseases simply don’t occur frequently enough
A typical academic medical center might have millions of images in its PACS, but only a few hundred labeled for any specific rare condition. This bottleneck limits what medical AI can achieve.
9.1.2 Three Use Cases for Generative Models
1. Data Augmentation
Generate synthetic training examples to expand limited datasets. If you have 50 real tumor images, perhaps you can generate 500 synthetic variations to help your classifier learn the pattern.
The key question: Do synthetic images actually improve model performance on real test cases? We’ll address validation strategies later in this chapter.
2. Anomaly Detection
Train a generative model exclusively on normal images. When presented with an abnormal image, the model fails to reconstruct it well—it’s never seen anything like it. This reconstruction error becomes an anomaly score.
This approach is powerful when abnormalities are diverse and rare. Instead of teaching the model what disease looks like (impossible with few examples), you teach it what normal looks like and flag deviations.
3. Image-to-Image Translation
Transform images between modalities or quality levels:
- CT ↔︎ MRI synthesis: Reduce the need for multiple expensive scans
- Low-dose → standard CT: Maintain diagnostic quality while reducing radiation exposure
- Stain normalization: Standardize histopathology slides from different labs
These translations can reduce patient burden, lower costs, and enable retrospective analysis when only one modality was acquired.
9.1.3 The Generative Modeling Landscape
This chapter covers four approaches, each with distinct strengths:
| Model | Key Idea | Strengths | Medical Applications |
|---|---|---|---|
| Autoencoder | Compress and reconstruct | Simple, fast, interpretable latent space | Anomaly detection, denoising |
| VAE | Probabilistic compression | Smooth latent space, principled generation | Data augmentation, interpolation |
| GAN | Adversarial training | High visual quality | Image synthesis, translation |
| Diffusion | Iterative denoising | State-of-the-art quality, stable training | Emerging: synthesis, super-resolution |
We’ll build up from simple autoencoders to cutting-edge diffusion models, always asking: how does this apply to clinical problems?
9.2 Autoencoders: Learning to Compress
Clinical Context: A hospital network wants to screen thousands of chest X-rays daily for abnormalities. Training a classifier for every possible disease is impractical—there are too many conditions, many extremely rare. Instead, they want a system that flags “unusual” images for radiologist review. An autoencoder trained on normal X-rays can identify images it struggles to reconstruct—potential anomalies.
9.2.1 The Encoder-Decoder Architecture
An autoencoder learns to compress an input into a low-dimensional representation, then reconstruct the original from this compressed form. It consists of:
- Encoder: Maps input \(x\) to a latent representation \(z = f_\text{enc}(x)\)
- Latent space: A bottleneck layer with far fewer dimensions than the input
- Decoder: Reconstructs the input \(\hat{x} = f_\text{dec}(z)\)
The network is trained to minimize reconstruction loss:
\[ \mathcal{L}_\text{recon} = \|x - \hat{x}\|^2 \]
The bottleneck forces the network to learn efficient representations—it can’t simply memorize inputs; it must discover the underlying structure.
9.2.2 Autoencoder Architecture for Medical Images
import torch
import torch.nn as nn

class ConvAutoencoder(nn.Module):
    def __init__(self, latent_dim=128):
        super().__init__()
        # Encoder: 224x224x1 -> 7x7x256 -> latent_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1),    # 112x112
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),   # 56x56
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 28x28
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1), # 14x14
            nn.ReLU(),
            nn.Conv2d(256, 256, 4, stride=2, padding=1), # 7x7
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(256 * 7 * 7, latent_dim)
        )
        # Decoder: latent_dim -> 7x7x256 -> 224x224x1
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256 * 7 * 7),
            nn.Unflatten(1, (256, 7, 7)),
            nn.ConvTranspose2d(256, 256, 4, stride=2, padding=1),  # 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # 28x28
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),   # 56x56
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),    # 112x112
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),     # 224x224
            nn.Sigmoid()  # Output in [0, 1]
        )

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

    def encode(self, x):
        return self.encoder(x)
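To make the reconstruction objective concrete, here is a minimal training-loop sketch for the ConvAutoencoder above; the train_loader is an assumed placeholder yielding normalized grayscale images.

# Minimal training sketch (assumes `train_loader` yields normalized grayscale
# tensors of shape [B, 1, 224, 224]; labels, if any, are ignored).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ConvAutoencoder(latent_dim=128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(50):
    model.train()
    for images, _ in train_loader:
        images = images.to(device)
        optimizer.zero_grad()
        reconstruction = model(images)
        loss = nn.functional.mse_loss(reconstruction, images)  # ||x - x_hat||^2
        loss.backward()
        optimizer.step()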
9.2.3 Anomaly Detection with Autoencoders
The key insight: an autoencoder trained on normal images learns to reconstruct normal anatomy. When presented with an abnormal image (tumor, fracture, foreign body), the reconstruction will be poor—the model has never learned to represent those patterns.
def compute_anomaly_score(model, image):
    """Higher score = more abnormal"""
    model.eval()
    with torch.no_grad():
        reconstruction = model(image)
        # Per-pixel reconstruction error
        error = (image - reconstruction) ** 2
        # Aggregate to a single score
        return error.mean().item()

def detect_anomalies(model, dataloader, threshold):
    """Flag images with reconstruction error above threshold"""
    anomalies = []
    for images, paths in dataloader:
        for img, path in zip(images, paths):
            score = compute_anomaly_score(model, img.unsqueeze(0))
            if score > threshold:
                anomalies.append((path, score))
    return sorted(anomalies, key=lambda x: -x[1])  # Highest scores first

Setting the threshold: Use a validation set of normal images to establish a baseline. Set the threshold at a percentile (e.g., 95th) of normal reconstruction errors. Images exceeding this threshold are flagged for review.
Limitations: Autoencoders can sometimes reconstruct abnormalities if they share features with normal anatomy. They also struggle with global anomalies (wrong patient orientation) versus local ones (small nodule). The reconstruction error doesn’t localize the abnormality—you know something is wrong, but not where.
9.3 Variational Autoencoders
Clinical Context: Your research team wants to study how lung nodules vary in appearance. Rather than clustering existing nodules, you want to explore the space of possible nodule appearances—interpolating between small and large, solid and ground-glass. A variational autoencoder provides a structured latent space where such exploration is meaningful.
9.3.1 From Compression to Generation
Standard autoencoders learn a deterministic mapping: input \(x\) → latent \(z\) → reconstruction \(\hat{x}\). But the latent space has no structure—similar images might map to distant points, and many regions of latent space decode to nonsense.
Variational autoencoders (VAEs) impose structure by treating the latent space probabilistically (Kingma and Welling 2013). Instead of encoding to a point \(z\), we encode to a distribution \(q(z|x)\), typically Gaussian:
\[ q(z|x) = \mathcal{N}(\mu(x), \sigma^2(x)) \]
The encoder outputs two vectors: mean \(\mu\) and log-variance \(\log \sigma^2\). We sample \(z\) from this distribution during training.
9.3.2 The VAE Objective
VAEs optimize the evidence lower bound (ELBO), balancing two terms:
\[ \mathcal{L}_\text{VAE} = \underbrace{\mathbb{E}_{z \sim q(z|x)}[\log p(x|z)]}_{\text{reconstruction}} - \underbrace{D_\text{KL}(q(z|x) \| p(z))}_{\text{regularization}} \]
- Reconstruction term: The decoded sample should match the input (same as autoencoder)
- KL divergence term: The encoded distribution should be close to a standard normal \(p(z) = \mathcal{N}(0, I)\)
The KL term is crucial—it prevents the encoder from collapsing to point estimates and ensures the latent space is smooth and continuous. Similar images map to nearby distributions, and interpolating in latent space produces meaningful intermediate images.
9.3.3 VAE Implementation
class VAE(nn.Module):
    def __init__(self, latent_dim=128):
        super().__init__()
        # Encoder outputs mean and log-variance
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(128 * 28 * 28, latent_dim)
        self.fc_logvar = nn.Linear(128 * 28 * 28, latent_dim)
        # Decoder (mirrors the encoder)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128 * 28 * 28),
            nn.Unflatten(1, (128, 28, 28)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder_conv(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * epsilon"""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar

def vae_loss(recon_x, x, mu, logvar):
    """ELBO loss = reconstruction + KL divergence"""
    recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_loss

9.3.4 Generation and Interpolation
Once trained, VAEs enable meaningful generation:
# Generate new images by sampling from the prior
def generate_samples(model, num_samples, latent_dim, device):
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim).to(device)
        samples = model.decoder(z)
        return samples

# Interpolate between two images
def interpolate(model, img1, img2, steps=10):
    model.eval()
    with torch.no_grad():
        mu1, _ = model.encode(img1)
        mu2, _ = model.encode(img2)
        interpolations = []
        for alpha in torch.linspace(0, 1, steps):
            z = (1 - alpha) * mu1 + alpha * mu2
            interpolations.append(model.decoder(z))
        return torch.cat(interpolations)

Interpolation reveals what the model has learned about the data manifold. Interpolating between a small nodule and a large nodule might show gradual growth. Interpolating between pneumonia and normal might reveal the transition from consolidated to clear lung fields.
9.3.5 Medical Applications of VAEs
Data augmentation: Sample from regions of latent space near existing rare examples to generate variations.
Latent space analysis: The latent dimensions often capture clinically meaningful factors of variation (size, intensity, location). This enables unsupervised discovery of disease subtypes.
Conditional generation: With labels, train a conditional VAE that generates images of a specific class on demand.
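A minimal sketch of the data-augmentation idea with the VAE above: encode a rare case, jitter its latent mean with small Gaussian noise, and decode the perturbed codes. The noise_scale knob is illustrative, and any variants generated this way still need expert review before use.

def augment_rare_case(model, image, num_variants=10, noise_scale=0.3):
    """Generate variations of a rare example by perturbing its latent code.
    `noise_scale` is an illustrative setting: larger values give more diverse
    but less faithful variants."""
    model.eval()
    with torch.no_grad():
        mu, _ = model.encode(image)                       # [1, latent_dim]
        eps = torch.randn(num_variants, mu.shape[1]) * noise_scale
        z = mu + eps.to(mu.device)                        # jitter around the encoding
        return model.decoder(z)                           # [num_variants, 1, H, W]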
9.4 Generative Adversarial Networks
Clinical Context: Your AI team needs to train a model to segment liver tumors in CT scans, but you only have 200 labeled cases. A collaborating institution has 2,000 labeled cases, but data sharing agreements prohibit direct transfer. Could they instead share a GAN trained on their data, allowing you to generate synthetic training images?
9.4.1 The Adversarial Framework
Generative adversarial networks (GANs) train two networks in competition (Goodfellow et al. 2014):
- Generator \(G\): Takes random noise \(z\) and produces fake images \(G(z)\)
- Discriminator \(D\): Classifies images as real (from training data) or fake (from generator)
The generator tries to fool the discriminator; the discriminator tries to catch fakes. This adversarial dynamic pushes the generator to produce increasingly realistic images.
The training objective is a minimax game:
\[ \min_G \max_D \mathbb{E}_{x \sim p_\text{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))] \]
At equilibrium, the generator produces images indistinguishable from real data, and the discriminator outputs 0.5 (random guessing) for all inputs.
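To make the minimax game concrete, here is a minimal training-step sketch using the common non-saturating formulation with binary cross-entropy. The names `generator`, `discriminator`, their optimizers, `real_images`, `batch_size`, `latent_dim`, and `device` are assumed placeholders, and the discriminator is assumed to output one logit per image.

criterion = nn.BCEWithLogitsLoss()

# --- Discriminator step: push real images toward 1, generated images toward 0 ---
d_optimizer.zero_grad()
z = torch.randn(batch_size, latent_dim, device=device)
fake_images = generator(z).detach()                 # don't backprop into G here
d_loss = criterion(discriminator(real_images), torch.ones(batch_size, 1, device=device)) + \
         criterion(discriminator(fake_images), torch.zeros(batch_size, 1, device=device))
d_loss.backward()
d_optimizer.step()

# --- Generator step: try to make D label fakes as real (non-saturating loss) ---
g_optimizer.zero_grad()
z = torch.randn(batch_size, latent_dim, device=device)
g_loss = criterion(discriminator(generator(z)), torch.ones(batch_size, 1, device=device))
g_loss.backward()
g_optimizer.step()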
9.4.2 GAN Training Dynamics
GAN training is notoriously unstable. Common issues:
Mode collapse: The generator produces only a few types of images, ignoring the diversity of real data. It finds a few outputs that fool the discriminator and sticks with them.
Training instability: The generator and discriminator can oscillate rather than converge. If the discriminator becomes too strong, gradients to the generator vanish.
Evaluation difficulty: Unlike VAEs with a clear loss function, GAN quality is hard to measure. The discriminator loss doesn’t correlate well with image quality.
Practical remedies:
- Use established architectures (DCGAN, StyleGAN)
- Apply spectral normalization to the discriminator (see the sketch below)
- Use progressive growing (start low-resolution, gradually increase)
- Monitor generated samples visually throughout training
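Spectral normalization, for example, is built into PyTorch; a sketch of wrapping discriminator layers follows, with layer sizes that are illustrative and assume 224x224 grayscale inputs.

from torch.nn.utils import spectral_norm

# Each conv/linear layer is wrapped so its spectral norm (largest singular value)
# is constrained, which keeps discriminator gradients better behaved.
discriminator = nn.Sequential(
    spectral_norm(nn.Conv2d(1, 64, 4, stride=2, padding=1)),    # 112x112
    nn.LeakyReLU(0.2),
    spectral_norm(nn.Conv2d(64, 128, 4, stride=2, padding=1)),  # 56x56
    nn.LeakyReLU(0.2),
    nn.Flatten(),
    spectral_norm(nn.Linear(128 * 56 * 56, 1)),                 # one real/fake logit
)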
9.4.3 Conditional GANs for Medical Imaging
Standard GANs generate random images from the data distribution. Conditional GANs (cGANs) add control by conditioning on additional information:
class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim, num_classes, img_channels=1):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 256 * 7 * 7),
            nn.Unflatten(1, (256, 7, 7)),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, img_channels, 4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z, labels):
        label_embed = self.label_embedding(labels)
        x = torch.cat([z, label_embed], dim=1)
        return self.model(x)  # [B, img_channels, 56, 56]

# Generate pneumonia images specifically
latent_dim = 100  # example value
generator = ConditionalGenerator(latent_dim, num_classes=2)
z = torch.randn(16, latent_dim)
labels = torch.ones(16, dtype=torch.long)  # Class 1 = pneumonia
fake_pneumonia = generator(z, labels)

9.4.4 Image-to-Image Translation
Pix2pix extends cGANs to paired image translation: given an input image, generate a corresponding output image. Applications include:
- CT → synthetic MRI (or vice versa)
- Low-dose CT → standard-dose CT
- Segmentation mask → realistic image
The generator is typically a U-Net (encoder-decoder with skip connections), and the discriminator evaluates image patches rather than the whole image (PatchGAN).
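A sketch of a PatchGAN-style discriminator: rather than one real/fake score per image, it outputs a grid of logits, each judging a local patch. The channel counts and the two-channel conditional input (source and translated image concatenated) are illustrative.

class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels=2):  # e.g., input image + translated image, concatenated
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 1, 4, stride=1, padding=1),  # one logit per patch
        )

    def forward(self, x):
        return self.net(x)  # [B, 1, h, w] grid of real/fake logits, one per receptive-field patch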
CycleGAN enables unpaired translation—learning the mapping between two domains without requiring paired examples. This is powerful when paired data doesn’t exist: you have CT scans and MRI scans, but not from the same patients.
# Using a pretrained CycleGAN for CT-MRI translation (conceptual)
from models import CycleGAN

# Load pretrained model
model = CycleGAN.load('ct_to_mri_pretrained.pt')

# Translate CT to synthetic MRI
ct_scan = load_ct('patient_001.nii')
synthetic_mri = model.translate_A_to_B(ct_scan)

# Note: Synthetic images should NEVER replace real diagnostic scans.
# They're useful for research, training, and hypothesis generation.

9.4.5 Quality Assessment
How do you know if generated images are good? Common metrics:
Fréchet Inception Distance (FID): Compares statistics of real and generated images in a pretrained network’s feature space. Lower is better. FID < 50 is generally considered reasonable for medical images.
Inception Score (IS): Measures how confident a classifier is on generated images and how diverse the generations are. Less used in medical imaging.
Clinical review: Ultimately, the gold standard is whether radiologists can distinguish real from fake, and whether synthetic images contain clinically accurate anatomy.
9.5 Diffusion Models
Clinical Context: Recent research shows that diffusion models can generate chest X-rays nearly indistinguishable from real ones, even to expert radiologists. These models are rapidly becoming the state-of-the-art for medical image synthesis. Understanding how they work—and their limitations—is essential as they move toward clinical applications.
9.5.1 The Denoising Perspective
Diffusion models take a different approach than GANs or VAEs. The key idea:
- Forward process: Gradually add noise to a real image until it becomes pure random noise
- Reverse process: Learn to gradually remove noise, recovering the original image
The forward process is fixed—just add Gaussian noise at each step:
\[ q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I) \]
After enough steps (typically 1000), \(x_T\) is indistinguishable from random noise.
The model learns the reverse: given a noisy image \(x_t\), predict the slightly-less-noisy \(x_{t-1}\). At inference time, start from pure noise and iteratively denoise to generate a new image.
9.5.2 Why Diffusion Models Excel
Training stability: Unlike GANs, there’s no adversarial dynamic to balance. The model simply learns to predict noise—a straightforward regression task.
Sample quality: Diffusion models now produce higher-quality images than GANs on many benchmarks, including medical imaging.
Mode coverage: GANs often miss modes of the data distribution (mode collapse). Diffusion models cover the full distribution more reliably.
Controllability: Conditioning and guidance are natural extensions. Classifier-free guidance allows trading off diversity for fidelity.
The downside: slow sampling. Generating one image requires hundreds or thousands of denoising steps. Recent advances (DDIM, consistency models) reduce this dramatically.
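For illustration, a minimal DDIM sampling sketch with the diffusers library, assuming a trained UNet2DModel like the one in the next subsection:

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)                      # 50 denoising steps instead of 1000

@torch.no_grad()
def sample_ddim(model, shape=(1, 1, 256, 256), device='cuda'):
    x = torch.randn(shape, device=device)        # start from pure Gaussian noise
    for t in scheduler.timesteps:                # timesteps run from high noise to low
        noise_pred = model(x, t).sample          # predict the noise at this step
        x = scheduler.step(noise_pred, t, x).prev_sample  # take one DDIM step
    return x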
9.5.3 Diffusion for Medical Imaging
# Using HuggingFace diffusers for medical image generation (conceptual)
import torch
import torch.nn as nn
from diffusers import DDPMScheduler, UNet2DModel

# Define model architecture
model = UNet2DModel(
    sample_size=256,
    in_channels=1,   # Grayscale medical images
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(64, 128, 256, 512),
    down_block_types=(
        "DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"
    ),
    up_block_types=(
        "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D"
    ),
)

# Training loop (simplified; dataloader yields batches of grayscale images)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for images in dataloader:
    optimizer.zero_grad()
    # Sample random timesteps
    timesteps = torch.randint(0, 1000, (images.shape[0],))
    # Add noise according to timestep
    noise = torch.randn_like(images)
    noisy_images = noise_scheduler.add_noise(images, noise, timesteps)
    # Predict the noise
    noise_pred = model(noisy_images, timesteps).sample
    # Simple MSE loss on noise prediction
    loss = nn.functional.mse_loss(noise_pred, noise)
    loss.backward()
    optimizer.step()

9.5.4 Conditional Diffusion and Guidance
Conditional diffusion models can generate images with specific attributes:
- Class-conditional: Generate images of a specific pathology
- Text-conditional: “Chest X-ray showing right lower lobe pneumonia”
- Image-conditional: Generate variations of an input image
Classifier-free guidance is a powerful technique: train the model both conditionally and unconditionally, then at inference time, extrapolate beyond the conditional distribution for higher-fidelity outputs.
# Classifier-free guidance (conceptual)
# During training: randomly drop the condition with probability p
# During inference: combine conditional and unconditional predictions
def guided_sampling(model, noisy_image, condition, null_condition, guidance_scale=7.5):
    # Get conditional prediction
    noise_pred_cond = model(noisy_image, condition)
    # Get unconditional prediction
    noise_pred_uncond = model(noisy_image, null_condition)
    # Guided prediction: extrapolate away from the unconditional one
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
    return noise_pred

9.5.5 Current State in Medical Imaging
Diffusion models for medical imaging are an active research area:
- Synthetic data generation: Creating diverse training sets for rare conditions
- Super-resolution: Enhancing low-resolution or low-dose scans
- Inpainting: Filling in missing or corrupted regions
- Anomaly detection: Score images by likelihood under the diffusion model
Challenges remain:
- Ensuring anatomical accuracy (no hallucinated structures)
- Validating clinical utility of synthetic data
- Computational cost for 3D volumetric data
- Regulatory pathway for clinical use
The field is moving rapidly. Models that seem cutting-edge today may be standard practice within a year.
9.6 Putting It Together: Augmentation Pipeline
Clinical Context: You’re building a classifier for a rare lung condition with only 100 positive training examples. You decide to use generative augmentation to expand your dataset. This section walks through the complete pipeline: training a generative model, validating synthetic data quality, and measuring the impact on downstream classification.
9.6.1 Step 1: Train a Conditional VAE
We’ll use a conditional VAE to generate class-specific synthetic images:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms

class ConditionalVAE(nn.Module):
    def __init__(self, latent_dim=128, num_classes=2):
        super().__init__()
        self.latent_dim = latent_dim
        # Class embedding
        self.class_embed = nn.Embedding(num_classes, 64)
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        encoder_out_dim = 128 * 28 * 28
        self.fc_mu = nn.Linear(encoder_out_dim + 64, latent_dim)
        self.fc_logvar = nn.Linear(encoder_out_dim + 64, latent_dim)
        # Decoder
        self.fc_decode = nn.Linear(latent_dim + 64, 128 * 28 * 28)
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (128, 28, 28)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x, labels):
        h = self.encoder(x)
        c = self.class_embed(labels)
        h = torch.cat([h, c], dim=1)
        return self.fc_mu(h), self.fc_logvar(h)

    def decode(self, z, labels):
        c = self.class_embed(labels)
        h = torch.cat([z, c], dim=1)
        h = self.fc_decode(h)
        return self.decoder(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, labels):
        mu, logvar = self.encode(x, labels)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, labels), mu, logvar

# Training (train_loader: DataLoader of (image, label) batches, defined elsewhere)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ConditionalVAE(latent_dim=128, num_classes=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(100):
    model.train()
    total_loss = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(images, labels)
        # ELBO loss
        recon_loss = nn.functional.mse_loss(recon, images, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss = recon_loss + kl_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader.dataset):.4f}')

9.6.2 Step 2: Generate Synthetic Training Data
def generate_synthetic_dataset(model, num_samples_per_class, latent_dim, device):
    """Generate a balanced synthetic dataset"""
    model.eval()
    synthetic_images = []
    synthetic_labels = []
    with torch.no_grad():
        for class_idx in range(2):
            z = torch.randn(num_samples_per_class, latent_dim).to(device)
            labels = torch.full((num_samples_per_class,), class_idx, dtype=torch.long).to(device)
            generated = model.decode(z, labels)
            synthetic_images.append(generated.cpu())
            synthetic_labels.extend([class_idx] * num_samples_per_class)
    return torch.cat(synthetic_images), torch.tensor(synthetic_labels)

# Generate 500 synthetic images per class
synthetic_images, synthetic_labels = generate_synthetic_dataset(
    model, num_samples_per_class=500, latent_dim=128, device=device
)
print(f"Generated {len(synthetic_images)} synthetic images")

9.6.3 Step 3: Validate Synthetic Data Quality
Before using synthetic data for training, validate its quality:
import numpy as np
from scipy import linalg
from torchvision.models import inception_v3

def calculate_fid(real_images, synthetic_images, device):
    """Calculate Fréchet Inception Distance"""
    # Use pretrained InceptionV3 features
    inception = inception_v3(weights='IMAGENET1K_V1', transform_input=False)
    inception.fc = nn.Identity()  # Remove final classification layer
    inception = inception.to(device).eval()

    def get_features(images):
        # Resize to 299x299 and convert to 3-channel for Inception
        images = nn.functional.interpolate(images, size=(299, 299))
        images = images.repeat(1, 3, 1, 1)  # Grayscale to RGB
        features = []
        with torch.no_grad():
            for batch in torch.split(images, 32):
                feat = inception(batch.to(device))
                features.append(feat.cpu().numpy())
        return np.concatenate(features)

    real_features = get_features(real_images)
    fake_features = get_features(synthetic_images)

    # Calculate statistics
    mu_real, sigma_real = real_features.mean(0), np.cov(real_features, rowvar=False)
    mu_fake, sigma_fake = fake_features.mean(0), np.cov(fake_features, rowvar=False)

    # FID formula
    diff = mu_real - mu_fake
    covmean = linalg.sqrtm(sigma_real @ sigma_fake)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    fid = diff @ diff + np.trace(sigma_real + sigma_fake - 2 * covmean)
    return fid

# real_images: a tensor of real images (e.g., from the validation set)
fid_score = calculate_fid(real_images, synthetic_images, device)
print(f"FID Score: {fid_score:.2f}")
# FID < 50 is generally reasonable; < 20 is good

Visual inspection: Always manually review a sample of generated images. Look for:
- Anatomical plausibility (correct structures in correct locations)
- Diversity (not all images look the same)
- Artifact absence (no obvious generation artifacts)
9.6.4 Step 4: Train Classifier with Augmented Data
from torch.utils.data import ConcatDataset, TensorDataset
from torchvision import models

# Create combined dataset
real_dataset = train_dataset  # Original 100 positives + negatives
synthetic_dataset = TensorDataset(synthetic_images, synthetic_labels)

# Option A: Replace minority class with synthetic
# Option B: Augment minority class with synthetic
# Option C: Use all synthetic + all real
# We'll use Option B-style augmentation (here we simply add the synthetic images to the real set)
combined_loader = DataLoader(
    ConcatDataset([real_dataset, synthetic_dataset]),
    batch_size=32, shuffle=True
)

# Train classifier
classifier = models.resnet18(weights='IMAGENET1K_V1')
classifier.fc = nn.Linear(classifier.fc.in_features, 2)
classifier = classifier.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-4)

for epoch in range(20):
    classifier.train()
    for images, labels in combined_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = classifier(images.repeat(1, 3, 1, 1))  # Grayscale to RGB
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

9.6.5 Step 5: Evaluate on Real Test Data
Critical: Always evaluate on held-out real data, never synthetic:
from sklearn.metrics import roc_auc_score, classification_report
def evaluate_classifier(model, test_loader, device):
model.eval()
all_labels = []
all_probs = []
with torch.no_grad():
for images, labels in test_loader:
images = images.to(device)
outputs = model(images.repeat(1, 3, 1, 1))
probs = torch.softmax(outputs, dim=1)[:, 1]
all_labels.extend(labels.numpy())
all_probs.extend(probs.cpu().numpy())
auroc = roc_auc_score(all_labels, all_probs)
return auroc
# Compare: classifier trained with vs. without synthetic data
auroc_with_synthetic = evaluate_classifier(classifier_with_synthetic, real_test_loader, device)
auroc_without_synthetic = evaluate_classifier(classifier_without_synthetic, real_test_loader, device)
print(f"AUROC without synthetic augmentation: {auroc_without_synthetic:.3f}")
print(f"AUROC with synthetic augmentation: {auroc_with_synthetic:.3f}")
print(f"Improvement: {auroc_with_synthetic - auroc_without_synthetic:+.3f}")9.6.6 When Synthetic Augmentation Helps (and When It Doesn’t)
Likely to help:
- Very small datasets (< 500 training images)
- Severe class imbalance
- When synthetic images are high quality and diverse

May not help:
- Large datasets (real data is sufficient)
- When synthetic images are low quality or lack diversity
- When the classifier already achieves near-ceiling performance

Can hurt:
- If synthetic data introduces distribution shift
- If generated pathology is unrealistic
- If the model memorizes synthetic artifacts rather than real features
Always validate with ablation studies: train with and without synthetic data, compare on real test sets.
9.7 Risks and Ethics of Synthetic Medical Data
Clinical Context: A startup offers to train custom radiology AI models using their proprietary synthetic data—“no need for real patient images.” Their marketing claims equal performance to models trained on real data. How do you evaluate this claim? What risks should you consider before deploying such a model?
9.7.1 Hallucinated Pathology
Generative models can create anatomically plausible images with pathology that never occurs in real patients—or miss subtle pathological features that distinguish disease subtypes.
The risk: A model trained on synthetic data learns to detect hallucinated disease patterns that don’t exist in real patients, or fails to learn rare but real patterns absent from the synthetic distribution.
Mitigation:
- Always validate on real, held-out test data
- Have domain experts review generated images for clinical accuracy
- Compare feature representations: do synthetic and real images occupy the same embedding space? (see the sketch below)
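One way to make the last check concrete, sketched below: embed both sets with a frozen pretrained encoder and train a small probe to separate real from synthetic. A cross-validated AUROC near 0.5 suggests the two sets overlap in feature space, while an AUROC near 1.0 signals an obvious gap. The ResNet-18 encoder and logistic-regression probe are illustrative choices.

import numpy as np
import torch
import torch.nn as nn
from torchvision import models
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

def feature_space_auroc(real_images, synthetic_images, device):
    """Train a real-vs-synthetic probe on frozen ResNet features.
    AUROC ~0.5: the sets overlap in feature space; AUROC ~1.0: easily separable."""
    encoder = models.resnet18(weights='IMAGENET1K_V1')
    encoder.fc = nn.Identity()
    encoder = encoder.to(device).eval()

    def embed(images):
        with torch.no_grad():
            x = images.repeat(1, 3, 1, 1).to(device)   # grayscale -> RGB
            return encoder(x).cpu().numpy()

    X = np.concatenate([embed(real_images), embed(synthetic_images)])
    y = np.concatenate([np.zeros(len(real_images)), np.ones(len(synthetic_images))])
    probe = LogisticRegression(max_iter=1000)
    return cross_val_score(probe, X, y, cv=5, scoring='roc_auc').mean()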
9.7.2 Distribution Shift
Synthetic data, no matter how realistic, comes from a model’s learned distribution—not reality. Subtle differences can cause problems:
- Scanner characteristics: Real images vary by manufacturer, protocol, patient positioning
- Patient populations: Real data reflects the demographics of source institutions
- Edge cases: Rare presentations may be underrepresented or absent in synthetic data
A model that performs well on synthetic validation data may fail on real clinical data from a different population or scanner.
9.7.3 Detection of Synthetic Images
Can we tell if an image is synthetic? This matters for:
- Research integrity: Preventing fabricated data in publications
- Medical records: Ensuring patient records contain only real images
- Legal/regulatory: Establishing provenance of diagnostic images
Current detection methods include:
- Forensic analysis of generation artifacts
- Training classifiers to distinguish real from fake
- Watermarking during generation
As generative models improve, detection becomes harder. This is an ongoing arms race.
9.7.4 Regulatory Considerations
Regulatory bodies (FDA, EMA) are still developing frameworks for AI trained on synthetic data:
- Validation requirements: What evidence is needed that synthetic-trained models work on real data?
- Documentation: How should synthetic data usage be disclosed?
- Liability: Who is responsible if a synthetic-trained model fails?
Current best practice: treat synthetic data as a supplement to, not replacement for, real data. Document all synthetic data usage. Validate extensively on real, diverse test sets.
9.7.5 Privacy Implications
Generative models trained on patient data may memorize and regenerate identifiable information:
- Membership inference: Can we determine if a specific patient was in the training data?
- Data extraction: Can we recover training images from a generative model?
Differential privacy techniques can provide formal guarantees, but often at the cost of generation quality. This is an active research area.
9.7.6 Recommendations for Clinical AI Teams
Never use synthetic data alone. Combine with real data and validate on real test sets.
Document everything. Record what generative model was used, how synthetic data was generated, and its proportion in training.
Validate with domain experts. Have radiologists review synthetic images for clinical plausibility.
Monitor deployed models. Track performance on real clinical data over time.
Stay current. The capabilities and risks of generative models evolve rapidly.
Synthetic medical data is a powerful tool, but it requires careful, thoughtful application. The goal is augmenting real data to improve clinical AI—not replacing the need for real patient data and rigorous validation.
9.8 Appendix 9A: Mathematical Foundations
This appendix provides formal derivations for readers who want the mathematical foundations behind VAEs and diffusion models.
9.8.1 VAE: The Evidence Lower Bound
We want to learn a generative model \(p_\theta(x)\) of images \(x\). Introduce a latent variable \(z\) and assume:
\[ p_\theta(x) = \int p_\theta(x|z) p(z) dz \]
where \(p(z) = \mathcal{N}(0, I)\) is the prior and \(p_\theta(x|z)\) is the decoder.
The problem: This integral is intractable—we can’t compute or differentiate through it.
The solution: Introduce an approximate posterior \(q_\phi(z|x)\) (the encoder) and derive a lower bound on \(\log p_\theta(x)\).
Starting from the log-likelihood:
\[ \log p_\theta(x) = \log \int p_\theta(x|z) p(z) dz \]
Multiply and divide by \(q_\phi(z|x)\):
\[ = \log \int \frac{p_\theta(x|z) p(z)}{q_\phi(z|x)} q_\phi(z|x) dz \]
Apply Jensen’s inequality (\(\log \mathbb{E}[X] \geq \mathbb{E}[\log X]\)):
\[ \geq \int q_\phi(z|x) \log \frac{p_\theta(x|z) p(z)}{q_\phi(z|x)} dz \]
Expanding:
\[ = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_\text{KL}(q_\phi(z|x) \| p(z)) \]
This is the Evidence Lower Bound (ELBO):
\[ \mathcal{L}_\text{ELBO} = \underbrace{\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]}_{\text{reconstruction}} - \underbrace{D_\text{KL}(q_\phi(z|x) \| p(z))}_{\text{regularization}} \]
Maximizing the ELBO simultaneously:
- Trains the decoder to reconstruct inputs from latent codes
- Trains the encoder to produce latent codes close to the prior
9.8.2 The Reparameterization Trick
To backpropagate through sampling \(z \sim q_\phi(z|x)\), we reparameterize:
\[ z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \]
Now \(z\) is a deterministic function of \(x\) and \(\epsilon\), and gradients flow through \(\mu_\phi\) and \(\sigma_\phi\).
9.8.3 KL Divergence for Gaussians
For \(q_\phi(z|x) = \mathcal{N}(\mu, \sigma^2 I)\) and \(p(z) = \mathcal{N}(0, I)\):
\[ D_\text{KL}(q \| p) = \frac{1}{2} \sum_{j=1}^{d} \left( \mu_j^2 + \sigma_j^2 - 1 - \log \sigma_j^2 \right) \]
This has a closed-form solution, making optimization efficient.
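As a quick numerical check of this closed form against torch.distributions (a sketch):

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(8, 128)       # a batch of 8 latent means
logvar = torch.randn(8, 128)   # and log-variances
sigma = torch.exp(0.5 * logvar)

# Closed-form expression used in the VAE loss above
kl_closed = 0.5 * torch.sum(mu.pow(2) + sigma.pow(2) - 1 - logvar, dim=1)

# The same quantity via torch.distributions
kl_lib = kl_divergence(Normal(mu, sigma), Normal(0.0, 1.0)).sum(dim=1)

print(torch.allclose(kl_closed, kl_lib, atol=1e-4))  # True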
9.8.4 Diffusion Models: Forward and Reverse Processes
Forward process (fixed, adds noise):
Define a sequence of increasingly noisy images \(x_0, x_1, \ldots, x_T\) where \(x_0\) is the real image and \(x_T \approx \mathcal{N}(0, I)\).
\[ q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I) \]
where \(\beta_t\) is a noise schedule (typically increasing from \(\beta_1 \approx 10^{-4}\) to \(\beta_T \approx 0.02\)).
A useful property: we can sample \(x_t\) directly from \(x_0\):
\[ q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t) I) \]
where \(\alpha_t = 1 - \beta_t\) and \(\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s\).
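A short sketch of sampling \(x_t\) directly from \(x_0\) using this property, with the linear \(\beta\) schedule quoted above:

import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)         # linear noise schedule
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)     # \bar{alpha}_t

def q_sample(x0, t, noise=None):
    """Sample x_t ~ q(x_t | x_0) in closed form for a single timestep index t."""
    if noise is None:
        noise = torch.randn_like(x0)
    a_bar = alpha_bars[t]
    return torch.sqrt(a_bar) * x0 + torch.sqrt(1 - a_bar) * noise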
Reverse process (learned, removes noise):
\[ p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 I) \]
The model learns \(\mu_\theta\), predicting the mean of the less-noisy image.
9.8.5 Noise Prediction Formulation
Instead of predicting \(\mu_\theta\) directly, we can equivalently predict the noise \(\epsilon\) that was added:
\[ x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \]
Given \(x_t\), the model predicts \(\epsilon_\theta(x_t, t)\), and we recover:
\[ \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) \]
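Continuing the schedule variables from the sketch above, one reverse (ancestral) sampling step with the common choice \(\sigma_t^2 = \beta_t\) looks like this sketch; `eps_model` stands in for the trained noise-prediction network.

@torch.no_grad()
def p_sample_step(eps_model, x_t, t):
    """One DDPM reverse step x_t -> x_{t-1} using the predicted noise."""
    beta_t = betas[t]
    alpha_t = alphas[t]
    a_bar_t = alpha_bars[t]

    eps = eps_model(x_t, t)                                     # epsilon_theta(x_t, t)
    mean = (x_t - beta_t / torch.sqrt(1 - a_bar_t) * eps) / torch.sqrt(alpha_t)

    if t == 0:
        return mean                                             # no noise added at the final step
    return mean + torch.sqrt(beta_t) * torch.randn_like(x_t)    # sigma_t^2 = beta_t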
9.8.6 Training Objective
The training loss is simply the MSE between predicted and actual noise:
\[ \mathcal{L}_\text{simple} = \mathbb{E}_{t, x_0, \epsilon} \left[ \| \epsilon - \epsilon_\theta(x_t, t) \|^2 \right] \]
This is derived from a variational bound but empirically works better than the full objective.
9.8.7 Score Matching Perspective
An alternative view: the model learns the score function \(\nabla_{x_t} \log p(x_t)\)—the gradient of the log-density.
The noise prediction \(\epsilon_\theta\) is related to the score by:
\[ \nabla_{x_t} \log p(x_t) \approx -\frac{\epsilon_\theta(x_t, t)}{\sqrt{1-\bar{\alpha}_t}} \]
Sampling then follows the score (Langevin dynamics), moving toward regions of high probability.
9.8.8 Further Reading
- Kingma & Welling (2014). “Auto-Encoding Variational Bayes.” The original VAE paper.
- Goodfellow et al. (2014). “Generative Adversarial Nets.” The original GAN paper.
- Ho, Jain & Abbeel (2020). “Denoising Diffusion Probabilistic Models.” The DDPM paper that sparked diffusion model interest.
- Song et al. (2021). “Score-Based Generative Modeling through Stochastic Differential Equations.” Unifying framework for diffusion models.
- Kazerouni et al. (2023). “Diffusion Models in Medical Imaging: A Comprehensive Survey.” Recent overview of medical applications.