Break: Remove the -max shift from softmax
The most classic break in all of phase 2: remove the - max(x) from softmax and watch nan propagate as soon as any logit exceeds ~88 (where exp overflows in fp32).
Target: any softmax implementation (yours from lab 01, or a fresh one for the §A13 5-tense classifier).
Hypothesis
The learner predicts: "Removing the - max(x) shift will make softmax silently work for small logits, then catastrophically nan-out the moment a single logit exceeds ~88 (the fp32 overflow threshold for exp)."
The break
In your softmax(x) function:
 def softmax(x: np.ndarray) -> np.ndarray:
-    m = x.max()
-    e = np.exp(x - m)
+    e = np.exp(x)
     return e / e.sum()
Run procedure
Test on three §A13-tense logit vectors of growing magnitude:
uv run python -c "
import numpy as np

def softmax_naive(x):
    # No -max shift: exp overflows to inf for large logits.
    e = np.exp(x)
    return e / e.sum()

cases = {
    'small': np.array([1.2, 4.7, 3.1, 0.5, 2.9], dtype=np.float32),
    'medium': np.array([10.0, 50.0, 30.0, 20.0, 40.0], dtype=np.float32),
    'large': np.array([10.0, 92.0, 30.0, 20.0, 40.0], dtype=np.float32),
}
for name, x in cases.items():
    p = softmax_naive(x)
    print(f'{name:6} max(x)={x.max():6.1f} p={p} sum={p.sum():.4f}')
"
Expected failure mode
small max(x)= 4.7 p=[0.021 0.708 0.143 0.011 0.117] sum=1.0000
medium max(x)= 50.0 p=[... finite ...] sum=1.0000
large max(x)= 92.0 p=[ 0. nan 0. 0. 0.] sum=nan
The cliff is sharp: exp(88.7) ≈ 3.4e38 is roughly the fp32 max. Beyond that, np.exp returns inf; inf / inf = nan. With the -max shift, the largest exponent is always 0, so exp(0) = 1 and the cliff is impossible.
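For contrast, here is a minimal sketch of the shifted version run on the same `large` vector. The name `softmax_stable` is chosen here for illustration and is not from the lab code.

```python
import numpy as np

def softmax_stable(x: np.ndarray) -> np.ndarray:
    # Subtract the max so the largest exponent is exactly 0 and exp never overflows.
    e = np.exp(x - x.max())
    return e / e.sum()

large = np.array([10.0, 92.0, 30.0, 20.0, 40.0], dtype=np.float32)
p = softmax_stable(large)
print(p, p.sum())  # every entry is finite; nearly all mass sits on index 1
```

The shifted exponents are [-82, 0, -62, -72, -52], so the smaller terms shrink toward zero instead of the largest one growing to inf.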
Diagnostic
From training logs alone the symptom is "loss became nan around step N." Diagnostic checks:
- assert not np.isnan(loss).any() after every step. Cheap, and it catches the symptom at step N rather than step N+1000.
- Log logits.max() per step. If it crosses ~80, you are one step from the cliff in fp32 (or ~10 in fp16, whose max value is ~65000, so exp overflows at a logit of ~11).
- Diff your softmax against scipy.special.softmax on the failing input. scipy uses the -max shift; if they disagree on large, your implementation is naive.
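The logits.max() check can be made dtype-aware by deriving the overflow point from `np.finfo` instead of hard-coding it. The sketch below is one way to do that; the helper names `exp_overflow_threshold` and `near_cliff`, and the 0.9 warning fraction, are assumptions made for illustration (0.9 happens to land near the ~80 and ~10 figures quoted above).

```python
import numpy as np

def exp_overflow_threshold(dtype) -> float:
    # Largest x for which exp(x) is still finite in this float dtype:
    # log of the dtype's max value, computed in float64 to avoid rounding.
    return float(np.log(float(np.finfo(dtype).max)))

def near_cliff(logits: np.ndarray, fraction: float = 0.9) -> bool:
    # Hypothetical warning rule: flag when the max logit exceeds
    # `fraction` of the overflow threshold. Assumes a float dtype.
    return float(logits.max()) > fraction * exp_overflow_threshold(logits.dtype)

print(exp_overflow_threshold(np.float32))  # about 88.7
print(exp_overflow_threshold(np.float16))  # about 11.1

medium = np.array([10.0, 50.0, 30.0, 20.0, 40.0], dtype=np.float32)
large = np.array([10.0, 92.0, 30.0, 20.0, 40.0], dtype=np.float32)
print(near_cliff(medium), near_cliff(large))
```

Calling `near_cliff(logits)` next to the nan assertion each step gives a warning before the loss actually turns into nan.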
Lesson
The -max shift is algebraically a no-op: exp(-m) cancels in numerator and denominator. It is numerically a hard requirement: without it, softmax in fp32 returns nan the moment any logit exceeds ~88. The cost is one extra reduction (max over the vector); the payoff is "never produces nan for finite input." Always pay the cost.
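The cancellation can be written out in one line. Here m stands for max(x), and the index notation (i for the output entry, j for the sum) is introduced only for this derivation:

```latex
\mathrm{softmax}(x)_i
  = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}
  = \frac{e^{-m}\, e^{x_i}}{e^{-m} \sum_j e^{x_j}}
  = \frac{e^{x_i}}{\sum_j e^{x_j}}
```

The identity holds for any constant m; choosing m = max(x) is what bounds every exponent at 0.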
The same idea generalizes to logsumexp(x) = max(x) + log Σ exp(x - max(x)). Phase 5's information-theory page reuses it for cross-entropy.
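A minimal sketch of that formula, together with the cross-entropy it leads to. It assumes a single logit vector and an integer class index; the function name `cross_entropy` and its signature are illustrative and not taken from the Phase 5 page.

```python
import numpy as np

def logsumexp(x: np.ndarray) -> float:
    # max(x) + log(sum(exp(x - max(x)))): the largest exponent is 0, so no overflow.
    m = x.max()
    return float(m + np.log(np.exp(x - m).sum()))

def cross_entropy(logits: np.ndarray, target: int) -> float:
    # -log softmax(logits)[target] = logsumexp(logits) - logits[target]
    return logsumexp(logits) - float(logits[target])

large = np.array([10.0, 92.0, 30.0, 20.0, 40.0], dtype=np.float32)

with np.errstate(over="ignore"):
    naive = float(np.log(np.exp(large).sum()))  # exp(92) overflows in fp32

print(naive)                    # inf
print(logsumexp(large))         # 92.0
print(cross_entropy(large, 4))  # 52.0
```

Working in log space also avoids taking log of a probability that has underflowed to 0, which the softmax-then-log route cannot guarantee.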
References
- Goldberg, What Every Computer Scientist Should Know About Floating-Point Arithmetic, §2.2 (overflow and the shift trick).
- The SciPy source: scipy/special/_logsumexp.py is the textbook stable implementation; read it once.