Lab 02: Log-sum-exp and stable cross-entropy from logits
Read `theory/04-log-sum-exp-and-stability.md` first. Do not consult `solutions/`.
Objective
Implement logsumexp, log_softmax, and cross_entropy_from_logits with full numerical-stability discipline. Demonstrate that the naïve implementations fail on adversarial inputs while the stable ones succeed.
Setup
Continue in `src/phase05/probability.py`.
Tasks
Task 1: naïve implementations (so you see them fail)
First, implement the naïve versions:
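```python
import numpy as np
# Naïve versions: they exponentiate raw logits directly, so they overflow/underflow by design.
def logsumexp_naive(z): return np.log(np.exp(z).sum())
def log_softmax_naive(z): return np.log(np.exp(z) / np.exp(z).sum())
def cross_entropy_naive(z, y_star): return -np.log(np.exp(z) / np.exp(z).sum())[y_star]
```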
Test them on the following inputs and document what happens:
| Input \(z\) | Expected outcome | Naïve result |
|---|---|---|
| [0, 0, 0] | sane | should work |
| [1, 2, 3] | sane | should work |
| [1000, 1001, 1002] | should be sane but won't | overflow → `inf` / `NaN` |
| [-1000, -999, -998] | should be sane but won't | underflow → 0 → `-inf` from `log` (and `NaN` from 0/0 in the softmax) |
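A minimal harness for filling in the table, assuming the naïve functions above (repeated here so the snippet runs on its own). The dictionary `INPUTS` and the helper `run_naive` are hypothetical names; the `RuntimeWarning`s it triggers are deliberately left visible.

```python
import numpy as np


def logsumexp_naive(z): return np.log(np.exp(z).sum())
def log_softmax_naive(z): return np.log(np.exp(z) / np.exp(z).sum())
def cross_entropy_naive(z, y_star): return -np.log(np.exp(z) / np.exp(z).sum())[y_star]


# The four Task 1 inputs, as float64 arrays.
INPUTS = {
    "zeros": np.array([0.0, 0.0, 0.0]),
    "small": np.array([1.0, 2.0, 3.0]),
    "large": np.array([1000.0, 1001.0, 1002.0]),
    "very negative": np.array([-1000.0, -999.0, -998.0]),
}


def run_naive(y_star=2):
    """Return (logsumexp, log_softmax, cross-entropy) for each input."""
    return {
        name: (logsumexp_naive(z), log_softmax_naive(z), cross_entropy_naive(z, y_star))
        for name, z in INPUTS.items()
    }


for name, (lse, lsm, ce) in run_naive().items():
    print(f"{name:>14}: lse={lse}, log_softmax={lsm}, ce={ce}")
```

Expect finite values for the first two rows, `inf` and `NaN` for the large row, and `-inf` and `NaN` for the very negative row.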
Task 2: stable logsumexp
Implement the stable version: subtract the max \(m = \max_i z_i\) before exponentiating, using \(\log \sum_i e^{z_i} = m + \log \sum_i e^{z_i - m}\). Re-run all 4 inputs from Task 1; all 4 should now produce finite, correct outputs. Verify against `scipy.special.logsumexp`.
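One way to write the stable version, sketched under the assumption that inputs are finite float64 arrays and that the reduction runs over the last axis (the `axis=-1` default is a choice made here, not part of the lab spec). Unlike `scipy.special.logsumexp`, it does not special-case `inf` entries.

```python
import numpy as np
from scipy.special import logsumexp as scipy_logsumexp


def logsumexp(z, axis=-1):
    """Stable log(sum(exp(z))) along `axis`; assumes finite float inputs."""
    z = np.asarray(z, dtype=np.float64)
    m = z.max(axis=axis, keepdims=True)
    # After the shift every exponent is <= 0, so exp() cannot overflow, and the
    # max entry contributes exp(0) = 1, so the sum is >= 1 and its log is finite.
    out = m + np.log(np.exp(z - m).sum(axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


# Re-run the four Task 1 inputs and compare with SciPy's reference.
for z in ([0, 0, 0], [1, 2, 3], [1000, 1001, 1002], [-1000, -999, -998]):
    ours, ref = logsumexp(z), scipy_logsumexp(z)
    print(z, float(ours), float(ref), abs(ours - ref) <= 1e-12)
```

Keeping `keepdims=True` until the final `squeeze` means the same function works unchanged on a batched `(B, V)` array.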
Task 3: stable log_softmax
Same exercise for log_softmax, using \(\log \operatorname{softmax}(z)_i = z_i - \operatorname{logsumexp}(z)\). Reference: `scipy.special.log_softmax`.
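A matching sketch for `log_softmax`, again assuming finite inputs and a last-axis reduction. It reuses the shifted logits rather than calling `logsumexp`, which avoids subtracting and re-adding the max.

```python
import numpy as np
from scipy.special import log_softmax as scipy_log_softmax


def log_softmax(z, axis=-1):
    """Stable log-softmax along `axis`; assumes finite float inputs."""
    z = np.asarray(z, dtype=np.float64)
    # Shift so the max along `axis` is 0; the constant cancels in z_i - logsumexp(z).
    shifted = z - z.max(axis=axis, keepdims=True)
    # log softmax(z)_i = shifted_i - log(sum_j exp(shifted_j))
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


for z in ([0, 0, 0], [1, 2, 3], [1000, 1001, 1002], [-1000, -999, -998]):
    lsm = log_softmax(z)
    # exp(lsm) should sum to 1; also compare against SciPy's reference.
    print(z, lsm, np.exp(lsm).sum(), np.allclose(lsm, scipy_log_softmax(z), rtol=0, atol=1e-12))
```

`scipy.special.log_softmax` defaults to `axis=None`, so pass `axis=-1` explicitly when comparing on batched inputs.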
Task 4: stable cross_entropy_from_logits
```python
def cross_entropy_from_logits(z, y_star):
    """Stable CE from raw logits. Equivalent to PyTorch's F.cross_entropy on a single example."""
    # log_softmax is the stable version from Task 3; index the log-probabilities directly.
    return -log_softmax(z)[y_star]
```
Verify on a small synthetic batch.
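For the batch check, a hypothetical helper `cross_entropy_batch` can index the row-wise log-softmax with `np.arange(B)` and the label vector. The sketch compares it against the single-example function and against a naïve reference on moderately sized logits, where the naïve formula is still safe. Batch size, vocabulary size, logit scale, and seed are arbitrary choices.

```python
import numpy as np


def log_softmax(z, axis=-1):
    # Stable log-softmax from Task 3 (repeated so this sketch is standalone).
    z = np.asarray(z, dtype=np.float64)
    shifted = z - z.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def cross_entropy_from_logits(z, y_star):
    return -log_softmax(z)[y_star]


def cross_entropy_batch(Z, y):
    """Hypothetical batched helper: Z has shape (B, V), y holds B integer labels."""
    Z = np.asarray(Z, dtype=np.float64)
    # Pick out log p(y_b | z_b) for every row b in one fancy-indexing step.
    return -log_softmax(Z, axis=-1)[np.arange(Z.shape[0]), np.asarray(y)]


rng = np.random.default_rng(0)
B, V = 8, 5
Z = rng.normal(scale=10.0, size=(B, V))  # moderate logits: the naïve formula is still safe
y = rng.integers(0, V, size=B)

ce = cross_entropy_batch(Z, y)
per_example = np.array([cross_entropy_from_logits(Z[b], y[b]) for b in range(B)])
probs = np.exp(Z) / np.exp(Z).sum(axis=-1, keepdims=True)
naive = -np.log(probs[np.arange(B), y])

print(np.allclose(ce, per_example), np.allclose(ce, naive))
```

If PyTorch is installed, `torch.nn.functional.cross_entropy(torch.from_numpy(Z), torch.from_numpy(y), reduction="none")` should give the same per-row values; that comparison is left out to keep the sketch NumPy-only.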
Task 5: property tests
Add to `tests/test_phase05_logsumexp.py`:
- Shift invariance. For any \(c \in \mathbb{R}\): `logsumexp(z + c) == logsumexp(z) + c` within tolerance.
- Softmax-shift invariance. For any \(c\): `log_softmax(z + c)` equals `log_softmax(z)` within tolerance (the constant cancels).
- Reduction sanity. Both rows of `log_softmax(np.stack([z, z]))` equal `log_softmax(z)`, so `log_softmax(np.stack([z, z])).sum() == 2 * log_softmax(z).sum()`; i.e., the result is well-defined per row.
- Reference parity. Compare against `scipy.special.log_softmax` on a battery of inputs (uniform, peaked, large, small, negative).
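A dependency-light sketch of the test file. The stable functions are inlined so it runs standalone; in the real file they would be imported from `src/phase05/probability.py`. The battery contents, shift constants, and tolerances are assumptions, not values fixed by the lab. Plain loops replace `pytest.mark.parametrize` so the functions also run without pytest, though pytest will still collect them.

```python
# Sketch of tests/test_phase05_logsumexp.py
import numpy as np
from scipy.special import log_softmax as scipy_log_softmax
from scipy.special import logsumexp as scipy_logsumexp


# Inlined for a standalone sketch; the real file would import these from
# src/phase05/probability.py instead.
def logsumexp(z, axis=-1):
    z = np.asarray(z, dtype=np.float64)
    m = z.max(axis=axis, keepdims=True)
    out = m + np.log(np.exp(z - m).sum(axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def log_softmax(z, axis=-1):
    z = np.asarray(z, dtype=np.float64)
    shifted = z - z.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


# Assumed battery: one representative per category named in the lab.
BATTERY = {
    "uniform": np.zeros(10),
    "peaked": np.array([0.0, 0.0, 50.0, 0.0]),
    "large": np.array([1000.0, 1001.0, 1002.0]),
    "small": np.array([1e-8, 2e-8, 3e-8]),
    "negative": np.array([-1000.0, -999.0, -998.0]),
}
SHIFTS = [-500.0, -1.0, 0.0, 3.5, 500.0]


def test_shift_invariance():
    for z in BATTERY.values():
        for c in SHIFTS:
            assert np.isclose(logsumexp(z + c), logsumexp(z) + c, rtol=1e-12, atol=1e-9)


def test_softmax_shift_invariance():
    for z in BATTERY.values():
        for c in SHIFTS:
            np.testing.assert_allclose(log_softmax(z + c), log_softmax(z), rtol=0, atol=1e-9)


def test_reduction_sanity():
    for z in BATTERY.values():
        batched = log_softmax(np.stack([z, z]))  # shape (2, V), normalized per row
        np.testing.assert_allclose(batched[0], log_softmax(z))
        np.testing.assert_allclose(batched[1], log_softmax(z))
        assert np.isclose(batched.sum(), 2 * log_softmax(z).sum())


def test_reference_parity():
    for z in BATTERY.values():
        np.testing.assert_allclose(log_softmax(z), scipy_log_softmax(z), rtol=0, atol=1e-12)
        assert abs(logsumexp(z) - scipy_logsumexp(z)) <= 1e-12
```

The shift tolerances are loose (`atol=1e-9`) because adding a large \(c\) to tiny logits like `1e-8` rounds them; exact equality would fail for reasons unrelated to stability.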
Task 6: measure speed
Time logsumexp on shape (B, V) = (64, 600):
- Time the stable NumPy version.
- Time `scipy.special.logsumexp`.
- Time the naïve broken version, just for context: it returns `inf`/`NaN` on large real logits, but it is a useful baseline on safe inputs.
Save measurements to `experiments/<date>-phase-05-logsumexp/timings.csv`.
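A timing sketch using `timeit.repeat` and `csv.DictWriter` from the standard library. It draws standard-normal logits so the naïve version stays finite, and its naïve variant takes an `axis` argument (an adaptation: the Task 1 version sums over every element). The helper names and CSV column names are hypothetical.

```python
import csv
import timeit
from pathlib import Path

import numpy as np
from scipy.special import logsumexp as scipy_logsumexp


def logsumexp_stable(z, axis=-1):
    m = z.max(axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.exp(z - m).sum(axis=axis, keepdims=True)), axis=axis)


def logsumexp_naive(z, axis=-1):
    # Row-wise adaptation of the Task 1 naïve version; only safe on small logits.
    return np.log(np.exp(z).sum(axis=axis))


def time_logsumexp(shape=(64, 600), number=200, repeat=5, seed=0):
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(shape)  # safe inputs, so the naïve version does not overflow
    candidates = {
        "stable_numpy": lambda: logsumexp_stable(z),
        "scipy": lambda: scipy_logsumexp(z, axis=-1),
        "naive": lambda: logsumexp_naive(z),
    }
    rows = []
    for name, fn in candidates.items():
        # Best of `repeat` runs, each averaging `number` calls, reported in microseconds.
        best = min(timeit.repeat(fn, number=number, repeat=repeat)) / number
        rows.append({"impl": name, "B": shape[0], "V": shape[1], "us_per_call": best * 1e6})
    return rows


def write_timings(rows, path):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["impl", "B", "V", "us_per_call"])
        writer.writeheader()
        writer.writerows(rows)
```

For example, `write_timings(time_logsumexp(), f"experiments/{date.today():%Y-%m-%d}-phase-05-logsumexp/timings.csv")` (after `from datetime import date`) fills `<date>` with an ISO date, which is an assumption about the intended format. Taking the minimum over repeats follows the usual `timeit` advice: slower runs mostly reflect interference from other processes.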
Acceptance
- All 4 inputs in Task 1 documented (naïve fails as predicted).
- Stable implementations pass on all 4 inputs.
- Property tests pass.
- Reference parity against scipy within 1e-12.
- Timings captured.
Pitfalls to expect
- `np.exp(1002) == np.inf` in float64; you'll see `RuntimeWarning: overflow encountered in exp`. That's the point. Don't suppress the warning; it's diagnostic.
- `np.log(0.0) == -np.inf`; downstream multiplication by 0 gives `NaN`. The stable version avoids this entirely: after subtracting the max, the largest term is \(e^0 = 1\), so the sum is at least 1 and `np.log` never sees an underflowed zero.
- When subtracting the max, watch axis semantics: use `z.max(axis=-1, keepdims=True)` for batched `z`.
- The stable `cross_entropy_from_logits` is fused: never compute softmax, then index, then take the log; just return `-log_softmax(z)[y_star]`. PyTorch's `F.cross_entropy` does the same fusion under the hood.
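A small demonstration of the axis pitfall. With a square `(B, V)` batch, forgetting `keepdims=True` does not raise: the `(B,)` vector of row maxima broadcasts across columns, and because the subtracted value then differs within a row, the log-softmax changes. The helper `log_softmax_rows`, the shapes, and the seed are arbitrary choices for illustration.

```python
import numpy as np


def log_softmax_rows(x):
    # Row-wise log-softmax of an already-shifted (or small) array.
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


rng = np.random.default_rng(0)
Z = rng.normal(scale=5.0, size=(4, 4))  # B == V, so the wrong broadcast raises no error

good = Z - Z.max(axis=-1, keepdims=True)  # (4, 4) - (4, 1): row i minus max of row i
bad = Z - Z.max(axis=-1)                  # (4, 4) - (4,): column j minus max of row j

print(np.allclose(good.max(axis=-1), 0.0))                       # each row's max is now 0
print(np.allclose(log_softmax_rows(good), log_softmax_rows(Z)))  # constant per row cancels
print(np.allclose(log_softmax_rows(bad), log_softmax_rows(Z)))   # different shift per column

# With B != V the same mistake fails loudly instead of silently:
try:
    np.ones((64, 600)) - np.ones((64, 600)).max(axis=-1)
except ValueError as err:
    print("ValueError:", err)
```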
Next: 03-calibration.md