Mnemosyne

Learning Rate Effects

The learning rate controls how big each optimization step is. Too high and training explodes; too low and it stalls. Schedules, warmup, and the LR finder are the tools practitioners use to get it right.

Intuition First

Imagine you're adjusting a dial with your eyes closed, trying to hit exactly 100. If you turn it by 50 each time, you overshoot and oscillate. If you turn it by 0.001 each time, you'll never get there.

The learning rate α is that dial sensitivity. Every gradient descent step moves the weights by α × gradient. Too big: the loss bounces around or explodes. Too small: training is painfully slow or gets stuck.

Getting the learning rate right is often the single biggest factor in whether a model trains at all.


What's Actually Happening

At each step, the update is:

w ← w - α · ∇L(w)

The gradient ∇L tells you the direction and magnitude of the slope. α scales how far you move. The mismatch between α and the actual curvature of the loss surface is the source of most training failures.

Too high: The step overshoots the valley floor, landing on the other side at a higher loss. Next step overshoots again. Loss oscillates or diverges.

Too low: Steps so small you barely move. Training appears to "work" but converges to a poor solution, or simply takes thousands of epochs to reach something useful.

Just right: Loss decreases smoothly. Each step takes you meaningfully closer to a good solution.
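
A minimal sketch of these three regimes on a toy one-dimensional loss L(w) = w², whose gradient is 2w (the starting point and learning rates below are made up for illustration):

def descend(lr, steps=5, w=5.0):
    trajectory = [w]
    for _ in range(steps):
        grad = 2 * w           # dL/dw for L(w) = w^2
        w = w - lr * grad      # the update rule from above
        trajectory.append(w)
    return trajectory

print(descend(lr=1.1))    # too high: overshoots, |w| grows every step (diverges)
print(descend(lr=0.001))  # too low: barely moves away from w = 5
print(descend(lr=0.3))    # reasonable: w shrinks smoothly toward the minimum at 0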


Build the Idea Step-by-Step

High LR: overshoot → loss diverges
Low LR: tiny steps → slow or stalled
Good LR: smooth loss decrease
LR schedule: start big, decrease over time
Warmup: start tiny, ramp up, then decay
Cyclic LR: periodically reset to escape local minima
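
As a small illustration of the cyclic idea, PyTorch's CosineAnnealingWarmRestarts scheduler decays the LR and then periodically resets it; the dummy parameters and step counts below are placeholders, not a training recipe:

import torch

params = [torch.nn.Parameter(torch.zeros(10))]
optimizer = torch.optim.SGD(params, lr=0.1)

# Cosine decay that restarts every T_0 = 10 steps, so the LR periodically
# jumps back up to its initial value.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

for step in range(25):
    optimizer.step()    # placeholder; no real gradients are computed here
    scheduler.step()
    print(step, f"{scheduler.get_last_lr()[0]:.4f}")  # decays toward 0, then restarts at 0.1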

Formal Explanation

Why learning rate affects convergence:

For a quadratic loss with a constant Hessian H (the curvature matrix), gradient descent converges stably only when:

α < 2 / λ_max(H)

where λ_max(H) is the largest eigenvalue of the Hessian, i.e. the direction of steepest curvature. In practice you don't compute H; you tune α empirically.
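
A quick numerical check of this bound on a made-up 2-D quadratic loss L(w) = ½ wᵀHw, whose gradient is Hw:

import numpy as np

H = np.array([[4.0, 0.0],
              [0.0, 1.0]])            # illustrative Hessian; λ_max = 4
print("threshold:", 2 / np.linalg.eigvalsh(H).max())   # 2 / λ_max = 0.5

def run(alpha, steps=50):
    w = np.array([1.0, 1.0])
    for _ in range(steps):
        w = w - alpha * (H @ w)       # gradient descent on 0.5 * w^T H w
    return np.linalg.norm(w)

print(run(alpha=0.45))   # below the threshold: ||w|| shrinks toward 0
print(run(alpha=0.55))   # above the threshold: ||w|| blows up along the steep direction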

Learning rate schedules:

Schedule                     | Rule                               | When to Use
Constant                     | α fixed                            | Prototyping, short runs
Step decay                   | α halved every k epochs            | When you know when to reduce
Exponential decay            | α = α₀ × γᵗ                        | Smooth continuous decay
Cosine annealing             | α follows a half-cosine curve      | Most modern transformer training
Linear warmup + cosine decay | ramp up for N steps, then cosine   | Standard for BERT, GPT, LLaMA
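
Most of these schedules map directly onto classes in torch.optim.lr_scheduler; a brief sketch with placeholder hyperparameters (the dummy model, step_size, and gamma are illustrative):

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Step decay: multiply the LR by gamma every step_size epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
# Exponential decay would instead be:
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

for epoch in range(90):
    optimizer.step()      # placeholder for the real training step
    scheduler.step()      # with StepLR, the LR halves at epochs 30 and 60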

The warmup intuition:

At the start of training, weights are random and Adam's moving averages (m, v) are zero. The bias correction helps, but the gradient signal is very noisy early on. A large LR in this chaotic phase can send weights far in bad directions that are hard to recover from.

Warmup linearly increases α from ~0 to the target value over the first 1–10% of training steps, giving the optimizer time to build stable momentum estimates before taking large steps. With a target LR of 3e-4 and 1,000 warmup steps, for example, the LR at step 100 would be 3e-4 × 100/1000 = 3e-5.


Key Properties / Rules

Scenario               | Symptom                           | Fix
LR too high            | Loss oscillates or goes to NaN    | Reduce by 10×
LR too low             | Loss barely decreases, very slow  | Increase by 10×
LR just right          | Smooth, monotonic loss decrease   | Keep it
Forgetting a schedule  | Loss plateaus early               | Add cosine decay
Training from scratch  | Adam default 1e-3 often works     | Start here
Fine-tuning pretrained | Default LR too high               | Use 1e-5 to 1e-4

Why It Matters

Learning rate is arguably the most important hyperparameter in neural network training. It appears in nearly every ML paper's experimental setup, and when a model "doesn't converge," the learning rate is the first thing to check.

In practice:

  • Hugging Face Transformers' Trainer supports warmup + decay schedules out of the box, and linear warmup followed by cosine (or linear) decay is the standard transformer recipe; PyTorch provides the same schedules in torch.optim.lr_scheduler
  • The LR range test / LR finder automates selecting a good LR
  • Gradient clipping (torch.nn.utils.clip_grad_norm_) prevents gradient explosions when LR is high or gradients spike

Common Pitfalls

  • Using the same LR for all layer types. Embedding layers often need a lower LR than transformer layers, and the final linear layer may need a higher LR than earlier layers. Per-layer LR groups fix this (see the sketch after this list).
  • Not resetting the optimizer when resuming training. If you load a checkpoint and change the LR, Adam's accumulated momentum still reflects the old trajectory. Sometimes you need to reset m and v.
  • Warmup too short for large batches. Large batch sizes are usually paired with a proportionally larger LR (the linear scaling rule), so the earliest steps are more aggressive; lengthen warmup proportionally.
  • Forgetting to step the scheduler. In PyTorch, scheduler.step() must be called once per epoch (or once per batch, depending on the scheduler). Forgetting it leaves the LR constant.
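
A minimal sketch of per-layer LR groups using optimizer parameter groups; the module names (embedding, encoder, head) and the specific learning rates are hypothetical, chosen only to show the pattern:

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(1000, 64)
        self.encoder = nn.Linear(64, 64)
        self.head = nn.Linear(64, 10)

model = TinyModel()

# One parameter group per block, each with its own learning rate.
optimizer = torch.optim.AdamW([
    {"params": model.embedding.parameters(), "lr": 1e-5},  # embeddings: lower LR
    {"params": model.encoder.parameters(),   "lr": 1e-4},  # body: base LR
    {"params": model.head.parameters(),      "lr": 1e-3},  # fresh head: higher LR
])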

Examples

import torch
import torch.nn as nn

model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=128, nhead=4, batch_first=True),
    num_layers=2
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# --- Cosine annealing ---
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=100, eta_min=1e-5
)

for epoch in range(100):
    # ... training step ...
    optimizer.step()
    scheduler.step()
    print(f"epoch {epoch}: lr={scheduler.get_last_lr()[0]:.6f}")

# --- Linear warmup + cosine decay (standard for transformers) ---
import math
from torch.optim.lr_scheduler import LambdaLR

def warmup_cosine_schedule(optimizer, warmup_steps, total_steps):
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return current_step / max(1, warmup_steps)   # linear ramp
        progress = (current_step - warmup_steps) / max(1, total_steps - warmup_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))  # cosine decay

    return LambdaLR(optimizer, lr_lambda)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = warmup_cosine_schedule(optimizer, warmup_steps=500, total_steps=10000)

# --- Learning rate range test (LR finder) ---
# Increase LR exponentially, record loss at each step.
# Plot loss vs LR. Choose LR just before loss starts rising.
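# (assumes an existing dataloader, an optimizer, and a train_step(batch) helper that returns the loss)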

lrs = []
losses = []
lr = 1e-6
multiplier = 1.2

for batch in dataloader:
    for g in optimizer.param_groups:
        g['lr'] = lr

    loss = train_step(batch)
    lrs.append(lr)
    losses.append(loss.item())

    lr *= multiplier
    if lr > 1.0 or loss > 10 * min(losses):
        break

# Plot lrs vs losses and pick an LR roughly 10x smaller than the LR at the minimum loss

# --- Gradient clipping: prevents explosions during high-LR phases ---
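# (runs inside the training loop, after computing the loss for the current batch)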
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

Review Questions