15  Multimodal Pipelines

Clinical Context: A radiologist reviews a chest X-ray alongside the patient’s history: 65-year-old smoker with chronic cough, elevated inflammatory markers, recent weight loss. The imaging findings alone might be ambiguous, but combined with the clinical context, the diagnosis becomes clearer. The best medical AI systems will need to reason across modalities just as clinicians do—integrating images, text, and structured data into unified understanding.

Medicine is inherently multimodal. A diagnosis emerges from imaging studies, laboratory values, clinical notes, patient history, and physical examination. Yet most medical AI systems operate on single modalities: a CNN analyzes an X-ray, an NLP model processes clinical notes, a classifier predicts from lab values. The frontier of medical AI lies in multimodal fusion—combining these data sources to capture the full clinical picture.

This chapter covers the architectures and techniques that enable multimodal medical AI. We’ll start with contrastive learning (CLIP), which learns shared representations across images and text. Then we’ll explore vision-language models that can answer questions about medical images and generate clinical reports. Finally, we’ll examine practical fusion strategies for combining imaging with structured EHR data.

15.1 Why Multimodal?

15.1.1 The Clinical Reality

Consider the information available for a hospitalized patient:

  • Imaging: Chest X-ray, CT scans, ultrasound
  • Clinical notes: Admission note, progress notes, discharge summary
  • Structured data: Vital signs, lab values, medications, diagnoses
  • Time series: Continuous monitoring, longitudinal records

Each modality captures different aspects of the patient’s condition. Imaging shows anatomy and pathology. Notes capture clinical reasoning and context. Labs provide objective biomarkers. A complete understanding requires all of them.

15.1.2 The Promise of Multimodal AI

Multimodal models can:

  1. Improve accuracy: Combining modalities often outperforms single-modality models
  2. Enable new capabilities: Visual question answering, automatic report generation
  3. Reduce workload: Generate draft reports, highlight relevant prior findings
  4. Catch errors: Cross-check imaging findings against clinical context

15.1.3 The Technical Challenges

Building multimodal systems is hard because:

  • Different data structures: Images are dense tensors; text is sequential; EHR is tabular
  • Alignment: Mapping between modalities (which image regions relate to which text?)
  • Missing data: Not all patients have all modalities
  • Scale mismatch: A CT scan has millions of voxels; a diagnosis code is one token

This chapter provides the tools to address these challenges.

15.2 Contrastive Learning: CLIP and Medical Adaptations

Clinical Context: You have a chest X-ray and want to know if it shows pneumonia—but you don’t have a trained classifier for pneumonia specifically. With CLIP-style models, you can classify images using natural language descriptions, even for categories not seen during training.

15.2.1 The CLIP Paradigm

CLIP (Contrastive Language-Image Pre-training) revolutionized vision-language AI with a simple idea: learn to match images with their text descriptions.

The architecture has two encoders:

  • Image encoder: CNN or Vision Transformer (ViT) that produces image embeddings
  • Text encoder: Transformer that produces text embeddings

Both encoders map to the same embedding space. During training, matching image-text pairs should have similar embeddings; non-matching pairs should be dissimilar.

15.2.2 Contrastive Training

Given a batch of \(N\) image-text pairs, CLIP creates an \(N \times N\) matrix of similarities. The diagonal contains matching pairs; off-diagonal entries are negatives.

The loss maximizes similarity for matches while minimizing it for non-matches:

\[\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N}\left[\log\frac{\exp(\text{sim}(I_i, T_i)/\tau)}{\sum_{j=1}^{N}\exp(\text{sim}(I_i, T_j)/\tau)}\right]\]

where \(\text{sim}(I, T)\) is cosine similarity and \(\tau\) is a temperature parameter.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CLIPModel(nn.Module):
    """Simplified CLIP architecture for medical imaging."""

    def __init__(self, image_encoder, text_encoder, embed_dim=512):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder

        # Projection heads to shared embedding space
        self.image_projection = nn.Linear(image_encoder.output_dim, embed_dim)
        self.text_projection = nn.Linear(text_encoder.output_dim, embed_dim)

        # Learnable temperature parameter
        self.temperature = nn.Parameter(torch.ones([]) * 0.07)

    def encode_image(self, images):
        features = self.image_encoder(images)
        embeddings = self.image_projection(features)
        return F.normalize(embeddings, dim=-1)

    def encode_text(self, text_tokens):
        features = self.text_encoder(text_tokens)
        embeddings = self.text_projection(features)
        return F.normalize(embeddings, dim=-1)

    def forward(self, images, text_tokens):
        image_embeddings = self.encode_image(images)
        text_embeddings = self.encode_text(text_tokens)

        # Compute similarity matrix
        logits = image_embeddings @ text_embeddings.T / self.temperature

        return logits

def contrastive_loss(logits):
    """Symmetric contrastive loss."""
    batch_size = logits.shape[0]
    labels = torch.arange(batch_size, device=logits.device)

    # Image-to-text loss
    loss_i2t = F.cross_entropy(logits, labels)
    # Text-to-image loss
    loss_t2i = F.cross_entropy(logits.T, labels)

    return (loss_i2t + loss_t2i) / 2
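
Putting the pieces together, a training step might look like the following sketch (the encoders, data loader, and optimizer settings are assumptions, not prescribed values):

# Assumed: image_encoder and text_encoder expose an .output_dim attribute, and
# train_loader yields batches of paired (images, text_tokens) from studies and reports.
clip_model = CLIPModel(image_encoder, text_encoder, embed_dim=512)
optimizer = torch.optim.AdamW(clip_model.parameters(), lr=1e-4)

for images, text_tokens in train_loader:
    logits = clip_model(images, text_tokens)   # [N, N] similarity matrix
    loss = contrastive_loss(logits)            # symmetric contrastive loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()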

15.2.3 Medical CLIP Variants

Several groups have adapted CLIP for medical imaging:

BiomedCLIP: Trained on figure-caption pairs from PubMed Central articles. Strong for scientific/research images.

PubMedCLIP: Similar approach, focused on biomedical literature.

MedCLIP: Specifically trained on chest X-rays and radiology reports from MIMIC-CXR.

PMC-CLIP: Trained on PubMed Central open-access figures.

These models understand medical terminology and visual patterns that general-purpose CLIP misses.

# Using a pretrained medical CLIP model
from transformers import AutoModel, AutoProcessor

# Load BiomedCLIP (illustrative; this checkpoint is distributed as an open_clip
# model, so the exact loading and encoding calls may differ in practice)
model = AutoModel.from_pretrained("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224")
processor = AutoProcessor.from_pretrained("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224")

# Encode an image
image = load_chest_xray("patient_001.png")
image_inputs = processor(images=image, return_tensors="pt")
image_embeddings = model.get_image_features(**image_inputs)

# Encode text descriptions
texts = [
    "chest x-ray showing pneumonia",
    "normal chest x-ray",
    "chest x-ray with cardiomegaly",
    "chest x-ray showing pleural effusion"
]
text_inputs = processor(text=texts, return_tensors="pt", padding=True)
text_embeddings = model.get_text_features(**text_inputs)

# Compute similarities
similarities = F.cosine_similarity(
    image_embeddings.unsqueeze(1),
    text_embeddings.unsqueeze(0),
    dim=-1
)
print("Similarities:", similarities)
# Highest similarity indicates most likely description

15.2.4 Zero-Shot Classification

CLIP’s power lies in zero-shot classification: classify images into categories using only text descriptions, without task-specific training.

def zero_shot_classify(model, processor, image, class_descriptions, device='cuda'):
    """
    Classify an image using text descriptions of each class.

    Args:
        model: CLIP model
        image: PIL image
        class_descriptions: List of text descriptions for each class
    Returns:
        Predicted class index and probabilities
    """
    model.eval()

    # Encode image
    image_inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_features = model.get_image_features(**image_inputs)
        image_features = F.normalize(image_features, dim=-1)

    # Encode class descriptions
    text_inputs = processor(text=class_descriptions, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_features = model.get_text_features(**text_inputs)
        text_features = F.normalize(text_features, dim=-1)

    # Compute similarities
    similarities = (image_features @ text_features.T).squeeze(0)
    probabilities = F.softmax(similarities * 100, dim=-1)  # Temperature scaling

    predicted_class = probabilities.argmax().item()
    return predicted_class, probabilities.cpu().numpy()

# Example: classify chest X-ray findings
class_descriptions = [
    "a chest x-ray showing no abnormalities, normal lung fields",
    "a chest x-ray showing consolidation consistent with pneumonia",
    "a chest x-ray showing enlarged cardiac silhouette indicating cardiomegaly",
    "a chest x-ray showing blunted costophrenic angles suggesting pleural effusion"
]

predicted, probs = zero_shot_classify(model, processor, xray_image, class_descriptions)
print(f"Predicted: {class_descriptions[predicted]}")
print(f"Probabilities: {dict(zip(['Normal', 'Pneumonia', 'Cardiomegaly', 'Effusion'], probs))}")

15.2.5 Prompt Engineering for Medical CLIP

The text descriptions matter enormously. Better prompts yield better zero-shot performance:

# Poor prompts (too generic)
poor_prompts = ["pneumonia", "normal", "effusion"]

# Better prompts (more descriptive)
better_prompts = [
    "chest radiograph demonstrating consolidation in the lung fields consistent with pneumonia",
    "normal chest radiograph with clear lung fields and normal cardiac silhouette",
    "chest radiograph showing layering pleural effusion with blunted costophrenic angle"
]

# Best prompts (ensemble multiple descriptions)
def ensemble_prompts(model, processor, class_name, prompt_templates):
    """Average text embeddings across multiple prompt templates for one class."""
    prompts = [template.format(class_name) for template in prompt_templates]
    text_inputs = processor(text=prompts, return_tensors="pt", padding=True)
    with torch.no_grad():
        embeddings = model.get_text_features(**text_inputs)
    return embeddings.mean(dim=0)

templates = [
    "a chest x-ray showing {}",
    "chest radiograph demonstrating {}",
    "x-ray image with findings of {}",
    "radiological evidence of {} on chest x-ray"
]
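
One way to turn the templates into per-class embeddings and score an image against them is sketched below (the class names are illustrative, and model, processor, and xray_image are the hedged examples used earlier in this section):

# Build one ensembled, normalized text embedding per class
class_names = ["pneumonia", "no acute cardiopulmonary findings", "pleural effusion"]
class_embeddings = torch.stack([
    F.normalize(ensemble_prompts(model, processor, name, templates), dim=-1)
    for name in class_names
])  # [num_classes, embed_dim]

# Score a chest X-ray against the ensembled class embeddings
image_inputs = processor(images=xray_image, return_tensors="pt")
with torch.no_grad():
    image_embedding = F.normalize(model.get_image_features(**image_inputs), dim=-1)

probs = F.softmax((image_embedding @ class_embeddings.T) * 100, dim=-1)
print(dict(zip(class_names, probs.squeeze(0).tolist())))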

15.3 Vision-Language Models

Clinical Context: A resident shows an attending a complex CT scan and asks, “What do you see in the right lower lobe?” The attending examines the image and provides a detailed explanation. Vision-language models aim to replicate this capability—answering natural language questions about medical images.

15.3.1 From CLIP to Generative VLMs

CLIP learns to match images and text but cannot generate new text. Vision-Language Models (VLMs) extend this to generation: given an image and a question, produce a natural language answer.

The key architectural insight: connect a visual encoder to a language model.

15.3.2 The LLaVA Architecture

LLaVA (Large Language and Vision Assistant) is an influential VLM architecture:

  1. Visual encoder: Pretrained CLIP ViT extracts image features
  2. Projection layer: Maps visual features to the LLM’s embedding space
  3. Language model: Pretrained LLM (Vicuna, LLaMA) generates responses

import torch
import torch.nn as nn

class LLaVAModel(nn.Module):
    """Simplified LLaVA-style vision-language model."""

    def __init__(self, vision_encoder, language_model, projection_dim):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.language_model = language_model

        # Project vision features to LLM embedding dimension
        self.vision_projection = nn.Sequential(
            nn.Linear(vision_encoder.output_dim, projection_dim),
            nn.GELU(),
            nn.Linear(projection_dim, language_model.config.hidden_size)
        )

    def encode_image(self, images):
        """Extract and project visual features."""
        # Get patch embeddings from vision encoder
        vision_outputs = self.vision_encoder(images)
        patch_features = vision_outputs.last_hidden_state  # [B, num_patches, dim]

        # Project to LLM space
        visual_tokens = self.vision_projection(patch_features)
        return visual_tokens

    def forward(self, images, input_ids, attention_mask):
        """
        Forward pass combining visual and text tokens.

        The visual tokens are prepended to the text tokens,
        allowing the LLM to attend to image information.
        """
        # Encode image to visual tokens
        visual_tokens = self.encode_image(images)  # [B, num_patches, hidden]

        # Get text embeddings
        text_embeddings = self.language_model.get_input_embeddings()(input_ids)

        # Concatenate: [visual_tokens, text_tokens]
        inputs_embeds = torch.cat([visual_tokens, text_embeddings], dim=1)

        # Extend attention mask for visual tokens
        visual_attention = torch.ones(
            visual_tokens.shape[:2],
            device=attention_mask.device
        )
        extended_attention = torch.cat([visual_attention, attention_mask], dim=1)

        # Forward through LLM
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=extended_attention
        )

        return outputs

15.3.3 Medical Vision-Language Models

Several medical VLMs have emerged:

LLaVA-Med: LLaVA fine-tuned on biomedical image-text data from PMC articles, then instruction-tuned on medical QA.

Med-PaLM M: Google’s multimodal medical model, trained on diverse medical data including imaging, genomics, and clinical text.

RadFM: Specifically designed for radiology, trained on millions of radiology images and reports.

CheXagent: Focused on chest X-ray interpretation with instruction following.

# Using a medical VLM for visual question answering
from transformers import AutoModelForCausalLM, AutoProcessor

# Load LLaVA-Med (conceptual - actual loading may differ)
model_id = "microsoft/llava-med-v1.5-mistral-7b"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
processor = AutoProcessor.from_pretrained(model_id)

def ask_about_image(model, processor, image, question, max_tokens=256):
    """Ask a question about a medical image."""

    # Format prompt with image placeholder
    prompt = f"<image>\nUser: {question}\nAssistant:"

    # Process inputs
    inputs = processor(
        text=prompt,
        images=image,
        return_tensors="pt"
    ).to(model.device)

    # Generate response
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False
        )

    # Decode response
    response = processor.decode(output_ids[0], skip_special_tokens=True)
    return response.split("Assistant:")[-1].strip()

# Example usage
image = load_chest_xray("patient_001.png")

questions = [
    "What abnormalities do you see in this chest X-ray?",
    "Is there evidence of pneumonia?",
    "Describe the cardiac silhouette.",
    "What is the most likely diagnosis?"
]

for q in questions:
    answer = ask_about_image(model, processor, image, q)
    print(f"Q: {q}")
    print(f"A: {answer}\n")

15.3.4 Visual Question Answering Evaluation

Evaluating VQA on medical images requires:

  1. Accuracy metrics: Exact match, token F1, BLEU for open-ended
  2. Clinical correctness: Does the answer align with ground truth diagnosis?
  3. Hallucination detection: Does the model describe features not in the image?

from collections import Counter
import re

def compute_token_f1(prediction, reference):
    """Compute token-level F1 score."""
    pred_tokens = set(prediction.lower().split())
    ref_tokens = set(reference.lower().split())

    if len(pred_tokens) == 0 or len(ref_tokens) == 0:
        return 0.0

    common = pred_tokens & ref_tokens
    precision = len(common) / len(pred_tokens)
    recall = len(common) / len(ref_tokens)

    if precision + recall == 0:
        return 0.0

    f1 = 2 * precision * recall / (precision + recall)
    return f1

def evaluate_vqa(model, processor, test_data):
    """Evaluate VQA performance on test set."""
    exact_matches = 0
    f1_scores = []

    for item in test_data:
        prediction = ask_about_image(model, processor, item['image'], item['question'])
        reference = item['answer']

        # Exact match (normalized)
        pred_norm = prediction.lower().strip()
        ref_norm = reference.lower().strip()
        if pred_norm == ref_norm:
            exact_matches += 1

        # Token F1
        f1 = compute_token_f1(prediction, reference)
        f1_scores.append(f1)

    return {
        'exact_match': exact_matches / len(test_data),
        'token_f1': sum(f1_scores) / len(f1_scores)
    }

15.4 Clinical Report Generation

Clinical Context: Radiologists spend significant time dictating reports. An AI system that generates draft reports from images could save hours daily—if the drafts are accurate enough to be useful. This section covers how to build and evaluate such systems.

15.4.1 The Report Generation Task

Given a medical image (or set of images), generate a structured clinical report. For radiology, this typically includes:

  • Findings: Detailed description of observations
  • Impression: Summary and diagnostic conclusions
  • Comparison: Reference to prior studies (if available)

15.4.2 Architecture for Report Generation

Report generation uses an encoder-decoder architecture:

  1. Image encoder: Extract visual features (CNN or ViT)
  2. Cross-attention: Allow decoder to attend to relevant image regions
  3. Text decoder: Generate report autoregressively

import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Config

class ReportGenerator(nn.Module):
    """Medical report generation from images."""

    def __init__(self, image_encoder, vocab_size, hidden_dim=768, num_layers=6):
        super().__init__()
        self.image_encoder = image_encoder

        # Project image features
        self.image_projection = nn.Linear(image_encoder.output_dim, hidden_dim)

        # Text decoder with cross-attention
        config = GPT2Config(
            vocab_size=vocab_size,
            n_embd=hidden_dim,
            n_layer=num_layers,
            n_head=12,
            add_cross_attention=True  # Enable cross-attention to image
        )
        self.decoder = GPT2LMHeadModel(config)

    def forward(self, images, input_ids, attention_mask=None):
        # Encode images
        image_features = self.image_encoder(images)
        encoder_hidden_states = self.image_projection(image_features)

        # Generate with cross-attention to image
        outputs = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            labels=input_ids  # For training
        )

        return outputs

    def generate_report(self, images, tokenizer, max_length=512):
        """Generate a report for the given images."""
        self.eval()

        # Encode images
        image_features = self.image_encoder(images)
        encoder_hidden_states = self.image_projection(image_features)

        # Start with BOS token
        input_ids = torch.tensor([[tokenizer.bos_token_id]]).to(images.device)

        # Generate autoregressively
        with torch.no_grad():
            output_ids = self.decoder.generate(
                input_ids=input_ids,
                encoder_hidden_states=encoder_hidden_states,
                max_length=max_length,
                num_beams=4,
                early_stopping=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        report = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return report

15.4.3 Using VLMs for Report Generation

Modern approach: leverage pretrained VLMs with appropriate prompting.

def generate_radiology_report(model, processor, image, prior_report=None):
    """Generate a structured radiology report using a VLM."""

    if prior_report:
        prompt = f"""<image>
You are an expert radiologist. Generate a detailed radiology report for this chest X-ray.

Prior study report for comparison:
{prior_report}

Please provide:
1. FINDINGS: Detailed description of all observations
2. IMPRESSION: Summary and diagnostic conclusions
3. COMPARISON: Changes from prior study

Report:"""
    else:
        prompt = """<image>
You are an expert radiologist. Generate a detailed radiology report for this chest X-ray.

Please provide:
1. FINDINGS: Detailed description of all observations
2. IMPRESSION: Summary and diagnostic conclusions

Report:"""

    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
            num_beams=4
        )

    report = processor.decode(output_ids[0], skip_special_tokens=True)
    return report.split("Report:")[-1].strip()

# Generate report
image = load_chest_xray("patient_001.png")
report = generate_radiology_report(model, processor, image)
print(report)

15.4.4 Evaluation Metrics

Report generation requires multiple evaluation approaches:

Natural language metrics:

  • BLEU: N-gram overlap with reference reports
  • ROUGE: Recall-oriented overlap measure
  • BERTScore: Semantic similarity using embeddings

Clinical metrics:

  • CheXbert labeler: Extract findings, compare to reference
  • RadGraph F1: Compare clinical entities and relations
  • Radiologist review: Gold standard but expensive

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer

def evaluate_report(generated, reference):
    """Evaluate generated report against reference."""

    # BLEU score
    reference_tokens = reference.lower().split()
    generated_tokens = generated.lower().split()
    smoothing = SmoothingFunction().method1
    bleu = sentence_bleu([reference_tokens], generated_tokens, smoothing_function=smoothing)

    # ROUGE scores
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    rouge_scores = scorer.score(reference, generated)

    return {
        'bleu': bleu,
        'rouge1': rouge_scores['rouge1'].fmeasure,
        'rouge2': rouge_scores['rouge2'].fmeasure,
        'rougeL': rouge_scores['rougeL'].fmeasure
    }

def evaluate_clinical_accuracy(generated, reference, chexbert_labeler):
    """
    Evaluate clinical accuracy using CheXbert.
    Extracts 14 finding labels and compares.
    """
    gen_labels = chexbert_labeler(generated)
    ref_labels = chexbert_labeler(reference)

    # Compare labels
    matches = sum(g == r for g, r in zip(gen_labels, ref_labels))
    accuracy = matches / len(gen_labels)

    return {'clinical_accuracy': accuracy, 'gen_labels': gen_labels, 'ref_labels': ref_labels}
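
BERTScore, listed among the natural language metrics above, can be computed with the bert-score package; a minimal sketch, assuming the package is installed and reports are plain strings:

from bert_score import score as bertscore

def evaluate_bertscore(generated_reports, reference_reports):
    """Semantic similarity between generated and reference report texts."""
    precision, recall, f1 = bertscore(generated_reports, reference_reports, lang="en")
    return {'bertscore_f1': f1.mean().item()}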

15.4.5 Current Limitations

Report generation faces significant challenges:

  1. Hallucination: Models may describe findings not present in the image
  2. Omission: Models may miss subtle but important findings
  3. Inconsistency: Same image may produce different reports
  4. Liability: Who is responsible for errors in AI-generated reports?

Best practice: Use AI-generated reports as drafts for radiologist review, not as final outputs.

15.5 Multimodal Fusion Strategies

Clinical Context: You’re building a model to predict ICU mortality. You have chest X-rays, clinical notes, and structured vitals/labs. How do you combine these different data types into a single prediction?

15.5.1 Early Fusion

Early fusion combines raw or lightly processed features before the main model.

class EarlyFusionModel(nn.Module):
    """Combine features early, then process jointly."""

    def __init__(self, image_dim, text_dim, tabular_dim, hidden_dim, num_classes):
        super().__init__()

        # Simple projections to common dimension
        self.image_proj = nn.Linear(image_dim, hidden_dim)
        self.text_proj = nn.Linear(text_dim, hidden_dim)
        self.tabular_proj = nn.Linear(tabular_dim, hidden_dim)

        # Joint processing after concatenation
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, image_features, text_features, tabular_features):
        # Project each modality
        img = self.image_proj(image_features)
        txt = self.text_proj(text_features)
        tab = self.tabular_proj(tabular_features)

        # Concatenate
        combined = torch.cat([img, txt, tab], dim=-1)

        # Classify
        return self.classifier(combined)
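
A brief usage sketch with random inputs (all dimensions here are illustrative placeholders):

# Illustrative dimensions: 2048-d pooled image features, 768-d text embedding, 32-d EHR vector
model = EarlyFusionModel(image_dim=2048, text_dim=768, tabular_dim=32,
                         hidden_dim=256, num_classes=2)

img = torch.randn(4, 2048)     # e.g. pooled image-encoder features
txt = torch.randn(4, 768)      # e.g. [CLS] embedding from a clinical-text encoder
tab = torch.randn(4, 32)       # normalized structured EHR features

logits = model(img, txt, tab)  # shape: [4, 2]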

Pros: Simple; allows cross-modal interactions early.

Cons: Requires all modalities to be present; can be dominated by one modality.

15.5.2 Late Fusion

Late fusion trains separate models per modality, then combines predictions.

class LateFusionModel(nn.Module):
    """Combine predictions from separate modality-specific models."""

    def __init__(self, image_model, text_model, tabular_model, num_classes):
        super().__init__()
        self.image_model = image_model
        self.text_model = text_model
        self.tabular_model = tabular_model
        self.num_classes = num_classes

        # Fusion layer combines predictions
        self.fusion = nn.Sequential(
            nn.Linear(num_classes * 3, num_classes * 2),
            nn.ReLU(),
            nn.Linear(num_classes * 2, num_classes)
        )

    def forward(self, image_features, text_features, tabular_features,
                image_mask=None, text_mask=None, tabular_mask=None):
        predictions = []

        # Get predictions from each modality; use zeros when a modality is masked out
        batch_size = image_features.shape[0]
        device = image_features.device

        if image_mask is None or image_mask.any():
            predictions.append(self.image_model(image_features))
        else:
            predictions.append(torch.zeros(batch_size, self.num_classes, device=device))

        if text_mask is None or text_mask.any():
            predictions.append(self.text_model(text_features))
        else:
            predictions.append(torch.zeros(batch_size, self.num_classes, device=device))

        if tabular_mask is None or tabular_mask.any():
            predictions.append(self.tabular_model(tabular_features))
        else:
            predictions.append(torch.zeros(batch_size, self.num_classes, device=device))

        # Combine predictions
        combined = torch.cat(predictions, dim=-1)
        return self.fusion(combined)

    def forward_with_missing(self, features_dict):
        """Handle missing modalities gracefully."""
        predictions = []

        if 'image' in features_dict:
            predictions.append(self.image_model(features_dict['image']))

        if 'text' in features_dict:
            predictions.append(self.text_model(features_dict['text']))

        if 'tabular' in features_dict:
            predictions.append(self.tabular_model(features_dict['tabular']))

        # Average available predictions
        return torch.stack(predictions).mean(dim=0)

Pros: Handles missing modalities; allows modality-specific architectures.

Cons: No cross-modal interaction during training.

15.5.3 Cross-Attention Fusion

Cross-attention allows modalities to attend to each other, learning what information to share.

class CrossAttentionFusion(nn.Module):
    """Modalities attend to each other via cross-attention."""

    def __init__(self, hidden_dim, num_heads=8):
        super().__init__()

        # Image attends to text
        self.img_to_txt_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)

        # Text attends to image
        self.txt_to_img_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)

        # Layer norms
        self.img_norm = nn.LayerNorm(hidden_dim)
        self.txt_norm = nn.LayerNorm(hidden_dim)

    def forward(self, image_tokens, text_tokens):
        """
        Args:
            image_tokens: [batch, num_patches, hidden_dim]
            text_tokens: [batch, seq_len, hidden_dim]
        """
        # Image attends to text
        img_attended, _ = self.img_to_txt_attn(
            query=image_tokens,
            key=text_tokens,
            value=text_tokens
        )
        image_out = self.img_norm(image_tokens + img_attended)

        # Text attends to image
        txt_attended, _ = self.txt_to_img_attn(
            query=text_tokens,
            key=image_tokens,
            value=image_tokens
        )
        text_out = self.txt_norm(text_tokens + txt_attended)

        return image_out, text_out


class MultimodalTransformer(nn.Module):
    """Full multimodal model with cross-attention fusion."""

    def __init__(self, image_encoder, text_encoder, hidden_dim, num_classes, num_fusion_layers=4):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder

        # Project to common dimension
        self.image_proj = nn.Linear(image_encoder.output_dim, hidden_dim)
        self.text_proj = nn.Linear(text_encoder.output_dim, hidden_dim)

        # Cross-attention fusion layers
        self.fusion_layers = nn.ModuleList([
            CrossAttentionFusion(hidden_dim) for _ in range(num_fusion_layers)
        ])

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, images, text_input_ids, text_attention_mask):
        # Encode modalities
        image_features = self.image_encoder(images).last_hidden_state
        text_features = self.text_encoder(text_input_ids, text_attention_mask).last_hidden_state

        # Project to common space
        image_tokens = self.image_proj(image_features)
        text_tokens = self.text_proj(text_features)

        # Cross-attention fusion
        for fusion_layer in self.fusion_layers:
            image_tokens, text_tokens = fusion_layer(image_tokens, text_tokens)

        # Pool and classify
        image_pooled = image_tokens.mean(dim=1)
        text_pooled = text_tokens.mean(dim=1)
        combined = torch.cat([image_pooled, text_pooled], dim=-1)

        return self.classifier(combined)

Pros: Rich cross-modal interactions; learns alignment between modalities.

Cons: Computationally expensive; requires both modalities during training.

15.5.4 Choosing a Fusion Strategy

Recommended strategies by scenario:

  • Simple baseline: Late fusion
  • Missing modalities are common: Late fusion with averaging
  • Cross-modal reasoning is needed: Cross-attention fusion
  • Limited compute: Early fusion
  • Pretrained unimodal models are available: Late fusion

15.6 Combining Imaging and EHR Data

Clinical Context: A chest X-ray shows an ambiguous opacity. Is it pneumonia, atelectasis, or malignancy? The answer might depend on context: is the patient febrile? immunocompromised? a smoker? Combining imaging with EHR data can resolve ambiguities that imaging alone cannot.

15.6.1 Why Add Structured Data?

Structured EHR data provides context that images lack:

  • Demographics: Age, sex (some conditions more common in certain groups)
  • Vitals: Fever suggests infection; hypotension suggests severity
  • Labs: Elevated WBC, procalcitonin support infectious etiology
  • History: Prior diagnoses, surgeries, risk factors
  • Medications: Immunosuppression, anticoagulation affect interpretation

15.6.2 Feature Extraction from EHR

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer

def prepare_ehr_features(ehr_df):
    """
    Prepare structured EHR features for multimodal fusion.

    Returns normalized feature vector for each patient.
    """
    # Define feature groups
    numeric_features = [
        'age', 'temperature', 'heart_rate', 'respiratory_rate',
        'systolic_bp', 'o2_saturation', 'wbc', 'creatinine', 'bun'
    ]

    categorical_features = ['sex', 'admission_type', 'insurance']

    binary_features = [
        'diabetes', 'hypertension', 'copd', 'chf', 'ckd',
        'immunocompromised', 'smoker', 'prior_pneumonia'
    ]

    # Preprocessing pipeline
    preprocessor = ColumnTransformer([
        ('numeric', StandardScaler(), numeric_features),
        ('categorical', OneHotEncoder(sparse_output=False, handle_unknown='ignore'), categorical_features),
        ('binary', 'passthrough', binary_features)
    ])

    features = preprocessor.fit_transform(ehr_df)
    return features, preprocessor

# Example EHR data
ehr_data = pd.DataFrame({
    'patient_id': [1, 2, 3],
    'age': [65, 45, 78],
    'sex': ['M', 'F', 'M'],
    'temperature': [38.5, 37.0, 39.2],
    'heart_rate': [95, 72, 110],
    'respiratory_rate': [22, 16, 28],
    'systolic_bp': [130, 120, 95],
    'o2_saturation': [92, 98, 88],
    'wbc': [15.2, 8.1, 18.5],
    'creatinine': [1.2, 0.9, 2.1],
    'bun': [25, 15, 45],
    'admission_type': ['emergency', 'elective', 'emergency'],
    'insurance': ['medicare', 'commercial', 'medicare'],
    'diabetes': [1, 0, 1],
    'hypertension': [1, 0, 1],
    'copd': [1, 0, 0],
    'chf': [0, 0, 1],
    'ckd': [0, 0, 1],
    'immunocompromised': [0, 0, 0],
    'smoker': [1, 0, 1],
    'prior_pneumonia': [1, 0, 0]
})

features, preprocessor = prepare_ehr_features(ehr_data)
print(f"EHR feature dimension: {features.shape[1]}")

15.6.3 Image + Tabular Fusion Model

import torch
import torch.nn as nn
from torchvision import models

class ImageTabularFusion(nn.Module):
    """Combine chest X-ray with structured EHR data."""

    def __init__(self, tabular_dim, num_classes, hidden_dim=256):
        super().__init__()

        # Image encoder (pretrained ResNet)
        self.image_encoder = models.resnet50(weights='IMAGENET1K_V1')
        image_feature_dim = self.image_encoder.fc.in_features
        self.image_encoder.fc = nn.Identity()

        # Tabular encoder
        self.tabular_encoder = nn.Sequential(
            nn.Linear(tabular_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Image projection
        self.image_proj = nn.Sequential(
            nn.Linear(image_feature_dim, hidden_dim),
            nn.ReLU()
        )

        # Fusion with gating mechanism
        self.gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Sigmoid()
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, images, tabular):
        # Encode each modality
        img_features = self.image_encoder(images)
        img_features = self.image_proj(img_features)

        tab_features = self.tabular_encoder(tabular)

        # Gated fusion
        combined = torch.cat([img_features, tab_features], dim=-1)
        gate = self.gate(combined)
        fused = gate * img_features + (1 - gate) * tab_features

        # Classify
        return self.classifier(fused)


# Training loop
def train_multimodal(model, train_loader, val_loader, epochs=20):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    best_val_auc = 0

    for epoch in range(epochs):
        model.train()
        train_loss = 0

        for images, tabular, labels in train_loader:
            images = images.to(device)
            tabular = tabular.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images, tabular)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Validation
        model.eval()
        val_probs = []
        val_labels = []

        with torch.no_grad():
            for images, tabular, labels in val_loader:
                images = images.to(device)
                tabular = tabular.to(device)

                outputs = model(images, tabular)
                probs = torch.softmax(outputs, dim=1)[:, 1]

                val_probs.extend(probs.cpu().numpy())
                val_labels.extend(labels.numpy())

        from sklearn.metrics import roc_auc_score
        val_auc = roc_auc_score(val_labels, val_probs)

        print(f"Epoch {epoch+1}: Train Loss={train_loss/len(train_loader):.4f}, Val AUC={val_auc:.4f}")

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            torch.save(model.state_dict(), 'best_multimodal.pt')

    return best_val_auc

15.6.4 Handling Missing Modalities

In practice, not all patients have all data. Handle this gracefully:

class RobustMultimodalModel(nn.Module):
    """Handle missing modalities during training and inference."""

    def __init__(self, image_encoder, tabular_encoder, hidden_dim, num_classes):
        super().__init__()
        self.image_encoder = image_encoder
        self.tabular_encoder = tabular_encoder
        self.hidden_dim = hidden_dim

        # Learnable default embeddings for missing modalities
        self.default_image_embedding = nn.Parameter(torch.randn(1, hidden_dim))
        self.default_tabular_embedding = nn.Parameter(torch.randn(1, hidden_dim))

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, images=None, tabular=None, image_mask=None, tabular_mask=None):
        batch_size = images.shape[0] if images is not None else tabular.shape[0]
        device = images.device if images is not None else tabular.device

        # Encode available modalities
        if images is not None and (image_mask is None or image_mask.any()):
            img_features = self.image_encoder(images)
        else:
            img_features = self.default_image_embedding.expand(batch_size, -1)

        if tabular is not None and (tabular_mask is None or tabular_mask.any()):
            tab_features = self.tabular_encoder(tabular)
        else:
            tab_features = self.default_tabular_embedding.expand(batch_size, -1)

        # Apply masks if provided
        if image_mask is not None:
            img_features = torch.where(
                image_mask.unsqueeze(-1),
                img_features,
                self.default_image_embedding.expand(batch_size, -1)
            )

        if tabular_mask is not None:
            tab_features = torch.where(
                tabular_mask.unsqueeze(-1),
                tab_features,
                self.default_tabular_embedding.expand(batch_size, -1)
            )

        combined = torch.cat([img_features, tab_features], dim=-1)
        return self.classifier(combined)
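
A brief usage sketch, assuming image_encoder and tabular_encoder each return hidden_dim-sized feature vectors and that missing studies are passed as zero placeholders flagged by the mask:

# Batch of three patients; patient 2 has no X-ray, so its image slot holds zeros
images = torch.randn(3, 3, 224, 224)
images[1] = 0.0                                 # placeholder for the missing study
image_mask = torch.tensor([True, False, True])  # marks which images are real
tabular = torch.randn(3, 32)                    # structured EHR features (illustrative)

model = RobustMultimodalModel(image_encoder, tabular_encoder,
                              hidden_dim=256, num_classes=2)
logits = model(images=images, tabular=tabular,
               image_mask=image_mask, tabular_mask=None)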

15.6.5 When Multimodal Helps

Multimodal fusion typically improves performance when:

  1. Modalities are complementary: Each provides unique information
  2. Task requires context: Imaging interpretation depends on clinical setting
  3. Single modality is ambiguous: Multiple interpretations possible from image alone

It may not help when:

  1. One modality dominates: The other adds only noise
  2. Modalities are redundant: Same information in different forms
  3. Data quality varies: Poor-quality modality hurts fusion

Always compare multimodal to best single-modality baseline.

15.7 Evaluation and Challenges

15.7.1 Evaluating Multimodal Models

Standard metrics apply, but consider:

  1. Modality ablation: How much does each modality contribute?
  2. Missing modality performance: How does the model degrade with missing data?
  3. Cross-modal consistency: Do predictions make sense given all inputs?

def modality_ablation_study(model, test_loader, device):
    """Evaluate the model with different modality combinations.

    Assumes an evaluate_model helper that runs inference using only the
    selected modalities and returns AUC on the test set.
    """
    results = {}

    # Full model (both modalities)
    auc_full = evaluate_model(model, test_loader, device,
                              use_image=True, use_tabular=True)
    results['full'] = auc_full

    # Image only
    auc_image = evaluate_model(model, test_loader, device,
                               use_image=True, use_tabular=False)
    results['image_only'] = auc_image

    # Tabular only
    auc_tabular = evaluate_model(model, test_loader, device,
                                 use_image=False, use_tabular=True)
    results['tabular_only'] = auc_tabular

    print("Modality Ablation Results:")
    print(f"  Full model (image + tabular): AUC = {auc_full:.4f}")
    print(f"  Image only: AUC = {auc_image:.4f}")
    print(f"  Tabular only: AUC = {auc_tabular:.4f}")
    print(f"  Multimodal improvement over best unimodal: "
          f"{auc_full - max(auc_image, auc_tabular):+.4f}")

    return results

15.7.2 Key Challenges

Hallucination in Generation

VLMs can generate plausible-sounding but incorrect findings. Mitigation strategies:

  • Constrain the generation vocabulary
  • Verify findings against image-level classifiers (see the sketch below)
  • Require human review
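
One way to operationalize the classifier cross-check is sketched below; it assumes a text labeler (such as the chexbert_labeler used earlier) and an independent image-level classifier, and the function name and threshold are illustrative:

def flag_unsupported_findings(report_labels, image_probs, finding_names, threshold=0.5):
    """Flag findings asserted in a generated report that an independent image
    classifier does not support, so they can be routed for human review.

    report_labels: binary labels extracted from the generated report text
    image_probs:   per-finding probabilities from an image-level classifier
    """
    return [
        name for name, asserted, prob in zip(finding_names, report_labels, image_probs)
        if asserted == 1 and prob < threshold
    ]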

Modality Imbalance

One modality may dominate learning. Possible solutions:

  • Gradient balancing across modalities
  • Modality dropout during training (see the sketch below)
  • Uncertainty-weighted fusion
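
A minimal sketch of modality dropout, assuming per-modality feature tensors as in the fusion models above (the function name and drop rate are illustrative):

def modality_dropout(img_features, tab_features, p_drop=0.3, training=True):
    """Randomly zero out one modality's features per sample during training."""
    if not training:
        return img_features, tab_features

    batch_size = img_features.shape[0]
    drop_img = torch.rand(batch_size, device=img_features.device) < p_drop
    drop_tab = torch.rand(batch_size, device=tab_features.device) < p_drop
    drop_tab = drop_tab & ~drop_img  # never drop both modalities for the same sample

    img_features = img_features * (~drop_img).float().unsqueeze(-1)
    tab_features = tab_features * (~drop_tab).float().unsqueeze(-1)
    return img_features, tab_features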

Computational Cost

Multimodal models are expensive. Consider:

  • Efficient architectures (early exit, pruning)
  • Modality-specific compression
  • When to use full multimodal inference versus a single modality

Regulatory Pathway

Multimodal medical devices face a complex approval path:

  • Each modality may have a different risk profile
  • Interaction effects must be validated
  • Missing-modality handling needs documentation

15.8 The Future: Medical Foundation Models

15.8.1 The Vision: Generalist Medical AI

Recent models point toward generalist medical AI—systems that handle diverse tasks across modalities:

Med-PaLM M (Google): Single model for medical QA, radiology, pathology, genomics, and dermatology.

GPT-4V (OpenAI): General vision-language model with emerging medical capabilities.

GigaPath (Microsoft): Pathology foundation model trained on 170,000 whole slide images.

These models suggest a future where one system handles many medical AI tasks, similar to how radiologists develop broad expertise.

15.8.2 Implications for Clinical Practice

If foundation models succeed:

  1. Reduced development cost: Fine-tune one model for many tasks
  2. Cross-task transfer: Learning from one task improves others
  3. Unified interfaces: Single system for diverse clinical AI needs
  4. New capabilities: Emergent abilities from scale

15.8.3 Remaining Challenges

  1. Validation: How do you validate a model that does everything?
  2. Failure modes: Generalist models may fail in unexpected ways
  3. Specialization vs. generalization: When is a specialist model better?
  4. Deployment complexity: Running large models in clinical settings
  5. Liability: Who is responsible when a general-purpose AI errs?

15.8.4 What This Means for Students

The field is moving fast. Key skills:

  1. Understand architectures: Know how multimodal models work
  2. Evaluation expertise: Be able to assess multimodal systems rigorously
  3. Clinical grounding: Always connect technical work to clinical needs
  4. Adaptability: Today’s SOTA will be tomorrow’s baseline

15.9 Chapter Summary

Multimodal AI represents the frontier of medical AI, combining information as clinicians do.

Contrastive learning (CLIP):

  • Learns shared image-text embeddings
  • Enables zero-shot classification
  • Medical variants (BiomedCLIP, MedCLIP) understand clinical content

Vision-language models:

  • LLaVA architecture: vision encoder + LLM
  • Visual question answering on medical images
  • Medical VLMs (LLaVA-Med, Med-PaLM M) for clinical applications

Report generation:

  • Encoder-decoder architecture with cross-attention
  • VLMs can generate structured reports
  • Evaluation requires both NLP and clinical metrics

Fusion strategies:

  • Early fusion: combine features before processing
  • Late fusion: combine predictions from separate models
  • Cross-attention: modalities attend to each other

Image + EHR integration:

  • Structured data provides clinical context
  • Gated fusion learns modality weighting
  • Handle missing modalities gracefully

The path forward leads to foundation models that integrate all medical data modalities. Understanding multimodal architectures today prepares you for this future.

15.10 Exercises

  1. Implement zero-shot classification with a medical CLIP model. Compare performance on chest X-ray classification using different prompt templates. Which prompts work best?

  2. Build a simple report generation system using a pretrained VLM. Evaluate on a small held-out set using BLEU, ROUGE, and manual clinical review. What types of errors does the model make?

  3. Implement early, late, and cross-attention fusion for a binary classification task using synthetic image and tabular features. Compare training dynamics and final performance.

  4. Create a dataset where some samples have missing modalities. Implement and compare strategies for handling missing data: zero imputation, learned defaults, and modality dropout.

  5. Design an evaluation protocol for a multimodal radiology AI system that will be deployed clinically. What metrics would you track? How would you monitor for failure modes?