02 — Softmax stability and log-sum-exp

The most important page of Phase 2. Three derivations, three lines of code, and every later phase depends on them: the -max trick, log-sum-exp, and stable cross-entropy from raw logits. Concrete example: classifying the next token as one of the model's 5 verb tenses (§A13).

This is the theory page of Phase 2. Re-derive everything below from a blank page until it's mechanical. Every later phase (autograd, attention, training loop, sampler) leans on these three rewrites.


The setup — softmax over §A13

The softmax function maps a length-K logit vector x = [x_1, ..., x_K] to a probability distribution:

\[ \mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{K} e^{x_j}} \]

Concrete §A13 instance: classify a token as one of the five tenses. K = 5. The model produces a logit vector like

x = [ 1.2, 4.7, 3.1, 0.5, 2.9 ]    # [infinitive, present, past, past-participle, future]

Softmax converts to:

p ≈ [ 0.021, 0.708, 0.143, 0.011, 0.117 ]

The model thinks "present" is most likely. Cross-entropy will compare this to the true label (say, "past simple" — index 2) and produce a scalar loss.
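The mapping from logits to probabilities can be checked directly. This is a minimal sketch that transcribes the definition as written; the `tenses` list is an illustrative name for the label order given in the comment above, and the direct `exp` is safe here only because these logits are small.

```python
import numpy as np

# Label order taken from the logit vector's comment in the text.
tenses = ["infinitive", "present", "past", "past-participle", "future"]
x = np.array([1.2, 4.7, 3.1, 0.5, 2.9])

e = np.exp(x)      # unshifted exp: fine for small logits only
p = e / e.sum()    # softmax exactly as defined

for name, prob in zip(tenses, p):
    print(f"{name:16s} {prob:.3f}")
print("sum    :", p.sum())
print("argmax :", tenses[int(p.argmax())])
```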

That's the math. Now the failure mode.

The overflow failure

Suppose the logit vector during early training has one huge value:

x = [ 1.2, 92.0, 3.1, 0.5, 2.9 ]

A logit of 92 is unremarkable for un-normalized model outputs. But:

\[ e^{92} \approx 9.0 \times 10^{39} \]

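The largest finite fp32 value is ~3.4 × 10^{38}, so exp(92) in fp32 returns +∞. The denominator sum(exp(x)) is then +∞ as well, and exp(92) / sum(exp(x)) becomes inf / inf, which IEEE-754 defines as NaN. Every other entry becomes finite / inf = 0, so the output is [0, NaN, 0, 0, 0]. The loss is NaN or +∞, depending on the label. The gradient is NaN. Training is dead.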
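The failure is easy to reproduce. The sketch below assumes NumPy with fp32 arrays; `naive_softmax` is an illustrative name for the unshifted definition, not a function from any library.

```python
import numpy as np

def naive_softmax(x):
    # Direct transcription of the definition: no shift, so exp can overflow.
    e = np.exp(x)
    return e / e.sum()

x = np.array([1.2, 92.0, 3.1, 0.5, 2.9], dtype=np.float32)

# NumPy reports overflow / invalid-value warnings here; silence them for the demo.
with np.errstate(over="ignore", invalid="ignore"):
    e = np.exp(x)
    p = naive_softmax(x)

print(np.finfo(np.float32).max)   # largest finite fp32 value
print(e)                          # the entry for logit 92 is inf
print(p)                          # inf / inf is nan, finite / inf is 0
```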

This is not a contrived example — it happens whenever a model's logits aren't aggressively normalized, which is the default state of most architectures. Naive softmax must not exist in your codebase. The fix is one line.

The -max trick — derivation

Observation: softmax is invariant under a constant shift of all logits.

\[ \mathrm{softmax}(x + c)_i = \frac{e^{x_i + c}}{\sum_j e^{x_j + c}} = \frac{e^c e^{x_i}}{e^c \sum_j e^{x_j}} = \mathrm{softmax}(x)_i \]

The e^c factors cancel. So we can shift x by any constant without changing the result.
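The invariance can also be checked numerically. A small sketch, assuming fp64 and logits small enough that the unshifted `exp` does not overflow; the shift values are arbitrary.

```python
import numpy as np

def softmax(x):
    e = np.exp(x)
    return e / e.sum()

x = np.array([1.2, 4.7, 3.1, 0.5, 2.9])
shifts = (-10.0, -x.max(), 0.0, 3.5)

# Each shifted input should give the same distribution, up to rounding.
same = [np.allclose(softmax(x + c), softmax(x)) for c in shifts]
for c, ok in zip(shifts, same):
    print(f"c = {c:6.2f}  unchanged: {ok}")
```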

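Choose c = -max(x). Then x' = x - max(x) has its maximum element equal to 0, and every other element is ≤ 0. So every entry of exp(x') is at most exp(0) = 1 and at least 0, and the denominator is at least 1 because the maximum contributes exactly exp(0) = 1. No overflow, and no division by zero, for any finite logits.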

The stable softmax:

import numpy as np

def stable_softmax(x):
    x_shifted = x - x.max()        # max element is now 0
    e = np.exp(x_shifted)          # all entries in [0, 1]; tiny ones may underflow to 0
    return e / e.sum()             # e.sum() >= 1, so no division by zero

Applied to our adversarial example:

x        = [ 1.2, 92.0, 3.1, 0.5, 2.9 ]
x_shifted= [-90.8, 0.0, -88.9, -91.5, -89.1 ]
e        ≈ [ 0,    1,    0,    0,    0      ]   # all underflow to ~0 except the max
softmax  ≈ [ 0,    1,    0,    0,    0      ]

The result is a one-hot vector at index 1 (present), which is the correct limiting behavior of softmax when one logit dwarfs the others. No NaN.

Underflow is acceptable. The values that vanish are of order e^{-91} ≈ 3 × 10^{-40}, which is below the smallest normal fp32 (~1.2 × 10^{-38}) and survives, at best, as a low-precision subnormal; added to 1 in the denominator, it rounds away entirely. The -max trick guarantees that whatever was rounded to zero was correctly near zero, never the dominant term.
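Running the stable version on the adversarial vector in fp32 confirms the behavior described above. This sketch repeats `stable_softmax` so that it is self-contained; whether the tiny entries print as subnormals or as exact zeros can depend on the platform, so only their negligible size is relied on.

```python
import numpy as np

def stable_softmax(x):
    x_shifted = x - x.max()
    e = np.exp(x_shifted)
    return e / e.sum()

x = np.array([1.2, 92.0, 3.1, 0.5, 2.9], dtype=np.float32)
p = stable_softmax(x)

print(p)                              # finite everywhere
print(np.isfinite(p).all())           # no inf, no nan
print(int(p.argmax()))                # index of "present"
print(np.finfo(np.float32).tiny)      # smallest normal fp32, for comparison
```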

Log-sum-exp — the same idea, applied to log of the denominator

Many computations want log(sum(exp(x))) directly (e.g., for log-likelihood or for the denominator term in log-softmax). Naive:

log(sum(exp(x)))   # overflows the same way softmax does

Apply the same shift:

\[ \log \sum_{i} e^{x_i} = \log \left( e^{m} \sum_i e^{x_i - m} \right) = m + \log \sum_i e^{x_i - m} \]

Choosing m = max(x), every x_i - m ≤ 0, so every exp is in (0, 1]. The term at the maximum is exactly exp(0) = 1, so the sum lies in [1, K] and its log lies in [0, log K]: finite, with no overflow and no log(0).

def log_sum_exp(x):
    m = x.max()
    return m + np.log(np.exp(x - m).sum())

This is the canonical logsumexp operation. scipy.special.logsumexp is the reference; your implementation should match it to fp32 ε on any input.
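A quick comparison of the naive and shifted forms, as a sketch. To stay NumPy-only, the independent check here uses `np.logaddexp.reduce`, which folds NumPy's pairwise log(exp(a) + exp(b)) over the vector; that is a stand-in chosen for this example, not the scipy reference named above.

```python
import numpy as np

def log_sum_exp(x):
    m = x.max()
    return m + np.log(np.exp(x - m).sum())

x = np.array([1.2, 92.0, 3.1, 0.5, 2.9], dtype=np.float32)

with np.errstate(over="ignore"):
    naive = np.log(np.exp(x).sum())      # exp(92) overflows in fp32, so this is inf

stable = log_sum_exp(x)                  # shifted form stays finite
reference = np.logaddexp.reduce(x)       # independent NumPy-only check

print(naive, stable, reference)
```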

Stable cross-entropy from raw logits

Cross-entropy between the predicted distribution p (over K classes) and the true label y (an integer in [0, K)) is:

\[ \mathrm{CE}(x, y) = -\log p_y = -\log \mathrm{softmax}(x)_y = -\left( x_y - \log \sum_j e^{x_j} \right) \]

The naive implementation is −log(softmax(x)[y]). This computes softmax (which is safe if you used the -max trick), then takes a log. For moderate logits the result is correct, but you've done extra work, and whenever the probability of the true class underflows to zero the log returns log(0) = -∞, which makes the loss +∞.

The stable implementation goes directly from logits to loss:

def stable_cross_entropy(x, y):
    # x: shape (K,) logits; y: integer label
    return log_sum_exp(x) - x[y]

One pass, no probabilities materialized, no log(0) risk. The probability of the correct class is implicit in the difference.

Verification: expand log_sum_exp(x) - x[y] and you get −log(softmax(x)[y]). Same answer, different path, never produces inf from finite logits.
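The equivalence, and the one place where the two routes part ways, can be seen in a few lines. This is a sketch in fp64; `x_hard` is a made-up two-class input chosen so that the true-class probability underflows to exactly zero.

```python
import numpy as np

def stable_softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def log_sum_exp(x):
    m = x.max()
    return m + np.log(np.exp(x - m).sum())

def stable_cross_entropy(x, y):
    return log_sum_exp(x) - x[y]

# Moderate logits: both routes agree.
x = np.array([1.2, 4.7, 3.1, 0.5, 2.9])
y = 2
via_probs = -np.log(stable_softmax(x)[y])
via_logits = stable_cross_entropy(x, y)
print(via_probs, via_logits)

# Extreme logits: exp(-1000) is exactly 0 in fp64, so p_y = 0.
x_hard = np.array([0.0, 1000.0])
with np.errstate(divide="ignore"):
    via_probs_hard = -np.log(stable_softmax(x_hard)[0])   # -log(0) is inf
via_logits_hard = stable_cross_entropy(x_hard, 0)         # finite
print(via_probs_hard, via_logits_hard)
```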

Example walkthrough

For our adversarial logit vector with true label y = 2 (past simple):

x = [ 1.2, 92.0, 3.1, 0.5, 2.9 ]
m = 92.0
exp(x - m).sum() = exp(-90.8) + exp(0) + exp(-88.9) + exp(-91.5) + exp(-89.1) ≈ 1.0
log_sum_exp(x) = 92.0 + log(1.0) = 92.0
x[y=2] = 3.1
CE = 92.0 - 3.1 = 88.9

The model is very confident in the wrong answer (logit 92 for "present"); the cross-entropy is huge (88.9 nats) but finite; the gradient is finite too, and it pushes probability away from "present" and toward "past", so the optimizer can move the weights to fix it. Correct, well-defined behavior, because of the trick.

If you had used naive softmax + −log p_y, the path would be:

exp(x) → [3.3, inf, 22.2, 1.6, 18.2]
exp(x).sum() → inf
softmax → [0, NaN, 0, 0, 0]
−log p_y → −log(0) → +inf

The loss is +inf, and the NaN in the softmax output poisons every downstream gradient. Training dies. One line of code, x = x - x.max(), separates dead training from live training.
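Both paths of the walkthrough can be reproduced step by step. A sketch in fp32 with the same logits and label; the variable names `m`, `s`, `lse`, and `ce` are just labels for the intermediate lines above.

```python
import numpy as np

x = np.array([1.2, 92.0, 3.1, 0.5, 2.9], dtype=np.float32)
y = 2  # true label: past

# Stable route: the intermediate values from the walkthrough.
m = x.max()
s = np.exp(x - m).sum()
lse = m + np.log(s)
ce = lse - x[y]
print(m, s, lse, ce)

# Naive route on the same logits; warnings silenced for the demo.
with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
    e = np.exp(x)
    p = e / e.sum()
    naive_ce = -np.log(p[y])
print(p, naive_ce)
```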

Numerical equivalence — what to test

The stable versions must produce the same answer as the naive versions when the naive versions don't overflow. Test inputs:

  1. Small magnitudes. x = [0.1, 0.2, 0.3, 0.4, 0.5]: both should agree to fp32 ε.
  2. Mixed magnitudes. x = [-3, 0, 1, 2, 5]: both agree.
  3. Large positive (adversarial). x = [1, 92, 3, 0, 2]: naive returns NaN in fp32 (exp(92) overflows), stable returns a valid distribution.
  4. Large negative. x = [-1000, -2000, -3000, -4000, -5000]: naive returns NaN (every exp underflows to 0, and 0 / 0 = NaN), stable returns a one-hot at the maximum (x = -1000 here). Magnitudes around -100 are not enough to trigger this: exp(-100) ≈ 3.7 × 10^{-44} still survives as an fp32 subnormal, so the naive sum is nonzero.
  5. Identical. x = [5, 5, 5, 5, 5]: both return [0.2, 0.2, 0.2, 0.2, 0.2].
  6. -inf entries (masked positions). x = [1, -inf, 3, 0, 2]: both should give zero probability for the -inf entry; the -max trick handles this correctly because -inf - max = -inf and exp(-inf) = 0.
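The checks above can be driven from a small table. This is a sketch, not the lab's harness: the case names are illustrative, everything runs in fp32, and the large-negative row uses magnitudes big enough that every unshifted exp underflows to exactly zero.

```python
import numpy as np

def naive_softmax(x):
    e = np.exp(x)
    return e / e.sum()

def stable_softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

cases = {
    "small":     [0.1, 0.2, 0.3, 0.4, 0.5],
    "mixed":     [-3, 0, 1, 2, 5],
    "large_pos": [1, 92, 3, 0, 2],
    "large_neg": [-1000, -2000, -3000, -4000, -5000],
    "identical": [5, 5, 5, 5, 5],
    "masked":    [1, -np.inf, 3, 0, 2],
}

results = {}
with np.errstate(all="ignore"):          # silence overflow / invalid warnings
    for name, values in cases.items():
        x = np.array(values, dtype=np.float32)
        naive, stable = naive_softmax(x), stable_softmax(x)
        results[name] = (naive, stable)
        naive_ok = bool(np.isfinite(naive).all())
        stable_ok = bool(np.isfinite(stable).all())
        # Only compare when the naive result is usable at all.
        agree = naive_ok and bool(np.allclose(naive, stable, atol=1e-6))
        print(f"{name:10s} naive_finite={naive_ok} stable_finite={stable_ok} agree={agree}")
```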

Lab 01-softmax-stability.md constructs each of these test inputs from a tense-classification context and asks Borja to predict the naive vs stable outputs before running.

Why the trick generalizes — and where else it lives

The -max trick is one instance of a broader pattern: scale the inputs into a regime where the operation behaves well, then compensate. Examples Borja will see later:

  • Layer normalization (Phase 10) — subtract mean, divide by std before applying transformations. Same idea, different motivation (gradient stability rather than overflow).
  • RMSNorm (Phase 10) — divide by RMS only. Modern LLMs prefer it (cheaper, equally stable).
  • Attention scores (Phase 15) — divide QK^T by √d_k before softmax. This is not numerical; it's a variance-preservation argument, but it has the side effect of pushing logits toward the regime where softmax is well-conditioned.
  • Gradient clipping (Phase 18) — re-scale gradients whose norm exceeds a threshold, so the optimizer step is well-conditioned.
  • Mixed precision training (Phase 26) — scale the loss by 2^k so gradients land in the representable range of fp16, then descale before applying the optimizer step.

Each one is "scale into a good regime, operate, compensate". The -max trick is the first instance you'll meet. Memorize it.
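As one more instance of the pattern, here is a toy illustration of the loss-scaling idea using a single number. It is a sketch under stated assumptions: the gradient value and the scale factor 2^16 are made up for the example, and real mixed-precision training applies the scale to the loss rather than to one gradient entry.

```python
import numpy as np

g_true = 1e-8            # hypothetical gradient value, too small for fp16
scale = 2.0 ** 16        # hypothetical scale factor

g_fp16 = np.float16(g_true)                    # underflows to 0.0 in fp16
g_scaled_fp16 = np.float16(g_true * scale)     # scaled value is representable
g_recovered = np.float32(g_scaled_fp16) / np.float32(scale)   # compensate in fp32

print(g_fp16, g_scaled_fp16, g_recovered)
```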

Honest caveats

  1. The -max trick doesn't fix everything. It fixes overflow. It does not fix the loss of precision in log(1 - p) when p is tiny, where 1 - p rounds to 1 (use log1p(-p) instead), or in exp(x) - 1 for tiny x (use expm1(x)). When p ≈ 1 the damage is already in p itself, and the remedy is to compute the quantity from logits rather than from p.
  2. The trick is not unique. Any constant shift (min(x), 0, the mean) leaves the result mathematically unchanged, but only a shift of at least max(x) guarantees that no exponent is positive. max is chosen because it makes the largest term exactly exp(0) = 1: no overflow, and a denominator that is at least 1.
  3. Vectorized over batches. For a batch of K-vectors, max must be computed per row, not over the whole tensor: use x.max(axis=-1, keepdims=True). Lab 01 makes you fall into this trap on purpose.
  4. max of all -inf is -inf. Edge case if every entry is -inf (a fully masked attention row): x - max is then -inf - (-inf) = NaN, so the plain implementation returns NaN. Stable softmax should return nan or a sentinel here by design; discussed in lab 01.
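The batching and fully-masked caveats can be seen in a short sketch. The function names are illustrative, `stable_softmax_wrong` contains the global-max bug on purpose, and the batch values are chosen so that the bug is visible in fp64.

```python
import numpy as np

def stable_softmax_batched(x):
    # Row-wise max with shape (B, 1), so it broadcasts against (B, K).
    m = x.max(axis=-1, keepdims=True)
    e = np.exp(x - m)
    return e / e.sum(axis=-1, keepdims=True)

def stable_softmax_wrong(x):
    # Deliberate bug: one global max shared by every row.
    e = np.exp(x - x.max())
    return e / e.sum(axis=-1, keepdims=True)

batch = np.array([[1.0, 2.0, 3.0],
                  [1001.0, 1002.0, 1003.0]])

good = stable_softmax_batched(batch)          # both rows are the same distribution
with np.errstate(invalid="ignore"):
    bad = stable_softmax_wrong(batch)         # row 0 underflows to 0 / 0
print(good)
print(bad)

# Fully masked row: -inf - (-inf) is undefined, so the result is nan.
masked = np.full((1, 3), -np.inf)
with np.errstate(invalid="ignore"):
    all_masked = stable_softmax_batched(masked)
print(all_masked)
```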

Drill problems (work these before lab)

Solutions in solutions/02-softmax-stability-ref.md (written at phase open, not visible during pre-write). Try them by reasoning, not running.

  1. For x = [10, 20, 30], compute naive softmax(x) and stable softmax(x) by hand to 4 significant digits. Do they agree? (Yes — naive doesn't overflow at these magnitudes.)
  2. For x = [100, 100, 100], compute both. Stable gives [1/3, 1/3, 1/3]. What does naive give in fp32? In fp64?
  3. Cross-entropy for x = [1.2, 4.7, 3.1, 0.5, 2.9], y = 1 (present). Compute by hand using log_sum_exp(x) - x[y]. Verify against −log(stable_softmax(x)[y]).
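  4. Show that log_sum_exp(x) ≥ max(x) always. In exact arithmetic, equality holds only when K = 1 or every other entry is -inf; in floating point, equality also occurs when one entry dominates enough that the other terms round away.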
  5. What is stable_softmax([0, 0, 0, 0, 0])? What is log_sum_exp([0, 0, 0, 0, 0])? Both have closed forms; derive them.

If you can answer all five from memory + paper, move to lab 01. If any feel wobbly, re-read.

One-paragraph recap

Naive softmax overflows whenever any logit is large enough that exp exceeds the representable range (in fp32, any logit above about 88.7). The fix is to shift logits by -max(x) before exponentiating: softmax is invariant under constant shifts, but the shift moves the post-exp values into (0, 1], eliminating overflow. The same shift turns log(sum(exp(x))) into the stable m + log(sum(exp(x - m))). Cross-entropy from raw logits is then log_sum_exp(x) - x[y]: one pass, no overflow, no log(0). These three transformations are non-negotiable in every later phase that touches a probability distribution.

What this page does NOT cover

  • Gradient of softmax — Phase 4.
  • Multi-head attention's specific use of softmax — Phase 15.
  • Distributed softmax / online softmax (FlashAttention) — Phase 27.

Next: theory/03-summation-and-cancellation.md.