8 Convolutional Architectures for Medical Imaging
Convolutional neural networks (CNNs) are the foundation of modern medical image analysis. This chapter covers the key architectural components—convolution, pooling, and skip connections—and their application to classification, segmentation, and volumetric medical data.
8.1 Medical Imaging Primer
Clinical Context: Before building AI models for medical images, you need to understand what makes these images different from photographs. A chest X-ray isn’t just a grayscale picture—it encodes physical measurements of tissue density. A CT scan isn’t a stack of photos—it’s a calibrated 3D map of radiodensity measured in Hounsfield units. Understanding these fundamentals prevents subtle but serious errors in preprocessing and model development.
8.1.1 Imaging Modalities
Medical imaging encompasses several distinct technologies, each producing images with different characteristics:
X-ray (Radiography): Projects 3D anatomy onto a 2D image. Dense structures (bone) appear bright; air-filled structures (lungs) appear dark. Fast and inexpensive, but superimposes all structures along the beam path.
Computed Tomography (CT): Reconstructs 3D volumes from multiple X-ray projections. Each voxel contains a Hounsfield unit (HU) value representing radiodensity:
- Air: −1000 HU
- Lung tissue: −500 HU
- Water/soft tissue: 0 HU
- Muscle: +40 HU
- Bone: +400 to +1000 HU
This calibrated scale means CT values are comparable across scanners and patients—a property most natural images lack.
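In practice, CT DICOM files store raw detector values; the standard RescaleSlope and RescaleIntercept tags convert them to HU. A minimal sketch with pydicom (the file path is illustrative):
import pydicom

dcm = pydicom.dcmread("ct_slice.dcm")
# Stored pixel values -> Hounsfield units via the DICOM rescale tags
hu = dcm.pixel_array * float(dcm.RescaleSlope) + float(dcm.RescaleIntercept)
print(hu.min(), hu.max())  # typically around -1024 up to a few thousand HU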
Magnetic Resonance Imaging (MRI): Uses magnetic fields and radio waves to image soft tissue. Excellent contrast between tissue types, but intensities are not calibrated—the same tissue can have different values across scanners, sequences, or even the same scan on different days. This makes MRI preprocessing more challenging than CT.
Ultrasound: Real-time imaging using sound waves. Operator-dependent, with artifacts from air and bone interfaces. Increasingly used with AI for point-of-care applications.
8.1.2 The DICOM Standard
Medical images are stored in DICOM (Digital Imaging and Communications in Medicine) format—a standard that bundles pixel data with extensive metadata. Unlike JPEG or PNG, a DICOM file contains:
- Pixel data: The actual image intensities (often 16-bit, not 8-bit)
- Patient information: ID, name, birth date, sex
- Acquisition parameters: Scanner model, imaging protocol, date/time
- Geometric information: Pixel spacing, slice thickness, patient orientation
- Clinical context: Body part, referring physician, study description
This metadata is essential for correct interpretation. Pixel spacing tells you whether a 10-pixel nodule is 5mm or 15mm. Patient orientation tells you which side is left vs. right. Ignoring metadata leads to models that fail silently on real clinical data.
A minimal example using pydicom:
import pydicom
import matplotlib.pyplot as plt
# Load a DICOM file
dcm = pydicom.dcmread("chest_xray.dcm")
# Access metadata
print(f"Patient ID: {dcm.PatientID}")
print(f"Modality: {dcm.Modality}")
print(f"Pixel Spacing: {dcm.PixelSpacing} mm")
# Access pixel data as numpy array
pixels = dcm.pixel_array
print(f"Image shape: {pixels.shape}")
print(f"Data type: {pixels.dtype}") # Often uint16, not uint8
# Display
plt.imshow(pixels, cmap='gray')
plt.title(f"{dcm.Modality}: {dcm.BodyPartExamined}")
plt.show()
For comprehensive DICOM handling—including coordinate systems, multi-frame images, and format conversion—see Appendix D.
8.1.3 Window and Level: Seeing What Matters
CT images contain far more intensity values than a monitor can display. A 12-bit CT has 4,096 possible HU values, but your screen shows only 256 gray levels. Windowing maps a subset of HU values to displayable intensities:
- Window center (level): The HU value at the middle of the display range
- Window width: The range of HU values mapped to the display
Different windows reveal different anatomy in the same CT scan:
| Window Name | Center | Width | Reveals |
|---|---|---|---|
| Lung | −600 | 1500 | Airways, nodules, parenchyma |
| Mediastinum | 40 | 400 | Soft tissue, vessels, lymph nodes |
| Bone | 400 | 1800 | Fractures, bone lesions |
| Brain | 40 | 80 | Gray/white matter, acute bleeding |
import numpy as np
def apply_window(image, center, width):
"""Apply window/level to CT image (in Hounsfield units)."""
lower = center - width // 2
upper = center + width // 2
windowed = np.clip(image, lower, upper)
# Normalize to 0-1 range
windowed = (windowed - lower) / (upper - lower)
return windowed
# Same CT, different views
lung_view = apply_window(ct_hu, center=-600, width=1500)
soft_tissue_view = apply_window(ct_hu, center=40, width=400)
bone_view = apply_window(ct_hu, center=400, width=1800)
Why this matters for AI: If you train on lung-windowed images, your model learns lung-relevant features. Training on bone-windowed images teaches different features from the same underlying data. Some approaches stack multiple windows as input channels, giving the model access to different “views” simultaneously.
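A minimal sketch of that multi-window idea, reusing apply_window and the ct_hu array from above: stack several windowed views along the channel axis so a CNN expecting a 3-channel input sees all of them at once.
import numpy as np

multi_window = np.stack([
    apply_window(ct_hu, center=-600, width=1500),  # lung window
    apply_window(ct_hu, center=40, width=400),     # mediastinal window
    apply_window(ct_hu, center=400, width=1800),   # bone window
], axis=0)
print(multi_window.shape)  # (3, H, W): one channel per window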
8.1.4 Intensity Normalization
Unlike natural images (0–255 RGB), medical images have modality-specific intensity ranges:
- CT: Hounsfield units, typically −1024 to +3071
- MRI: Arbitrary units, varying by sequence and scanner
- X-ray: Detector-dependent, often 12-bit or 14-bit
Before feeding images to neural networks, normalize appropriately:
def normalize_ct(image):
"""Normalize CT from HU to 0-1 range."""
# Clip to clinically relevant range
image = np.clip(image, -1024, 3071)
# Shift and scale
image = (image + 1024) / (3071 + 1024)
return image
def normalize_mri(image):
"""Z-score normalization for MRI (per-image)."""
# MRI lacks absolute scale, so normalize per image
mean = np.mean(image)
std = np.std(image)
    return (image - mean) / (std + 1e-8)
For CT, the fixed HU scale allows consistent normalization. For MRI, per-image or per-dataset z-score normalization is standard because absolute intensities are meaningless.
8.1.5 From DICOM to Training Data
Real-world medical imaging AI pipelines typically follow this flow:
- Load DICOM: Read files, extract pixel data and relevant metadata
- Apply corrections: Rescale slope/intercept (for CT), handle photometric interpretation
- Resample: Standardize pixel spacing across the dataset
- Window/normalize: Apply appropriate intensity transformation
- Convert format: Save as NumPy, NIfTI, or PNG for efficient training
Libraries like MONAI and TorchIO abstract much of this complexity, but understanding the underlying steps helps debug issues when preprocessing fails silently.
from monai.transforms import (
LoadImage, EnsureChannelFirst, Spacing,
ScaleIntensityRange, ToTensor, Compose
)
# MONAI preprocessing pipeline
preprocess = Compose([
LoadImage(image_only=True), # Handles DICOM, NIfTI, etc.
EnsureChannelFirst(), # Add channel dimension
Spacing(pixdim=(1.0, 1.0, 1.0)), # Resample to 1mm isotropic
ScaleIntensityRange(
a_min=-1024, a_max=3071, # CT HU range
b_min=0, b_max=1, clip=True
),
ToTensor()
])
# Apply to a CT volume
tensor = preprocess("path/to/ct_series/")
With these fundamentals in place, we can now explore how convolutional neural networks exploit the spatial structure of medical images.
8.2 Why Convolutions?
Clinical Context: A radiologist doesn’t examine a chest X-ray pixel by pixel in random order. They scan systematically, looking for patterns—a hazy opacity here, an enlarged cardiac silhouette there. The spatial arrangement matters: pixels near each other form coherent structures like lung fields, ribs, and airways. Any effective image analysis system must exploit this spatial structure.
8.2.1 The Problem with Flat Vectors
Chapter 7’s feedforward networks treated images as flat vectors—a 224×224 image becomes 50,176 numbers fed into fully connected layers. This approach has two major problems:
Parameter explosion. A fully connected layer mapping 50,176 inputs to 1,000 hidden units requires 50 million parameters—just for the first layer. Deep networks become impossibly large and prone to overfitting.
Lost spatial structure. Flattening destroys neighborhood relationships. Pixel (100, 100) and pixel (100, 101) are adjacent in the image but become arbitrary positions in the flat vector. The network must learn from scratch that nearby pixels are related—a massive waste when this structure is known a priori.
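A quick parameter count makes the contrast concrete (a sketch; the layer sizes match the figures quoted above):
import torch.nn as nn

fc = nn.Linear(224 * 224, 1000)         # flattened 224x224 image -> 1,000 hidden units
conv = nn.Conv2d(1, 64, kernel_size=3)  # 64 learnable 3x3 filters

def count_params(module):
    return sum(p.numel() for p in module.parameters())

print(count_params(fc))    # 50,177,000 parameters
print(count_params(conv))  # 640 parameters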
8.2.2 The Key Insight: Local Patterns
Consider what a radiologist actually looks for: edges between tissue types, circular opacities suggesting nodules, linear densities indicating vessels. These patterns are local—they depend on small neighborhoods of pixels, not the entire image at once.
Moreover, these patterns can appear anywhere in the image. A pulmonary nodule in the upper right lung looks the same as one in the lower left. An effective detector should recognize the pattern regardless of position.
Convolutions embody both insights:
- Local connectivity: Each output depends only on a small input region (the receptive field)
- Weight sharing: The same pattern detector slides across the entire image
This reduces parameters dramatically and bakes in the prior knowledge that images have spatial structure.
8.2.3 Intuition: The Sliding Magnifying Glass
Imagine scanning a chest X-ray with a small magnifying glass—say, a 3×3 pixel window. At each position, you ask: “Does this local region contain an edge? A bright spot? A texture pattern?” You slide the magnifying glass across the entire image, building a map of where each pattern appears.
That’s exactly what a convolutional layer does. The “magnifying glass” is a learnable filter (kernel) that detects a specific pattern. Sliding it across the image produces a feature map showing where that pattern is strong. Stack many filters, and you detect many different patterns simultaneously.
The beauty is that the same filter works everywhere—we don’t need separate edge detectors for the left lung and right lung. This translation invariance is why CNNs excel at visual recognition.
8.3 The Convolution Operation
Clinical Context: A chest X-ray classifier must detect subtle patterns like pulmonary nodules, rib fractures, or cardiomegaly regardless of their position in the image. Convolutional layers achieve this by learning filters that slide across the entire image, detecting local features in a position-invariant manner.
8.3.1 Convolution Intuition
A 2D convolution slides a small filter (kernel) across an input image, computing element-wise products and summing them at each position:
\[ (X * K)_{i,j} = \sum_{m}\sum_{n} X_{i+m, j+n} \cdot K_{m,n} \]
where \(X\) is the input, \(K\) is a learnable kernel (typically 3×3 or 5×5), and the output is a feature map.
Key concepts:
- Stride: How many pixels the filter moves between positions. Stride 2 halves spatial dimensions.
- Padding: Adding zeros around the input to control output size. “Same” padding preserves dimensions.
- Multiple channels: A kernel for an RGB input has shape \((3 \times K_h \times K_w)\), one 2D filter per input channel, summed to produce a single feature map. Medical images may have different channel counts (e.g., one grayscale channel, or several CT window/level views stacked as channels).
Output dimension formula for input size \(N\), kernel size \(K\), padding \(P\), and stride \(S\):
\[ N_{\text{out}} = \left\lfloor \frac{N - K + 2P}{S} \right\rfloor + 1 \]
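The formula is easy to verify against an actual layer. A sketch with arbitrary example values (N = 224, K = 3, P = 1, S = 2):
import torch
import torch.nn as nn

N, K, P, S = 224, 3, 1, 2
expected = (N - K + 2 * P) // S + 1  # floor((224 - 3 + 2) / 2) + 1 = 112

conv = nn.Conv2d(1, 8, kernel_size=K, padding=P, stride=S)
out = conv(torch.zeros(1, 1, N, N))
print(expected, out.shape)  # 112 torch.Size([1, 8, 112, 112])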
8.3.2 What Convolutions Learn
Early layers learn simple edge and texture detectors—horizontal edges, vertical edges, color gradients. Deeper layers combine these into complex patterns: lung fields, bone contours, organ boundaries. This hierarchical feature learning is why CNNs excel at visual recognition.
In medical imaging:
- Layer 1–2: Edges, intensity gradients, noise patterns
- Layer 3–5: Textures, local shapes, tissue boundaries
- Layer 6+: Anatomical structures, pathological patterns
8.4 Pooling and Feature Hierarchies
Pooling reduces spatial dimensions while retaining important features. The two main types:
Max pooling: Takes maximum value in each window. Preserves strongest activations.
\[ \text{MaxPool}(X)_{i,j} = \max_{m,n \in \text{window}} X_{i+m, j+n} \]
Average pooling: Takes mean value. Smoother but may dilute strong signals.
\[ \text{AvgPool}(X)_{i,j} = \frac{1}{|\text{window}|} \sum_{m,n \in \text{window}} X_{i+m, j+n} \]
A 2×2 pooling with stride 2 halves each spatial dimension, reducing computation by 4× per layer. This creates a feature pyramid: early layers have high spatial resolution but few channels; deep layers have low resolution but many channels capturing abstract concepts.
Global average pooling (GAP) reduces each feature map to a single value—commonly used before the final classification layer to aggregate spatial information.
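A shape-level sketch of both operations (the feature-map size is arbitrary):
import torch
import torch.nn as nn

x = torch.randn(1, 64, 56, 56)  # (batch, channels, height, width)

pool = nn.MaxPool2d(kernel_size=2, stride=2)
print(pool(x).shape)            # torch.Size([1, 64, 28, 28]): spatial dims halved

gap = nn.AdaptiveAvgPool2d(1)   # global average pooling
print(gap(x).shape)             # torch.Size([1, 64, 1, 1]): one value per feature map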
8.5 Classic Architectures
Clinical Context: When building a pneumonia classifier for PneumoniaMNIST, you don’t start from scratch. Proven architectures developed on natural images transfer remarkably well to medical imaging. Understanding their design principles helps you choose and adapt them effectively.
8.5.1 LeNet: The Pioneer
Yann LeCun’s LeNet-5 (1998) established the CNN template: alternating convolution and pooling layers, followed by fully connected layers. For small images like 28×28 grayscale, a LeNet-style architecture suffices:
import torch.nn as nn
class SimpleCNN(nn.Module):
def __init__(self, num_classes=2):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(64 * 7 * 7, 128),
nn.ReLU(),
nn.Linear(128, num_classes)
)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
        return x
8.5.2 VGG: Simplicity Through Depth
VGG networks (2014) showed that stacking many 3×3 convolutions outperforms fewer large kernels. Two 3×3 layers have the same receptive field as one 5×5 layer but fewer parameters and more nonlinearity.
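Counting parameters for a 64-channel example (a sketch) shows the saving:
import torch.nn as nn

def count_params(module):
    return sum(p.numel() for p in module.parameters())

two_3x3 = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(),
)
one_5x5 = nn.Conv2d(64, 64, kernel_size=5, padding=2)

print(count_params(two_3x3))  # 73,856 parameters, two nonlinearities
print(count_params(one_5x5))  # 102,464 parameters, one nonlinearity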
8.5.3 ResNet: Skip Connections
Residual networks (ResNet, 2015) introduced skip connections that add the input directly to the output of a block (He et al. 2016):
\[ y = F(x) + x \]
where \(F(x)\) represents the convolutional layers. This simple change enables training of very deep networks (50, 101, even 152 layers) by providing gradient highways that bypass many layers.
Skip connections solve the degradation problem: without them, adding more layers can actually hurt performance because gradients vanish or the network struggles to learn identity mappings.
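A minimal residual block, written out to make \(y = F(x) + x\) concrete (a simplified sketch, not the exact block used in torchvision's ResNets):
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))  # F(x), first half
        out = self.bn2(self.conv2(out))           # F(x), second half
        return self.relu(out + x)                 # y = F(x) + x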
For medical imaging, ResNet-18 and ResNet-50 are workhorses. Use pretrained ImageNet weights as initialization:
from torchvision import models
import torch.nn as nn
# Load pretrained ResNet-18
model = models.resnet18(weights='IMAGENET1K_V1')
# Modify for binary classification (e.g., pneumonia vs normal)
model.fc = nn.Linear(model.fc.in_features, 2)
# For single-channel images (grayscale), modify first conv layer
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2,
                        padding=3, bias=False)
8.6 3D Convolutions for Volumetric Data
Clinical Context: CT and MRI scans are 3D volumes—stacks of 2D slices. A lung nodule might span multiple slices; analyzing slices independently loses this 3D context. 3D convolutions process the full volume, capturing inter-slice relationships.
8.6.1 3D Convolution Formulation
A 3D kernel has shape \((C \times D \times H \times W)\) where \(D\) is depth (number of slices). The operation extends naturally:
\[ \text{Conv3d}(X, K)_{i,j,k} = \sum_{c}\sum_{d}\sum_{m}\sum_{n} X_{c,i+d,j+m,k+n} \cdot K_{c,d,m,n} \]
Common kernel sizes: 3×3×3 or 3×3×1 (treating depth differently than spatial dimensions due to anisotropic voxel spacing in medical scans).
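A shape-level sketch on a small sub-volume (sizes are arbitrary); the second layer shows the anisotropic variant with a single-voxel extent along the slice axis:
import torch
import torch.nn as nn

volume = torch.randn(1, 1, 64, 64, 64)  # (batch, channels, depth, height, width)

conv3d = nn.Conv3d(1, 16, kernel_size=3, padding=1)
print(conv3d(volume).shape)              # torch.Size([1, 16, 64, 64, 64])

# 3x3 in-plane, 1 across slices (PyTorch orders kernel dims as depth, height, width)
conv_aniso = nn.Conv3d(1, 16, kernel_size=(1, 3, 3), padding=(0, 1, 1))
print(conv_aniso(volume).shape)          # torch.Size([1, 16, 64, 64, 64])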
8.6.2 Memory Considerations
3D convolutions are memory-intensive. A 512×512×128 CT volume with 64 feature channels requires ~8 GB just for one layer’s activations. Strategies:
- Patch-based training: Process 64×64×64 patches, aggregate at inference
- Mixed precision: Run most of the forward and backward computation in float16 while keeping weights and optimizer updates in float32 (see the sketch after this list)
- Gradient checkpointing: Recompute activations during backward pass instead of storing
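A minimal mixed-precision training step using PyTorch's automatic mixed precision (a sketch; model, optimizer, loss_fn, volumes, and labels are assumed to exist and live on the GPU):
import torch

scaler = torch.cuda.amp.GradScaler()

def training_step(model, optimizer, loss_fn, volumes, labels):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():       # run eligible ops in float16
        outputs = model(volumes)
        loss = loss_fn(outputs, labels)
    scaler.scale(loss).backward()         # scale the loss to avoid float16 underflow
    scaler.step(optimizer)                # unscale gradients, then update weights
    scaler.update()
    return loss.item()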
The MONAI library provides optimized 3D architectures designed for medical imaging:
from monai.networks.nets import DenseNet121
# 3D DenseNet for volumetric classification
model = DenseNet121(
spatial_dims=3,
in_channels=1,
out_channels=2
)
8.7 U-Net for Segmentation
Clinical Context: Radiation therapy planning requires precise tumor delineation. Unlike classification (one label per image), segmentation assigns a label to every pixel. U-Net’s encoder-decoder architecture with skip connections has become the standard for medical image segmentation (Ronneberger, Fischer, and Brox 2015).
8.7.1 Encoder-Decoder Architecture
U-Net consists of:
- Encoder (contracting path): Repeated convolution + pooling, extracting features while reducing resolution
- Bottleneck: Deepest layer with lowest resolution, highest-level features
- Decoder (expanding path): Upsampling + convolution to restore resolution
- Skip connections: Concatenate encoder features with decoder features at each resolution level
The skip connections are critical: they provide high-resolution spatial information from early layers to the decoder, enabling precise boundary localization.
8.7.2 U-Net Implementation
MONAI provides flexible U-Net implementations:
from monai.networks.nets import UNet
model = UNet(
spatial_dims=2,
in_channels=1,
out_channels=3, # background, tumor, organ
channels=(32, 64, 128, 256),
strides=(2, 2, 2),
)
# For 3D segmentation (CT/MRI volumes)
model_3d = UNet(
spatial_dims=3,
in_channels=1,
out_channels=2,
channels=(16, 32, 64, 128),
strides=(2, 2, 2),
)
8.7.3 Segmentation Loss Functions
For imbalanced segmentation (small tumors in large images):
Dice loss: Measures overlap between prediction and ground truth \[ \mathcal{L}_{\text{Dice}} = 1 - \frac{2|P \cap G|}{|P| + |G|} \]
Focal loss: Down-weights easy examples, focusing on hard cases
Combined loss: Dice + cross-entropy often works best
from monai.losses import DiceCELoss
loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
8.8 Transfer Learning
Clinical Context: Medical imaging datasets are often small (hundreds to thousands of images) due to annotation costs and privacy constraints. Transfer learning leverages features learned from large natural image datasets—a technique that became practical after deep CNNs achieved breakthrough performance on ImageNet (Krizhevsky, Sutskever, and Hinton 2012)—to bootstrap medical image analysis.
8.8.1 Why Transfer Learning Works
Early CNN layers learn universal visual features—edges, textures, shapes—that apply across domains. A model trained on cats and cars has already learned to detect boundaries and gradients; fine-tuning teaches it which patterns indicate pneumonia vs. normal lung.
Studies show that ImageNet pretraining improves medical imaging performance even though the domains differ substantially. The pretrained weights provide a better starting point than random initialization.
8.8.2 Fine-Tuning Strategies
Three common approaches:
Feature extraction: Freeze all pretrained layers, train only the new classifier head. Fast but limited adaptation.
Fine-tuning: Unfreeze some or all layers, train with small learning rate. More flexible but risks catastrophic forgetting.
Gradual unfreezing: Start with frozen layers, progressively unfreeze deeper layers as training continues.
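A sketch of the third strategy for a torchvision ResNet; the epoch schedule is an arbitrary illustration, and the optimizer should be built over all model parameters so newly unfrozen blocks receive updates:
import torch.nn as nn
from torchvision import models

model = models.resnet18(weights='IMAGENET1K_V1')
model.fc = nn.Linear(model.fc.in_features, 2)

# Start with everything frozen except the new classifier head
for param in model.parameters():
    param.requires_grad = False
model.fc.requires_grad_(True)

# Unfreeze deeper blocks first, earlier blocks later (example schedule)
unfreeze_schedule = {3: model.layer4, 6: model.layer3, 9: model.layer2}

def maybe_unfreeze(epoch):
    block = unfreeze_schedule.get(epoch)
    if block is not None:
        block.requires_grad_(True)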
Example for PneumoniaMNIST:
import torch
import torch.nn as nn
from torchvision import models
# Load pretrained ResNet-18
model = models.resnet18(weights='IMAGENET1K_V1')
# Strategy 1: Feature extraction (freeze backbone)
for param in model.parameters():
param.requires_grad = False
# Replace classifier for binary task
model.fc = nn.Linear(512, 2)
# Strategy 2: Fine-tuning (unfreeze all, lower learning rate)
for param in model.parameters():
param.requires_grad = True
optimizer = torch.optim.Adam([
{'params': model.fc.parameters(), 'lr': 1e-3},
{'params': model.layer4.parameters(), 'lr': 1e-4},
{'params': model.layer3.parameters(), 'lr': 1e-5},
], lr=1e-5)  # Default lr for any group without its own rate; layers not listed receive no updates
8.8.3 When Transfer Learning Fails
Transfer learning may underperform when:
- Domain gap is too large (e.g., dermoscopy vs. chest X-ray)
- Target dataset is very large (>100k images)—training from scratch may work better
- Image characteristics differ fundamentally (e.g., 16-bit CT Hounsfield units vs. 8-bit natural images)
For highly specialized modalities, consider self-supervised pretraining on unlabeled medical images before fine-tuning on labeled data.
8.9 Vision Transformers
Clinical Context: Transformers revolutionized NLP; can they do the same for medical imaging? Vision Transformers (ViT) treat images as sequences of patches, applying self-attention instead of convolution. For very large datasets, they can outperform CNNs.
8.9.1 Patch Embedding
ViT divides an image into fixed-size patches (e.g., 16×16 pixels), flattens each patch into a vector, and adds positional embeddings:
\[ z_0 = [\texttt{CLS}; x_1 E; x_2 E; \ldots; x_N E] + E_{\text{pos}} \]
where \(E\) is a learnable projection matrix and \(E_{\text{pos}}\) encodes spatial position.
The sequence of patch embeddings is processed by standard Transformer encoder layers. The CLS token’s final representation is used for classification.
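In most implementations the patch projection \(E\) is a convolution whose kernel size and stride equal the patch size. A sketch with ViT-B/16 dimensions (16×16 patches, 768-dimensional embeddings):
import torch
import torch.nn as nn

patch_embed = nn.Conv2d(3, 768, kernel_size=16, stride=16)  # one patch -> one 768-d vector

image = torch.randn(1, 3, 224, 224)
patches = patch_embed(image)                  # (1, 768, 14, 14)
tokens = patches.flatten(2).transpose(1, 2)   # (1, 196, 768): 196 patch tokens
print(tokens.shape)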
8.9.2 When to Use ViT vs. CNN
Current recommendations for medical imaging:
- Small datasets (<10k images): Use CNN with transfer learning; ViT needs more data
- Medium datasets (10k–100k): Either works; CNN may have slight edge
- Large datasets (>100k): ViT can outperform; consider hybrid architectures
Hybrid approaches (CNN backbone + Transformer layers) often provide the best of both worlds, combining CNN’s inductive biases with Transformer’s global attention.
from torchvision.models import vit_b_16, ViT_B_16_Weights
import torch.nn as nn
# Pretrained Vision Transformer
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
# Modify for medical classification
model.heads = nn.Linear(768, 2)
For medical imaging applications, the jury is still out on ViT dominance. Most successful deployments still use CNNs, but research is rapidly evolving.
8.10 Putting It Together: Chest X-ray Classification
Clinical Context: You’ve been given a dataset of 5,000 chest X-rays labeled as “normal” or “pneumonia” and asked to build a classifier. This section walks through the complete pipeline—from loading data to evaluating clinical performance.
8.10.1 Data Loading and Augmentation
Medical imaging datasets require careful preprocessing. Unlike natural images, medical images have specific characteristics: consistent orientation, calibrated intensity ranges, and clinical conventions. Augmentation must respect these constraints.
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Training transforms with augmentation
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(), # Lungs are roughly symmetric
transforms.RandomRotation(10), # Small rotations only
transforms.ColorJitter(brightness=0.2, contrast=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# Validation/test transforms (no augmentation)
val_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# Load datasets (assumes ImageFolder structure: data/train/normal/, data/train/pneumonia/)
train_dataset = datasets.ImageFolder('data/train', transform=train_transform)
val_dataset = datasets.ImageFolder('data/val', transform=val_transform)
test_dataset = datasets.ImageFolder('data/test', transform=val_transform)
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
Augmentation guidelines for medical imaging:
- Horizontal flip: Often safe for bilateral structures (lungs, brain hemispheres). Avoid for lateralized findings.
- Rotation: Keep small (±10–15°). Large rotations are clinically unrealistic.
- Scaling/zoom: Modest ranges. Extreme zoom loses diagnostic context.
- Intensity augmentation: Simulate different exposure/contrast settings.
- Avoid: Vertical flips (unrealistic), extreme distortions, cropping that removes anatomy.
8.10.2 Handling Class Imbalance
Medical datasets are often imbalanced—many more normal cases than pathological ones. A naive classifier achieves high accuracy by predicting “normal” for everything, which is clinically useless.
Two common solutions:
Weighted loss function: Penalize errors on the minority class more heavily.
import torch.nn as nn
# Count samples per class
class_counts = [3500, 1500] # [normal, pneumonia]
total = sum(class_counts)
weights = torch.tensor([total / c for c in class_counts])
weights = weights / weights.sum() # Normalize
criterion = nn.CrossEntropyLoss(weight=weights.to(device))
Oversampling: Sample minority class more frequently during training.
from torch.utils.data import WeightedRandomSampler
# Compute sample weights
sample_weights = [weights[label].item() for label in train_dataset.targets]  # .targets avoids loading every image
sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
8.10.3 Model Setup with Transfer Learning
For a dataset of 5,000 images, transfer learning from ImageNet is essential. We use ResNet-18 as a lightweight but effective backbone:
import torch
import torch.nn as nn
from torchvision import models
# Load pretrained ResNet-18
model = models.resnet18(weights='IMAGENET1K_V1')
# Modify first layer for grayscale input (if needed)
# model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
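# (Optional, a common trick rather than the only option: keep some pretrained
#  information by averaging the RGB kernels of conv1 into a single channel.)
# pretrained_w = model.conv1.weight.data.clone()                     # (64, 3, 7, 7)
# model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# model.conv1.weight.data = pretrained_w.mean(dim=1, keepdim=True)   # (64, 1, 7, 7)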
# Replace classifier for binary classification
model.fc = nn.Linear(model.fc.in_features, 2)
# Move to GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Optimizer with differential learning rates
optimizer = torch.optim.Adam([
{'params': model.fc.parameters(), 'lr': 1e-3}, # New layers: higher LR
{'params': model.layer4.parameters(), 'lr': 1e-4}, # Fine-tune last block
{'params': model.layer3.parameters(), 'lr': 1e-5}, # Fine-tune earlier blocks
], lr=1e-5)  # Default lr; earlier layers (conv1, layer1, layer2) are not passed in and stay frozen
# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', patience=3, factor=0.5
)
8.10.4 The Training Loop
Putting it all together with validation monitoring:
def train_epoch(model, loader, criterion, optimizer, device):
model.train()
running_loss = 0.0
correct = 0
total = 0
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * images.size(0)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
return running_loss / total, correct / total
def evaluate(model, loader, criterion, device):
model.eval()
running_loss = 0.0
all_labels = []
all_probs = []
with torch.no_grad():
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
running_loss += loss.item() * images.size(0)
probs = torch.softmax(outputs, dim=1)[:, 1] # Probability of pneumonia
all_labels.extend(labels.cpu().numpy())
all_probs.extend(probs.cpu().numpy())
return running_loss / len(loader.dataset), all_labels, all_probs
# Training loop
best_val_loss = float('inf')
num_epochs = 30
for epoch in range(num_epochs):
train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
val_loss, val_labels, val_probs = evaluate(model, val_loader, criterion, device)
scheduler.step(val_loss)
print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.3f}, '
f'Val Loss: {val_loss:.4f}')
# Save best model
if val_loss < best_val_loss:
best_val_loss = val_loss
torch.save(model.state_dict(), 'best_chest_xray_model.pt')
# Load best model for final evaluation
model.load_state_dict(torch.load('best_chest_xray_model.pt'))
8.10.5 Clinical Evaluation Metrics
Accuracy alone is insufficient for medical AI. A pneumonia classifier needs evaluation on clinically meaningful metrics:
from sklearn.metrics import (
roc_auc_score, confusion_matrix, classification_report,
roc_curve, precision_recall_curve
)
import numpy as np
# Get predictions on test set
test_loss, test_labels, test_probs = evaluate(model, test_loader, criterion, device)
test_preds = [1 if p > 0.5 else 0 for p in test_probs]
# Core metrics
auroc = roc_auc_score(test_labels, test_probs)
cm = confusion_matrix(test_labels, test_preds)
tn, fp, fn, tp = cm.ravel()
sensitivity = tp / (tp + fn) # True positive rate (recall)
specificity = tn / (tn + fp) # True negative rate
ppv = tp / (tp + fp) # Positive predictive value (precision)
npv = tn / (tn + fn) # Negative predictive value
print(f"AUROC: {auroc:.3f}")
print(f"Sensitivity: {sensitivity:.3f}")
print(f"Specificity: {specificity:.3f}")
print(f"PPV: {ppv:.3f}")
print(f"NPV: {npv:.3f}")
print(f"\nConfusion Matrix:\n{cm}")Interpreting these metrics clinically:
- Sensitivity (recall): Of all pneumonia cases, what fraction did we catch? Critical for screening—we don’t want to miss disease.
- Specificity: Of all normal cases, what fraction did we correctly identify? Important to avoid unnecessary follow-up.
- AUROC: Overall discriminative ability across all thresholds. >0.9 is generally considered excellent for medical applications.
- PPV/NPV: Depend on disease prevalence in your population. A model with 95% sensitivity may still have low PPV in a low-prevalence screening setting.
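A quick worked example with hypothetical numbers shows the prevalence effect: at 95% sensitivity and 90% specificity, a screening population with 2% pneumonia prevalence yields a PPV of only about 16%.
sens, spec, prev = 0.95, 0.90, 0.02  # hypothetical screening scenario

ppv = (sens * prev) / (sens * prev + (1 - spec) * (1 - prev))
npv = (spec * (1 - prev)) / (spec * (1 - prev) + (1 - sens) * prev)
print(f"PPV: {ppv:.2f}, NPV: {npv:.3f}")  # PPV: 0.16, NPV: 0.999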
8.10.6 Operating Point Selection
The default 0.5 threshold is rarely optimal. Choose based on clinical priorities:
# Plot ROC curve to visualize tradeoffs
fpr, tpr, thresholds = roc_curve(test_labels, test_probs)
# Find threshold for 95% sensitivity (catch most pneumonia cases)
target_sensitivity = 0.95
idx = np.argmin(np.abs(tpr - target_sensitivity))
optimal_threshold = thresholds[idx]
print(f"Threshold for {target_sensitivity:.0%} sensitivity: {optimal_threshold:.3f}")
print(f"Specificity at this threshold: {1 - fpr[idx]:.3f}")For a screening application, you might accept lower specificity (more false positives sent for follow-up) to achieve very high sensitivity (few missed cases). For a triage application with limited resources, you might prioritize specificity.
8.10.7 What’s Next: Interpretation
A black-box prediction of “85% pneumonia” is clinically limited. Radiologists want to know where in the image the model sees pathology. Chapter 18 covers interpretation methods like GradCAM that highlight the image regions driving the prediction—essential for building clinical trust and catching model failures.
# Preview: GradCAM visualization (covered in Chapter 18)
from torchcam.methods import GradCAM
cam_extractor = GradCAM(model, target_layer='layer4')
# ... generates heatmap showing where model "looks"
The complete pipeline—data loading, augmentation, transfer learning, training, and clinical evaluation—provides a template you can adapt for other medical imaging tasks: skin lesion classification, retinal disease detection, tumor segmentation, and beyond.