

Break — Remove the -max shift from softmax

The most classic break in all of Phase 2: remove the - max(x) from softmax and watch nan propagate as soon as any logit exceeds ~88 (where exp overflows in fp32).

Target: any softmax implementation (yours from lab 01, or a fresh one for the §A13 5-tense classifier).

Hypothesis

The learner predicts: "Removing the - max(x) shift will make softmax silently work for small logits, then catastrophically nan-out the moment a single logit exceeds ~88 (the fp32 overflow threshold for exp)."

The break

In your softmax(x) function:

 def softmax(x: np.ndarray) -> np.ndarray:
-    m = x.max()
-    e = np.exp(x - m)
+    e = np.exp(x)
     return e / e.sum()
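For comparison, here is the unbroken version generalized to a batch of logit rows, which is closer to what a training loop passes in. This is a minimal sketch. It assumes the last axis indexes the classes; the `axis=-1` parameter is an assumption and may not match how your lab 01 function is shaped.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis` (assumed to be the class axis)."""
    # Subtract the per-row max so the largest exponent in each row is exp(0) = 1.
    m = np.max(x, axis=axis, keepdims=True)
    e = np.exp(x - m)
    # Each row's denominator is >= 1, so finite input can never produce inf / inf.
    return e / np.sum(e, axis=axis, keepdims=True)

batch = np.array([[1.2, 4.7, 3.1, 0.5, 2.9],
                  [10.0, 92.0, 30.0, 20.0, 40.0]], dtype=np.float32)
p = softmax(batch)
print(p.sum(axis=-1))  # each row sums to 1, up to float32 rounding
```

`keepdims=True` keeps the max and the sum broadcastable against `x`, so the same function works for a single vector or a batch.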

Run procedure

Test on three logit vectors for the §A13 5-tense classifier, in order of growing magnitude:

uv run python -c "
import numpy as np
def softmax_naive(x):
    e = np.exp(x)
    return e / e.sum()

cases = {
    'small':  np.array([1.2, 4.7, 3.1, 0.5, 2.9], dtype=np.float32),
    'medium': np.array([10.0, 50.0, 30.0, 20.0, 40.0], dtype=np.float32),
    'large':  np.array([10.0, 92.0, 30.0, 20.0, 40.0], dtype=np.float32),
}
for name, x in cases.items():
    p = softmax_naive(x)
    print(f'{name:6}  max(x)={x.max():6.1f}  p={p}  sum={p.sum():.4f}')
"
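By default the naive run only prints a warning and keeps going. If you want the overflow to stop execution at the exact `np.exp` call, NumPy's `np.errstate` context manager can turn it into a `FloatingPointError`. Here is a sketch that reuses the same three cases. The `failed` list exists only for illustration.

```python
import numpy as np

def softmax_naive(x):
    e = np.exp(x)
    return e / e.sum()

cases = {
    "small":  np.array([1.2, 4.7, 3.1, 0.5, 2.9], dtype=np.float32),
    "medium": np.array([10.0, 50.0, 30.0, 20.0, 40.0], dtype=np.float32),
    "large":  np.array([10.0, 92.0, 30.0, 20.0, 40.0], dtype=np.float32),
}

failed = []
for name, x in cases.items():
    # over="raise": an overflow inside np.exp raises instead of silently returning inf.
    # invalid="raise": an inf / inf division would raise too.
    with np.errstate(over="raise", invalid="raise"):
        try:
            softmax_naive(x)
            print(f"{name:6}  ok")
        except FloatingPointError as err:
            failed.append(name)
            print(f"{name:6}  FloatingPointError: {err}")
```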

Expected failure mode

small   max(x)=   4.7  p=[0.021 0.708 0.143 0.011 0.117]  sum=1.0000
medium  max(x)=  50.0  p=[... finite ...]                 sum=1.0000
large   max(x)=  92.0  p=[ 0.  nan  0.  0.  0.]            sum=nan

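The cliff is sharp. The fp32 maximum is about 3.4e38 and ln(3.4e38) ≈ 88.72, so exp(88.7) is still finite while exp(88.8) is not. Past that point np.exp returns inf, and NumPy prints "RuntimeWarning: overflow encountered in exp" and carries on. The sum then becomes inf, every finite entry divides to 0, and inf / inf = nan. With the -max shift, the largest exponent is always 0. That term is exp(0) = 1, so the denominator is at least 1 and the cliff cannot occur for finite input.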
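Where the cliff sits depends only on the dtype: exp(x) overflows once x exceeds ln(largest finite value). Here is a quick check using `np.finfo` and `np.exp` that prints the threshold for fp16, fp32, and fp64 and then probes either side of the fp32 cliff.

```python
import math
import numpy as np

# exp(x) overflows once x exceeds ln(largest finite value of the dtype).
for dt in (np.float16, np.float32, np.float64):
    fmax = float(np.finfo(dt).max)
    print(f"{np.dtype(dt).name:8} max={fmax:.4g}  exp overflows past x ~ {math.log(fmax):.2f}")

# Probe either side of the fp32 cliff; errstate silences the expected overflow warning.
with np.errstate(over="ignore"):
    below = np.exp(np.float32(88.7))   # still finite, about 3.3e38
    above = np.exp(np.float32(88.8))   # inf
print(below, above)
```

The same arithmetic gives ~11.09 for fp16, which is why fp16 training hits this wall so much sooner.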

Diagnostic

From training logs alone, the symptom is "loss became nan around step N." Diagnostic checks:

  1. assert np.isfinite(loss).all() after every step (isfinite rejects inf as well as nan). It is cheap, and it catches the symptom at step N rather than step N+1000.
  2. Log logits.max() per step. If it crosses ~80, you are close to the fp32 cliff at ~88.7. In fp16 the warning level is ~10: fp16 overflows at 65504, and ln 65504 ≈ 11.1.
  3. Compare your softmax against scipy.special.softmax on the failing input. SciPy applies the -max shift, so if the two disagree on the large case, your implementation is the naive one.
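The first two checks can be folded into one per-step guard. This is a minimal sketch: `check_step` and its `fraction` parameter are hypothetical names, and it assumes floating-point logits. The warning level is derived from the dtype rather than hard-coded. With `fraction=0.9`, it lands at ~80 for fp32 and ~10 for fp16, matching the numbers above.

```python
import math
import numpy as np

def check_step(step: int, loss: float, logits: np.ndarray, fraction: float = 0.9) -> bool:
    """Hypothetical per-step guard; returns True when logits are near the exp overflow cliff."""
    # Check 1: stop at step N, not step N+1000. isfinite rejects inf as well as nan.
    if not np.isfinite(loss):
        raise FloatingPointError(f"step {step}: loss is {loss}")
    # Check 2: ln(largest finite value) is where exp overflows for this dtype,
    # about 88.72 for float32 and 11.09 for float16.
    cliff = math.log(float(np.finfo(logits.dtype).max))
    peak = float(logits.max())
    near = peak > fraction * cliff  # 0.9 * cliff is ~80 in fp32, ~10 in fp16
    if near:
        print(f"step {step}: logits.max()={peak:.1f} is close to the exp cliff at ~{cliff:.2f}")
    return near

logits = np.array([[10.0, 85.0, 30.0]], dtype=np.float32)
print(check_step(0, 0.42, logits))          # 85 exceeds 0.9 * 88.72
print(check_step(1, 0.40, logits - 40.0))   # 45 is far from the cliff
```

Scaling the threshold by a fraction of the cliff, rather than subtracting a fixed margin, keeps one setting usable across dtypes whose cliffs differ by an order of magnitude.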

Lesson

The -max shift is algebraically a no-op: the factor exp(-m) cancels between numerator and denominator. Numerically it is a hard requirement: without it, fp32 softmax returns nan the moment any logit exceeds ~88. The cost is one extra reduction (a max over the vector) plus a subtraction. The payoff is a softmax that never produces nan for finite input. Always pay the cost.
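The "algebraic no-op" claim is easy to check numerically. Subtracting the same constant c from every logit leaves the output unchanged up to rounding, because exp(x_i - c) / Σ_j exp(x_j - c) = exp(x_i) / Σ_j exp(x_j). The -max shift is just the choice c = max(x). Here is a small check on the small case in float64. `softmax_shifted` is a helper introduced only for this illustration.

```python
import numpy as np

def softmax_shifted(x: np.ndarray, c: float) -> np.ndarray:
    # Softmax with an arbitrary constant c subtracted from every logit.
    e = np.exp(x - c)
    return e / e.sum()

x = np.array([1.2, 4.7, 3.1, 0.5, 2.9])   # float64
reference = softmax_shifted(x, 0.0)        # the naive form, safe for these small logits
for c in (0.0, x.max(), -3.0, 2.5):
    # Any shift gives the same probabilities; only overflow behavior differs.
    assert np.allclose(softmax_shifted(x, c), reference)
print(np.round(reference, 3))  # [0.021 0.708 0.143 0.011 0.117]
```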

The same idea generalizes to logsumexp(x) = max(x) + log Σ exp(x - max(x)). Phase 5's information-theory page reuses it for cross-entropy.
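Here is the same shift written out as a stable logsumexp, together with the single-example cross-entropy it enables, computed straight from raw logits. This is a minimal sketch: the function names and the integer `target` index are illustrative assumptions, not the Phase 5 page's code.

```python
import numpy as np

def logsumexp(x: np.ndarray) -> float:
    # log Σ exp(x) = m + log Σ exp(x - m), with m = max(x); the largest term is exp(0) = 1.
    m = np.max(x)
    return float(m + np.log(np.sum(np.exp(x - m))))

def cross_entropy(logits: np.ndarray, target: int) -> float:
    # -log softmax(logits)[target] = logsumexp(logits) - logits[target]; no division, no nan.
    return logsumexp(logits) - float(logits[target])

small = np.array([1.2, 4.7, 3.1, 0.5, 2.9])
large = np.array([10.0, 92.0, 30.0, 20.0, 40.0], dtype=np.float32)
print(cross_entropy(small, 1))   # -log(0.708...), about 0.345
print(logsumexp(large))          # 92.0: the other terms vanish below float32 precision
print(cross_entropy(large, 0))   # 82.0: a confident wrong prediction, still finite
```

Working in log space also avoids taking log of a probability that has underflowed to 0. That is why cross-entropy is usually computed from logits rather than from softmax output.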

References

  • Goldberg, What Every Computer Scientist Should Know About Floating-Point Arithmetic, §2.2 (overflow and the shift trick).
  • The SciPy source: scipy/special/_logsumexp.py is the textbook stable implementation; read it once.