import torch
import torch.nn as nn
import torch.nn.functional as F
class CLIPModel(nn.Module):
"""Simplified CLIP architecture for medical imaging."""
def __init__(self, image_encoder, text_encoder, embed_dim=512):
super().__init__()
self.image_encoder = image_encoder
self.text_encoder = text_encoder
# Projection heads to shared embedding space
self.image_projection = nn.Linear(image_encoder.output_dim, embed_dim)
self.text_projection = nn.Linear(text_encoder.output_dim, embed_dim)
# Learnable temperature parameter
self.temperature = nn.Parameter(torch.ones([]) * 0.07)
def encode_image(self, images):
features = self.image_encoder(images)
embeddings = self.image_projection(features)
return F.normalize(embeddings, dim=-1)
def encode_text(self, text_tokens):
features = self.text_encoder(text_tokens)
embeddings = self.text_projection(features)
return F.normalize(embeddings, dim=-1)
def forward(self, images, text_tokens):
image_embeddings = self.encode_image(images)
text_embeddings = self.encode_text(text_tokens)
# Compute similarity matrix
logits = image_embeddings @ text_embeddings.T / self.temperature
return logits
def contrastive_loss(logits):
"""Symmetric contrastive loss."""
batch_size = logits.shape[0]
labels = torch.arange(batch_size, device=logits.device)
# Image-to-text loss
loss_i2t = F.cross_entropy(logits, labels)
# Text-to-image loss
loss_t2i = F.cross_entropy(logits.T, labels)
return (loss_i2t + loss_t2i) / 2
15 Multimodal Pipelines
Clinical Context: A radiologist reviews a chest X-ray alongside the patient’s history: 65-year-old smoker with chronic cough, elevated inflammatory markers, recent weight loss. The imaging findings alone might be ambiguous, but combined with the clinical context, the diagnosis becomes clearer. The best medical AI systems will need to reason across modalities just as clinicians do—integrating images, text, and structured data into unified understanding.
Medicine is inherently multimodal. A diagnosis emerges from imaging studies, laboratory values, clinical notes, patient history, and physical examination. Yet most medical AI systems operate on single modalities: a CNN analyzes an X-ray, an NLP model processes clinical notes, a classifier predicts from lab values. The frontier of medical AI lies in multimodal fusion—combining these data sources to capture the full clinical picture.
This chapter covers the architectures and techniques that enable multimodal medical AI. We’ll start with contrastive learning (CLIP), which learns shared representations across images and text. Then we’ll explore vision-language models that can answer questions about medical images and generate clinical reports. Finally, we’ll examine practical fusion strategies for combining imaging with structured EHR data.
15.1 Why Multimodal?
15.1.1 The Clinical Reality
Consider the information available for a hospitalized patient:
- Imaging: Chest X-ray, CT scans, ultrasound
- Clinical notes: Admission note, progress notes, discharge summary
- Structured data: Vital signs, lab values, medications, diagnoses
- Time series: Continuous monitoring, longitudinal records
Each modality captures different aspects of the patient’s condition. Imaging shows anatomy and pathology. Notes capture clinical reasoning and context. Labs provide objective biomarkers. A complete understanding requires all of them.
15.1.2 The Promise of Multimodal AI
Multimodal models can:
- Improve accuracy: Combining modalities often outperforms single-modality models
- Enable new capabilities: Visual question answering, automatic report generation
- Reduce workload: Generate draft reports, highlight relevant prior findings
- Catch errors: Cross-check imaging findings against clinical context
15.1.3 The Technical Challenges
Building multimodal systems is hard because:
- Different data structures: Images are dense tensors; text is sequential; EHR is tabular
- Alignment: Mapping between modalities (which image regions relate to which text?)
- Missing data: Not all patients have all modalities
- Scale mismatch: A CT scan has millions of voxels; a diagnosis code is one token
This chapter provides the tools to address these challenges.
15.2 Contrastive Learning: CLIP and Medical Adaptations
Clinical Context: You have a chest X-ray and want to know if it shows pneumonia—but you don’t have a trained classifier for pneumonia specifically. With CLIP-style models, you can classify images using natural language descriptions, even for categories not seen during training.
15.2.1 The CLIP Paradigm
CLIP (Contrastive Language-Image Pre-training) revolutionized vision-language AI with a simple idea: learn to match images with their text descriptions.
The architecture has two encoders:
- Image encoder: CNN or Vision Transformer (ViT) that produces image embeddings
- Text encoder: Transformer that produces text embeddings
Both encoders map to the same embedding space. During training, matching image-text pairs should have similar embeddings; non-matching pairs should be dissimilar.
15.2.2 Contrastive Training
Given a batch of \(N\) image-text pairs, CLIP creates an \(N \times N\) matrix of similarities. The diagonal contains matching pairs; off-diagonal entries are negatives.
The loss maximizes similarity for matches while minimizing it for non-matches:
\[\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N}\left[\log\frac{\exp(\text{sim}(I_i, T_i)/\tau)}{\sum_{j=1}^{N}\exp(\text{sim}(I_i, T_j)/\tau)}\right]\]
where \(\text{sim}(I, T)\) is cosine similarity and \(\tau\) is a temperature parameter. The full objective is symmetric: the contrastive_loss function shown earlier averages this image-to-text term with the analogous text-to-image term.
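To make the shapes concrete, here is a minimal usage sketch of the CLIPModel and contrastive_loss defined earlier. DummyEncoder and the random feature tensors are illustrative stand-ins for real backbones and data.
import torch
import torch.nn as nn

class DummyEncoder(nn.Module):
    """Stand-in encoder: exposes output_dim and maps features to a fixed-size vector."""
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.output_dim = output_dim
        self.fc = nn.Linear(input_dim, output_dim)
    def forward(self, x):
        return self.fc(x)

# Toy batch of 8 paired "images" and "reports" (random features for illustration)
clip_model = CLIPModel(DummyEncoder(2048, 768), DummyEncoder(512, 768), embed_dim=512)
images = torch.randn(8, 2048)
text_tokens = torch.randn(8, 512)

logits = clip_model(images, text_tokens)  # [8, 8]; the diagonal holds matching pairs
loss = contrastive_loss(logits)           # symmetric image-to-text + text-to-image loss
print(logits.shape, loss.item())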
15.2.3 Medical CLIP Variants
Several groups have adapted CLIP for medical imaging:
BiomedCLIP: Trained on PubMed figure-caption pairs. Strong for scientific/research images.
PubMedCLIP: Similar approach, focused on biomedical literature.
MedCLIP: Specifically trained on chest X-rays and radiology reports from MIMIC-CXR.
PMC-CLIP: Trained on PubMed Central open-access figures.
These models understand medical terminology and visual patterns that general-purpose CLIP misses.
# Using a pretrained medical CLIP model
from transformers import AutoModel, AutoTokenizer, AutoProcessor
# Load BiomedCLIP (example; exact loading code depends on the model release, see the model card)
model = AutoModel.from_pretrained("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224")
processor = AutoProcessor.from_pretrained("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224")
# Encode an image
image = load_chest_xray("patient_001.png")
image_inputs = processor(images=image, return_tensors="pt")
image_embeddings = model.get_image_features(**image_inputs)
# Encode text descriptions
texts = [
"chest x-ray showing pneumonia",
"normal chest x-ray",
"chest x-ray with cardiomegaly",
"chest x-ray showing pleural effusion"
]
text_inputs = processor(text=texts, return_tensors="pt", padding=True)
text_embeddings = model.get_text_features(**text_inputs)
# Compute similarities
similarities = F.cosine_similarity(
image_embeddings.unsqueeze(1),
text_embeddings.unsqueeze(0),
dim=-1
)
print("Similarities:", similarities)
# Highest similarity indicates most likely description
15.2.4 Zero-Shot Classification
CLIP’s power lies in zero-shot classification: classify images into categories using only text descriptions, without task-specific training.
def zero_shot_classify(model, processor, image, class_descriptions, device='cuda'):
"""
Classify an image using text descriptions of each class.
Args:
model: CLIP model
image: PIL image
class_descriptions: List of text descriptions for each class
Returns:
Predicted class index and probabilities
"""
model.eval()
# Encode image
image_inputs = processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
image_features = model.get_image_features(**image_inputs)
image_features = F.normalize(image_features, dim=-1)
# Encode class descriptions
text_inputs = processor(text=class_descriptions, return_tensors="pt", padding=True).to(device)
with torch.no_grad():
text_features = model.get_text_features(**text_inputs)
text_features = F.normalize(text_features, dim=-1)
# Compute similarities
similarities = (image_features @ text_features.T).squeeze(0)
probabilities = F.softmax(similarities * 100, dim=-1) # Temperature scaling
predicted_class = probabilities.argmax().item()
return predicted_class, probabilities.cpu().numpy()
# Example: classify chest X-ray findings
class_descriptions = [
"a chest x-ray showing no abnormalities, normal lung fields",
"a chest x-ray showing consolidation consistent with pneumonia",
"a chest x-ray showing enlarged cardiac silhouette indicating cardiomegaly",
"a chest x-ray showing blunted costophrenic angles suggesting pleural effusion"
]
predicted, probs = zero_shot_classify(model, processor, xray_image, class_descriptions)
print(f"Predicted: {class_descriptions[predicted]}")
print(f"Probabilities: {dict(zip(['Normal', 'Pneumonia', 'Cardiomegaly', 'Effusion'], probs))}")
15.2.5 Prompt Engineering for Medical CLIP
The text descriptions matter enormously. Better prompts yield better zero-shot performance:
# Poor prompts (too generic)
poor_prompts = ["pneumonia", "normal", "effusion"]
# Better prompts (more descriptive)
better_prompts = [
"chest radiograph demonstrating consolidation in the lung fields consistent with pneumonia",
"normal chest radiograph with clear lung fields and normal cardiac silhouette",
"chest radiograph showing layering pleural effusion with blunted costophrenic angle"
]
# Best prompts (ensemble multiple descriptions)
def ensemble_prompts(model, processor, image, class_name, prompt_templates):
"""Average embeddings across multiple prompt templates."""
prompts = [template.format(class_name) for template in prompt_templates]
text_inputs = processor(text=prompts, return_tensors="pt", padding=True)
with torch.no_grad():
embeddings = model.get_text_features(**text_inputs)
return embeddings.mean(dim=0)
templates = [
"a chest x-ray showing {}",
"chest radiograph demonstrating {}",
"x-ray image with findings of {}",
"radiological evidence of {} on chest x-ray"
]
15.3 Vision-Language Models
Clinical Context: A resident shows an attending a complex CT scan and asks, “What do you see in the right lower lobe?” The attending examines the image and provides a detailed explanation. Vision-language models aim to replicate this capability—answering natural language questions about medical images.
15.3.1 From CLIP to Generative VLMs
CLIP learns to match images and text but cannot generate new text. Vision-Language Models (VLMs) extend this to generation: given an image and a question, produce a natural language answer.
The key architectural insight: connect a visual encoder to a language model.
15.3.2 The LLaVA Architecture
LLaVA (Large Language and Vision Assistant) is an influential VLM architecture:
- Visual encoder: Pretrained CLIP ViT extracts image features
- Projection layer: Maps visual features to the LLM’s embedding space
- Language model: Pretrained LLM (Vicuna, LLaMA) generates responses
import torch
import torch.nn as nn
class LLaVAModel(nn.Module):
"""Simplified LLaVA-style vision-language model."""
def __init__(self, vision_encoder, language_model, projection_dim):
super().__init__()
self.vision_encoder = vision_encoder
self.language_model = language_model
# Project vision features to LLM embedding dimension
self.vision_projection = nn.Sequential(
nn.Linear(vision_encoder.output_dim, projection_dim),
nn.GELU(),
nn.Linear(projection_dim, language_model.config.hidden_size)
)
def encode_image(self, images):
"""Extract and project visual features."""
# Get patch embeddings from vision encoder
vision_outputs = self.vision_encoder(images)
patch_features = vision_outputs.last_hidden_state # [B, num_patches, dim]
# Project to LLM space
visual_tokens = self.vision_projection(patch_features)
return visual_tokens
def forward(self, images, input_ids, attention_mask):
"""
Forward pass combining visual and text tokens.
The visual tokens are prepended to the text tokens,
allowing the LLM to attend to image information.
"""
# Encode image to visual tokens
visual_tokens = self.encode_image(images) # [B, num_patches, hidden]
# Get text embeddings
text_embeddings = self.language_model.get_input_embeddings()(input_ids)
# Concatenate: [visual_tokens, text_tokens]
inputs_embeds = torch.cat([visual_tokens, text_embeddings], dim=1)
# Extend attention mask for visual tokens
visual_attention = torch.ones(
visual_tokens.shape[:2],
dtype=attention_mask.dtype,
device=attention_mask.device
)
extended_attention = torch.cat([visual_attention, attention_mask], dim=1)
# Forward through LLM
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=extended_attention
)
return outputs
15.3.3 Medical Vision-Language Models
Several medical VLMs have emerged:
LLaVA-Med: LLaVA fine-tuned on biomedical image-text data from PMC articles, then instruction-tuned on medical QA.
Med-PaLM M: Google’s multimodal medical model, trained on diverse medical data including imaging, genomics, and clinical text.
RadFM: Specifically designed for radiology, trained on millions of radiology images and reports.
CheXagent: Focused on chest X-ray interpretation with instruction following.
# Using a medical VLM for visual question answering
from transformers import AutoModelForCausalLM, AutoProcessor
# Load LLaVA-Med (conceptual - actual loading may differ)
model_id = "microsoft/llava-med-v1.5-mistral-7b"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
processor = AutoProcessor.from_pretrained(model_id)
def ask_about_image(model, processor, image, question, max_tokens=256):
"""Ask a question about a medical image."""
# Format prompt with image placeholder
prompt = f"<image>\nUser: {question}\nAssistant:"
# Process inputs
inputs = processor(
text=prompt,
images=image,
return_tensors="pt"
).to(model.device)
# Generate response
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=max_tokens,
do_sample=False
)
# Decode response
response = processor.decode(output_ids[0], skip_special_tokens=True)
return response.split("Assistant:")[-1].strip()
# Example usage
image = load_chest_xray("patient_001.png")
questions = [
"What abnormalities do you see in this chest X-ray?",
"Is there evidence of pneumonia?",
"Describe the cardiac silhouette.",
"What is the most likely diagnosis?"
]
for q in questions:
answer = ask_about_image(model, processor, image, q)
print(f"Q: {q}")
print(f"A: {answer}\n")
15.3.4 Visual Question Answering Evaluation
Evaluating VQA on medical images requires:
- Accuracy metrics: Exact match, token F1, or BLEU for open-ended answers
- Clinical correctness: Does the answer align with ground truth diagnosis?
- Hallucination detection: Does the model describe features not in the image?
from collections import Counter
import re
def compute_token_f1(prediction, reference):
"""Compute token-level F1 score."""
pred_tokens = set(prediction.lower().split())
ref_tokens = set(reference.lower().split())
if len(pred_tokens) == 0 or len(ref_tokens) == 0:
return 0.0
common = pred_tokens & ref_tokens
precision = len(common) / len(pred_tokens)
recall = len(common) / len(ref_tokens)
if precision + recall == 0:
return 0.0
f1 = 2 * precision * recall / (precision + recall)
return f1
def evaluate_vqa(model, processor, test_data):
"""Evaluate VQA performance on test set."""
exact_matches = 0
f1_scores = []
for item in test_data:
prediction = ask_about_image(model, processor, item['image'], item['question'])
reference = item['answer']
# Exact match (normalized)
pred_norm = prediction.lower().strip()
ref_norm = reference.lower().strip()
if pred_norm == ref_norm:
exact_matches += 1
# Token F1
f1 = compute_token_f1(prediction, reference)
f1_scores.append(f1)
return {
'exact_match': exact_matches / len(test_data),
'token_f1': sum(f1_scores) / len(f1_scores)
}
15.4 Clinical Report Generation
Clinical Context: Radiologists spend significant time dictating reports. An AI system that generates draft reports from images could save hours daily—if the drafts are accurate enough to be useful. This section covers how to build and evaluate such systems.
15.4.1 The Report Generation Task
Given a medical image (or set of images), generate a structured clinical report. For radiology, this typically includes:
- Findings: Detailed description of observations
- Impression: Summary and diagnostic conclusions
- Comparison: Reference to prior studies (if available)
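Downstream tooling often needs to split a generated report back into these sections. A minimal parsing helper (illustrative; assumes the section headers appear verbatim in the text):
import re

def split_report_sections(report_text):
    """Split a report into FINDINGS / IMPRESSION / COMPARISON parts (illustrative)."""
    parts = re.split(r'(FINDINGS|IMPRESSION|COMPARISON):', report_text)
    # re.split keeps the captured headers: [preamble, header, body, header, body, ...]
    return {header.lower(): body.strip() for header, body in zip(parts[1::2], parts[2::2])}

# Example:
# split_report_sections("FINDINGS: Right lower lobe opacity. IMPRESSION: Pneumonia.")
# -> {'findings': 'Right lower lobe opacity.', 'impression': 'Pneumonia.'}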
15.4.2 Architecture for Report Generation
Report generation uses an encoder-decoder architecture:
- Image encoder: Extract visual features (CNN or ViT)
- Cross-attention: Allow decoder to attend to relevant image regions
- Text decoder: Generate report autoregressively
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Config
class ReportGenerator(nn.Module):
"""Medical report generation from images."""
def __init__(self, image_encoder, vocab_size, hidden_dim=768, num_layers=6):
super().__init__()
self.image_encoder = image_encoder
# Project image features
self.image_projection = nn.Linear(image_encoder.output_dim, hidden_dim)
# Text decoder with cross-attention
config = GPT2Config(
vocab_size=vocab_size,
n_embd=hidden_dim,
n_layer=num_layers,
n_head=12,
add_cross_attention=True # Enable cross-attention to image
)
self.decoder = GPT2LMHeadModel(config)
def forward(self, images, input_ids, attention_mask=None):
# Encode images
image_features = self.image_encoder(images)
encoder_hidden_states = self.image_projection(image_features)
# Generate with cross-attention to image
outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
labels=input_ids # For training
)
return outputs
def generate_report(self, images, tokenizer, max_length=512):
"""Generate a report for the given images."""
self.eval()
# Encode images
image_features = self.image_encoder(images)
encoder_hidden_states = self.image_projection(image_features)
# Start with BOS token
input_ids = torch.tensor([[tokenizer.bos_token_id]]).to(images.device)
# Generate autoregressively
with torch.no_grad():
output_ids = self.decoder.generate(
input_ids=input_ids,
encoder_hidden_states=encoder_hidden_states,
max_length=max_length,
num_beams=4,
early_stopping=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
)
report = tokenizer.decode(output_ids[0], skip_special_tokens=True)
return report
15.4.3 Using VLMs for Report Generation
Modern approach: leverage pretrained VLMs with appropriate prompting.
def generate_radiology_report(model, processor, image, prior_report=None):
"""Generate a structured radiology report using a VLM."""
if prior_report:
prompt = f"""<image>
You are an expert radiologist. Generate a detailed radiology report for this chest X-ray.
Prior study report for comparison:
{prior_report}
Please provide:
1. FINDINGS: Detailed description of all observations
2. IMPRESSION: Summary and diagnostic conclusions
3. COMPARISON: Changes from prior study
Report:"""
else:
prompt = """<image>
You are an expert radiologist. Generate a detailed radiology report for this chest X-ray.
Please provide:
1. FINDINGS: Detailed description of all observations
2. IMPRESSION: Summary and diagnostic conclusions
Report:"""
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=512,
do_sample=False,
num_beams=4
)
report = processor.decode(output_ids[0], skip_special_tokens=True)
return report.split("Report:")[-1].strip()
# Generate report
image = load_chest_xray("patient_001.png")
report = generate_radiology_report(model, processor, image)
print(report)
15.4.4 Evaluation Metrics
Report generation requires multiple evaluation approaches:
Natural language metrics:
- BLEU: N-gram overlap with reference reports
- ROUGE: Recall-oriented overlap measure
- BERTScore: Semantic similarity using embeddings
Clinical metrics:
- CheXbert labeler: Extract findings, compare to reference
- RadGraph F1: Compare clinical entities and relations
- Radiologist review: Gold standard but expensive
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
def evaluate_report(generated, reference):
"""Evaluate generated report against reference."""
# BLEU score
reference_tokens = reference.lower().split()
generated_tokens = generated.lower().split()
smoothing = SmoothingFunction().method1
bleu = sentence_bleu([reference_tokens], generated_tokens, smoothing_function=smoothing)
# ROUGE scores
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
rouge_scores = scorer.score(reference, generated)
return {
'bleu': bleu,
'rouge1': rouge_scores['rouge1'].fmeasure,
'rouge2': rouge_scores['rouge2'].fmeasure,
'rougeL': rouge_scores['rougeL'].fmeasure
}
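BERTScore (listed above under natural language metrics) can be computed with the bert-score package; a minimal sketch, assuming the package is installed:
from bert_score import score as bertscore

def evaluate_bertscore(generated, reference):
    """Return BERTScore F1 between a generated and a reference report (illustrative)."""
    P, R, F1 = bertscore([generated], [reference], lang="en")
    return F1.item()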
def evaluate_clinical_accuracy(generated, reference, chexbert_labeler):
"""
Evaluate clinical accuracy using CheXbert.
Extracts 14 finding labels and compares.
"""
gen_labels = chexbert_labeler(generated)
ref_labels = chexbert_labeler(reference)
# Compare labels
matches = sum(g == r for g, r in zip(gen_labels, ref_labels))
accuracy = matches / len(gen_labels)
return {'clinical_accuracy': accuracy, 'gen_labels': gen_labels, 'ref_labels': ref_labels}
15.4.5 Current Limitations
Report generation faces significant challenges:
- Hallucination: Models may describe findings not present in the image
- Omission: Models may miss subtle but important findings
- Inconsistency: Same image may produce different reports
- Liability: Who is responsible for errors in AI-generated reports?
Best practice: Use AI-generated reports as drafts for radiologist review, not as final outputs.
15.5 Multimodal Fusion Strategies
Clinical Context: You’re building a model to predict ICU mortality. You have chest X-rays, clinical notes, and structured vitals/labs. How do you combine these different data types into a single prediction?
15.5.1 Early Fusion
Early fusion combines raw or lightly processed features before the main model.
class EarlyFusionModel(nn.Module):
"""Combine features early, then process jointly."""
def __init__(self, image_dim, text_dim, tabular_dim, hidden_dim, num_classes):
super().__init__()
# Simple projections to common dimension
self.image_proj = nn.Linear(image_dim, hidden_dim)
self.text_proj = nn.Linear(text_dim, hidden_dim)
self.tabular_proj = nn.Linear(tabular_dim, hidden_dim)
# Joint processing after concatenation
self.classifier = nn.Sequential(
nn.Linear(hidden_dim * 3, hidden_dim),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, num_classes)
)
def forward(self, image_features, text_features, tabular_features):
# Project each modality
img = self.image_proj(image_features)
txt = self.text_proj(text_features)
tab = self.tabular_proj(tabular_features)
# Concatenate
combined = torch.cat([img, txt, tab], dim=-1)
# Classify
return self.classifier(combined)
Pros: Simple; allows cross-modal interactions early.
Cons: Requires all modalities present; can be dominated by one modality.
15.5.2 Late Fusion
Late fusion trains separate models per modality, then combines predictions.
class LateFusionModel(nn.Module):
"""Combine predictions from separate modality-specific models."""
def __init__(self, image_model, text_model, tabular_model, num_classes):
super().__init__()
self.image_model = image_model
self.text_model = text_model
self.tabular_model = tabular_model
self.num_classes = num_classes
# Fusion layer combines predictions
self.fusion = nn.Sequential(
nn.Linear(num_classes * 3, num_classes * 2),
nn.ReLU(),
nn.Linear(num_classes * 2, num_classes)
)
def forward(self, image_features, text_features, tabular_features,
image_mask=None, text_mask=None, tabular_mask=None):
predictions = []
# Get predictions from each modality
if image_mask is None or image_mask.any():
img_pred = self.image_model(image_features)
else:
# Zero prediction vector when the image modality is entirely missing
img_pred = torch.zeros(image_features.shape[0], self.num_classes, device=image_features.device)
predictions.append(img_pred)
if text_mask is None or text_mask.any():
txt_pred = self.text_model(text_features)
else:
txt_pred = torch.zeros(text_features.shape[0], self.num_classes, device=text_features.device)
predictions.append(txt_pred)
if tabular_mask is None or tabular_mask.any():
tab_pred = self.tabular_model(tabular_features)
else:
tab_pred = torch.zeros(tabular_features.shape[0], self.num_classes, device=tabular_features.device)
predictions.append(tab_pred)
# Combine predictions
combined = torch.cat(predictions, dim=-1)
return self.fusion(combined)
def forward_with_missing(self, features_dict):
"""Handle missing modalities gracefully."""
predictions = []
if 'image' in features_dict:
predictions.append(self.image_model(features_dict['image']))
if 'text' in features_dict:
predictions.append(self.text_model(features_dict['text']))
if 'tabular' in features_dict:
predictions.append(self.tabular_model(features_dict['tabular']))
# Average available predictions
return torch.stack(predictions).mean(dim=0)
Pros: Handles missing modalities; allows modality-specific architectures.
Cons: No cross-modal interaction during training.
15.5.3 Cross-Attention Fusion
Cross-attention allows modalities to attend to each other, learning what information to share.
class CrossAttentionFusion(nn.Module):
"""Modalities attend to each other via cross-attention."""
def __init__(self, hidden_dim, num_heads=8):
super().__init__()
# Image attends to text
self.img_to_txt_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
# Text attends to image
self.txt_to_img_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
# Layer norms
self.img_norm = nn.LayerNorm(hidden_dim)
self.txt_norm = nn.LayerNorm(hidden_dim)
def forward(self, image_tokens, text_tokens):
"""
Args:
image_tokens: [batch, num_patches, hidden_dim]
text_tokens: [batch, seq_len, hidden_dim]
"""
# Image attends to text
img_attended, _ = self.img_to_txt_attn(
query=image_tokens,
key=text_tokens,
value=text_tokens
)
image_out = self.img_norm(image_tokens + img_attended)
# Text attends to image
txt_attended, _ = self.txt_to_img_attn(
query=text_tokens,
key=image_tokens,
value=image_tokens
)
text_out = self.txt_norm(text_tokens + txt_attended)
return image_out, text_out
class MultimodalTransformer(nn.Module):
"""Full multimodal model with cross-attention fusion."""
def __init__(self, image_encoder, text_encoder, hidden_dim, num_classes, num_fusion_layers=4):
super().__init__()
self.image_encoder = image_encoder
self.text_encoder = text_encoder
# Project to common dimension
self.image_proj = nn.Linear(image_encoder.output_dim, hidden_dim)
self.text_proj = nn.Linear(text_encoder.output_dim, hidden_dim)
# Cross-attention fusion layers
self.fusion_layers = nn.ModuleList([
CrossAttentionFusion(hidden_dim) for _ in range(num_fusion_layers)
])
# Classification head
self.classifier = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, num_classes)
)
def forward(self, images, text_input_ids, text_attention_mask):
# Encode modalities
image_features = self.image_encoder(images).last_hidden_state
text_features = self.text_encoder(text_input_ids, text_attention_mask).last_hidden_state
# Project to common space
image_tokens = self.image_proj(image_features)
text_tokens = self.text_proj(text_features)
# Cross-attention fusion
for fusion_layer in self.fusion_layers:
image_tokens, text_tokens = fusion_layer(image_tokens, text_tokens)
# Pool and classify
image_pooled = image_tokens.mean(dim=1)
text_pooled = text_tokens.mean(dim=1)
combined = torch.cat([image_pooled, text_pooled], dim=-1)
return self.classifier(combined)
Pros: Rich cross-modal interactions; learns alignment between modalities.
Cons: Computationally expensive; requires both modalities during training.
15.5.4 Choosing a Fusion Strategy
| Scenario | Recommended Strategy |
|---|---|
| Simple baseline | Late fusion |
| Missing modalities common | Late fusion with averaging |
| Cross-modal reasoning needed | Cross-attention |
| Limited compute | Early fusion |
| Pretrained unimodal models | Late fusion |
15.6 Combining Imaging and EHR Data
Clinical Context: A chest X-ray shows an ambiguous opacity. Is it pneumonia, atelectasis, or malignancy? The answer might depend on context: is the patient febrile? immunocompromised? a smoker? Combining imaging with EHR data can resolve ambiguities that imaging alone cannot.
15.6.1 Why Add Structured Data?
Structured EHR data provides context that images lack:
- Demographics: Age, sex (some conditions more common in certain groups)
- Vitals: Fever suggests infection; hypotension suggests severity
- Labs: Elevated WBC, procalcitonin support infectious etiology
- History: Prior diagnoses, surgeries, risk factors
- Medications: Immunosuppression, anticoagulation affect interpretation
15.6.2 Feature Extraction from EHR
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
def prepare_ehr_features(ehr_df):
"""
Prepare structured EHR features for multimodal fusion.
Returns normalized feature vector for each patient.
"""
# Define feature groups
numeric_features = [
'age', 'temperature', 'heart_rate', 'respiratory_rate',
'systolic_bp', 'o2_saturation', 'wbc', 'creatinine', 'bun'
]
categorical_features = ['sex', 'admission_type', 'insurance']
binary_features = [
'diabetes', 'hypertension', 'copd', 'chf', 'ckd',
'immunocompromised', 'smoker', 'prior_pneumonia'
]
# Preprocessing pipeline
preprocessor = ColumnTransformer([
('numeric', StandardScaler(), numeric_features),
('categorical', OneHotEncoder(sparse_output=False, handle_unknown='ignore'), categorical_features),
('binary', 'passthrough', binary_features)
])
features = preprocessor.fit_transform(ehr_df)
return features, preprocessor
# Example EHR data
ehr_data = pd.DataFrame({
'patient_id': [1, 2, 3],
'age': [65, 45, 78],
'sex': ['M', 'F', 'M'],
'temperature': [38.5, 37.0, 39.2],
'heart_rate': [95, 72, 110],
'respiratory_rate': [22, 16, 28],
'systolic_bp': [130, 120, 95],
'o2_saturation': [92, 98, 88],
'wbc': [15.2, 8.1, 18.5],
'creatinine': [1.2, 0.9, 2.1],
'bun': [25, 15, 45],
'admission_type': ['emergency', 'elective', 'emergency'],
'insurance': ['medicare', 'commercial', 'medicare'],
'diabetes': [1, 0, 1],
'hypertension': [1, 0, 1],
'copd': [1, 0, 0],
'chf': [0, 0, 1],
'ckd': [0, 0, 1],
'immunocompromised': [0, 0, 0],
'smoker': [1, 0, 1],
'prior_pneumonia': [1, 0, 0]
})
features, preprocessor = prepare_ehr_features(ehr_data)
print(f"EHR feature dimension: {features.shape[1]}")
15.6.3 Image + Tabular Fusion Model
import torch
import torch.nn as nn
from torchvision import models
class ImageTabularFusion(nn.Module):
"""Combine chest X-ray with structured EHR data."""
def __init__(self, tabular_dim, num_classes, hidden_dim=256):
super().__init__()
# Image encoder (pretrained ResNet)
self.image_encoder = models.resnet50(weights='IMAGENET1K_V1')
image_feature_dim = self.image_encoder.fc.in_features
self.image_encoder.fc = nn.Identity()
# Tabular encoder
self.tabular_encoder = nn.Sequential(
nn.Linear(tabular_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
# Image projection
self.image_proj = nn.Sequential(
nn.Linear(image_feature_dim, hidden_dim),
nn.ReLU()
)
# Fusion with gating mechanism
self.gate = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.Sigmoid()
)
# Classification head
self.classifier = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(hidden_dim // 2, num_classes)
)
def forward(self, images, tabular):
# Encode each modality
img_features = self.image_encoder(images)
img_features = self.image_proj(img_features)
tab_features = self.tabular_encoder(tabular)
# Gated fusion
combined = torch.cat([img_features, tab_features], dim=-1)
gate = self.gate(combined)
fused = gate * img_features + (1 - gate) * tab_features
# Classify
return self.classifier(fused)
# Training loop
def train_multimodal(model, train_loader, val_loader, epochs=20):
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
best_val_auc = 0
for epoch in range(epochs):
model.train()
train_loss = 0
for images, tabular, labels in train_loader:
images = images.to(device)
tabular = tabular.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(images, tabular)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
# Validation
model.eval()
val_probs = []
val_labels = []
with torch.no_grad():
for images, tabular, labels in val_loader:
images = images.to(device)
tabular = tabular.to(device)
outputs = model(images, tabular)
probs = torch.softmax(outputs, dim=1)[:, 1]
val_probs.extend(probs.cpu().numpy())
val_labels.extend(labels.numpy())
from sklearn.metrics import roc_auc_score
val_auc = roc_auc_score(val_labels, val_probs)
print(f"Epoch {epoch+1}: Train Loss={train_loss/len(train_loader):.4f}, Val AUC={val_auc:.4f}")
if val_auc > best_val_auc:
best_val_auc = val_auc
torch.save(model.state_dict(), 'best_multimodal.pt')
return best_val_auc
15.6.4 Handling Missing Modalities
In practice, not all patients have all data. Handle this gracefully:
class RobustMultimodalModel(nn.Module):
"""Handle missing modalities during training and inference."""
def __init__(self, image_encoder, tabular_encoder, hidden_dim, num_classes):
super().__init__()
self.image_encoder = image_encoder
self.tabular_encoder = tabular_encoder
self.hidden_dim = hidden_dim
# Learnable default embeddings for missing modalities
self.default_image_embedding = nn.Parameter(torch.randn(1, hidden_dim))
self.default_tabular_embedding = nn.Parameter(torch.randn(1, hidden_dim))
# Classifier
self.classifier = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, num_classes)
)
def forward(self, images=None, tabular=None, image_mask=None, tabular_mask=None):
batch_size = images.shape[0] if images is not None else tabular.shape[0]
device = images.device if images is not None else tabular.device
# Encode available modalities
if images is not None and (image_mask is None or image_mask.any()):
img_features = self.image_encoder(images)
else:
img_features = self.default_image_embedding.expand(batch_size, -1)
if tabular is not None and (tabular_mask is None or tabular_mask.any()):
tab_features = self.tabular_encoder(tabular)
else:
tab_features = self.default_tabular_embedding.expand(batch_size, -1)
# Apply masks if provided
if image_mask is not None:
img_features = torch.where(
image_mask.unsqueeze(-1),
img_features,
self.default_image_embedding.expand(batch_size, -1)
)
if tabular_mask is not None:
tab_features = torch.where(
tabular_mask.unsqueeze(-1),
tab_features,
self.default_tabular_embedding.expand(batch_size, -1)
)
combined = torch.cat([img_features, tab_features], dim=-1)
return self.classifier(combined)
15.6.5 When Multimodal Helps
Multimodal fusion typically improves performance when:
- Modalities are complementary: Each provides unique information
- Task requires context: Imaging interpretation depends on clinical setting
- Single modality is ambiguous: Multiple interpretations possible from image alone
It may not help when:
- One modality dominates: The other adds only noise
- Modalities are redundant: Same information in different forms
- Data quality varies: Poor-quality modality hurts fusion
Always compare multimodal to best single-modality baseline.
15.7 Evaluation and Challenges
15.7.1 Evaluating Multimodal Models
Standard metrics apply, but consider:
- Modality ablation: How much does each modality contribute?
- Missing modality performance: How does the model degrade with missing data?
- Cross-modal consistency: Do predictions make sense given all inputs?
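The ablation study below relies on an evaluate_model helper that scores the test set with selected modalities masked out. A minimal sketch, written against the RobustMultimodalModel interface from Section 15.6.4 and assuming a binary classification task:
import torch
from sklearn.metrics import roc_auc_score

def evaluate_model(model, test_loader, device, use_image=True, use_tabular=True):
    """AUC on the test set with the unused modalities masked out (illustrative)."""
    model.eval()
    probs, labels = [], []
    with torch.no_grad():
        for images, tabular, y in test_loader:
            n = images.shape[0]
            image_mask = torch.full((n,), use_image, dtype=torch.bool, device=device)
            tabular_mask = torch.full((n,), use_tabular, dtype=torch.bool, device=device)
            outputs = model(images=images.to(device), tabular=tabular.to(device),
                            image_mask=image_mask, tabular_mask=tabular_mask)
            probs.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())
            labels.extend(y.numpy())
    return roc_auc_score(labels, probs)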
def modality_ablation_study(model, test_loader, device):
"""Evaluate model with different modality combinations."""
results = {}
# Full model (both modalities)
auc_full = evaluate_model(model, test_loader, device,
use_image=True, use_tabular=True)
results['full'] = auc_full
# Image only
auc_image = evaluate_model(model, test_loader, device,
use_image=True, use_tabular=False)
results['image_only'] = auc_image
# Tabular only
auc_tabular = evaluate_model(model, test_loader, device,
use_image=False, use_tabular=True)
results['tabular_only'] = auc_tabular
print("Modality Ablation Results:")
print(f" Full model (image + tabular): AUC = {auc_full:.4f}")
print(f" Image only: AUC = {auc_image:.4f}")
print(f" Tabular only: AUC = {auc_tabular:.4f}")
print(f" Multimodal improvement over best unimodal: "
f"{auc_full - max(auc_image, auc_tabular):+.4f}")
return results
15.7.2 Key Challenges
Hallucination in Generation
VLMs can generate plausible-sounding but incorrect findings. Mitigation:
- Constrain generation vocabulary
- Verify findings against image-level classifiers
- Require human review
Modality Imbalance
One modality may dominate learning. Solutions:
- Gradient balancing across modalities
- Modality dropout during training (see the sketch below)
- Uncertainty-weighted fusion
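A minimal sketch of modality dropout (illustrative; assumes a model such as RobustMultimodalModel that accepts per-sample boolean modality masks):
import torch

def modality_dropout_masks(batch_size, p_drop=0.3, device='cpu'):
    """Randomly drop each modality per sample during training (illustrative)."""
    image_mask = torch.rand(batch_size, device=device) > p_drop
    tabular_mask = torch.rand(batch_size, device=device) > p_drop
    # Keep at least one modality for every sample
    image_mask = image_mask | ~(image_mask | tabular_mask)
    return image_mask, tabular_mask

# During training:
# image_mask, tabular_mask = modality_dropout_masks(images.shape[0], device=device)
# outputs = model(images=images, tabular=tabular, image_mask=image_mask, tabular_mask=tabular_mask)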
Computational Cost
Multimodal models are expensive. Consider:
- Efficient architectures (early exit, pruning)
- Modality-specific compression
- When to use full multimodal vs. single modality
Regulatory Pathway
Multimodal medical devices face complex approval:
- Each modality may have different risk profiles
- Interaction effects must be validated
- Missing modality handling needs documentation
15.8 The Future: Medical Foundation Models
15.8.1 The Vision: Generalist Medical AI
Recent models point toward generalist medical AI—systems that handle diverse tasks across modalities:
Med-PaLM M (Google): Single model for medical QA, radiology, pathology, genomics, and dermatology.
GPT-4V (OpenAI): General vision-language model with emerging medical capabilities.
GigaPath (Microsoft): Pathology foundation model trained on 170,000 whole slide images.
These models suggest a future where one system handles many medical AI tasks, similar to how radiologists develop broad expertise.
15.8.2 Implications for Clinical Practice
If foundation models succeed:
- Reduced development cost: Fine-tune one model for many tasks
- Cross-task transfer: Learning from one task improves others
- Unified interfaces: Single system for diverse clinical AI needs
- New capabilities: Emergent abilities from scale
15.8.3 Remaining Challenges
- Validation: How do you validate a model that does everything?
- Failure modes: Generalist models may fail in unexpected ways
- Specialization vs. generalization: When is a specialist model better?
- Deployment complexity: Running large models in clinical settings
- Liability: Who is responsible when a general-purpose AI errs?
15.8.4 What This Means for Students
The field is moving fast. Key skills:
- Understand architectures: Know how multimodal models work
- Evaluation expertise: Be able to assess multimodal systems rigorously
- Clinical grounding: Always connect technical work to clinical needs
- Adaptability: Today’s SOTA will be tomorrow’s baseline
15.9 Chapter Summary
Multimodal AI represents the frontier of medical AI, combining information as clinicians do.
Contrastive learning (CLIP):
- Learns shared image-text embeddings
- Enables zero-shot classification
- Medical variants (BiomedCLIP, MedCLIP) understand clinical content
Vision-language models:
- LLaVA architecture: vision encoder + LLM
- Visual question answering on medical images
- Medical VLMs (LLaVA-Med, Med-PaLM M) for clinical applications
Report generation:
- Encoder-decoder architecture with cross-attention
- VLMs can generate structured reports
- Evaluation requires both NLP and clinical metrics
Fusion strategies:
- Early fusion: combine features before processing
- Late fusion: combine predictions from separate models
- Cross-attention: modalities attend to each other
Image + EHR integration:
- Structured data provides clinical context
- Gated fusion learns modality weighting
- Handle missing modalities gracefully
The path forward leads to foundation models that integrate all medical data modalities. Understanding multimodal architectures today prepares you for this future.
15.10 Exercises
1. Implement zero-shot classification with a medical CLIP model. Compare performance on chest X-ray classification using different prompt templates. Which prompts work best?
2. Build a simple report generation system using a pretrained VLM. Evaluate on a small held-out set using BLEU, ROUGE, and manual clinical review. What types of errors does the model make?
3. Implement early, late, and cross-attention fusion for a binary classification task using synthetic image and tabular features. Compare training dynamics and final performance.
4. Create a dataset where some samples have missing modalities. Implement and compare strategies for handling missing data: zero imputation, learned defaults, and modality dropout.
5. Design an evaluation protocol for a multimodal radiology AI system that will be deployed clinically. What metrics would you track? How would you monitor for failure modes?