13 Imaging Case Studies
This chapter walks through complete medical imaging AI pipelines, from data loading to deployment considerations. Each case study illustrates techniques from earlier chapters applied to real clinical problems.
13.1 Chest X-Ray Classification
Clinical Context: Chest X-rays are the most common imaging study worldwide—over 2 billion performed annually. AI systems can triage urgent findings, assist in screening programs, and support diagnosis in resource-limited settings where radiologists are scarce. Early deep learning work demonstrated radiologist-level performance on specific findings (Rajpurkar et al. 2017), sparking a wave of clinical AI development (Topol 2019).
13.1.1 The PneumoniaMNIST Dataset
PneumoniaMNIST provides a standardized benchmark: 5,856 pediatric chest X-rays (28×28 grayscale) labeled as normal or pneumonia. While simplified from clinical resolution, it demonstrates the complete classification pipeline.
import medmnist
from medmnist import PneumoniaMNIST
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

# Data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load datasets
train_dataset = PneumoniaMNIST(split='train', transform=transform,
                               download=True)
val_dataset = PneumoniaMNIST(split='val', transform=transform)
test_dataset = PneumoniaMNIST(split='test', transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)

print(f"Training: {len(train_dataset)} images")
print(f"Validation: {len(val_dataset)} images")
print(f"Test: {len(test_dataset)} images")
13.1.2 Handling Class Imbalance
PneumoniaMNIST is imbalanced: approximately 75% pneumonia, 25% normal. Without correction, the model may predict “pneumonia” for everything and still achieve 75% accuracy.
Strategies:
- Weighted loss: Increase penalty for misclassifying the minority class
- Oversampling: Sample minority class examples more often (see the sampler sketch after this list)
- Threshold adjustment: Use a threshold other than 0.5 for classification
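For the oversampling strategy, PyTorch's WeightedRandomSampler achieves the rebalancing at load time by drawing minority-class images more often. A minimal sketch, assuming the train_dataset and DataLoader defined above:

from torch.utils.data import WeightedRandomSampler
import numpy as np

# Per-sample weight = inverse frequency of that sample's class
labels = train_dataset.labels.flatten().astype(int)
class_counts = np.bincount(labels)
sample_weights = 1.0 / class_counts[labels]

sampler = WeightedRandomSampler(weights=sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)
balanced_loader = DataLoader(train_dataset, batch_size=64, sampler=sampler)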
import torch.nn as nn

# Calculate class weights (inverse frequency)
num_pneumonia = int(train_dataset.labels.sum())
num_normal = len(train_dataset) - num_pneumonia
class_counts = [num_normal, num_pneumonia]
weights = torch.tensor([1.0 / c for c in class_counts], dtype=torch.float32)
weights = weights / weights.sum()  # Normalize

# Weighted cross-entropy loss
criterion = nn.CrossEntropyLoss(weight=weights)

13.1.3 Training Loop with Validation
A complete training pipeline with early stopping:
import torch.optim as optim
from sklearn.metrics import roc_auc_score
import numpy as np

def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    for images, labels in loader:
        # Labels arrive as (batch, 1); CrossEntropyLoss needs a 1D LongTensor
        images = images.to(device)
        labels = labels.squeeze().long().to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)
def evaluate(model, loader, device):
    model.eval()
    all_labels, all_probs = [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)[:, 1]
            all_labels.extend(labels.numpy().flatten())
            all_probs.extend(probs.cpu().numpy())
    auroc = roc_auc_score(all_labels, all_probs)
    return auroc
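# The loop below assumes a model, optimizer, and device have been created.
# A minimal illustrative setup (this small CNN for the 28x28 inputs is an
# assumption, not a prescribed architecture) might be:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(32 * 7 * 7, 2)
).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Move the class weights to the same device as the model
criterion = nn.CrossEntropyLoss(weight=weights.to(device))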
# Training with early stopping
best_auroc = 0
patience_counter = 0
for epoch in range(100):
    train_loss = train_epoch(model, train_loader, criterion,
                             optimizer, device)
    val_auroc = evaluate(model, val_loader, device)
    if val_auroc > best_auroc:
        best_auroc = val_auroc
        torch.save(model.state_dict(), 'best_model.pt')
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= 10:
            print(f"Early stopping at epoch {epoch}")
            break
    print(f"Epoch {epoch}: Loss={train_loss:.4f}, Val AUROC={val_auroc:.4f}")

13.1.4 Clinical Deployment Considerations
Moving from PneumoniaMNIST to real chest X-rays requires:
- Resolution: Clinical X-rays are 2000×2000+ pixels, not 28×28
- Preprocessing: DICOM handling, window/level adjustment, orientation normalization
- Multi-label: Real findings include pneumonia, effusion, cardiomegaly, nodules, etc.
- Uncertainty: Flag low-confidence predictions for radiologist review (see the sketch after this list)
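A minimal sketch of the uncertainty item, assuming the trained classifier from the training loop above (the 0.8 review threshold is illustrative and would be tuned on validation data):

# Flag predictions whose top softmax probability falls below a review threshold
with torch.no_grad():
    probs = torch.softmax(model(images.to(device)), dim=1)
confidence, predicted_class = probs.max(dim=1)
needs_review = confidence < 0.8  # route these cases to a radiologist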
FDA-cleared chest X-ray AI systems (e.g., Qure.ai qXR, Lunit INSIGHT CXR) typically detect 10+ findings and include explainability features like heatmaps.
13.2 CT Segmentation with U-Net
Clinical Context: Radiation therapy requires precise tumor delineation to maximize dose to cancer while sparing healthy tissue. Manual contouring takes 30–90 minutes per patient. AI-assisted segmentation can reduce this to minutes while improving consistency.
13.2.1 3D Medical Image Loading
CT volumes come as DICOM series. MONAI simplifies loading and preprocessing:
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd,
    Spacingd, ScaleIntensityRanged, CropForegroundd,
    RandCropByPosNegLabeld, RandFlipd, ToTensord
)
from monai.data import Dataset, DataLoader

# Define transforms for training
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Spacingd(keys=["image", "label"],
             pixdim=(1.5, 1.5, 2.0),  # Resample to uniform spacing
             mode=("bilinear", "nearest")),
    ScaleIntensityRanged(
        keys=["image"],
        a_min=-1000, a_max=400,  # CT Hounsfield units
        b_min=0.0, b_max=1.0,
        clip=True
    ),
    CropForegroundd(keys=["image", "label"], source_key="image"),
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=(96, 96, 96),  # 3D patch size
        pos=1, neg=1,  # Balance positive/negative samples
        num_samples=4
    ),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    ToTensord(keys=["image", "label"])
])

# Create dataset from lists of image and label file paths
data_dicts = [
    {"image": img_path, "label": seg_path}
    for img_path, seg_path in zip(image_files, label_files)
]
train_ds = Dataset(data=data_dicts, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)

13.2.2 U-Net Training for Organ Segmentation
MONAI provides optimized U-Net implementations:
from monai.networks.nets import UNet
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.transforms import AsDiscrete
from monai.data import decollate_batch
from monai.inferers import sliding_window_inference

# 3D U-Net for liver segmentation
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,  # Background + liver
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2
).to(device)

# Combined Dice + Cross-Entropy loss
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Dice metric for evaluation (expects one-hot predictions and labels)
dice_metric = DiceMetric(include_background=False, reduction="mean")
post_pred = AsDiscrete(argmax=True, to_onehot=2)
post_label = AsDiscrete(to_onehot=2)

# Training loop
for epoch in range(max_epochs):
    model.train()
    epoch_loss = 0
    for batch_data in train_loader:
        inputs = batch_data["image"].to(device)
        labels = batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # Validation: sliding-window inference handles full volumes even though
    # the model was trained on 96x96x96 patches
    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_inputs = val_data["image"].to(device)
            val_labels = val_data["label"].to(device)
            val_outputs = sliding_window_inference(
                val_inputs, roi_size=(96, 96, 96), sw_batch_size=4,
                predictor=model)
            # Post-process: argmax, then one-hot for the Dice metric
            val_outputs = [post_pred(o) for o in decollate_batch(val_outputs)]
            val_labels = [post_label(l) for l in decollate_batch(val_labels)]
            dice_metric(y_pred=val_outputs, y=val_labels)
        mean_dice = dice_metric.aggregate().item()
        dice_metric.reset()

    print(f"Epoch {epoch}: Loss={epoch_loss/len(train_loader):.4f}, "
          f"Dice={mean_dice:.4f}")

13.2.3 Evaluation Metrics for Segmentation
Beyond Dice score, clinical evaluation includes:
- Hausdorff distance: Maximum surface-to-surface distance (measures worst-case boundary error)
- Surface Dice: Dice computed only on boundary voxels (sensitive to contour accuracy)
- Volume difference: Absolute difference in segmented volume (see the sketch below)
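The first two are computed with MONAI in the next block; the volume difference can be derived directly from the masks. A minimal sketch, assuming pred_mask and gt_mask are binary arrays on the resampled 1.5 × 1.5 × 2.0 mm grid:

# Absolute volume difference between AI and reference contours
voxel_volume_ml = (1.5 * 1.5 * 2.0) / 1000.0  # mm^3 per voxel -> mL
pred_volume_ml = pred_mask.sum() * voxel_volume_ml
gt_volume_ml = gt_mask.sum() * voxel_volume_ml
volume_difference_ml = abs(pred_volume_ml - gt_volume_ml)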
from monai.metrics import HausdorffDistanceMetric, SurfaceDiceMetric

# Both metrics expect one-hot-encoded predictions and ground truth;
# distances are in voxel units unless spacing is supplied
hausdorff = HausdorffDistanceMetric(include_background=False,
                                    percentile=95)
surface_dice = SurfaceDiceMetric(include_background=False,
                                 class_thresholds=[1.0])  # boundary tolerance per class

# Compute metrics
hausdorff(y_pred=predictions, y=ground_truth)
hd95 = hausdorff.aggregate().item()
surface_dice(y_pred=predictions, y=ground_truth)
sd = surface_dice.aggregate().item()

print(f"95% Hausdorff distance: {hd95:.2f}")
print(f"Surface Dice: {sd:.4f}")

13.3 Radiation Oncology Applications
Clinical Context: Radiation therapy planning involves multiple imaging and contouring tasks. AI can automate organ-at-risk delineation, predict treatment response, and optimize dose distributions.
13.3.1 Auto-Contouring Workflow
A typical auto-contouring pipeline:
- Import: Load planning CT from treatment planning system
- Preprocess: Resample to standard spacing, normalize intensities
- Segment: Run trained model to generate contours
- Post-process: Remove small islands, smooth boundaries (see the sketch after this list)
- Export: Save as DICOM RT Structure Set for clinical review
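For the post-processing step, MONAI's KeepLargestConnectedComponent removes small islands; a minimal sketch, assuming predicted_mask is the discrete label map produced in the segmentation step (boundary smoothing not shown):

from monai.transforms import KeepLargestConnectedComponent

# Keep only the largest connected region of the target label (liver = 1)
postprocess = KeepLargestConnectedComponent(applied_labels=[1])
cleaned_mask = postprocess(predicted_mask)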
from rt_utils import RTStructBuilder
import numpy as np

def export_to_rtstruct(ct_series_path, segmentation, structure_name):
    """Export a binary numpy segmentation mask as a DICOM RT Structure Set."""
    # Create RT Structure referencing the planning CT series
    rtstruct = RTStructBuilder.create_new(dicom_series_path=ct_series_path)
    # Add segmentation as ROI; rt_utils expects a boolean array whose
    # dimensions match the CT series (check axis order against your volume)
    rtstruct.add_roi(
        mask=segmentation,
        color=[255, 0, 0],  # Red
        name=structure_name
    )
    # Save DICOM file
    rtstruct.save(f"{structure_name}_rtstruct.dcm")
    return rtstruct

# Example: export liver segmentation (class index 1) for clinical review
liver_mask = (model_output.argmax(dim=1) == 1).cpu().numpy()
export_to_rtstruct("./ct_series/", liver_mask[0], "Liver_AI")

13.3.2 Quality Assurance Metrics
Before clinical use, auto-contours require QA:
- Geometric accuracy: Dice > 0.85 for most organs, HD95 < 5 mm (see the sketch after this list)
- Dosimetric impact: Does using the auto-contour change the dose distribution significantly?
- Time savings: Measured reduction in contouring time
- Inter-observer comparison: AI vs. expert variability
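A simple automated gate on the geometric criteria can flag contours for manual review; a minimal sketch, assuming dice_score and hd95 come from the metrics in Section 13.2.3:

def passes_geometric_qa(dice_score, hd95, dice_min=0.85, hd95_max=5.0):
    """Return True if the auto-contour meets the geometric QA thresholds."""
    return dice_score >= dice_min and hd95 <= hd95_max

if not passes_geometric_qa(dice_score, hd95):
    print("Auto-contour flagged for manual review")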
13.4 Histopathology Analysis
Clinical Context: Pathologists examine tissue slides to diagnose cancer, grade tumors, and identify prognostic features. A single slide can contain billions of pixels—far too large for standard CNNs. Specialized techniques handle these gigapixel images.
13.4.1 Whole Slide Image Processing
Whole slide images (WSIs) are typically 50,000×50,000+ pixels. The standard approach: extract patches and aggregate predictions.
import openslide
from PIL import Image
import numpy as np

def extract_patches(slide_path, patch_size=256, level=0, stride=256):
    """Extract tissue patches from a whole slide image."""
    slide = openslide.OpenSlide(slide_path)
    width, height = slide.level_dimensions[level]
    patches = []
    coordinates = []
    for y in range(0, height - patch_size, stride):
        for x in range(0, width - patch_size, stride):
            # Read patch (read_region expects level-0 coordinates)
            patch = slide.read_region((x, y), level,
                                      (patch_size, patch_size))
            patch = np.array(patch.convert('RGB'))
            # Skip background (mostly white) patches
            if patch.mean() > 220:
                continue
            patches.append(patch)
            coordinates.append((x, y))
    return patches, coordinates

# Extract patches from slide
patches, coords = extract_patches("tumor_slide.svs",
                                  patch_size=256, stride=128)
print(f"Extracted {len(patches)} tissue patches")

13.4.2 Multiple Instance Learning
For slide-level labels (e.g., “this patient has cancer”), multiple instance learning (MIL) aggregates patch predictions:
import torch
import torch.nn as nn
from torchvision import models

class AttentionMIL(nn.Module):
    """Attention-based MIL for WSI classification."""
    def __init__(self, num_classes=2):
        super().__init__()
        # Feature extractor (pretrained ResNet without its classifier head)
        resnet = models.resnet18(weights='IMAGENET1K_V1')
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Linear(512, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )
        # Classifier
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, patches):
        # patches: (num_patches, 3, 224, 224)
        # Extract features for all patches
        features = self.features(patches)  # (num_patches, 512, 1, 1)
        features = features.squeeze(-1).squeeze(-1)  # (num_patches, 512)
        # Compute attention weights
        attention_weights = self.attention(features)  # (num_patches, 1)
        attention_weights = torch.softmax(attention_weights, dim=0)
        # Weighted aggregation into a single slide-level feature
        slide_feature = (attention_weights * features).sum(dim=0)
        # Classification
        logits = self.classifier(slide_feature)
        return logits, attention_weights

# High attention weights indicate which patches drove the prediction
model = AttentionMIL(num_classes=2)
logits, attention = model(patch_tensor)

13.4.3 Clinical Applications
Histopathology AI is used for:
- Cancer detection: Identifying tumor regions in biopsies
- Grading: Gleason grading for prostate cancer, Nottingham grading for breast cancer
- Biomarker prediction: Predicting molecular markers (HER2, MSI) from H&E stains
- Prognosis: Predicting survival from tissue morphology
FDA-cleared systems include Paige Prostate (prostate cancer detection) and PathAI’s tools for liver fibrosis staging.
13.5 Multi-Modal Imaging
Clinical Context: Clinical decisions often integrate multiple imaging modalities—CT for anatomy, PET for metabolism, MRI for soft tissue detail. Multi-modal AI can leverage complementary information.
13.5.1 Fusion Strategies
Three main approaches to combine modalities:
Early fusion: Concatenate images as input channels
# Stack CT and PET as 2-channel input
combined = torch.cat([ct_tensor, pet_tensor], dim=1)
model = UNet(in_channels=2, out_channels=num_classes, ...)
Late fusion: Separate encoders, combine features
class LateFusionModel(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        # Each encoder is assumed to take a 1-channel image and return a
        # 512-dimensional feature vector (e.g., a ResNet-18 with its first
        # convolution and classification head adapted)
        self.ct_encoder = ResNet18(in_channels=1)
        self.pet_encoder = ResNet18(in_channels=1)
        self.classifier = nn.Linear(512 * 2, num_classes)

    def forward(self, ct, pet):
        ct_features = self.ct_encoder(ct)
        pet_features = self.pet_encoder(pet)
        combined = torch.cat([ct_features, pet_features], dim=1)
        return self.classifier(combined)
Cross-attention fusion: Modalities attend to each other
class CrossModalAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)

    def forward(self, x1, x2):
        # x1, x2: (num_tokens, dim); features from x1 attend to features from x2
        q = self.query(x1)
        k = self.key(x2)
        v = self.value(x2)
        attention = torch.softmax(q @ k.T / k.size(-1) ** 0.5, dim=-1)
        return attention @ v

13.5.2 Registration Challenges
Multi-modal fusion requires spatial alignment. CT and PET may be acquired on the same scanner (hardware fusion), but MRI requires software registration.
MONAI provides registration transforms:
from monai.transforms import Compose, Spacingd, Orientationd

# Ensure consistent spacing and orientation across modalities
align_transforms = Compose([
    Spacingd(keys=["ct", "pet", "mri"],
             pixdim=(1.0, 1.0, 1.0)),
    Orientationd(keys=["ct", "pet", "mri"],
                 axcodes="RAS")
])

For deformable registration (anatomical changes between scans), consider libraries such as ANTsPy or MONAI's registration modules.
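A deformable registration with ANTsPy might look like the following sketch (the SyN transform choice and the file paths are illustrative assumptions):

import ants

# Register a diagnostic MRI onto the planning CT with a deformable (SyN) transform
fixed = ants.image_read("planning_ct.nii.gz")      # illustrative path
moving = ants.image_read("diagnostic_mri.nii.gz")  # illustrative path
reg = ants.registration(fixed=fixed, moving=moving,
                        type_of_transform="SyN")
mri_in_ct_space = reg["warpedmovout"]              # MRI resampled into CT space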