02 — Softmax stability and log-sum-exp¶
The most important page of Phase 2. Three derivations, three lines of code, and every later phase depends on them: the -max trick, log-sum-exp, and stable cross-entropy from raw logits. Concrete example: classifying the next token as one of the model's 5 verb tenses (§A13).
This is the theory page of Phase 2. Re-derive everything below from a blank page until it's mechanical. Every later phase (autograd, attention, training loop, sampler) leans on these three rewrites.
The setup — softmax over §A13¶
The softmax function maps a length-K logit vector x = [x_1, ..., x_K] to a probability distribution:
$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{K} e^{x_j}}, \qquad i = 1, \dots, K$$
Concrete §A13 instance: classify a token as one of the five tenses. K = 5. The model produces a logit vector like
Softmax converts to:
The model thinks "present" is most likely. Cross-entropy will compare this to the true label (say, "past simple" — index 2) and produce a scalar loss.
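As a quick numeric check of the definition, here is a minimal NumPy sketch. The logit vector is an assumption: it is borrowed from drill problem 3 further down this page (x = [1.2, 4.7, 3.1, 0.5, 2.9], where index 1 is labeled "present"), and the helper name naive_softmax is illustrative only.

```python
import numpy as np

def naive_softmax(x):
    """Direct transcription of the definition: exp(x_i) / sum_j exp(x_j)."""
    e = np.exp(x)
    return e / e.sum()

# Assumed example logits (taken from the drill problems on this page).
x = np.array([1.2, 4.7, 3.1, 0.5, 2.9])
p = naive_softmax(x)

print(np.round(p, 3))    # roughly [0.021 0.708 0.143 0.011 0.117]
print(p.sum())           # 1 up to rounding
print(int(p.argmax()))   # 1, the class with the largest logit
```

With logits this small the naive form is harmless, which is exactly why the bug tends to stay hidden until a large logit shows up.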
That's the math. Now the failure mode.
The overflow failure¶
Suppose the logit vector during early training has one huge value:
x = [ 1.2, 92.0, 3.1, 0.5, 2.9 ]
A logit of 92 is unremarkable for un-normalized model outputs. But:
exp(92) ≈ 9.0 × 10^{39}
The largest representable fp32 value is ~3.4 × 10^{38}, so exp(92) in fp32 returns +∞. Then exp(92) / sum(exp(x)) becomes inf / inf, which IEEE-754 defines as NaN, while every other entry becomes finite / inf = 0. The output is no longer a probability distribution. Loss is inf or NaN. Gradient is NaN. Training is dead.
This is not a contrived example — it happens whenever a model's logits aren't aggressively normalized, which is the default state of most architectures. Naive softmax must not exist in your codebase. The fix is one line.
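The failure is easy to reproduce. The following is a minimal sketch in NumPy, assuming fp32 inputs and the adversarial vector used on this page; np.errstate only silences the overflow and invalid-operation warnings so the output stays readable.

```python
import numpy as np

x = np.array([1.2, 92.0, 3.1, 0.5, 2.9], dtype=np.float32)

with np.errstate(over="ignore", invalid="ignore"):
    e = np.exp(x)         # exp(92) exceeds the fp32 maximum, so it becomes inf
    naive = e / e.sum()   # inf / inf is nan; finite / inf is 0

print(np.finfo(np.float32).max)   # about 3.4e38
print(e)                          # the second entry is inf
print(naive)                      # the second entry is nan, the rest are 0
```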
The -max trick — derivation¶
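Observation: softmax is invariant under a constant shift of all logits. For any constant c:
$$\mathrm{softmax}(x + c)_i = \frac{e^{x_i + c}}{\sum_{j} e^{x_j + c}} = \frac{e^{c}\, e^{x_i}}{e^{c} \sum_{j} e^{x_j}} = \mathrm{softmax}(x)_i$$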
The e^c factors cancel. So we can shift x by any constant without changing the result.
Choose c = -max(x). Then x' = x - max(x) has its maximum element equal to 0, and every other element is ≤ 0. So every entry of exp(x') is bounded above by exp(0) = 1 and below by zero. No overflow, ever.
The stable softmax:
import numpy as np

def stable_softmax(x):
    x_shifted = x - x.max()   # max element is now 0
    e = np.exp(x_shifted)     # all entries in (0, 1]
    return e / e.sum()
Applied to our adversarial example:
x = [ 1.2, 92.0, 3.1, 0.5, 2.9 ]
x_shifted= [-90.8, 0.0, -88.9, -91.5, -89.1 ]
e ≈ [ 0, 1, 0, 0, 0 ] # all underflow to ~0 except the max
softmax ≈ [ 0, 1, 0, 0, 0 ]
The result is a one-hot vector at index 1 (present), which is the correct limiting behavior of softmax when one logit dwarfs the others. No NaN.
Underflow is acceptable. The values that round to (nearly) zero are of order e^{-91} ≈ 3 × 10^{-40}, which is below the smallest normal fp32 value (~1.2 × 10^{-38}); they survive only as subnormals or are flushed to zero, and either way they are negligible next to the dominant term 1. The -max trick guarantees that whatever was rounded to zero was correctly near zero, never the dominant term.
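To see the shifted computation behave in fp32, here is a small sketch that repeats stable_softmax so it runs on its own. The comparison against np.finfo(np.float32).tiny (the smallest normal fp32 value) is an illustrative check, not part of the page's lab.

```python
import numpy as np

def stable_softmax(x):
    x_shifted = x - x.max()   # max element is now 0
    e = np.exp(x_shifted)     # no entry can exceed 1
    return e / e.sum()

x = np.array([1.2, 92.0, 3.1, 0.5, 2.9], dtype=np.float32)
p = stable_softmax(x)

tiny = np.finfo(np.float32).tiny   # smallest normal fp32, about 1.18e-38
print(p)                           # finite everywhere, 1.0 at index 1
print(p[[0, 2, 3, 4]] < tiny)      # the non-dominant entries are subnormal or zero
```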
Log-sum-exp — the same idea, applied to log of the denominator¶
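Many computations want log(sum(exp(x))) directly (e.g., for log-likelihood or for the denominator term in log-softmax). Naive:
$$\mathrm{LSE}(x) = \log \sum_{i=1}^{K} e^{x_i}$$
This overflows exactly as naive softmax does, as soon as any e^{x_i} exceeds the floating-point range. Apply the same shift, for any constant m:
$$\mathrm{LSE}(x) = m + \log \sum_{i=1}^{K} e^{x_i - m}$$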
Choosing m = max(x), every x_i - m ≤ 0 and every exp is in (0, 1]. The maximum term contributes exactly exp(0) = 1, so the sum lies in [1, K] and its log lies in [0, log K]: always finite.
This is the canonical logsumexp operation. scipy.special.logsumexp is the reference; your implementation should match it to fp32 ε on any input.
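A minimal NumPy sketch of the shifted form follows. The function name log_sum_exp matches the name used later on this page; the cast to float64 and the cross-check against np.logaddexp.reduce are choices made for this sketch (the page's stated reference, scipy.special.logsumexp, is left out only to keep the example NumPy-only).

```python
import numpy as np

def log_sum_exp(x):
    """log(sum(exp(x))) computed as m + log(sum(exp(x - m))) with m = max(x)."""
    x = np.asarray(x, dtype=np.float64)
    m = x.max()
    return m + np.log(np.exp(x - m).sum())

x = np.array([1.2, 92.0, 3.1, 0.5, 2.9])
print(log_sum_exp(x))            # about 92.0
print(np.logaddexp.reduce(x))    # pairwise stable log-add, used here as a cross-check

big = np.array([1000.0, 1000.0])
print(log_sum_exp(big))          # 1000 + log(2); naive exp(1000) would overflow
```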
Stable cross-entropy from raw logits¶
Cross-entropy between the predicted distribution p = softmax(x) (over K classes) and the true label y (an integer in [0, K)) is:
$$\mathrm{CE}(x, y) = -\log p_y = -\log \mathrm{softmax}(x)_y$$
The naive implementation is −log(softmax(x)[y]). This computes softmax (which is fine if you used the -max trick), then takes a log. That works as long as p_y is representable, but you've done extra work, and any p_y that underflowed to zero becomes log(0) = -∞, i.e. an infinite loss.
The stable implementation goes directly from logits to loss:
def stable_cross_entropy(x, y):
    # x: shape (K,) logits; y: integer label
    return log_sum_exp(x) - x[y]
One pass, no probabilities materialized, no log(0) risk. The probability of the correct class is implicit in the difference.
Verification: expand log_sum_exp(x) - x[y] and you get −log(softmax(x)[y]). Same answer, different path, never produces inf from finite logits.
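The expansion is short enough to write out in full. Here LSE denotes log-sum-exp, the sums run over all K classes, and x_y is the logit of the true class:

```latex
\begin{aligned}
-\log \mathrm{softmax}(x)_y
  &= -\log \frac{e^{x_y}}{\sum_{j} e^{x_j}} \\
  &= -\log e^{x_y} + \log \sum_{j} e^{x_j} \\
  &= \mathrm{LSE}(x) - x_y
\end{aligned}
```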
Example walkthrough¶
For our adversarial logit vector with true label y = 2 (past simple):
x = [ 1.2, 92.0, 3.1, 0.5, 2.9 ]
m = 92.0
exp(x - m).sum() = exp(-90.8) + exp(0) + exp(-88.9) + exp(-91.5) + exp(-89.1) ≈ 1.0
log_sum_exp(x) = 92.0 + log(1.0) = 92.0
x[y=2] = 3.1
CE = 92.0 - 3.1 = 88.9
The model is very confident in the wrong answer (logit 92 for "present"); the cross-entropy is huge (88.9 nats); the gradient is finite and points firmly toward the true class; the optimizer will move the weights strongly to fix it. Correct, well-defined behavior, because of the trick.
If you had used naive softmax + −log p_y, the path would be:
exp(x) → [3.3, inf, 22.2, 1.6, 18.2]
exp(x).sum() → inf
softmax → [0, NaN, 0, 0, 0]
−log p_y → −log(0) → +inf
NaN poisons every downstream gradient. Training dies. One line of code — x = x - x.max() — separates dead training from live training.
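Both paths of the walkthrough can be run side by side. This is a sketch under stated assumptions: fp32 inputs, and naive_cross_entropy is a hypothetical helper written only to reproduce the failing path.

```python
import numpy as np

def log_sum_exp(x):
    m = x.max()
    return m + np.log(np.exp(x - m).sum())

def stable_cross_entropy(x, y):
    # Straight from logits to loss; no probabilities are materialized.
    return log_sum_exp(x) - x[y]

def naive_cross_entropy(x, y):
    # Hypothetical helper: unshifted softmax followed by -log.
    e = np.exp(x)
    return -np.log(e[y] / e.sum())

x = np.array([1.2, 92.0, 3.1, 0.5, 2.9], dtype=np.float32)
y = 2

stable = stable_cross_entropy(x, y)
with np.errstate(over="ignore", divide="ignore", invalid="ignore"):
    naive = naive_cross_entropy(x, y)

print(stable)   # about 88.9
print(naive)    # inf, because p_y came out as exactly 0
```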
Numerical equivalence — what to test¶
The stable versions must produce the same answer as the naive versions when the naive versions don't overflow. Test inputs:
- Small magnitudes. x = [0.1, 0.2, 0.3, 0.4, 0.5]: both should agree to fp32 ε.
- Mixed magnitudes. x = [-3, 0, 1, 2, 5]: both agree.
- Large positive (adversarial). x = [1, 92, 3, 0, 2]: naive returns NaN, stable returns a valid distribution.
- Large negative. x = [-100, -200, -300, -400, -500]: stable returns a one-hot at the maximum (x = -100 here). Naive keeps only the largest term, and only as an fp32 subnormal (exp(-100) ≈ 3.7 × 10^{-44}); once every exp underflows to 0 (all logits below roughly -104 in fp32), naive computes 0 / 0 = NaN.
- Identical. x = [5, 5, 5, 5, 5]: both return [0.2, 0.2, 0.2, 0.2, 0.2].
- -inf entries (masked positions). x = [1, -inf, 3, 0, 2]: both should give zero probability for the -inf entry; the -max trick handles this correctly because -inf - max = -inf and exp(-inf) = 0.
Lab 01-softmax-stability.md constructs each of these test inputs from a tense-classification context and asks Borja to predict the naive vs stable outputs before running.
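The test inputs listed above can be turned into a small harness along these lines. This is only a sketch, not the lab's code: it assumes fp32 inputs, the dictionary keys are made-up labels, and it asserts agreement only on the cases where the naive form is expected to be well behaved.

```python
import numpy as np

def naive_softmax(x):
    e = np.exp(x)
    return e / e.sum()

def stable_softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

cases = {
    "small":     [0.1, 0.2, 0.3, 0.4, 0.5],
    "mixed":     [-3.0, 0.0, 1.0, 2.0, 5.0],
    "identical": [5.0, 5.0, 5.0, 5.0, 5.0],
    "masked":    [1.0, -np.inf, 3.0, 0.0, 2.0],
    "large_pos": [1.0, 92.0, 3.0, 0.0, 2.0],
    "large_neg": [-100.0, -200.0, -300.0, -400.0, -500.0],
}

results = {}
for name, values in cases.items():
    x = np.array(values, dtype=np.float32)
    with np.errstate(over="ignore", invalid="ignore"):
        results[name] = (naive_softmax(x), stable_softmax(x))

# Where naive does not overflow, the two versions must agree.
for name in ("small", "mixed", "identical", "masked"):
    naive, stable = results[name]
    assert np.allclose(naive, stable, atol=1e-6), name

# The stable version must stay finite on every case.
for name, (naive, stable) in results.items():
    assert np.isfinite(stable).all(), name
    print(name, naive, stable)
```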
Why the trick generalizes — and where else it lives¶
The -max trick is one instance of a broader pattern: scale the inputs into a regime where the operation behaves well, then compensate. Examples Borja will see later:
- Layer normalization (Phase 10): subtract mean, divide by std before applying transformations. Same idea, different motivation (gradient stability rather than overflow).
- RMSNorm (Phase 10): divide by RMS only. Modern LLMs prefer it (cheaper, equally stable).
- Attention scores (Phase 15): divide QK^T by √d_k before softmax. This is not numerical; it's a variance-preservation argument, but it has the side effect of pushing logits toward the regime where softmax is well-conditioned.
- Gradient clipping (Phase 18): re-scale gradients whose norm exceeds a threshold, so the optimizer step is well-conditioned.
- Mixed precision training (Phase 26): scale the loss by 2^k so gradients land in the representable range of fp16, then descale before applying the optimizer step.
Each one is "scale into a good regime, operate, compensate". The -max trick is the first instance you'll meet. Memorize it.
Honest caveats¶
- The -max trick doesn't fix everything. It fixes overflow. It does not fix the precision loss in log(1 - p) when p ≈ 0 (use log1p(-p) instead) or the loss of precision in exp(x) - 1 for tiny x (use expm1(x)).
- The trick is not unique. You could shift by min(x), by 0, by the mean, by anything constant. max is chosen because it makes the largest exponential exactly exp(0) = 1, which is the tightest possible bound on the post-shift values.
- Vectorized over batches. For a batch of K-vectors, max must be computed per row, not over the whole tensor: x.max(axis=-1, keepdims=True). Lab 01 makes you fall into this trap on purpose.
- max of all -inf is -inf. Edge case if every entry is -inf (a fully masked attention row). Stable softmax should return nan or a sentinel here; this is discussed in lab 01.
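The batching caveat above is the one most easily shown in code. Below is a hedged sketch: stable_softmax_rows is an illustrative name, the second row of the batch is invented for the demonstration, and the "global max" variant is the mistake being illustrated, not something to copy.

```python
import numpy as np

def stable_softmax_rows(x):
    """Row-wise stable softmax for an array of shape (..., K)."""
    m = x.max(axis=-1, keepdims=True)   # one max per row, shape (..., 1)
    e = np.exp(x - m)
    return e / e.sum(axis=-1, keepdims=True)

batch = np.array(
    [[1.2, 92.0, 3.1, 0.5, 2.9],
     [-500.0, -501.0, -502.0, -503.0, -504.0]],   # invented second row
    dtype=np.float32,
)

per_row = stable_softmax_rows(batch)

# The trap: a single global max (92) shifts the second row to about -592,
# so every exp in that row underflows to 0 and the row becomes 0 / 0.
with np.errstate(invalid="ignore"):
    e = np.exp(batch - batch.max())
    global_max = e / e.sum(axis=-1, keepdims=True)

print(per_row.sum(axis=-1))   # each row sums to 1
print(global_max[1])          # all nan
```

Keeping keepdims=True is what lets the per-row max and per-row sum broadcast back against the (rows, K) array without reshaping.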
Drill problems (work these before lab)¶
Solutions in solutions/02-softmax-stability-ref.md (written at phase open, not visible during pre-write). Try them by reasoning, not running.
- For x = [10, 20, 30], compute naive softmax(x) and stable softmax(x) by hand to 4 significant digits. Do they agree? (Yes: naive doesn't overflow at these magnitudes.)
- For x = [100, 100, 100], compute both. Stable gives [1/3, 1/3, 1/3]. What does naive give in fp32? In fp64?
- Cross-entropy for x = [1.2, 4.7, 3.1, 0.5, 2.9], y = 1 (present). Compute by hand using log_sum_exp(x) - x[y]. Verify against −log(stable_softmax(x)[y]).
- Show that log_sum_exp(x) ≥ max(x) always. In exact arithmetic, equality requires every other entry to be -inf; in floating point it also appears once one entry dominates strongly enough for the remaining terms to round away.
- What is stable_softmax([0, 0, 0, 0, 0])? What is log_sum_exp([0, 0, 0, 0, 0])? Both have closed forms; derive them.
If you can answer all five from memory + paper, move to lab 01. If any feel wobbly, re-read.
One-paragraph recap¶
Naive softmax overflows whenever any logit is large enough that exp exceeds the representable range (in fp32, any logit above about 88.7). The fix is to shift logits by -max(x) before exponentiating: softmax is invariant under constant shifts, but the shift moves the post-exp values into (0, 1], eliminating overflow. The same shift turns log(sum(exp(x))) into the stable m + log(sum(exp(x - m))). Cross-entropy from raw logits is then log_sum_exp(x) - x[y]: one pass, no overflow, no log(0). These three transformations are non-negotiable in every later phase that touches a probability distribution.
What this page does NOT cover¶
- Gradient of softmax — Phase 4.
- Multi-head attention's specific use of softmax — Phase 15.
- Distributed softmax / online softmax (FlashAttention) — Phase 27.
Next: theory/03-summation-and-cancellation.md.