Chain Rule

When functions are composed — one feeding into another — the chain rule tells you how to differentiate through the chain. It's not just a calculus rule; it is backpropagation.

Intuition First

You're a translator working through two assistants. You ask assistant A to translate English → French. Assistant B translates French → Spanish. If a change in the original English shifts assistant A's output slightly, and that shift in turn affects B's output, how does the original English change ultimately affect the final Spanish?

You multiply the two "sensitivities" together: (how much A's output changes per unit of English) × (how much B's output changes per unit of French input).

That multiplication is the chain rule — and it's exactly what happens when backpropagation flows the loss gradient backward through every layer in a neural network.


What's Actually Happening

If y = f(g(x)), where g is applied first and then f is applied to g's output, the derivative of y with respect to x is:

dy/dx = f'(g(x)) · g'(x)

Rate of y with respect to x = (rate of y with respect to the middle) × (rate of middle with respect to x).

This chains together as many functions as you like. Compose 100 functions, just multiply 100 derivatives.
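
A quick numerical check of the two-function rule (a minimal sketch; exp and sin are arbitrary choices for illustration):

import math

def d(f, x, h=1e-6):
    # Central-difference estimate of f'(x)
    return (f(x + h) - f(x - h)) / (2 * h)

g = math.sin                    # inner function
f = math.exp                    # outer function
composed = lambda x: f(g(x))

x = 1.0
# f'(g(x)) · g'(x) = exp(sin x) · cos x
chain = math.exp(math.sin(x)) * math.cos(x)

print(chain)            # ≈ 1.2534
print(d(composed, x))   # same value, via finite differences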


Build the Idea Step-by-Step

Input x
First function: u = g(x)
Second function: y = f(u)
Chain rule: dy/dx = (dy/du) · (du/dx)
Three functions: multiply 3 derivatives
100 layers: multiply 100 derivatives (backprop! see the sketch below)
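
Here's that sketch: compose one function with itself n times and multiply the n local derivatives as you go (tanh is an arbitrary stand-in for a layer; all names are illustrative):

import math

f  = math.tanh                       # stand-in "layer"
df = lambda u: 1 - math.tanh(u)**2   # its derivative

x, n = 0.5, 100
u, grad = x, 1.0
for _ in range(n):
    grad *= df(u)   # multiply in this layer's local derivative
    u = f(u)        # forward through the layer

print(grad)   # dy/dx of 100 composed layers: one long product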

Formal Explanation

Two functions:

y = f(g(x))
dy/dx = f'(g(x)) · g'(x)
        = (df/du) · (du/dx)   where u = g(x)

Three functions:

y = f(g(h(x)))
dy/dx = f'(g(h(x))) · g'(h(x)) · h'(x)

General pattern: Multiply derivatives going from the outside in.
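
The three-function version, checked numerically (a minimal sketch; y = exp(sin(x²)) is an arbitrary example):

import math

# y = exp(sin(x²)): outer exp, middle sin, inner x²
x = 0.7
three_factors = math.exp(math.sin(x**2)) * math.cos(x**2) * 2*x

h = 1e-6
numeric = (math.exp(math.sin((x+h)**2)) - math.exp(math.sin((x-h)**2))) / (2*h)

print(three_factors)   # the product of three derivatives
print(numeric)         # matches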

Concrete example: y = (3x + 1)²

Let u = 3x + 1, so y = u².

dy/du = 2u = 2(3x+1)
du/dx = 3
dy/dx = 2(3x+1) · 3 = 6(3x+1)
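
The same example checked symbolically (this sketch assumes sympy is available; any CAS would do):

import sympy as sp

x = sp.symbols('x')
y = (3*x + 1)**2

# sympy applies the chain rule internally
print(sp.diff(y, x))   # 18*x + 6, i.e. 6(3x + 1)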

Key Properties / Rules

Concept          | Formula
Two functions    | d/dx f(g(x)) = f'(g(x)) · g'(x)
Three functions  | (f∘g∘h)'(x) = f'(g(h(x))) · g'(h(x)) · h'(x)
Leibniz form     | dy/dx = (dy/du) · (du/dx)
Direction        | Derivatives multiply going outer → inner
More layers      | More multiplications — the chain gets longer

Why It Matters

Backpropagation is the chain rule. Full stop.

A neural network with L layers computes:

output = fₗ(fₗ₋₁(... f₂(f₁(x))))

This is a chain of composed functions. To train the network, you need ∂Loss/∂w for every weight w at every layer. The chain rule gives you this by propagating the gradient backward:

∂Loss/∂w_layer1 = (∂Loss/∂output) · (∂output/∂layer_L-1_output) · ... · (∂layer_2_output/∂layer_1_output) · (∂layer_1_output/∂w_layer1)

Every · in that expression is one application of the chain rule. Backprop is literally computing this product efficiently by reusing intermediate values.
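
Here is that product computed by hand for a tiny two-layer ReLU network (a minimal NumPy sketch; shapes and names are illustrative assumptions, not PyTorch's internals):

import numpy as np

rng = np.random.default_rng(0)

# Tiny network: y = W2 @ relu(W1 @ x), loss = (y - t)²
W1 = rng.normal(size=(3, 2))
W2 = rng.normal(size=(1, 3))
x  = rng.normal(size=(2, 1))
t  = np.array([[1.0]])

# Forward pass: save intermediates for the backward pass
z1 = W1 @ x              # layer 1 pre-activation
a1 = np.maximum(z1, 0)   # ReLU
y  = W2 @ a1             # output

# Backward pass: one chain-rule factor per step, outer → inner
dL_dy  = 2 * (y - t)          # ∂Loss/∂output
dL_dW2 = dL_dy @ a1.T         # ∂Loss/∂W2
dL_da1 = W2.T @ dL_dy         # through W2
dL_dz1 = dL_da1 * (z1 > 0)    # through ReLU: derivative is 1 or 0
dL_dW1 = dL_dz1 @ x.T         # ∂Loss/∂W1

print(dL_dW1.shape, dL_dW2.shape)   # (3, 2) (1, 3)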

Why ReLU is popular: ReLU has a simple derivative — 1 if active, 0 if not. The chain rule multiplies these derivatives through the network. Sigmoid's derivative is at most 0.25, so in a deep network 0.25 × 0.25 × ... → 0. That's the vanishing gradient problem. ReLU avoids it by keeping derivatives at 1 (or 0) rather than fractions.
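
The arithmetic is easy to see directly (worst-case sigmoid slope vs. an active ReLU unit; the depth of 50 is an arbitrary assumption):

depth = 50
print(0.25 ** depth)   # ≈ 7.9e-31: sigmoid's worst case, the gradient vanishes
print(1.0 ** depth)    # 1.0: ReLU on active units, the gradient survives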


Common Pitfalls

  • Forgetting to evaluate at the right point. f'(g(x)) means: evaluate f' at the value g(x), not at x. Substitute first; the snippet after this list shows the difference.
  • Stopping too early. Each composed function requires one more multiplication. Chain 3 functions → multiply 3 derivatives. If you only multiply 2, you're missing a layer.
  • Confusing which layer is "inside" vs "outside." Start with the outermost function, work inward. In sin(x²), the outer function is sin and the inner is x².
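
The first pitfall, shown concretely (a minimal sketch reusing sin(x²)):

import math

# y = sin(x²): f = sin (outer), g = x² (inner)
x = 1.5
right = math.cos(x**2) * 2*x   # f' evaluated at g(x) = x²
wrong = math.cos(x) * 2*x      # f' evaluated at x: the common slip

print(right)   # ≈ -1.885, matches a numerical derivative
print(wrong)   # ≈  0.212, not even the right sign here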

Examples

# Verifying the chain rule numerically
import math

def numerical_derivative(f, x, dx=1e-5):
    # Central difference: accurate enough to match the analytic value
    return (f(x + dx) - f(x - dx)) / (2 * dx)

# y = sin(x²)
# Chain rule: dy/dx = cos(x²) · 2x
x = 1.5
composed = lambda x: math.sin(x**2)
chain_rule_result = math.cos(x**2) * 2*x     # ≈ cos(2.25) * 3.0

print(f"chain rule:  {chain_rule_result:.5f}")
print(f"numerical:   {numerical_derivative(composed, x):.5f}")  # should match

# How PyTorch traces the chain rule through a network
import torch
import torch.nn as nn

# Simple 2-layer network
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 1)
)

x = torch.randn(1, 4)
y_true = torch.tensor([[1.0]])

y_pred = model(x)
loss = (y_pred - y_true)**2
loss.backward()   # PyTorch applies chain rule through the whole computation graph

# Every weight now has .grad, computed via the chain rule
for name, param in model.named_parameters():
    print(f"{name}: grad shape = {param.grad.shape}")

What backprop is doing:

  1. Forward pass: compute and save intermediate values at every layer
  2. Backward pass: apply chain rule from loss back to first layer, reusing saved values

Each .backward() call in PyTorch is the chain rule running automatically.

Review Questions