Loss Functions
Loss functions measure how wrong a model's predictions are. Choosing the right one — MSE for regression, cross-entropy for classification — determines what the model actually optimizes for during training.
Intuition First
A loss function is a scorekeeping system. After the model makes a prediction, the loss function asks: "How wrong were you?"
The training loop then tries to minimize that score.
The key insight: what you measure is what you optimize. If you use the wrong loss function, you'll get a model that's very good at the wrong thing.
What's Actually Happening
Every training step:
- Model makes a prediction (`y_pred`)
- You compare it to the true answer (`y_true`)
- The loss function returns a single number: the "wrongness"
- Gradient descent nudges weights to make that number smaller
The loss function must be:
- Differentiable — so gradients can flow backward
- Smooth — so gradient descent can navigate it
- Semantically correct — so minimizing it actually solves your task
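The loop and the three requirements can be seen in a toy sketch. The data, learning rate, and single-weight model below are made up for illustration; the gradient is derived by hand from the MSE formula:

```python
# Tiny regression: fit y = w * x with MSE and plain gradient descent.
xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]  # true relationship is y = 2x
w = 0.0
lr = 0.05

for step in range(100):
    preds = [w * x for x in xs]
    # The loss: one number summarizing "how wrong"
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
    # dL/dw = (2/n) * sum((w*x - y) * x); differentiable, so we can descend
    grad = 2 * sum((p - y) * x for p, x, y in zip(preds, xs, ys)) / len(xs)
    w = w - lr * grad

print(round(w, 2))  # w approaches 2.0
```

Because MSE is differentiable and smooth, each step has a well-defined downhill direction, and because it semantically matches the task, minimizing it recovers the true weight.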
Build the Idea Step-by-Step
Formal Explanation
MSE — Mean Squared Error
Used for regression (predicting continuous numbers: prices, temperatures, scores).
L = (1/n) × Σ (y_pred - y_true)²
- Average of squared differences between predictions and targets
- Squaring makes all errors positive and penalizes large errors heavily (a 10× larger error becomes 100× worse)
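As a sketch, the formula is a one-liner in plain Python, and the quadratic penalty is visible when you compare a small error against one 10× larger (the inputs are made up):

```python
def mse(y_pred, y_true):
    # Mean of squared differences: (1/n) * sum((pred - true)^2)
    return sum((p - t) ** 2 for p, t in zip(y_pred, y_true)) / len(y_true)

print(mse([11.0], [1.0]))  # error of 10 -> loss 100.0
print(mse([2.0], [1.0]))   # error of 1  -> loss 1.0
```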
Cross-Entropy Loss
Used for classification (predicting which category something belongs to).
Binary cross-entropy (two classes: yes/no):
L = -[y × log(p) + (1 - y) × log(1 - p)]
Categorical cross-entropy (multiple classes):
L = -Σ y_true × log(y_pred)
Where y_pred is a probability (output of softmax, between 0 and 1).
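A minimal from-scratch version of the categorical formula, with an epsilon guard against log(0) (the one-hot label and probabilities here are made-up examples):

```python
import math

def categorical_cross_entropy(y_true, y_pred, eps=1e-12):
    # y_true: one-hot list; y_pred: probabilities (e.g. softmax output)
    # eps keeps log() away from exactly 0
    return -sum(t * math.log(max(p, eps)) for t, p in zip(y_true, y_pred))

# True class is index 0; the model put 70% probability on it
print(round(categorical_cross_entropy([1, 0, 0], [0.7, 0.2, 0.1]), 4))  # 0.3567
```

Only the true class's term survives the sum, so the loss is just -log of the probability assigned to the correct answer.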
Key Properties / Rules
| Loss | Task | Punishes |
|---|---|---|
| MSE | Regression | Large prediction errors quadratically |
| Cross-Entropy | Classification | Confident wrong predictions harshly |
| MAE (L1 loss) | Regression (robust) | All errors equally (less sensitive to outliers) |
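The outlier sensitivity in the table can be checked directly. The data below is made up, with one corrupted target:

```python
def mean_squared_error(preds, targets):
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)

def mean_absolute_error(preds, targets):
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)

preds = [1.0, 2.0, 3.0, 4.0]
clean = [1.1, 2.1, 2.9, 4.0]   # small errors everywhere
dirty = [1.1, 2.1, 2.9, 40.0]  # same, plus one outlier target

# One outlier inflates MSE by tens of thousands of times, MAE by only ~120x
print(mean_squared_error(preds, clean), mean_squared_error(preds, dirty))
print(mean_absolute_error(preds, clean), mean_absolute_error(preds, dirty))
```

This is why MAE is the "robust" choice: the outlier dominates the MSE gradient but contributes to MAE only in proportion to its size.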
Why cross-entropy punishes confident wrong answers so hard:
log(0.01) = -4.6 — if you predicted 1% probability for the true class, your loss is 4.6.
log(0.99) = -0.01 — if you predicted 99% probability for the true class, your loss is near zero.
The log function creates a steep cliff: assigning near-zero probability to the true class is catastrophically punished.
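The cliff is easy to see by tabulating -log(p) for a few probabilities assigned to the true class:

```python
import math

# Loss contributed by the true class when the model gave it probability p
for p in [0.99, 0.9, 0.5, 0.1, 0.01]:
    print(f"p(true class) = {p:<5} loss = {-math.log(p):.2f}")
```

The loss barely moves between 0.99 and 0.9, but roughly doubles from 0.1 to 0.01: confidence in the wrong direction is what gets punished.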
Why It Matters
The loss function is the only signal the model has. It defines the optimization landscape that gradient descent navigates.
In neural networks:
- Cross-entropy + softmax output is the standard classifier setup
- MSE is used in regression heads, value networks (RL), and reconstruction losses (autoencoders)
- The loss function choice changes which gradients flow back — so it changes what the model learns
Common Pitfalls
- Using MSE for classification. It works somewhat, but doesn't push probabilities toward 0 and 1 as sharply as cross-entropy. The gradients are weaker and training is slower.
- Applying softmax before cross-entropy. Most frameworks (PyTorch's `CrossEntropyLoss`) include softmax internally, so applying it yourself runs it twice.
- High loss but model seems fine. Check that your loss function matches the task. A 0.1 MSE on normalized targets is great; a 0.1 cross-entropy is decent but not great.
- Loss goes NaN. Usually caused by `log(0)` in cross-entropy: predictions of exactly 0 break it. Add a small epsilon or use framework loss functions that handle this.
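The epsilon fix can be sketched in plain Python (framework losses handle this clamping internally; the function here is illustrative, not a framework API):

```python
import math

def binary_cross_entropy(y, p, eps=0.0):
    # Clamp p away from exactly 0 and 1 so log() never sees 0
    p = min(max(p, eps), 1.0 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))

try:
    binary_cross_entropy(1.0, 0.0)  # log(0): Python raises; tensors give inf/NaN
except ValueError:
    print("log(0) blew up")

print(round(binary_cross_entropy(1.0, 0.0, eps=1e-7), 2))  # large but finite
```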
Examples
```python
import torch
import torch.nn as nn

# MSE for regression
mse = nn.MSELoss()
predictions = torch.tensor([2.5, 3.0, 5.0])
targets = torch.tensor([1.0, 3.0, 4.0])
loss = mse(predictions, targets)
# = mean([(2.5-1)², (3-3)², (5-4)²]) = mean([2.25, 0, 1]) = 1.083
print(f"MSE loss: {loss:.4f}")  # 1.0833

# Cross-entropy for classification (multi-class)
# PyTorch's CrossEntropyLoss expects raw logits (no softmax)
ce = nn.CrossEntropyLoss()
logits = torch.tensor([[2.0, 1.0, 0.5]])  # raw scores for 3 classes
targets = torch.tensor([0])  # true class is index 0
loss = ce(logits, targets)
print(f"Cross-entropy loss: {loss:.4f}")

# Binary cross-entropy
bce = nn.BCEWithLogitsLoss()  # includes sigmoid internally
logits = torch.tensor([1.5, -0.5, 2.0])  # raw scores
targets = torch.tensor([1.0, 0.0, 1.0])  # binary labels
loss = bce(logits, targets)
print(f"Binary CE loss: {loss:.4f}")
```
Manual MSE calculation:
| Prediction | Target | Error | Squared Error |
|---|---|---|---|
| 2.5 | 1.0 | +1.5 | 2.25 |
| 3.0 | 3.0 | 0.0 | 0.00 |
| 5.0 | 4.0 | +1.0 | 1.00 |
| | | Mean: | 1.08 |