Gradient Descent Variants
SGD, mini-batch, momentum, RMSProp, Adam — what each variant fixes and why Adam became the default optimizer for deep learning.
Intuition First
Basic gradient descent has a problem: it moves the same amount in every direction regardless of how steep or noisy that direction is. Smarter optimizers learned to adapt.
Think of it like hiking: vanilla gradient descent takes the same stride length no matter the terrain. Some variants add memory (momentum) — like a ball rolling downhill that builds up speed. Others adapt the stride per dimension — taking tiny steps in noisy ravines and bold steps on smooth slopes.
What's Actually Happening
All gradient descent variants share the same core loop:
- Compute the gradient ∇L(w) (via backprop)
- Update the weights w using some rule that involves the gradient
What differs is how they use the gradient:
- How many samples they use to estimate it
- Whether they maintain state (velocity, squared gradient averages) between steps
- Whether the effective learning rate adapts per parameter
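As a minimal sketch of that shared loop (NumPy; the helper names train, loss_grad, and gd_update are made up for illustration, not from any library), the variant-specific part is just the update rule you plug in:

import numpy as np

def loss_grad(w, X, y):
    # gradient of mean squared error for a linear model: (2/N) · Xᵀ(Xw - y)
    return 2.0 * X.T @ (X @ w - y) / len(y)

def train(w, X, y, update, steps=200):
    state = {}                                  # optimizer state (velocity, moment estimates, ...)
    for _ in range(steps):
        grad = loss_grad(w, X, y)               # 1. compute the gradient
        w, state = update(w, grad, state)       # 2. apply the variant-specific update rule
    return w

def gd_update(w, grad, state, lr=0.1):
    # vanilla gradient descent: no state, same step size in every direction
    return w - lr * grad, state

rng = np.random.default_rng(0)
X, y = rng.normal(size=(100, 3)), rng.normal(size=100)
w = train(np.zeros(3), X, y, gd_update)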
Build the Idea Step-by-Step
Formal Explanation
Batch Gradient Descent (Vanilla GD)
∇L = (1/N) Σᵢ ∇ℓ(xᵢ, yᵢ) # average over ALL N samples
w ← w - α · ∇L
Exact gradient, but requires a full pass over all data before each update.
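A toy NumPy version of batch GD on a least-squares problem (data and step count are arbitrary, purely illustrative):

import numpy as np

rng = np.random.default_rng(0)
X, y = rng.normal(size=(100, 3)), rng.normal(size=100)   # N = 100 samples
w, alpha = np.zeros(3), 0.1

for _ in range(200):
    grad = 2.0 * X.T @ (X @ w - y) / len(y)   # exact gradient: average over ALL 100 samples
    w -= alpha * grad                          # one update per full pass over the data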
Stochastic Gradient Descent (SGD)
sample one (xᵢ, yᵢ) at random
∇L ≈ ∇ℓ(xᵢ, yᵢ) # one-sample estimate
w ← w - α · ∇L
The gradient estimate is very noisy, but you get one update per sample (N updates per pass over the data instead of one), so progress per epoch is much faster.
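The same toy problem with true one-sample SGD; each step uses a single random example, which is why the path is noisy (self-contained sketch, values arbitrary):

import numpy as np

rng = np.random.default_rng(0)
X, y = rng.normal(size=(100, 3)), rng.normal(size=100)
w, alpha = np.zeros(3), 0.01

for _ in range(2000):
    i = rng.integers(len(y))                   # draw one example at random
    grad = 2.0 * X[i] * (X[i] @ w - y[i])      # noisy one-sample gradient estimate
    w -= alpha * grad                          # one update per sample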
Mini-batch SGD (Standard in Practice)
sample B examples {(x₁,y₁), ..., (xB,yB)}
∇L ≈ (1/B) Σᵢ ∇ℓ(xᵢ, yᵢ) # batch estimate
w ← w - α · ∇L
B is typically 32–256. This is what "SGD" almost always means in practice — mini-batch, not one-sample.
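And the mini-batch version, which is what you get in practice when you pair torch.optim.SGD with a DataLoader: the optimizer applies the update, the DataLoader supplies the batches. A toy NumPy sketch (B chosen arbitrarily):

import numpy as np

rng = np.random.default_rng(0)
X, y = rng.normal(size=(100, 3)), rng.normal(size=100)
w, alpha, B = np.zeros(3), 0.05, 32

for _ in range(500):
    idx = rng.choice(len(y), size=B, replace=False)   # sample a mini-batch of B examples
    Xb, yb = X[idx], y[idx]
    grad = 2.0 * Xb.T @ (Xb @ w - yb) / B             # batch estimate of the gradient
    w -= alpha * grad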
SGD with Momentum
v ← β·v + (1-β)·∇L # exponential moving average of gradients
w ← w - α·v
- β ≈ 0.9 is the momentum coefficient
- v carries velocity from previous steps
- Smooths out noisy gradient directions, accelerates in consistent ones
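A sketch of this momentum rule on the same toy problem (it mirrors the exponential-moving-average form written above; some libraries accumulate the velocity slightly differently, e.g. without the (1-β) factor):

import numpy as np

rng = np.random.default_rng(0)
X, y = rng.normal(size=(100, 3)), rng.normal(size=100)
w, v = np.zeros(3), np.zeros(3)
alpha, beta = 0.1, 0.9

for _ in range(300):
    grad = 2.0 * X.T @ (X @ w - y) / len(y)
    v = beta * v + (1 - beta) * grad           # exponential moving average of gradients
    w -= alpha * v                             # step along the smoothed direction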
RMSProp
s ← β·s + (1-β)·(∇L)² # moving average of squared gradient
w ← w - α · ∇L / (√s + ε)
- Divides each weight update by the RMS of recent gradients for that weight
- Parameters with large recent gradients get smaller updates (stabilizes training)
- Parameters with small gradients get larger updates (explores more)
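A sketch of the RMSProp rule on the same toy problem; the divisor √s is maintained per parameter, so each coordinate gets its own effective learning rate:

import numpy as np

rng = np.random.default_rng(0)
X, y = rng.normal(size=(100, 3)), rng.normal(size=100)
w, s = np.zeros(3), np.zeros(3)
alpha, beta, eps = 0.01, 0.9, 1e-8

for _ in range(1000):
    grad = 2.0 * X.T @ (X @ w - y) / len(y)
    s = beta * s + (1 - beta) * grad ** 2       # moving average of squared gradients
    w -= alpha * grad / (np.sqrt(s) + eps)      # big recent gradients -> smaller steps, and vice versa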
Adam (Adaptive Moment Estimation)
m ← β₁·m + (1-β₁)·∇L # 1st moment: mean of gradients (like momentum)
v ← β₂·v + (1-β₂)·(∇L)² # 2nd moment: mean of squared gradients (like RMSProp)
m̂ = m / (1 - β₁ᵗ) # bias-corrected (important early in training)
v̂ = v / (1 - β₂ᵗ)
w ← w - α · m̂ / (√v̂ + ε)
Default hyperparameters: α=1e-3, β₁=0.9, β₂=0.999, ε=1e-8
Adam = Momentum + RMSProp, with bias correction for the first few steps.
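To see why the bias correction matters, look at the very first step (t = 1): with m and v initialized to zero, the estimates are biased toward zero, so m holds only 10% of the gradient and v only 0.1% of its square. Dividing by (1 - βᵗ) undoes exactly that shrinkage. A quick numeric check:

import numpy as np

grad, beta1, beta2, t = np.array([0.5]), 0.9, 0.999, 1

m = beta1 * 0.0 + (1 - beta1) * grad        # first update from zero init -> [0.05]
v = beta2 * 0.0 + (1 - beta2) * grad ** 2   # -> [0.00025]
m_hat = m / (1 - beta1 ** t)                # corrected -> [0.5], the actual gradient
v_hat = v / (1 - beta2 ** t)                # corrected -> [0.25], the actual squared gradient
print(m, m_hat, v, v_hat)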
Key Properties / Rules
| Optimizer | Memory | Adaptive LR | Best For |
|---|---|---|---|
| Batch GD | none | no | tiny datasets, convex problems |
| SGD | none | no | simple, interpretable, sometimes best for generalization |
| Mini-batch SGD | none | no | standard default |
| Momentum | velocity v | no | smoother convergence on deep nets |
| RMSProp | gradient² s | yes | RNNs, non-stationary objectives |
| Adam | m and v | yes | most deep learning — default choice |
Why It Matters
Adam is the standard optimizer for training transformers, CNNs, and most deep learning models. Here's why:
- Momentum helps escape saddle points and noisy gradient directions
- Adaptive LR per weight handles gradients of wildly different scales (early layers vs late layers)
- Bias correction compensates for the zero initialization of m and v, which would otherwise skew the first few updates
When you see optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) in any ML codebase — that's this exact algorithm running.
Caveat: SGD with momentum sometimes generalizes better than Adam on vision tasks. Adam finds a good solution fast; SGD often finds a flatter, more generalizable minimum given enough time. This is an active research area.
Common Pitfalls
- Using Adam's default learning rate for fine-tuning large models. 1e-3 is too aggressive for fine-tuning — use 1e-5 or 1e-4. Adam's 1e-3 default is tuned for training from scratch.
- Forgetting optimizer.zero_grad(). In PyTorch, gradients accumulate by default. Call zero_grad() before each backward pass, or gradients from previous batches corrupt the update.
- Momentum carries stale velocity across learning rate changes. If you cut the LR mid-training, the accumulated momentum can overshoot for a few steps. Learning rate schedulers handle this gracefully.
- Adam with weight decay is not the same as AdamW. Standard Adam applies weight decay incorrectly (inside the adaptive step). AdamW fixes this — prefer torch.optim.AdamW for regularized training.
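To make that last pitfall concrete, here is a hedged single-function sketch of where the decay term enters each variant (hypothetical values; the exact scaling of the decay differs slightly across implementations). "Adam + L2" adds wd·w to the gradient, so the decay also gets rescaled by 1/√v̂, while AdamW subtracts it from the weights directly:

import numpy as np

def adam_step(w, grad, m, v, t, alpha=1e-3, b1=0.9, b2=0.999, eps=1e-8,
              wd=0.01, decoupled=False):
    if not decoupled:
        grad = grad + wd * w                        # Adam + L2: decay flows through the adaptive scaling
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad ** 2
    m_hat, v_hat = m / (1 - b1 ** t), v / (1 - b2 ** t)
    w = w - alpha * m_hat / (np.sqrt(v_hat) + eps)
    if decoupled:
        w = w - alpha * wd * w                      # AdamW: decay applied directly to the weights
    return w, m, v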
Examples
import torch
import torch.nn as nn
model = nn.Linear(10, 1)
X = torch.randn(100, 10)
y = torch.randn(100, 1)
loss_fn = nn.MSELoss()
# --- SGD ---
opt_sgd = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# --- Adam ---
opt_adam = torch.optim.Adam(model.parameters(), lr=1e-3)
# --- AdamW (preferred for transformers) ---
opt_adamw = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
# Training loop (same for all)
for epoch in range(50):
    opt_adam.zero_grad()                     # clear accumulated gradients
    pred = model(X)
    loss = loss_fn(pred, y)
    loss.backward()                          # compute ∇L for all parameters
    opt_adam.step()                          # apply the Adam update rule
    if epoch % 10 == 0:
        print(f"epoch {epoch}: loss={loss.item():.4f}")
# Implementing Adam from scratch to understand it
import numpy as np
def adam_update(w, grad, m, v, t, alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
"""
w - current weights
grad - gradient at current w
m, v - running 1st and 2nd moment estimates
t - time step (starts at 1)
"""
m = beta1 * m + (1 - beta1) * grad # update biased 1st moment
v = beta2 * v + (1 - beta2) * (grad ** 2) # update biased 2nd moment
m_hat = m / (1 - beta1 ** t) # bias correction
v_hat = v / (1 - beta2 ** t)
w = w - alpha * m_hat / (np.sqrt(v_hat) + eps)
return w, m, v
# Usage
w = np.array([0.0, 0.0, 0.0])
m = np.zeros_like(w)
v = np.zeros_like(w)
for t in range(1, 10001):                          # Adam's small default lr needs plenty of steps
    grad = 2 * (w - np.array([1.0, 2.0, 3.0]))     # gradient of ||w - target||²
    w, m, v = adam_update(w, grad, m, v, t)
print(f"converged to: {w}")  # should be close to [1, 2, 3]