
Softmax and Cross-Entropy

Softmax converts raw scores into probabilities. Cross-entropy measures how wrong those probabilities are. Together, they form the standard output layer + loss function for classification — and they're mathematically designed to complement each other.

Intuition First

Your model scores three candidate answers: [3.0, 1.0, 0.2]. These are raw numbers — no probability yet. Which one is the answer?

Softmax converts these scores into probabilities that sum to 1: roughly [0.84, 0.11, 0.05]. Now you can say "the model is 84% confident it's class 0."

Cross-entropy then measures how far those probabilities are from the truth. If class 0 is correct: loss = -log(0.84) ≈ 0.17. If class 2 is correct but we assigned it only 5%: loss = -log(0.05) ≈ 3.0. The more confident you were in the wrong answer, the higher the penalty.

They're paired because softmax produces exactly what cross-entropy needs: a probability distribution over classes.


What's Actually Happening

Softmax is an "amplifying normalizer." It exponentiates each score (making all values positive) and then divides by the total (making them sum to 1):

softmax(z)ᵢ = eᶻⁱ / Σⱼ eᶻʲ

Why exponentiation? It preserves the ordering of scores (higher score → higher probability) but magnifies the differences. A score of 3 doesn't just become 3× more likely than a score of 1 — it becomes e³/e¹ = e² ≈ 7.4× more likely.
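A quick numeric check of the amplification claim, assuming NumPy is available. The ratio of two softmax probabilities depends only on the difference of their logits, pᵢ/pⱼ = e^(zᵢ - zⱼ). For contrast, the sketch also divides the raw scores by their sum, a hypothetical "linear normalizer" that only works here because all three scores happen to be positive.

```python
import numpy as np

def softmax(z):
    z = np.asarray(z, dtype=float)
    e = np.exp(z - z.max())  # shift by the max for stability
    return e / e.sum()

z = np.array([3.0, 1.0, 0.2])
p = softmax(z)

# Ratio of class-0 to class-1 probability: exp(3 - 1) = e^2
softmax_ratio = p[0] / p[1]

# Hypothetical linear normalization (only valid for positive scores)
linear = z / z.sum()
linear_ratio = linear[0] / linear[1]

print(f"softmax ratio: {softmax_ratio:.2f}")  # 7.39
print(f"linear ratio:  {linear_ratio:.2f}")   # 3.00
```

Linear normalization keeps the 3:1 ratio of the raw scores, while softmax turns the same gap of 2 into a factor of about 7.4.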

Cross-entropy then measures: given the true label, how many nats (or bits, if you use log base 2) does it cost to encode it using the model's predicted distribution? The lower the probability you assigned to the correct answer, the higher the cost.

Loss = -log(softmax(z)_c)      where c is the correct class
     = -log(eᶻᶜ / Σⱼ eᶻʲ)
     = -z_c + log(Σⱼ eᶻʲ)

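This is the log-sum-exp form. PyTorch's F.cross_entropy computes the loss this way internally (via log_softmax), and evaluating log-sum-exp with the max-subtraction trick from the Numerical Stability section keeps it numerically stable.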
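A small check, using only Python's standard `math` module, that the direct form -log(softmax(z)_c) and the log-sum-exp form -z_c + log(Σⱼ eᶻʲ) give the same number for the running example with c = 0. The `logsumexp` helper is hand-written for this sketch and uses the max shift.

```python
import math

def logsumexp(z):
    # log(sum(exp(z))) computed with the max-shift trick
    m = max(z)
    return m + math.log(sum(math.exp(x - m) for x in z))

z = [3.0, 1.0, 0.2]
c = 0  # index of the correct class

# Direct form: -log of the softmax probability of the correct class
p_c = math.exp(z[c]) / sum(math.exp(x) for x in z)
direct = -math.log(p_c)

# Log-sum-exp form: -z_c + log(sum_j exp(z_j))
lse_form = -z[c] + logsumexp(z)

print(f"{direct:.4f} {lse_form:.4f}")  # 0.1791 0.1791
```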


Build the Idea Step-by-Step

Network produces logits z = [3.0, 1.0, 0.2] — raw scores
Softmax: exponentiate → normalize → probabilities [0.84, 0.11, 0.05]
True label: class 0 (correct)
Cross-entropy: -log(0.84) ≈ 0.17 — low loss, model was right
Alternative: if true label were class 2: -log(0.05) ≈ 3.0 — high loss
Gradient flows: ∂Loss/∂zᵢ = pᵢ - yᵢ (softmax output minus one-hot truth)

Formal Explanation

Softmax

Given logit vector z ∈ ℝᵏ for k classes:

softmax(z)ᵢ = eᶻⁱ / Σⱼ₌₁ᵏ eᶻʲ

Properties:

  • All outputs are in (0, 1) — strictly between 0 and 1
  • Outputs sum to 1 — valid probability distribution
  • Relative ordering is preserved (larger logit → larger probability)
  • Translation-invariant: softmax(z) = softmax(z + c) for any constant c
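The four properties can be spot-checked numerically. A minimal sketch, assuming NumPy; `softmax_naive` is a hypothetical helper name, and the shift constant 5.0 is arbitrary.

```python
import numpy as np

def softmax_naive(z):
    # No max shift here, so the shift-invariance check below is not
    # true merely by construction (fine for small logits like these)
    e = np.exp(np.asarray(z, dtype=float))
    return e / e.sum()

z = np.array([3.0, 1.0, 0.2])
p = softmax_naive(z)

in_open_interval = bool(np.all((p > 0) & (p < 1)))
sums_to_one = bool(np.isclose(p.sum(), 1.0))
same_order = bool(np.array_equal(np.argsort(z), np.argsort(p)))
shift_invariant = bool(np.allclose(softmax_naive(z), softmax_naive(z + 5.0)))

print(in_open_interval, sums_to_one, same_order, shift_invariant)  # True True True True
```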

Cross-Entropy Loss (classification)

For true one-hot label y and predicted probabilities p = softmax(z):

L = -Σᵢ yᵢ · log(pᵢ)
  = -log(p_correct)       ← since yᵢ = 1 only for the true class
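With a one-hot y, the full sum collapses to a single term. A sketch, assuming NumPy and the rounded probabilities from the running example:

```python
import numpy as np

p = np.array([0.836, 0.113, 0.051])  # predicted probabilities (rounded)
y = np.array([1.0, 0.0, 0.0])        # one-hot label: class 0 is correct

full_sum = -np.sum(y * np.log(p))    # -Σ yᵢ log pᵢ
shortcut = -np.log(p[0])             # -log(p_correct)

print(round(full_sum, 4), round(shortcut, 4))  # 0.1791 0.1791
```

The sum form is the more general one: it still applies when y is not one-hot (for example, a target distribution that spreads some mass over several classes), whereas the shortcut only works for a single correct class.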

The Combined Formula

Substituting softmax into cross-entropy:

L = -log(eᶻᶜ / Σⱼ eᶻʲ)
  = -z_c + log(Σⱼ eᶻʲ)
  = log-sum-exp(z) - z_c

where c is the correct class index.

Gradient of the combined loss with respect to logits (the beautiful result):

∂L/∂zᵢ = softmax(z)ᵢ - yᵢ = pᵢ - yᵢ

If i is the correct class: pᵢ - 1 (push probability up)
If i is any other class: pᵢ - 0 = pᵢ (push probability down)

This is one of the cleanest gradients in all of machine learning. The combined softmax + cross-entropy gradient is just "predicted probability minus truth."
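One way to convince yourself that ∂L/∂zᵢ = pᵢ - yᵢ is a finite-difference check: nudge each logit slightly and measure how the loss changes. A sketch, assuming NumPy; the helper names and the step size h = 1e-6 are choices made for this illustration.

```python
import numpy as np

def loss(z, c):
    # Combined softmax + cross-entropy in log-sum-exp form: log Σ exp(z) - z_c
    m = z.max()
    return m + np.log(np.sum(np.exp(z - m))) - z[c]

def analytic_grad(z, c):
    e = np.exp(z - z.max())
    p = e / e.sum()
    y = np.zeros_like(z)
    y[c] = 1.0
    return p - y  # pᵢ - yᵢ

def numeric_grad(z, c, h=1e-6):
    g = np.zeros_like(z)
    for i in range(len(z)):
        step = np.zeros_like(z)
        step[i] = h
        # Central difference approximation of ∂L/∂zᵢ
        g[i] = (loss(z + step, c) - loss(z - step, c)) / (2 * h)
    return g

z = np.array([3.0, 1.0, 0.2])
ga = analytic_grad(z, 0)
gn = numeric_grad(z, 0)
print(np.round(ga, 4))
print(np.max(np.abs(ga - gn)))  # should be tiny
```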


Numerical Stability

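The problem: eᶻ overflows to ∞ for large z. For example, e¹⁰⁰ ≈ 2.7 × 10⁴³ exceeds the float32 maximum of about 3.4 × 10³⁸, so it evaluates to inf (float64 overflows once z exceeds about 709).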

The fix: Subtract the max logit before exponentiating. Since softmax is translation-invariant:

z_shifted = z - max(z)        # no overflow: largest value becomes 0
softmax(z) = exp(z_shifted) / sum(exp(z_shifted))

Log-softmax is even better: When you're computing log(softmax(z)) for the loss:

log(softmax(z)ᵢ) = zᵢ - log(Σⱼ exp(zⱼ))
                 = zᵢ - max(z) - log(Σⱼ exp(zⱼ - max(z)))

This avoids computing softmax then log separately (which introduces precision loss). PyTorch's F.cross_entropy does this automatically — pass raw logits, not pre-softmaxed probabilities.
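A short demonstration of the failure and the fix, assuming NumPy and float64 arrays. The logits [1000, 998, 990] are an arbitrary choice, picked to be large enough that exp overflows.

```python
import numpy as np

def softmax_naive(z):
    e = np.exp(z)  # overflows to inf for large z
    return e / e.sum()

def log_softmax_stable(z):
    shifted = z - z.max()  # largest entry becomes 0
    return shifted - np.log(np.sum(np.exp(shifted)))

z = np.array([1000.0, 998.0, 990.0])

with np.errstate(over="ignore", invalid="ignore"):
    naive = softmax_naive(z)  # inf / inf -> nan

stable = log_softmax_stable(z)
print(naive)           # [nan nan nan]
print(np.exp(stable))  # valid probabilities that sum to 1
```

Only the logit differences matter, so the stable version produces the same probabilities as softmax([10, 8, 0]).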


Key Properties / Rules

Concept | Detail
Softmax temperature | softmax(z/T): high T → closer to uniform; low T → sharper, argmax-like
Gradient of combined loss | pᵢ - yᵢ: prediction minus truth
Never apply softmax before CrossEntropyLoss | PyTorch applies log_softmax internally; passing softmaxed outputs applies softmax twice and degrades training
Log-sum-exp trick | Subtract the max before exponentiating to avoid overflow
Softmax vs sigmoid | Multi-class (exactly one true) → softmax; multi-label (several can be true) → sigmoid per class

Why It Matters

Temperature scaling changes model confidence. Language models use temperature T to control how "creative" or "random" their outputs are. T = 1.0 is standard softmax. T → 0 always picks the highest logit (greedy). T = 2.0 spreads probability more evenly, allowing more varied outputs.
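A sketch of temperature sampling, assuming NumPy's `default_rng` generator. `sample_with_temperature` is a hypothetical helper; treating T = 0 as plain argmax is a convention adopted here because dividing by zero is undefined. Real decoders usually layer other controls on top of this.

```python
import numpy as np

def sample_with_temperature(logits, temperature, rng):
    logits = np.asarray(logits, dtype=float)
    if temperature == 0:
        return int(np.argmax(logits))  # T -> 0 limit: greedy decoding
    z = logits / temperature
    p = np.exp(z - z.max())
    p /= p.sum()
    return int(rng.choice(len(p), p=p))  # draw one index according to p

rng = np.random.default_rng(0)
logits = [3.0, 1.0, 0.2]
for T in (0, 1.0, 2.0):
    draws = [sample_with_temperature(logits, T, rng) for _ in range(1000)]
    share = np.bincount(draws, minlength=3) / len(draws)
    print(f"T={T}: empirical class frequencies {np.round(share, 2)}")
```

At T = 1.0 the empirical frequencies should land near the softmax probabilities (about 0.84, 0.11, 0.05); at T = 2.0 they spread out.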

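The gradient is why training works cleanly. Because ∂L/∂zᵢ = pᵢ - yᵢ, the network gets a direct error signal: "you gave class 0 probability 0.84, but the truth is 1.0." The gradient on that logit is 0.84 - 1 = -0.16, so gradient descent raises it. The signal stays large exactly when the model is confidently wrong. This tends to make softmax + cross-entropy easier to train than, say, sigmoid + MSE on a multi-class problem, where the gradient carries an extra σ'(z) = σ(z)(1 - σ(z)) factor that shrinks toward zero when outputs saturate.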
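To make the comparison concrete, here is one confidently wrong example scored both ways. This is a sketch assuming NumPy; the logits are arbitrary, and the MSE variant is taken as a plain sum of squared errors over per-class sigmoid outputs.

```python
import numpy as np

z = np.array([-4.0, 4.0, 0.0])  # confidently wrong: class 1 has the largest logit
y = np.array([1.0, 0.0, 0.0])   # but class 0 is correct

# Softmax + cross-entropy: dL/dz = p - y
p = np.exp(z - z.max())
p /= p.sum()
grad_ce = p - y

# Per-class sigmoid + squared error L = Σ (σ(zᵢ) - yᵢ)²:
# dL/dzᵢ = 2 (σ(zᵢ) - yᵢ) · σ(zᵢ)(1 - σ(zᵢ))
s = 1.0 / (1.0 + np.exp(-z))
grad_mse = 2.0 * (s - y) * s * (1.0 - s)

print("correct-class gradient, softmax+CE :", round(grad_ce[0], 3))   # about -1.0
print("correct-class gradient, sigmoid+MSE:", round(grad_mse[0], 3))  # about -0.035
```

The saturated sigmoid shrinks the correct-class gradient by roughly a factor of 30 in this example, even though the prediction is badly wrong.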

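Language models use the same structure. A language model produces logits over its vocabulary; softmax converts them to probabilities, and the negative log-probability of the target token is the cross-entropy loss. Pretraining and supervised fine-tuning maximize the log-likelihood of the correct next token, which is the same formula, and RLHF methods build on these same token log-probabilities.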
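For a sequence, the same loss is computed at every position and averaged. A toy sketch, assuming NumPy; the 5-token vocabulary, the logits, and the targets below are made up for illustration. Perplexity, the exponential of the mean cross-entropy, is shown as a commonly reported derived metric.

```python
import numpy as np

def log_softmax(z):
    # Row-wise, numerically stable log-softmax
    shifted = z - z.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

# Toy setup (assumed): vocabulary of 5 tokens, sequence of 3 positions
logits = np.array([
    [2.0, 0.5, 0.1, -1.0, 0.0],
    [0.2, 3.0, 0.3, 0.0, -0.5],
    [1.0, 1.0, 1.0, 1.0, 1.0],
])
targets = np.array([0, 1, 4])  # correct next token at each position

logp = log_softmax(logits)
token_nll = -logp[np.arange(len(targets)), targets]  # -log p(target) per position
mean_loss = token_nll.mean()                         # what training minimizes
perplexity = np.exp(mean_loss)

print(np.round(token_nll, 3), round(mean_loss, 3), round(perplexity, 3))
```

The third position has identical logits, so its loss is exactly log(5): the model is uniformly unsure among 5 tokens.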


Common Pitfalls

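  • Applying softmax before nn.CrossEntropyLoss in PyTorch. Don't. The loss applies log_softmax internally, so if you pass softmaxed probabilities, softmax is effectively applied twice. The loss no longer matches the true cross-entropy (it can never approach zero), and the gradients reaching the logits are damped, so training slows or stalls.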
  • Using softmax for multi-label classification. If multiple classes can be true simultaneously (e.g., image tags: cat, outdoor, cute), softmax is wrong — its outputs compete. Use sigmoid on each logit independently.
  • Interpreting logits as probabilities. Logits are raw scores. A logit of 5 doesn't mean 50%. Only after softmax do you have probabilities.
  • Forgetting numerical stability in custom implementations. If you implement softmax yourself, always subtract the max first.
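A sketch of the multi-label case, assuming NumPy; the tags and logits are hypothetical. Softmax forces two true tags to split the probability mass, while independent sigmoids can score both high. The loss shown is binary cross-entropy averaged over classes; in PyTorch the logits-based counterpart is nn.BCEWithLogitsLoss.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical tags: [cat, outdoor, cute]; two of them are true at once
logits = np.array([2.5, 1.5, -3.0])
targets = np.array([1.0, 1.0, 0.0])

softmax_probs = np.exp(logits - logits.max())
softmax_probs /= softmax_probs.sum()  # classes compete for one unit of mass
sigmoid_probs = sigmoid(logits)       # each class scored independently

# Binary cross-entropy, averaged over classes
bce = -np.mean(targets * np.log(sigmoid_probs)
               + (1 - targets) * np.log(1 - sigmoid_probs))

print("softmax:", np.round(softmax_probs, 3), "sum =", round(softmax_probs.sum(), 3))
print("sigmoid:", np.round(sigmoid_probs, 3))
print("BCE:", round(bce, 4))
```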

Examples

import numpy as np

# --- Softmax with numerical stability ---
def softmax(z):
    z = np.array(z, dtype=float)
    z -= z.max()           # subtract max for stability
    exp_z = np.exp(z)
    return exp_z / exp_z.sum()

logits = np.array([3.0, 1.0, 0.2])
probs = softmax(logits)
print("Probs:", np.round(probs, 4))   # [0.836  0.1131 0.0508]
print("Sum:  ", probs.sum())           # 1.0 (up to floating-point rounding)

# --- Cross-entropy loss ---
def cross_entropy(probs, true_class):
    return -np.log(probs[true_class] + 1e-10)

print("Loss (class 0 is correct):", round(cross_entropy(probs, 0), 4))  # ≈ 0.1791
print("Loss (class 2 is correct):", round(cross_entropy(probs, 2), 4))  # ≈ 2.9791

# --- Gradient: the beautiful result ---
def softmax_crossentropy_gradient(logits, true_class):
    p = softmax(logits)
    y = np.zeros_like(p)
    y[true_class] = 1.0
    return p - y   # prediction minus truth

grad = softmax_crossentropy_gradient(logits, true_class=0)
print("Gradient:", np.round(grad, 4))
# [-0.164   0.1131  0.0508]
# correct class: 0.836 - 1 = -0.164  (negative: gradient descent raises this logit)
# wrong classes: positive (gradient descent lowers these logits)

# --- Temperature scaling: how confidence changes ---
def softmax_with_temp(z, temperature=1.0):
    z = np.array(z, dtype=float) / temperature
    z -= z.max()
    exp_z = np.exp(z)
    return exp_z / exp_z.sum()

logits = np.array([3.0, 1.0, 0.2])

print("T=0.5 (sharp):", np.round(softmax_with_temp(logits, 0.5), 3))
print("T=1.0 (normal):", np.round(softmax_with_temp(logits, 1.0), 3))
print("T=2.0 (flat):  ", np.round(softmax_with_temp(logits, 2.0), 3))

# T=0.5: [0.978, 0.018, 0.004]   (very confident)
# T=1.0: [0.836, 0.113, 0.051]   (standard)
# T=2.0: [0.619, 0.228, 0.153]   (more spread out)

# --- PyTorch: always pass raw logits to CrossEntropyLoss ---
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[3.0, 1.0, 0.2]])
target = torch.tensor([0])   # class 0 is correct

# ✓ Correct: pass logits directly
loss = nn.CrossEntropyLoss()(logits, target)
print(f"Correct loss: {loss.item():.4f}")

# ✗ Wrong: softmax applied twice
probs = F.softmax(logits, dim=1)
wrong_loss = nn.CrossEntropyLoss()(probs, target)
print(f"Double-softmax loss (wrong): {wrong_loss.item():.4f}")  # ≈ 0.6634 vs. 0.1791: inflated, and with 3 classes it can never fall below about 0.55

Review Questions