18  Interpretability & Uncertainty

When deploying AI in clinical settings, “the model says so” is insufficient. Clinicians need to understand why a model made a prediction and how confident that prediction is. This chapter covers methods to explain model decisions and quantify uncertainty.

18.1 Why Interpretability Matters

Clinical Context: A radiologist reviewing an AI-flagged chest X-ray needs more than a probability score. Which regions drove the prediction? Are there image artifacts that might have fooled the model? Interpretability tools help clinicians appropriately trust, question, and override AI recommendations.

Interpretability serves multiple purposes:

  • Clinical validation: Confirm the model uses medically relevant features (e.g., looking at the lung, not the scanner label)
  • Debugging: Identify spurious correlations and dataset biases
  • Regulatory compliance: Regulators, including the FDA, increasingly emphasize transparency for AI-enabled medical devices
  • Education: Help trainees understand what patterns indicate disease
  • Trust calibration: Know when to rely on vs. question predictions

18.2 Feature Attribution Methods

Feature attribution methods assign importance scores to input features, showing which parts of the input most influenced the prediction.

18.2.1 Gradient-Based Methods

The simplest approach: compute the gradient of the output with respect to each input pixel. Large gradients indicate pixels that, if changed slightly, would significantly affect the prediction.

Vanilla Gradients: \(\text{Attribution} = \left| \frac{\partial y}{\partial x} \right|\)

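Limitations: Gradients can be noisy and may highlight edges rather than semantically meaningful regions. They also saturate: once a feature has pushed the output close to its maximum, its gradient can be near zero even though the feature clearly matters.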
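To make the definition concrete without a deep-learning framework, here is a minimal numpy sketch for a hypothetical two-layer network \(y = w_2^\top \tanh(W_1 x + b_1)\). The gradient is derived by hand with the chain rule (PyTorch would obtain it via autograd), and the random weights are placeholders, not a trained clinical model.

```python
import numpy as np

def forward(x, W1, b1, w2):
    """Scalar output of a toy two-layer network: y = w2 . tanh(W1 x + b1)."""
    return w2 @ np.tanh(W1 @ x + b1)

def vanilla_gradient(x, W1, b1, w2):
    """Saliency |dy/dx| for each input feature, via the chain rule."""
    h = np.tanh(W1 @ x + b1)
    # dy/dx = W1^T (w2 * (1 - tanh^2(W1 x + b1)))
    grad = W1.T @ (w2 * (1.0 - h ** 2))
    return np.abs(grad)

# Hypothetical random weights and a 4-feature input
rng = np.random.default_rng(0)
W1 = rng.normal(size=(8, 4))
b1 = rng.normal(size=8)
w2 = rng.normal(size=8)
x = rng.normal(size=4)

saliency = vanilla_gradient(x, W1, b1, w2)
print(saliency)  # one non-negative score per input feature
```

A finite-difference check (perturbing one feature at a time) is a cheap way to confirm a hand-derived gradient before trusting the saliency map.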

Integrated Gradients (IG) accumulate gradients along the straight-line path from a baseline input \(x'\) (e.g., an all-black image) to the actual input \(x\):

\[ \text{IG}_i = (x_i - x'_i) \times \int_0^1 \frac{\partial F(x' + \alpha(x-x'))}{\partial x_i} d\alpha \]

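IG satisfies several desirable axioms, including sensitivity, implementation invariance, and completeness (the attributions sum to \(F(x) - F(x')\)), and often produces cleaner attributions than vanilla gradients.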
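In practice the integral is approximated with a Riemann sum over a few dozen points on the path. The sketch below assumes a hypothetical logistic model \(F(x) = \sigma(w^\top x + b)\) whose gradient is known in closed form; the weights, input, and all-zeros baseline are illustrative choices. The last line illustrates completeness: the attributions sum, up to discretization error, to \(F(x) - F(x')\).

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def logistic_model(x, w, b):
    """Toy model F(x) = sigmoid(w . x + b)."""
    return sigmoid(w @ x + b)

def logistic_grad(x, w, b):
    """dF/dx = sigmoid'(w . x + b) * w."""
    p = logistic_model(x, w, b)
    return p * (1.0 - p) * w

def integrated_gradients(x, baseline, w, b, steps=100):
    # Midpoint Riemann sum over alpha in (0, 1)
    alphas = (np.arange(steps) + 0.5) / steps
    grads = np.array([logistic_grad(baseline + a * (x - baseline), w, b)
                      for a in alphas])
    return (x - baseline) * grads.mean(axis=0)

w = np.array([1.5, -2.0, 0.5])
b = -0.2
x = np.array([1.0, 0.5, 2.0])
baseline = np.zeros_like(x)  # analogue of a black image

ig = integrated_gradients(x, baseline, w, b)
print(ig)
# Completeness: these two numbers should nearly match
print(ig.sum(), logistic_model(x, w, b) - logistic_model(baseline, w, b))
```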

18.2.2 Grad-CAM for CNNs

Grad-CAM (Gradient-weighted Class Activation Mapping) creates visual explanations by weighting feature maps by their importance to the target class (Selvaraju et al. 2017). Unlike the original CAM, it requires no architectural changes or retraining, so it applies to a wide range of CNNs.

For target class \(c\) with score \(y^c\) (the logit before the softmax) and feature maps \(A^k\) from the last convolutional layer:

\[ \alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A_{ij}^k} \]

\[ L^c_{\text{Grad-CAM}} = \text{ReLU}\left(\sum_k \alpha_k^c A^k\right) \]

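Here \(Z\) is the number of spatial positions \((i, j)\) in each feature map, so \(\alpha_k^c\) is the spatially averaged gradient. The result is a coarse heatmap at the resolution of the last convolutional layer (7×7 for ResNet-18 with a 224×224 input) highlighting image regions important for the prediction; it is usually upsampled to the input size for display.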

import torch
import torch.nn.functional as F
from torchvision import models

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.gradients = None
        self.activations = None

        # Save the target layer's output on the forward pass; a tensor hook on
        # that output then captures its gradient on the backward pass
        # (Module.register_backward_hook is deprecated)
        target_layer.register_forward_hook(self.save_activation)

    def save_activation(self, module, input, output):
        self.activations = output.detach()
        if output.requires_grad:
            output.register_hook(self.save_gradient)

    def save_gradient(self, grad):
        self.gradients = grad.detach()

    def generate(self, input_image, target_class):
        """input_image: preprocessed tensor of shape (1, C, H, W)."""
        self.model.eval()  # BatchNorm uses running statistics; dropout off

        # Forward pass (raw class scores, before softmax)
        output = self.model(input_image)
        self.model.zero_grad()

        # Backward pass for the target class score only
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1
        output.backward(gradient=one_hot)

        # alpha_k^c: average the gradients over spatial positions (i, j)
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)

        # Upsample the coarse map to the input size, then normalize to [0, 1]
        cam = F.interpolate(cam, size=input_image.shape[2:],
                            mode='bilinear', align_corners=False)
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)
        return cam

# Usage: `image` is a normalized tensor of shape (1, 3, 224, 224)
model = models.resnet18(weights='IMAGENET1K_V1')
gradcam = GradCAM(model, model.layer4[-1])
heatmap = gradcam.generate(image, target_class=1)  # shape (1, 1, 224, 224)

18.2.3 SHAP for Tabular Data

SHAP (SHapley Additive exPlanations) uses game-theoretic Shapley values to fairly distribute the difference between a prediction and the average (expected) prediction among the input features (Lundberg and Lee 2017). Each feature's SHAP value is a weighted average of its marginal contribution over all possible subsets of the other features.

For tabular clinical data (lab values, demographics, vitals), SHAP provides both local explanations (why this patient) and global feature importance.

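import shap
from sklearn.ensemble import RandomForestClassifier

# Assumes X_train, y_train, X_test (pandas) and feature_names are already
# loaded, e.g., from a binary diabetes-outcome dataset
model = RandomForestClassifier(n_estimators=100, random_state=0)
model.fit(X_train, y_train)

# TreeExplainer computes SHAP values efficiently for tree ensembles
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)

# Depending on the SHAP version, multi-class output is either a list with one
# array per class or one array of shape (n_samples, n_features, n_classes)
if isinstance(shap_values, list):
    sv_pos = shap_values[1]
else:
    sv_pos = shap_values[:, :, 1]

# Explain a single patient's prediction for the positive class
shap.force_plot(explainer.expected_value[1],
                sv_pos[0],
                X_test.iloc[0],
                feature_names=feature_names,
                matplotlib=True)

# Global feature importance across the test set
shap.summary_plot(sv_pos, X_test, feature_names=feature_names)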

SHAP values have useful properties:

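  • Local accuracy: For each patient, the attributions sum to the difference between the model's prediction and the expected (baseline) prediction
  • Consistency: If a model changes so that a feature's marginal contribution increases or stays the same, that feature's attribution does not decrease
  • Positive/negative: A positive value pushes the prediction above the baseline; a negative value pushes it below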
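To see what "weighted average of marginal contributions over all subsets" means, the exact Shapley values of a small model can be computed by brute force. This sketch makes a simplifying assumption: a feature absent from a subset is set to a single reference value (a hypothetical "average patient"), whereas the shap library averages over background data or uses model-specific algorithms. The number of subsets grows as \(2^M\) in the number of features \(M\), which is why practical explainers avoid brute force.

```python
import itertools
import math
import numpy as np

def shapley_values(f, x, reference):
    """Exact Shapley values of f at x; 'absent' features take reference values."""
    M = len(x)
    phi = np.zeros(M)
    for i in range(M):
        others = [j for j in range(M) if j != i]
        for size in range(M):
            for S in itertools.combinations(others, size):
                # Shapley weight |S|! (M - |S| - 1)! / M!
                weight = (math.factorial(size) * math.factorial(M - size - 1)
                          / math.factorial(M))
                z_without = reference.copy()
                z_without[list(S)] = x[list(S)]
                z_with = z_without.copy()
                z_with[i] = x[i]
                phi[i] += weight * (f(z_with) - f(z_without))
    return phi

# Hypothetical risk score with an interaction between features 0 and 2
def risk(z):
    return 0.3 * z[0] + 0.5 * z[1] + 0.2 * z[0] * z[2]

x = np.array([2.0, 1.0, 3.0])          # patient being explained
reference = np.array([1.0, 0.0, 1.0])  # hypothetical "average patient"

phi = shapley_values(risk, x, reference)
print(phi)  # approximately [0.7, 0.5, 0.6]
print(phi.sum(), risk(x) - risk(reference))  # local accuracy: both 1.8
```

The interaction term's credit is split between features 0 and 2, while feature 1, which enters linearly, receives exactly \(0.5 \times (1 - 0)\).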

18.3 Model Calibration

Clinical Context: When a model outputs 80% probability of pneumonia, how often is pneumonia actually present? A calibrated model’s predicted probabilities match observed frequencies. Miscalibration can lead to over- or under-treatment.

18.3.1 Reliability Diagrams

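A reliability diagram bins predictions by predicted probability and plots the observed fraction of positive cases in each bin against the bin's mean predicted probability. A perfectly calibrated model lies on the diagonal; points below it indicate that the model overestimates risk, and points above it indicate underestimation.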

import numpy as np
import matplotlib.pyplot as plt
from sklearn.calibration import calibration_curve

# y_true: binary labels, y_prob: predicted probabilities
fraction_pos, mean_predicted = calibration_curve(
    y_true, y_prob, n_bins=10, strategy='uniform'
)

plt.figure(figsize=(6, 6))
plt.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
plt.plot(mean_predicted, fraction_pos, 'o-', label='Model')
plt.xlabel('Mean predicted probability')
plt.ylabel('Fraction of positives')
plt.legend()
plt.title('Reliability Diagram')
plt.show()

18.3.2 Brier Score

The Brier score is the mean squared error of the predicted probabilities, so it reflects both calibration and discrimination:

\[ \text{Brier} = \frac{1}{N}\sum_{i=1}^N (p_i - y_i)^2 \]

where \(p_i\) is the predicted probability and \(y_i \in \{0, 1\}\) is the true label. Lower is better; 0 is perfect.

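Murphy's decomposition splits the Brier score into three components, \(\text{Brier} = \text{REL} - \text{RES} + \text{UNC}\): reliability (calibration error; lower is better), resolution (how much the observed outcome rate varies across groups of predictions, a measure of discrimination; higher is better), and uncertainty (the outcome variance \(\bar{y}(1-\bar{y})\), which depends only on prevalence).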
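A short numpy sketch of the Brier score and Murphy's three-term decomposition, grouping patients by their predicted probability. The predictions and labels are made up for illustration. Grouping by each distinct predicted value makes the identity Brier = reliability − resolution + uncertainty exact; with continuous predictions grouped into bins, it becomes approximate.

```python
import numpy as np

def brier_score(p, y):
    return np.mean((p - y) ** 2)

def murphy_decomposition(p, y):
    """Return (reliability, resolution, uncertainty), grouping by distinct p."""
    n = len(y)
    base_rate = y.mean()
    rel = res = 0.0
    for pk in np.unique(p):
        mask = p == pk
        nk = mask.sum()
        obs_k = y[mask].mean()                 # observed event rate in group
        rel += nk * (pk - obs_k) ** 2          # calibration error
        res += nk * (obs_k - base_rate) ** 2   # spread around the base rate
    unc = base_rate * (1 - base_rate)          # depends only on prevalence
    return rel / n, res / n, unc

p = np.array([0.1, 0.1, 0.1, 0.1, 0.4, 0.4, 0.4, 0.4, 0.8, 0.8, 0.8, 0.8])
y = np.array([0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0])

rel, res, unc = murphy_decomposition(p, y)
print(brier_score(p, y), rel - res + unc)  # both 0.22
```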

18.3.3 Calibration Methods

Common approaches to improve calibration:

Platt Scaling: Fit a logistic regression to transform logits:

\[ p_{\text{calibrated}} = \sigma(a \cdot z + b) \]

where \(z\) is the model’s raw logit and \(a, b\) are learned on a held-out calibration set.

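Temperature Scaling: A one-parameter variant of Platt scaling with \(a = 1/T\) and \(b = 0\), applied to all logits before the softmax so that it extends to multiclass models: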

\[ p_{\text{calibrated}} = \text{softmax}(z / T) \]

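The temperature \(T > 0\) is learned by minimizing the negative log-likelihood on the calibration set. \(T > 1\) softens overconfident predictions and \(T < 1\) sharpens them. Because dividing all logits by the same positive constant preserves their order, the predicted class, and therefore accuracy, is unchanged.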

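from sklearn.calibration import CalibratedClassifierCV

# base_model: a classifier already fitted on the training set
# X_calib, y_calib: a held-out calibration set not used for training
# cv='prefit' keeps base_model fixed and fits only the sigmoid (Platt) mapping.
# scikit-learn >= 1.6 deprecates cv='prefit'; there, wrap the model instead:
#   CalibratedClassifierCV(FrozenEstimator(base_model), method='sigmoid')
# with FrozenEstimator imported from sklearn.frozen
calibrated_model = CalibratedClassifierCV(
    base_model, method='sigmoid', cv='prefit'
)
calibrated_model.fit(X_calib, y_calib)

# Calibrated probabilities for the positive class
probs = calibrated_model.predict_proba(X_test)[:, 1]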
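`CalibratedClassifierCV` handles Platt scaling for scikit-learn models. For a multiclass neural network, temperature scaling is typically fit directly on held-out logits by minimizing the negative log-likelihood. A minimal numpy sketch using a simple grid search over \(T\) (gradient-based optimizers are also common); the synthetic logits are deliberately overconfident, generated so that the "true" temperature is 2, purely for illustration.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def nll(logits, labels, T):
    """Mean negative log-likelihood of the true labels after dividing by T."""
    probs = softmax(logits / T)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels] + 1e-12))

def fit_temperature(logits, labels, grid=np.linspace(0.25, 5.0, 96)):
    losses = [nll(logits, labels, T) for T in grid]
    return grid[int(np.argmin(losses))]

# Synthetic calibration set: labels are drawn from softmax(logits / 2),
# so the raw logits are overconfident by construction
rng = np.random.default_rng(0)
logits = rng.normal(scale=3.0, size=(5000, 3))
true_probs = softmax(logits / 2.0)
u = rng.random((5000, 1))
labels = np.minimum((true_probs.cumsum(axis=1) < u).sum(axis=1), 2)

T = fit_temperature(logits, labels)
print(T)  # should land near 2; T > 1 softens the predictions
calibrated = softmax(logits / T)
# Dividing by T > 0 preserves the argmax, so accuracy is unchanged
assert (calibrated.argmax(axis=1) == logits.argmax(axis=1)).all()
```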

18.4 Uncertainty Quantification

Clinical Context: A model’s prediction on a patient with unusual lab values should come with higher uncertainty. Uncertainty estimates help identify cases requiring extra review and support appropriate human oversight.

18.4.1 Types of Uncertainty

  • Aleatoric uncertainty: Inherent noise in the data (e.g., ambiguous images). Cannot be reduced with more training data.
  • Epistemic uncertainty: Model uncertainty due to limited training data. Reducible with more examples.

For clinical decision support, epistemic uncertainty is especially important: it tends to rise on out-of-distribution inputs, where the model may be unreliable.
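Given several stochastic predictions for the same input (for example, from MC dropout passes or an ensemble), one common heuristic splits the predictive entropy into an aleatoric part (the average entropy of the individual predictions) and an epistemic part (their disagreement, measured as mutual information). A minimal numpy sketch; the sample arrays are made up for illustration.

```python
import numpy as np

def entropy(p, axis=-1):
    return -np.sum(p * np.log(p + 1e-12), axis=axis)

def decompose_uncertainty(samples):
    """samples: (n_samples, n_classes) predicted probabilities for one input."""
    total = entropy(samples.mean(axis=0))        # entropy of the mean prediction
    aleatoric = entropy(samples, axis=1).mean()  # mean entropy per prediction
    epistemic = total - aleatoric                # mutual information (>= 0)
    return total, aleatoric, epistemic

# Passes agree on an ambiguous 50/50 call: mostly aleatoric uncertainty
agree = np.array([[0.5, 0.5]] * 4)
# Passes are individually confident but disagree: mostly epistemic uncertainty
disagree = np.array([[0.95, 0.05], [0.05, 0.95], [0.95, 0.05], [0.05, 0.95]])

print(decompose_uncertainty(agree))
print(decompose_uncertainty(disagree))
```

Both cases have the same total entropy (that of a 50/50 prediction), but only the second signals that more training data might help.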

18.4.2 Monte Carlo Dropout

A simple uncertainty estimate: keep dropout active at inference time and run multiple stochastic forward passes. The spread (e.g., standard deviation) of the resulting predictions approximates epistemic (model) uncertainty.

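import torch
import torch.nn as nn

DROPOUT_TYPES = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)

def mc_dropout_predict(model, x, n_samples=30):
    # Put the model in eval mode, then re-enable only the dropout layers.
    # (model.train() would also make BatchNorm use batch statistics
    # and overwrite its running averages.)
    model.eval()
    for m in model.modules():
        if isinstance(m, DROPOUT_TYPES):
            m.train()
    predictions = []

    with torch.no_grad():
        for _ in range(n_samples):
            pred = torch.softmax(model(x), dim=1)
            predictions.append(pred)

    model.eval()  # restore standard inference mode
    predictions = torch.stack(predictions)  # (n_samples, batch, n_classes)
    mean_pred = predictions.mean(dim=0)
    uncertainty = predictions.std(dim=0)
    return mean_pred, uncertainty

# High uncertainty suggests the model is unsure.
# patient_data: a batch of one input; threshold: a cutoff chosen on validation data
mean, std = mc_dropout_predict(model, patient_data)
if std[0, 1] > threshold:
    print("High uncertainty - recommend human review")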

18.4.3 Conformal Prediction

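Conformal prediction provides prediction sets with a coverage guarantee: the set contains the true label with probability at least \(1 - \alpha\). The guarantee is marginal: it holds on average over patients, not for each individual patient.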

Basic approach:

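  1. Compute a nonconformity score for each of the \(n\) examples in a held-out calibration set, e.g., \(s_i = 1 - \hat{p}_{y_i}(x_i)\), one minus the predicted probability of the true class
  2. Set \(\hat{q}\) to the \(\lceil (n+1)(1-\alpha) \rceil / n\) empirical quantile of these scores, i.e., the \((1-\alpha)\) quantile with a finite-sample correction
  3. For a new input, include every class \(c\) with \(\hat{p}_c \geq 1 - \hat{q}\)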
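
import numpy as np

def conformal_calibrate(y_true, y_prob, alpha=0.1):
    """Compute the score threshold q_hat for (1 - alpha) coverage.

    y_true: integer class labels, shape (n,); y_prob: probabilities, shape (n, K).
    """
    # Nonconformity: 1 - probability assigned to the true class
    scores = 1 - y_prob[np.arange(len(y_true)), y_true]
    # Quantile level with finite-sample correction
    n = len(scores)
    level = np.ceil((n + 1) * (1 - alpha)) / n
    if level > 1:
        # Calibration set too small for this alpha: include every class
        return np.inf
    # method='higher' (NumPy >= 1.22) never interpolates below an observed score
    q = np.quantile(scores, level, method='higher')
    return q

def conformal_predict(y_prob, threshold):
    """Return a boolean mask (n, K): True where the class is in the set."""
    return y_prob >= (1 - threshold)

# Calibrate on held-out data (integer labels, predicted probabilities)
threshold = conformal_calibrate(y_calib, probs_calib, alpha=0.1)

# Prediction sets with 90% marginal coverage
pred_sets = conformal_predict(probs_test, threshold)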

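Conformal prediction is distribution-free: the coverage guarantee holds for any model and any data distribution, provided the calibration and test examples are exchangeable (e.g., i.i.d. draws from the same population). Under distribution shift, such as deployment at a new hospital, the guarantee no longer applies.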
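The coverage guarantee can be checked empirically. This sketch simulates a hypothetical 5-class classifier whose predicted probabilities are random Dirichlet draws, samples "true" labels from those probabilities (so calibration and test data are exchangeable by construction), calibrates on one split, and measures how often the prediction set contains the true label on another. It re-implements the calibration step in a self-contained form, using `method='higher'` (NumPy ≥ 1.22) so the quantile is never rounded down.

```python
import numpy as np

def conformal_threshold(y, probs, alpha):
    scores = 1 - probs[np.arange(len(y)), y]  # nonconformity scores
    n = len(scores)
    level = np.ceil((n + 1) * (1 - alpha)) / n
    if level > 1:
        return np.inf  # too few calibration points: include every class
    return np.quantile(scores, level, method='higher')

def simulate(n, n_classes, rng):
    """Hypothetical model outputs plus labels sampled from them."""
    probs = rng.dirichlet(np.ones(n_classes), size=n)
    u = rng.random((n, 1))
    y = (probs.cumsum(axis=1) < u).sum(axis=1)
    return probs, np.minimum(y, n_classes - 1)  # guard against round-off

rng = np.random.default_rng(42)
alpha = 0.1
probs_cal, y_cal = simulate(2000, 5, rng)
probs_test, y_test = simulate(20000, 5, rng)

q_hat = conformal_threshold(y_cal, probs_cal, alpha)
pred_sets = probs_test >= 1 - q_hat  # boolean mask (n_test, n_classes)
coverage = pred_sets[np.arange(len(y_test)), y_test].mean()
avg_size = pred_sets.sum(axis=1).mean()
print(f"coverage={coverage:.3f}, mean set size={avg_size:.2f}")  # coverage near 0.90
```

The mean set size is the practical cost of the guarantee: a less informative model needs larger sets to reach the same coverage.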

18.5 Clinical Trust Guidelines

When should clinicians trust model predictions? Consider these factors:

  • Confidence level: High probability (>90%) with good calibration deserves more weight
  • Uncertainty: High uncertainty warrants extra scrutiny
  • Explanation quality: Does Grad-CAM highlight clinically relevant regions?
  • Patient characteristics: Is this patient similar to training data?
  • Clinical context: High-stakes decisions require higher confidence thresholds

For regulatory purposes, document:

  • Model performance on relevant subpopulations
  • Calibration metrics with confidence intervals
  • Known failure modes and out-of-distribution indicators
  • Recommended use conditions and override procedures

Models should support, not replace, clinical judgment. The most effective human-AI teams combine machine pattern recognition with human contextual reasoning.

18.6 Quick Reference: The Clinician’s Checklist

Before trusting an AI tool’s output, ask yourself these questions:

If you cannot answer these questions, you are not ready to use the tool clinically. Seek out the model card, validation documentation, or the responsible AI team before proceeding.