7 Deep Neural Networks & Backprop
This chapter introduces neural network architectures and the backpropagation algorithm for training. We build intuition for how networks learn hierarchical representations—a capability that has driven deep learning’s remarkable success across medicine and other domains (LeCun, Bengio, and Hinton 2015). We then cover practical optimization and regularization techniques used in medical AI systems.
7.1 From Pixels to Predictions
Clinical Context: A radiologist glances at a chest X-ray and sees lung fields, cardiac silhouette, and perhaps the hazy opacity of early pneumonia. A computer sees something different: a grid of 224×224 numbers, each representing how much X-ray passed through that spot. Our challenge is bridging these two views—teaching a machine to go from raw pixel values to clinically meaningful predictions.
7.1.1 What the Computer Sees
A digital chest X-ray is a two-dimensional array of intensity values. In the standard display convention, each pixel might range from 0 (black, where X-rays passed freely through air) to 255 (white, where dense tissue absorbed them). A single image contains roughly 50,000 pixels—50,000 numbers that somehow encode whether this patient has pneumonia, a tumor, or healthy lungs.
The classification problem sounds simple: take these 50,000 inputs and output a probability. “85% chance of bacterial pneumonia.” But how do you write a formula that does this?
7.1.2 Why Simple Models Fail
You might try a linear approach: assign a weight to each pixel, sum them up, and threshold the result. If the weighted sum exceeds some value, predict “pneumonia.” This is essentially what logistic regression does.
The problem is that disease patterns aren’t that simple. Pneumonia doesn’t mean “pixels 1,000 through 2,000 are brighter than average.” It means a consolidation in the lung fields—a specific spatial pattern that could appear in the left lung, right lung, upper lobes, or lower lobes. The same pixels might indicate pneumonia in one patient and healthy tissue in another, depending on what surrounds them.
Linear models can only draw flat decision boundaries (hyperplanes) through pixel space. But the boundary between “pneumonia” and “normal” is a complex, curving surface that depends on how pixels relate to each other.
7.1.3 The Need for Learned Representations
What we need is a system that can learn intermediate representations—transformations of the raw pixels that make the classification problem easier. Perhaps one transformation detects edges. Another combines edges into textures. Another recognizes that a certain texture pattern in the lung region suggests consolidation.
This is exactly what neural networks do. They stack layers of simple transformations, each building on the last, until the final layer sees not 50,000 raw pixels but a handful of learned features like “consolidation present” or “cardiac silhouette enlarged.” At that point, classification becomes straightforward.
The rest of this chapter shows how to build and train these networks. We start with the basic building block—the artificial neuron—and work up to complete systems that learn from thousands of labeled X-rays.
7.2 Neurons and Activation Functions
Clinical Context: A chest X-ray classifier must recognize that certain pixel patterns indicate pathology—but the relationship isn’t straightforward. A bright region might be normal bone, abnormal consolidation, or an artifact, depending on its shape, location, and surrounding context. We need a building block that can learn these nonlinear relationships.
7.2.1 The Artificial Neuron
A single artificial neuron computes a weighted sum of inputs, adds a bias term, and applies a nonlinear activation function:
\[ a = \sigma\bigl(w^\top x + b\bigr) \]
where \(x \in \mathbb{R}^d\) is the input vector, \(w \in \mathbb{R}^d\) are learnable weights, \(b \in \mathbb{R}\) is a learnable bias, and \(\sigma\) is the activation function. The weights determine which input features matter; the activation function introduces nonlinearity.
Without activation functions, stacking multiple layers would collapse to a single linear transformation—no matter how deep the network, it could only learn linear decision boundaries. The activation function is what gives neural networks their representational power.
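To see the collapse concretely, the short sketch below (shapes chosen arbitrarily for illustration) stacks two nn.Linear layers with no activation between them and confirms that the composition equals a single linear map:

import torch
import torch.nn as nn

torch.manual_seed(0)

# Two stacked linear layers with no activation in between
f1 = nn.Linear(4, 8)
f2 = nn.Linear(8, 2)

x = torch.randn(5, 4)
stacked = f2(f1(x))

# Collapse them analytically into a single linear map W x + b
W = f2.weight @ f1.weight                 # shape (2, 4)
b = f2.weight @ f1.bias + f2.bias         # shape (2,)
collapsed = x @ W.T + b

print(torch.allclose(stacked, collapsed, atol=1e-6))  # True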
7.2.2 Common Activation Functions
Several activation functions appear throughout medical AI:
Sigmoid: \(\sigma(z) = \frac{1}{1+e^{-z}}\). Outputs in \((0,1)\), interpretable as probabilities. Used in output layers for binary classification (e.g., disease present/absent). Suffers from vanishing gradients when \(|z|\) is large.
Tanh: \(\tanh(z) = \frac{e^z - e^{-z}}{e^z + e^{-z}}\). Outputs in \((-1,1)\), zero-centered. Similar vanishing gradient issues as sigmoid.
ReLU: \(\text{ReLU}(z) = \max(0, z)\). The workhorse of modern deep learning. Computationally efficient and avoids vanishing gradients for positive inputs. Can suffer from “dead neurons” if inputs are always negative.
Leaky ReLU: \(\text{LeakyReLU}(z) = \max(\alpha z, z)\) with small \(\alpha\) (e.g., 0.01). Addresses the dead neuron problem by allowing small gradients for negative inputs.
GELU: \(\text{GELU}(z) = z \cdot \Phi(z)\) where \(\Phi\) is the standard normal CDF. Used in Transformers; smooth approximation to ReLU.
For hidden layers in image and tabular medical data, ReLU is the default choice. For output layers, use sigmoid for binary classification or softmax for multi-class classification (e.g., classifying chest X-rays into normal, pneumonia, or COVID-19).
7.2.3 PyTorch Activation Example
The following code demonstrates applying activation functions to synthetic patient vitals data:
import torch
import torch.nn as nn

# Synthetic patient data: [heart_rate, bp_systolic, temp, resp_rate]
x = torch.tensor([[85.0, 120.0, 37.2, 16.0],
                  [110.0, 85.0, 38.5, 24.0]])  # 2 patients

# Define a single neuron with 4 inputs
linear = nn.Linear(in_features=4, out_features=1)

# Compute pre-activation (weighted sum + bias)
z = linear(x)
print(f"Pre-activation z: {z}")

# Apply different activations
print(f"Sigmoid: {torch.sigmoid(z)}")
print(f"Tanh: {torch.tanh(z)}")
print(f"ReLU: {torch.relu(z)}")
The nn.Linear layer contains learnable parameters weight and bias. In practice, we stack many such layers with activations between them to form deep networks.
7.3 Feedforward Network Architecture
Clinical Context: Classifying a chest X-ray as normal, pneumonia, or cardiomegaly requires combining information across the entire image. A feedforward neural network processes pixel values through successive layers, each extracting increasingly abstract patterns—from edges to textures to anatomical structures to diagnostic features.
7.3.1 Layers and Dimension Flow
A feedforward neural network (also called a multilayer perceptron or MLP) consists of an input layer, one or more hidden layers, and an output layer. Data flows in one direction: forward from input to output, with no cycles.
Each layer performs two operations:
- Linear transformation: \(z = Wx + b\), where \(W\) is a weight matrix and \(b\) is a bias vector
- Nonlinear activation: \(a = \sigma(z)\)
The dimensions at each layer matter for understanding network capacity:
- Input layer: \(d_{\text{in}}\) features (e.g., 50 patient variables)
- Hidden layer 1: \(h_1\) neurons (e.g., 128 units)
- Hidden layer 2: \(h_2\) neurons (e.g., 64 units)
- Output layer: \(d_{\text{out}}\) classes (e.g., 3 risk levels)
The weight matrix connecting a layer with \(n\) inputs to \(m\) outputs has shape \((m \times n)\), containing \(m \cdot n\) learnable parameters. A network with layers [50, 128, 64, 3] has \(50 \times 128 + 128 \times 64 + 64 \times 3 = 14,784\) weights, plus biases.
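As a sanity check on this arithmetic, the following sketch (not part of the chapter's running example) builds an MLP with these layer sizes and counts the parameters directly:

import torch.nn as nn

# Hypothetical MLP matching the [50, 128, 64, 3] layer sizes above
mlp = nn.Sequential(
    nn.Linear(50, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 3)
)

n_weights = sum(p.numel() for name, p in mlp.named_parameters() if "weight" in name)
n_biases = sum(p.numel() for name, p in mlp.named_parameters() if "bias" in name)
print(n_weights, n_biases)  # 14784 weights, 128 + 64 + 3 = 195 biases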
7.3.2 The Universal Approximation Theorem
A foundational result states that a feedforward network with a single hidden layer containing enough neurons can approximate any continuous function to arbitrary precision. This universal approximation theorem explains why neural networks are so flexible—but it doesn’t guarantee that training will find good weights, nor that the network will generalize to new data.
In practice, we use deeper networks (more layers) rather than extremely wide single-layer networks. Depth enables hierarchical feature learning: early layers detect simple patterns, later layers combine them into complex concepts.
7.3.3 PyTorch MLP Example
Here is a complete feedforward network for clinical risk prediction:
import torch
import torch.nn as nn

class ClinicalMLP(nn.Module):
    def __init__(self, n_features, n_hidden, n_classes):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(n_features, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden // 2),
            nn.ReLU(),
            nn.Linear(n_hidden // 2, n_classes)
        )

    def forward(self, x):
        return self.network(x)

# Example: 50 patient features -> 3 risk levels
model = ClinicalMLP(n_features=50, n_hidden=128, n_classes=3)

# Synthetic batch of 32 patients
x = torch.randn(32, 50)
logits = model(x)  # Shape: (32, 3)

# Convert to probabilities for classification
probs = torch.softmax(logits, dim=1)

The nn.Sequential container chains layers together automatically. For classification, we typically omit the final activation during training (using nn.CrossEntropyLoss, which applies softmax internally) and add it during inference when we need probabilities.
7.4 Backpropagation and Automatic Differentiation
Clinical Context: When training a chest X-ray classifier, we compare the network’s predictions to radiologist-provided labels. If the model predicts “normal” but the true label is “pneumonia,” we need to adjust the millions of weights throughout the network to reduce this error. Backpropagation efficiently computes how each weight contributed to the mistake, enabling targeted updates.
7.4.1 Loss Functions and Gradients
Training a neural network means finding weights that minimize a loss function \(\mathcal{L}\) measuring prediction error. Common choices:
Cross-entropy loss for classification: \[ \mathcal{L}_{\text{CE}} = -\sum_{c=1}^{C} y_c \log(\hat{y}_c) \] where \(y\) is the one-hot true label and \(\hat{y}\) is the predicted probability distribution.
Binary cross-entropy for binary classification: \[ \mathcal{L}_{\text{BCE}} = -\bigl[y \log(\hat{y}) + (1-y)\log(1-\hat{y})\bigr] \]
Mean squared error for regression: \[ \mathcal{L}_{\text{MSE}} = \frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)^2 \]
To minimize the loss, we compute its gradient with respect to each weight: \(\nabla_W \mathcal{L}\). The gradient points in the direction of steepest increase, so we update weights in the opposite direction.
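The sketch below evaluates each loss with PyTorch's built-in modules on synthetic tensors; note that nn.CrossEntropyLoss expects raw logits and integer class labels rather than one-hot probabilities, and nn.BCEWithLogitsLoss likewise works on logits for numerical stability:

import torch
import torch.nn as nn

# Multiclass cross-entropy: logits for 4 patients, 3 classes
logits = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 2])
ce = nn.CrossEntropyLoss()(logits, labels)

# Binary cross-entropy on raw logits
bin_logits = torch.randn(4, 1)
bin_labels = torch.randint(0, 2, (4, 1)).float()
bce = nn.BCEWithLogitsLoss()(bin_logits, bin_labels)

# Mean squared error for a regression target
pred = torch.randn(4, 1)
target = torch.randn(4, 1)
mse = nn.MSELoss()(pred, target)

print(ce.item(), bce.item(), mse.item())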
7.4.2 The Chain Rule
Consider a simple two-layer network computing \(\hat{y} = \sigma(W_2 \cdot \text{ReLU}(W_1 x))\). To find how \(\mathcal{L}\) changes with \(W_1\), we apply the chain rule:
\[ \frac{\partial \mathcal{L}}{\partial W_1} = \frac{\partial \mathcal{L}}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial z_2} \cdot \frac{\partial z_2}{\partial a_1} \cdot \frac{\partial a_1}{\partial z_1} \cdot \frac{\partial z_1}{\partial W_1} \]
Each term represents how one intermediate quantity affects the next. Backpropagation organizes this computation efficiently by:
- Forward pass: Compute and cache all intermediate values (\(z_1, a_1, z_2, \hat{y}\))
- Backward pass: Propagate gradients from output to input, reusing cached values
This avoids redundant computation—each intermediate gradient is computed exactly once.
7.4.3 Computational Graphs
Modern deep learning frameworks represent computations as computational graphs, where nodes are operations and edges carry tensors. During the forward pass, the framework builds this graph. During the backward pass, it traverses the graph in reverse, applying the chain rule at each node.
For example, computing \(\mathcal{L} = (Wx - y)^2\) creates nodes for matrix multiply, subtract, and square. The backward pass computes:
\[ \frac{\partial \mathcal{L}}{\partial W} = 2(Wx - y) \cdot x^\top \]
automatically by composing the local gradients of each operation.
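A minimal autograd check of this example (shapes chosen only for illustration) confirms the analytic gradient:

import torch

torch.manual_seed(0)
W = torch.randn(3, 5, requires_grad=True)
x = torch.randn(5, 1)
y = torch.randn(3, 1)

loss = ((W @ x - y) ** 2).sum()
loss.backward()

# Analytic gradient: 2 (Wx - y) x^T
manual = 2 * (W @ x - y) @ x.T
print(torch.allclose(W.grad, manual, atol=1e-6))  # True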
7.4.4 PyTorch Autograd
PyTorch implements automatic differentiation via its autograd system. Any tensor with requires_grad=True tracks operations for gradient computation:
import torch
import torch.nn as nn
# Simple network and loss
model = nn.Linear(10, 1)
criterion = nn.BCEWithLogitsLoss()
# Forward pass
x = torch.randn(32, 10) # 32 patients, 10 features
y = torch.randint(0, 2, (32, 1)).float() # Binary labels
logits = model(x)
loss = criterion(logits, y)
# Backward pass: compute all gradients
loss.backward()
# Gradients are now stored in model.weight.grad
print(f"Weight gradient shape: {model.weight.grad.shape}")
print(f"Bias gradient shape: {model.bias.grad.shape}")The loss.backward() call traverses the computational graph, computing \(\partial\mathcal{L}/\partial w\) for every parameter \(w\). These gradients are stored in the .grad attribute of each parameter tensor, ready for the optimizer to use.
Key points for medical AI practitioners:
- Call optimizer.zero_grad() before each backward pass to clear old gradients
- Use torch.no_grad() during inference to save memory
- Gradient values can diagnose training problems (e.g., vanishing or exploding gradients)
7.5 Optimization Algorithms
Clinical Context: Training a diagnostic model on 100,000 chest X-rays could take days or weeks. Choosing the right optimizer and learning rate schedule dramatically affects both training speed and final model quality. A well-tuned optimizer finds good weights faster and often reaches better solutions.
7.5.1 Stochastic Gradient Descent
Stochastic gradient descent (SGD) updates weights using gradients computed on random mini-batches rather than the full dataset:
\[ W \leftarrow W - \eta \nabla_W \mathcal{L} \]
where \(\eta\) is the learning rate. Mini-batch training has two benefits:
- Computational efficiency: Gradient computation scales with batch size, not dataset size
- Regularization effect: Noise from random sampling can help escape local minima
SGD with momentum adds a velocity term that accumulates past gradients:
\[ v \leftarrow \beta v + \nabla_W \mathcal{L}, \quad W \leftarrow W - \eta v \]
Momentum (typically \(\beta = 0.9\)) smooths updates and accelerates convergence, especially in ravine-like loss landscapes.
7.5.2 Adaptive Learning Rates: Adam
Adam (Adaptive Moment Estimation) adapts the learning rate for each parameter based on historical gradients:
\[ m \leftarrow \beta_1 m + (1-\beta_1) g, \quad v \leftarrow \beta_2 v + (1-\beta_2) g^2 \]
\[ W \leftarrow W - \eta \frac{\hat{m}}{\sqrt{\hat{v}} + \epsilon} \]
where \(g = \nabla_W \mathcal{L}\), and \(\hat{m}, \hat{v}\) are bias-corrected estimates.
Adam is the default choice for most medical AI applications because:
- Robust to learning rate choice (less hyperparameter tuning)
- Works well with sparse gradients (common in medical imaging)
- Fast convergence on diverse architectures
Default hyperparameters (\(\beta_1=0.9\), \(\beta_2=0.999\), \(\epsilon=10^{-8}\)) work well in most cases. Learning rates of \(10^{-3}\) to \(10^{-4}\) are typical starting points.
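In PyTorch, either optimizer is a one-line construction. The sketch below uses a small stand-in model and makes the default Adam hyperparameters explicit:

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(50, 3)   # stand-in model for illustration

# SGD with momentum
sgd = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

# Adam with its standard defaults spelled out
adam = optim.Adam(model.parameters(), lr=1e-3,
                  betas=(0.9, 0.999), eps=1e-8)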
7.5.3 Learning Rate Schedules
A fixed learning rate is rarely optimal. Common schedules:
- Step decay: Reduce by factor (e.g., 0.1) every \(N\) epochs
- Cosine annealing: Smoothly decrease following a cosine curve
- Warmup: Start with small learning rate, gradually increase, then decay
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=100)
for epoch in range(100):
    train_one_epoch(model, optimizer)
    scheduler.step()  # Adjust learning rate

7.6 Regularization Techniques
Clinical Context: A chest X-ray classifier trained at one hospital might memorize institution-specific patterns—the particular scanner model, positioning conventions, or patient demographics—rather than learning generalizable radiographic features. Regularization techniques encourage the model to learn simpler, more robust patterns that transfer across imaging sites.
7.6.1 Dropout
Dropout randomly sets a fraction \(p\) of neuron activations to zero during training:
\[ a_{\text{dropped}} = \frac{1}{1-p} \cdot a \cdot \text{mask} \]
where mask is a random binary vector. The \(1/(1-p)\) scaling ensures expected activation magnitude is preserved.
Dropout prevents co-adaptation—neurons cannot rely on specific other neurons, forcing each to learn independently useful features. At inference time, dropout is disabled (all neurons active).
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(100, 256),
    nn.ReLU(),
    nn.Dropout(p=0.5),   # 50% dropout
    nn.Linear(256, 64),
    nn.ReLU(),
    nn.Dropout(p=0.3),   # 30% dropout
    nn.Linear(64, 2)
)

model.train()  # Dropout active
model.eval()   # Dropout disabled

Typical dropout rates: 0.2–0.5 for fully connected layers, 0.1–0.2 for convolutional layers.
7.6.2 Batch Normalization
Batch normalization normalizes layer inputs to have zero mean and unit variance across the mini-batch:
\[ \hat{x} = \frac{x - \mu_{\text{batch}}}{\sqrt{\sigma^2_{\text{batch}} + \epsilon}} \]
followed by learnable scale (\(\gamma\)) and shift (\(\beta\)) parameters: \(y = \gamma \hat{x} + \beta\).
Benefits for medical AI:
- Training stability: Reduces internal covariate shift, allowing higher learning rates
- Mild regularization: Batch statistics add noise similar to dropout
- Faster convergence: Networks often train 5–10x faster with batch normalization
model = nn.Sequential(
    nn.Linear(100, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(),
    nn.Linear(256, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 2)
)

Note: Batch normalization behavior differs between training (uses batch statistics) and inference (uses running averages). Always call model.eval() before inference.
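To connect the formula to the layer's behavior, the following check uses a freshly initialized nn.BatchNorm1d (so \(\gamma = 1\) and \(\beta = 0\)) in training mode and reproduces its output by hand:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)   # gamma initialized to 1, beta to 0
bn.train()

x = torch.randn(32, 4)
out = bn(x)

# Manual normalization with batch mean and (biased) batch variance
mu = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
manual = (x - mu) / torch.sqrt(var + bn.eps)

print(torch.allclose(out, manual, atol=1e-5))  # True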
7.6.3 Weight Decay (L2 Regularization)
Weight decay adds a penalty proportional to weight magnitude to the loss:
\[ \mathcal{L}_{\text{reg}} = \mathcal{L} + \lambda \sum_i w_i^2 \]
This encourages smaller weights, which often correspond to simpler, more generalizable models. In PyTorch, set via the optimizer:
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

Typical weight decay values: \(10^{-4}\) to \(10^{-2}\). Higher values impose stronger regularization but may underfit if set too high.
7.6.4 Early Stopping
The simplest regularization: stop training when validation performance stops improving. Monitor validation loss or a clinical metric (e.g., AUROC) and save the best checkpoint:
best_val_loss = float('inf')
patience = 10
epochs_no_improve = 0

for epoch in range(max_epochs):
    train_loss = train_one_epoch(model, train_loader)
    val_loss = evaluate(model, val_loader)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pt')
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print("Early stopping triggered")
            break

For medical AI, early stopping is essential—overfit models may achieve perfect training accuracy but fail catastrophically on new patients.
7.7 Putting It Together: The Complete Training Pipeline
Clinical Context: You have a dataset of 10,000 labeled chest X-rays and a network architecture. How do you actually train it? This section walks through the complete pipeline—from loading data to saving your best model.
7.7.1 Data Loading and Batching
Training on all 10,000 images at once would exhaust GPU memory, while updating on a single image at a time gives slow, noisy training. Instead, we process data in mini-batches—small random subsets (typically 16–128 images) that fit in memory and provide reasonably stable gradient estimates.
PyTorch’s DataLoader handles batching, shuffling, and parallel data loading:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Define preprocessing: resize, normalize to ImageNet statistics
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Load dataset (assumes images organized in class folders)
train_dataset = datasets.ImageFolder('data/train', transform=transform)
val_dataset = datasets.ImageFolder('data/val', transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

Shuffling the training data each epoch prevents the model from learning spurious patterns based on data order. Validation data doesn’t need shuffling since we only evaluate, never train on it.
7.7.2 The Training Loop
Every training loop follows the same structure: iterate over epochs, iterate over batches, compute loss, backpropagate, update weights. Here’s the complete pattern:
import torch
import torch.nn as nn
import torch.optim as optim
# Setup
model = ChestXrayClassifier() # Your network architecture
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    model.train()  # Enable dropout, batch norm training mode
    train_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()              # Clear previous gradients
        outputs = model(images)            # Forward pass
        loss = criterion(outputs, labels)  # Compute loss
        loss.backward()                    # Backward pass
        optimizer.step()                   # Update weights

        train_loss += loss.item()

    # Validation phase
    model.eval()  # Disable dropout, use running batch norm stats
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():  # Don't compute gradients for validation
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    val_accuracy = 100. * correct / total
    print(f'Epoch {epoch+1}: Train Loss: {train_loss/len(train_loader):.4f}, '
          f'Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {val_accuracy:.2f}%')

Key details that matter:
- model.train() vs model.eval(): Switches behavior of dropout and batch normalization. Forgetting this is a common bug.
- optimizer.zero_grad(): Gradients accumulate by default. Clear them before each batch.
- torch.no_grad(): Disables gradient computation during validation, saving memory.
- .to(device): Moves tensors to GPU. Both model and data must be on the same device.
7.7.3 Monitoring and Checkpointing
Watch the training and validation loss curves. Healthy training shows both decreasing initially, then validation loss plateauing while training loss continues to drop (mild overfitting is normal). If validation loss increases while training loss drops sharply, you’re overfitting—stop training or increase regularization.
Save checkpoints so you can recover your best model:
best_val_loss = float('inf')

for epoch in range(num_epochs):
    # ... training code ...

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        }, 'best_model.pt')

# Load best model for final evaluation
checkpoint = torch.load('best_model.pt')
model.load_state_dict(checkpoint['model_state_dict'])

7.7.4 From Training to Deployment
Once training completes, you have a model that maps pixel values to predictions. But before clinical use, you need to:
- Evaluate on held-out test set: Performance on data the model has never seen, not even for hyperparameter tuning
- Compute clinical metrics: Sensitivity, specificity, AUROC—not just accuracy
- Analyze failure modes: Which cases does the model get wrong? Are there systematic biases?
- Validate externally: Test on data from different institutions, scanners, patient populations
Chapters 14 and 16 cover deployment and evaluation in depth. For now, the key insight is that training is just the beginning—rigorous validation determines whether your model is ready for patients.
7.8 Looking Ahead: Exploiting Image Structure
The networks in this chapter treat images as flat vectors—50,000 independent numbers fed into fully connected layers. This works, but it ignores something important: pixels near each other are related. The lung fields are spatially coherent regions, not random collections of pixels.
Chapter 8 introduces convolutional neural networks, which exploit this spatial structure. Instead of learning weights for every pixel independently, CNNs learn small filters that detect local patterns (edges, textures) and slide them across the image. This dramatically reduces parameters and improves performance on medical imaging tasks.
The training machinery you learned here—backpropagation, optimizers, regularization, the training loop—carries over directly. CNNs just change what the network computes, not how it learns.
7.9 Appendix 7A: The Mathematics of Backpropagation
This appendix provides formal derivations for readers who want the mathematical foundations behind the intuitions in the main chapter.
7.9.1 Notation and Setup
Consider an \(L\)-layer feedforward network. For layer \(\ell \in \{1, \ldots, L\}\):
- \(W^{(\ell)} \in \mathbb{R}^{n_\ell \times n_{\ell-1}}\): weight matrix
- \(b^{(\ell)} \in \mathbb{R}^{n_\ell}\): bias vector
- \(z^{(\ell)} = W^{(\ell)} a^{(\ell-1)} + b^{(\ell)}\): pre-activation (linear combination)
- \(a^{(\ell)} = \sigma(z^{(\ell)})\): activation (after nonlinearity)
- \(a^{(0)} = x\): input
- \(a^{(L)} = \hat{y}\): output prediction
The forward pass computes sequentially:
\[ x \xrightarrow{W^{(1)}, b^{(1)}} z^{(1)} \xrightarrow{\sigma} a^{(1)} \xrightarrow{W^{(2)}, b^{(2)}} z^{(2)} \xrightarrow{\sigma} \cdots \xrightarrow{W^{(L)}, b^{(L)}} z^{(L)} \xrightarrow{\sigma} \hat{y} \]
7.9.2 Loss Function Gradients
For a single training example \((x, y)\), common loss functions and their gradients with respect to the output \(\hat{y}\):
Cross-entropy (multiclass, \(\hat{y}\) after softmax): \[ \mathcal{L} = -\sum_{c=1}^{C} y_c \log \hat{y}_c, \qquad \frac{\partial \mathcal{L}}{\partial \hat{y}_c} = -\frac{y_c}{\hat{y}_c} \]
When using softmax output \(\hat{y}_c = \frac{e^{z_c}}{\sum_j e^{z_j}}\), the gradient with respect to the pre-softmax logits simplifies elegantly: \[ \frac{\partial \mathcal{L}}{\partial z_c} = \hat{y}_c - y_c \]
This is why cross-entropy pairs naturally with softmax—the gradient is just prediction minus target.
Binary cross-entropy (\(\hat{y}\) after sigmoid): \[ \mathcal{L} = -[y \log \hat{y} + (1-y) \log(1-\hat{y})], \qquad \frac{\partial \mathcal{L}}{\partial \hat{y}} = \frac{\hat{y} - y}{\hat{y}(1-\hat{y})} \]
With sigmoid \(\hat{y} = \sigma(z) = \frac{1}{1+e^{-z}}\), this simplifies to: \[ \frac{\partial \mathcal{L}}{\partial z} = \hat{y} - y \]
Mean squared error: \[ \mathcal{L} = \frac{1}{2}\|\hat{y} - y\|^2, \qquad \frac{\partial \mathcal{L}}{\partial \hat{y}} = \hat{y} - y \]
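A quick autograd check of the softmax + cross-entropy simplification (synthetic logits; the sigmoid + BCE case can be verified the same way):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
z = torch.randn(1, 5, requires_grad=True)   # logits for 5 classes
y = torch.tensor([3])                        # true class index

loss = F.cross_entropy(z, y)
loss.backward()

# Gradient w.r.t. the logits should be prediction minus one-hot target
y_hat = torch.softmax(z.detach(), dim=1)
y_onehot = F.one_hot(y, num_classes=5).float()
print(torch.allclose(z.grad, y_hat - y_onehot, atol=1e-6))  # True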
7.9.3 The Backpropagation Algorithm
Define the error signal at layer \(\ell\): \[ \delta^{(\ell)} = \frac{\partial \mathcal{L}}{\partial z^{(\ell)}} \]
This measures how the loss changes with the pre-activation at layer \(\ell\).
Step 1: Compute output error
At the final layer \(L\): \[ \delta^{(L)} = \frac{\partial \mathcal{L}}{\partial z^{(L)}} = \frac{\partial \mathcal{L}}{\partial \hat{y}} \odot \sigma'(z^{(L)}) \]
where \(\odot\) denotes element-wise multiplication and \(\sigma'\) is the activation derivative.
Step 2: Backpropagate errors
For \(\ell = L-1, L-2, \ldots, 1\): \[ \delta^{(\ell)} = \bigl((W^{(\ell+1)})^\top \delta^{(\ell+1)}\bigr) \odot \sigma'(z^{(\ell)}) \]
The error at layer \(\ell\) is the error from layer \(\ell+1\), transformed by the transpose of the weight matrix, then scaled by the local activation derivative.
Step 3: Compute parameter gradients
Once we have all error signals: \[ \frac{\partial \mathcal{L}}{\partial W^{(\ell)}} = \delta^{(\ell)} (a^{(\ell-1)})^\top, \qquad \frac{\partial \mathcal{L}}{\partial b^{(\ell)}} = \delta^{(\ell)} \]
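These three steps translate directly into code. The sketch below implements them by hand for a hypothetical two-layer network (ReLU hidden layer, sigmoid output, squared-error loss, single example) and checks the weight gradients against autograd:

import torch

torch.manual_seed(0)

# Tiny network: 4 inputs -> 3 hidden (ReLU) -> 1 output (sigmoid)
W1 = torch.randn(3, 4, requires_grad=True)
b1 = torch.randn(3, requires_grad=True)
W2 = torch.randn(1, 3, requires_grad=True)
b2 = torch.randn(1, requires_grad=True)

x = torch.randn(4)
y = torch.tensor([1.0])

# Forward pass, caching intermediates
z1 = W1 @ x + b1
a1 = torch.relu(z1)
z2 = W2 @ a1 + b2
y_hat = torch.sigmoid(z2)
loss = 0.5 * ((y_hat - y) ** 2).sum()
loss.backward()

# Manual backward pass
with torch.no_grad():
    # Step 1: output error  delta2 = dL/dy_hat * sigma'(z2)
    delta2 = (y_hat - y) * y_hat * (1 - y_hat)
    # Step 2: backpropagate  delta1 = (W2^T delta2) * ReLU'(z1)
    delta1 = (W2.T @ delta2) * (z1 > 0).float()
    # Step 3: parameter gradients
    dW2 = delta2.unsqueeze(1) @ a1.unsqueeze(0)
    dW1 = delta1.unsqueeze(1) @ x.unsqueeze(0)

print(torch.allclose(W2.grad, dW2, atol=1e-6),
      torch.allclose(W1.grad, dW1, atol=1e-6))  # True True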
7.9.4 Activation Function Derivatives
The local derivatives \(\sigma'(z)\) for common activations:
| Activation | \(\sigma(z)\) | \(\sigma'(z)\) |
|---|---|---|
| Sigmoid | \(\frac{1}{1+e^{-z}}\) | \(\sigma(z)(1-\sigma(z))\) |
| Tanh | \(\frac{e^z - e^{-z}}{e^z + e^{-z}}\) | \(1 - \tanh^2(z)\) |
| ReLU | \(\max(0, z)\) | \(\mathbf{1}_{z > 0}\) (1 if positive, 0 otherwise) |
| Leaky ReLU | \(\max(\alpha z, z)\) | \(\mathbf{1}_{z > 0} + \alpha \cdot \mathbf{1}_{z \leq 0}\) |
7.9.5 Matrix Dimensions in Practice
For a batch of \(m\) examples, the dimensions are:
- Input: \(X \in \mathbb{R}^{m \times n_0}\) (each row is one example)
- Activations: \(A^{(\ell)} \in \mathbb{R}^{m \times n_\ell}\)
- Weight gradients: \(\frac{\partial \mathcal{L}}{\partial W^{(\ell)}} \in \mathbb{R}^{n_\ell \times n_{\ell-1}}\)
The gradient for a batch is the average over examples: \[ \frac{\partial \mathcal{L}}{\partial W^{(\ell)}} = \frac{1}{m} (\delta^{(\ell)})^\top A^{(\ell-1)} \]
7.9.6 Vanishing and Exploding Gradients
The error signal propagates as: \[ \delta^{(\ell)} = \Bigl(\prod_{k=\ell}^{L-1} \text{diag}\bigl(\sigma'(z^{(k)})\bigr)\,(W^{(k+1)})^\top\Bigr) \delta^{(L)} \] with the factors applied from \(k=\ell\) on the left to \(k=L-1\) on the right.
The gradient magnitude depends on the product of weight matrices and activation derivatives across all layers between \(\ell\) and \(L\).
Vanishing gradients occur when this product shrinks exponentially:
- Sigmoid/tanh: \(\sigma'(z) \in (0, 0.25]\) for sigmoid, so each layer multiplies the gradient by at most 0.25. After 10 layers: \(0.25^{10} \approx 10^{-6}\).
- Small weights: If \(\|W^{(\ell)}\| < 1\), gradients shrink each layer.
Exploding gradients occur when the product grows exponentially:
- Large weights: If \(\|W^{(\ell)}\| > 1\), gradients can blow up.
- Deep networks without normalization are prone to this.
Mitigations:
- ReLU activations: Derivative is 1 for positive inputs, avoiding the squashing problem.
- Careful initialization: Xavier/Glorot or He initialization sets weight scales to preserve gradient magnitudes.
- Batch normalization: Keeps activations in a reasonable range.
- Residual connections: Skip connections in ResNets allow gradients to flow directly through addition operations.
- Gradient clipping: Cap gradient magnitudes to prevent explosion.
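The effect is easy to observe empirically. The sketch below (a hypothetical 10-layer MLP on random inputs; the width, depth, and seed are arbitrary) compares the gradient norm reaching the first layer under sigmoid versus ReLU activations; the sigmoid gradient is typically several orders of magnitude smaller:

import torch
import torch.nn as nn

def first_layer_grad_norm(activation, depth=10, width=64):
    torch.manual_seed(0)
    layers = []
    for _ in range(depth):
        layers += [nn.Linear(width, width), activation()]
    net = nn.Sequential(*layers, nn.Linear(width, 1))

    x = torch.randn(16, width)
    loss = net(x).pow(2).mean()
    loss.backward()
    return net[0].weight.grad.norm().item()

print("sigmoid:", first_layer_grad_norm(nn.Sigmoid))
print("relu:   ", first_layer_grad_norm(nn.ReLU))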
7.9.7 Computational Complexity
For an \(L\)-layer network with \(n\) neurons per layer:
- Forward pass: \(O(Ln^2)\) multiplications (matrix-vector products)
- Backward pass: \(O(Ln^2)\) multiplications (same order as forward)
- Memory: Must store all activations \(a^{(\ell)}\) for the backward pass
The backward pass is approximately 2× the cost of the forward pass—we compute the same operations in reverse, plus the gradient accumulations.
7.9.8 Further Reading
For deeper treatment of these topics:
Goodfellow, Bengio & Courville (2016). Deep Learning, Chapter 6. The standard reference for neural network fundamentals. Freely available at deeplearningbook.org.
Bishop (2006). Pattern Recognition and Machine Learning, Chapter 5. Rigorous probabilistic perspective on neural networks.
Rumelhart, Hinton & Williams (1986). “Learning representations by back-propagating errors.” Nature 323:533–536. The paper that popularized backpropagation.
He et al. (2015). “Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification.” Introduces He initialization for ReLU networks.
Ioffe & Szegedy (2015). “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” The batch normalization paper.