18 Interpretability & Uncertainty
When deploying AI in clinical settings, “the model says so” is insufficient. Clinicians need to understand why a model made a prediction and how confident that prediction is. This chapter covers methods to explain model decisions and quantify uncertainty.
18.1 Why Interpretability Matters
Clinical Context: A radiologist reviewing an AI-flagged chest X-ray needs more than a probability score. Which regions drove the prediction? Are there image artifacts that might have fooled the model? Interpretability tools help clinicians appropriately trust, question, and override AI recommendations.
Interpretability serves multiple purposes:
- Clinical validation: Confirm the model uses medically relevant features (e.g., looking at the lung, not the scanner label)
- Debugging: Identify spurious correlations and dataset biases
- Regulatory compliance: FDA guidance increasingly expects explainability
- Education: Help trainees understand what patterns indicate disease
- Trust calibration: Know when to rely on a prediction and when to question it
18.2 Feature Attribution Methods
Feature attribution methods assign importance scores to input features, showing which parts of the input most influenced the prediction.
18.2.1 Gradient-Based Methods
The simplest approach: compute the gradient of the output with respect to each input pixel. Large gradients indicate pixels that, if changed slightly, would significantly affect the prediction.
Vanilla Gradients: \(\text{Attribution} = \left| \frac{\partial y}{\partial x} \right|\)
Limitations: Gradients can be noisy and may highlight edges rather than semantically meaningful regions.
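As a minimal, illustrative sketch of the vanilla-gradient computation above (assuming model is a PyTorch image classifier and x is a preprocessed tensor of shape (1, 3, H, W); the helper name is ours):
import torch

def vanilla_gradient_saliency(model, x, target_class):
    """Absolute gradient of the target-class score w.r.t. each input pixel."""
    model.eval()
    x = x.clone().requires_grad_(True)     # track gradients w.r.t. the input
    score = model(x)[0, target_class]      # scalar score for the target class
    score.backward()
    # Collapse the color channels to one importance value per pixel
    return x.grad.abs().amax(dim=1)        # shape: (1, H, W)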
Integrated Gradients average gradients along a path from a baseline (e.g., black image) to the actual input:
\[ \text{IG}_i = (x_i - x'_i) \times \int_0^1 \frac{\partial F(x' + \alpha(x-x'))}{\partial x_i} d\alpha \]
This satisfies desirable axioms (sensitivity, implementation invariance) and often produces cleaner attributions.
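In practice the integral is approximated with a Riemann sum over a small number of interpolation steps. A minimal sketch under the same assumptions as above, using an all-zero (black) baseline:
import torch

def integrated_gradients(model, x, target_class, steps=50):
    """Riemann-sum approximation of IG along the path from a black baseline."""
    model.eval()
    x = x.detach()
    baseline = torch.zeros_like(x)          # black-image baseline x'
    total_grad = torch.zeros_like(x)
    for alpha in torch.linspace(0.0, 1.0, steps):
        point = (baseline + alpha * (x - baseline)).detach().requires_grad_(True)
        score = model(point)[0, target_class]
        score.backward()
        total_grad += point.grad
    # (x - x') times the average gradient along the path
    return (x - baseline) * total_grad / steps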
18.2.2 Grad-CAM for CNNs
Grad-CAM (Gradient-weighted Class Activation Mapping) creates visual explanations by weighting feature maps by their importance to the target class (Selvaraju et al. 2017). It works with any CNN architecture.
For target class \(c\) and feature maps \(A^k\) from the last convolutional layer:
\[ \alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A_{ij}^k} \]
\[ L^c_{\text{Grad-CAM}} = \text{ReLU}\left(\sum_k \alpha_k^c A^k\right) \]
The result is a coarse heatmap highlighting image regions important for the prediction.
In PyTorch, Grad-CAM can be implemented with a forward hook that captures the target layer's activations and a backward hook that captures its gradients:
import torch
import torch.nn.functional as F
from torchvision import models

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.gradients = None
        self.activations = None
        # Register hooks on the target convolutional layer
        target_layer.register_forward_hook(self.save_activation)
        target_layer.register_full_backward_hook(self.save_gradient)

    def save_activation(self, module, input, output):
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def generate(self, input_image, target_class):
        # Forward pass
        output = self.model(input_image)
        self.model.zero_grad()
        # Backward pass for the target class only
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1
        output.backward(gradient=one_hot)
        # Channel weights: global-average-pooled gradients
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)
        # Normalize to [0, 1]
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)
        return cam

# Usage
model = models.resnet18(weights='IMAGENET1K_V1')
model.eval()
gradcam = GradCAM(model, model.layer4[-1])
heatmap = gradcam.generate(image, target_class=1)  # image: (1, 3, 224, 224) tensor
18.2.3 SHAP for Tabular Data
SHAP (SHapley Additive exPlanations) uses game-theoretic Shapley values to fairly distribute the prediction among input features (Lundberg and Lee 2017). For each feature, SHAP computes its average marginal contribution across all possible feature subsets.
For tabular clinical data (lab values, demographics, vitals), SHAP provides both local explanations (why this patient) and global feature importance.
import shap
from sklearn.ensemble import RandomForestClassifier
# Train model on diabetes dataset
model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)
# Create SHAP explainer
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)
# Visualize for a single patient
shap.force_plot(explainer.expected_value[1],
                shap_values[1][0],
                X_test.iloc[0],
                feature_names=feature_names)
# Global feature importance
shap.summary_plot(shap_values[1], X_test, feature_names=feature_names)
SHAP values have useful properties:
- Local accuracy: Attributions sum to the difference between the prediction and the expected value (see the check after this list)
- Consistency: If a feature contributes more in one model, its attribution is higher
- Positive/negative: Signs indicate direction of effect
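The local accuracy property can be verified directly. A quick sanity check, assuming the explainer, shap_values, model, and X_test from the example above (indices follow the older list-style SHAP output for binary classifiers):
import numpy as np

# For the first test patient: SHAP attributions plus the expected value
# should reconstruct the model's predicted probability for class 1.
reconstructed = explainer.expected_value[1] + shap_values[1][0].sum()
predicted = model.predict_proba(X_test.iloc[[0]])[0, 1]
assert np.isclose(reconstructed, predicted, atol=1e-6)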
18.3 Model Calibration
Clinical Context: When a model outputs 80% probability of pneumonia, how often is pneumonia actually present? A calibrated model’s predicted probabilities match observed frequencies. Miscalibration can lead to over- or under-treatment.
18.3.1 Reliability Diagrams
A reliability diagram bins predictions by confidence and plots the actual frequency of positive cases in each bin. A perfectly calibrated model falls on the diagonal.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.calibration import calibration_curve
# y_true: binary labels, y_prob: predicted probabilities
fraction_pos, mean_predicted = calibration_curve(
    y_true, y_prob, n_bins=10, strategy='uniform'
)
plt.figure(figsize=(6, 6))
plt.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
plt.plot(mean_predicted, fraction_pos, 'o-', label='Model')
plt.xlabel('Mean predicted probability')
plt.ylabel('Fraction of positives')
plt.legend()
plt.title('Reliability Diagram')
18.3.2 Brier Score
The Brier score measures calibration and accuracy together:
\[ \text{Brier} = \frac{1}{N}\sum_{i=1}^N (p_i - y_i)^2 \]
where \(p_i\) is the predicted probability and \(y_i \in \{0, 1\}\) is the true label. Lower is better; 0 is perfect.
The Brier score can be decomposed into calibration, refinement (discrimination ability), and uncertainty components.
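Computing the score is a one-liner. A minimal example, assuming y_true and y_prob are the arrays from the reliability-diagram example above:
import numpy as np
from sklearn.metrics import brier_score_loss

# Direct implementation of the formula ...
brier_manual = np.mean((y_prob - y_true) ** 2)
# ... and the scikit-learn equivalent
brier_sklearn = brier_score_loss(y_true, y_prob)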
18.3.3 Calibration Methods
Common approaches to improve calibration:
Platt Scaling: Fit a logistic regression to transform logits:
\[ p_{\text{calibrated}} = \sigma(a \cdot z + b) \]
where \(z\) is the model’s raw logit and \(a, b\) are learned on a held-out calibration set.
Temperature Scaling: A single-parameter variant that divides the logits by a learned temperature \(T\) (in the binary case, equivalent to Platt scaling with \(a = 1/T\) and \(b = 0\)):
\[ p_{\text{calibrated}} = \text{softmax}(z / T) \]
A temperature \(T > 1\) softens the predicted probabilities, reducing overconfidence; a PyTorch sketch follows the scikit-learn example below.
from sklearn.calibration import CalibratedClassifierCV
# Calibrate using Platt scaling (sigmoid)
calibrated_model = CalibratedClassifierCV(
    base_model, method='sigmoid', cv='prefit'
)
calibrated_model.fit(X_calib, y_calib)
# Get calibrated probabilities
probs = calibrated_model.predict_proba(X_test)[:, 1]
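Temperature scaling is not provided by CalibratedClassifierCV. A minimal PyTorch sketch that fits \(T\) on held-out data, assuming logits_calib and labels_calib are tensors of validation-set logits and integer class labels (names are ours):
import torch

def fit_temperature(logits_calib, labels_calib, max_iter=50):
    """Learn a single temperature T by minimizing NLL on a calibration set."""
    log_T = torch.zeros(1, requires_grad=True)   # optimize log(T) so T stays positive
    optimizer = torch.optim.LBFGS([log_T], lr=0.1, max_iter=max_iter)

    def closure():
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(
            logits_calib / log_T.exp(), labels_calib
        )
        loss.backward()
        return loss

    optimizer.step(closure)
    return log_T.exp().item()

# Apply at inference time: softmax over temperature-scaled logits
# T = fit_temperature(logits_calib, labels_calib)
# probs = torch.softmax(logits_test / T, dim=1)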
18.4 Uncertainty Quantification
Clinical Context: A model’s prediction on a patient with unusual lab values should come with higher uncertainty. Uncertainty estimates help identify cases requiring extra review and support appropriate human oversight.
18.4.1 Types of Uncertainty
- Aleatoric uncertainty: Inherent noise in the data (e.g., ambiguous images). Cannot be reduced with more training data.
- Epistemic uncertainty: Model uncertainty due to limited training data. Reducible with more examples.
For clinical decision support, epistemic uncertainty is especially important: it flags out-of-distribution inputs where the model may be unreliable.
18.4.2 Monte Carlo Dropout
A simple uncertainty estimate: enable dropout at inference time and run multiple forward passes. The variance in predictions reflects model uncertainty.
import torch

def mc_dropout_predict(model, x, n_samples=30):
    # Enable dropout layers only; keep BatchNorm and other layers in eval mode
    model.eval()
    for m in model.modules():
        if isinstance(m, (torch.nn.Dropout, torch.nn.Dropout2d)):
            m.train()
    predictions = []
    with torch.no_grad():
        for _ in range(n_samples):
            pred = torch.softmax(model(x), dim=1)
            predictions.append(pred)
    predictions = torch.stack(predictions)
    mean_pred = predictions.mean(dim=0)
    uncertainty = predictions.std(dim=0)
    return mean_pred, uncertainty

# High uncertainty suggests the model is unsure
mean, std = mc_dropout_predict(model, patient_data)
if std[0, 1] > threshold:
    print("High uncertainty - recommend human review")
18.4.3 Conformal Prediction
Conformal prediction provides prediction sets with guaranteed coverage: the true label is contained with probability at least \(1 - \alpha\).
Basic approach:
- Compute a nonconformity score for calibration examples (e.g., \(1 - p_{\text{correct}}\))
- Find the \((1-\alpha)\) quantile \(\hat{q}\) of these scores
- For new inputs, include all classes with \(p_c \geq 1 - \hat{q}\)
import numpy as np

def conformal_calibrate(y_true, y_prob, alpha=0.1):
    """Compute threshold for (1-alpha) coverage."""
    # Nonconformity: 1 - probability of true class
    scores = 1 - y_prob[np.arange(len(y_true)), y_true]
    # Quantile with finite-sample correction
    n = len(scores)
    q = np.quantile(scores, np.ceil((n + 1) * (1 - alpha)) / n)
    return q

def conformal_predict(y_prob, threshold):
    """Return prediction set for each example."""
    return y_prob >= (1 - threshold)

# Calibrate on held-out data
threshold = conformal_calibrate(y_calib, probs_calib, alpha=0.1)
# Predict sets with 90% coverage guarantee
pred_sets = conformal_predict(probs_test, threshold)
Conformal prediction is distribution-free: the coverage guarantee holds regardless of the model or data distribution (given exchangeability).
18.5 Clinical Trust Guidelines
When should clinicians trust model predictions? Consider these factors:
- Confidence level: High probability (>90%) with good calibration deserves more weight
- Uncertainty: High uncertainty warrants extra scrutiny
- Explanation quality: Does Grad-CAM highlight clinically relevant regions?
- Patient characteristics: Is this patient similar to training data?
- Clinical context: High-stakes decisions require higher confidence thresholds
For regulatory purposes, document:
- Model performance on relevant subpopulations
- Calibration metrics with confidence intervals
- Known failure modes and out-of-distribution indicators
- Recommended use conditions and override procedures
Models should support, not replace, clinical judgment. The most effective human-AI teams combine machine pattern recognition with human contextual reasoning.
18.6 Quick Reference: The Clinician’s Checklist
Before trusting an AI tool’s output, ask yourself these questions:
- Is the model well calibrated, and how confident is this particular prediction?
- Is the uncertainty low enough for the stakes of this decision?
- Does the explanation (e.g., Grad-CAM or SHAP) highlight clinically plausible features?
- Is this patient similar to the population the model was trained and validated on?
- Do I know the model’s documented failure modes and recommended use conditions?
If you cannot answer these questions, you are not ready to use the tool clinically. Seek out the model card, validation documentation, or the responsible AI team before proceeding.