Lab 02 — Log-sum-exp and stable cross-entropy from logits

Read theory/04-log-sum-exp-and-stability.md. Do not consult solutions/.

Objective

Implement logsumexp, log_softmax, and cross_entropy_from_logits with full numerical-stability discipline. Demonstrate that the naïve implementations fail on adversarial inputs while the stable ones succeed.

Setup

Continue in src/phase05/probability.py.

Tasks

Task 1 — naïve implementations (so you see them fail)

First, implement the naïve versions:

import numpy as np

def logsumexp_naive(z): return np.log(np.exp(z).sum())
def log_softmax_naive(z): return np.log(np.exp(z) / np.exp(z).sum())
def cross_entropy_naive(z, y_star): return -np.log(np.exp(z) / np.exp(z).sum())[y_star]

Test on the following inputs and document what happens:

| Input \(z\) | Expected outcome | Naïve result |
| --- | --- | --- |
| [0, 0, 0] | sane | should work |
| [1, 2, 3] | sane | should work |
| [1000, 1001, 1002] | should be sane but won't | overflow → inf / NaN |
| [-1000, -999, -998] | should be sane but won't | underflow → 0, so log gives -inf / NaN |
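One way to fill in the table is a small harness like the sketch below. It re-declares two of the naïve functions so it runs on its own, and records NumPy's warnings instead of hiding them. The `report` structure and the print format are illustrative choices, not part of the lab spec.

```python
import warnings

import numpy as np


def logsumexp_naive(z):
    return np.log(np.exp(z).sum())


def log_softmax_naive(z):
    return np.log(np.exp(z) / np.exp(z).sum())


inputs = [
    [0.0, 0.0, 0.0],
    [1.0, 2.0, 3.0],
    [1000.0, 1001.0, 1002.0],
    [-1000.0, -999.0, -998.0],
]

report = []  # hypothetical container, one dict per input row
for raw in inputs:
    z = np.array(raw, dtype=np.float64)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning, do not drop repeats
        lse = logsumexp_naive(z)
        lsm = log_softmax_naive(z)
    messages = sorted({str(w.message) for w in caught})
    report.append({"z": raw, "logsumexp": lse, "log_softmax": lsm, "warnings": messages})
    print(raw, "| logsumexp:", lse, "| log_softmax:", lsm, "| warnings:", messages)
```

On the large input the naïve logsumexp returns inf and the naïve log_softmax returns NaN (inf / inf). On the very negative input every exponential underflows to 0, so logsumexp returns -inf and log_softmax returns NaN (0 / 0).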

Task 2 — stable logsumexp

Implement the stable version (subtract max before exp). Re-run all 4 inputs from Task 1; all 4 should now produce finite, correct outputs. Verify against scipy.special.logsumexp.
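As a reference point for the max-subtraction trick, here is a minimal sketch that assumes float64 inputs and finite logits (it does not handle rows that are all -inf or that contain +inf). The `axis` parameter is an addition for batched use and is not required by the task. The checks use closed-form values, so the scipy comparison is still left to do.

```python
import numpy as np


def logsumexp(z, axis=-1):
    """Stable log(sum(exp(z))) along `axis`, using the max-shift identity."""
    z = np.asarray(z, dtype=np.float64)
    m = z.max(axis=axis, keepdims=True)      # largest logit per row
    shifted = np.exp(z - m)                  # entries are at most 1; the max maps to exactly 1
    out = m + np.log(shifted.sum(axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


# Closed form for three consecutive logits whose largest value is top:
#   top + log(1 + e^-1 + e^-2)
tail = np.log(1.0 + np.exp(-1.0) + np.exp(-2.0))

big = logsumexp([1000.0, 1001.0, 1002.0])
small = logsumexp([-1000.0, -999.0, -998.0])
print(big, 1002.0 + tail)
print(small, -998.0 + tail)
```

Because the shifted sum always includes the term exp(0) = 1, the argument of np.log is at least 1, so the log can produce neither inf nor -inf for finite inputs.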

Task 3 — stable log_softmax

Same exercise for log_softmax. Reference: scipy.special.log_softmax.
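A possible shape for the parity check is sketched below, assuming scipy is installed. Note that scipy.special.log_softmax defaults to axis=None (it normalizes over the whole array), so the sketch passes axis=-1 explicitly. The list of cases and the name `max_abs_err` are illustrative.

```python
import numpy as np
from scipy.special import log_softmax as scipy_log_softmax


def log_softmax(z, axis=-1):
    """Stable log-softmax: shift by the row max, then subtract the log of the shifted sum."""
    z = np.asarray(z, dtype=np.float64)
    shifted = z - z.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


cases = [
    np.array([0.0, 0.0, 0.0]),
    np.array([1.0, 2.0, 3.0]),
    np.array([1000.0, 1001.0, 1002.0]),
    np.array([-1000.0, -999.0, -998.0]),
]

# Largest absolute disagreement with the scipy reference over all cases.
max_abs_err = max(
    float(np.max(np.abs(log_softmax(z) - scipy_log_softmax(z, axis=-1))))
    for z in cases
)
print("max |ours - scipy| =", max_abs_err)
```

Exponentiating the result and summing should give 1 for every row, which is a cheap extra check that does not need a reference library.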

Task 4 — stable cross_entropy_from_logits

def cross_entropy_from_logits(z, y_star):
    """Stable CE from raw logits. Equivalent to PyTorch's F.cross_entropy on a single example."""
    return -log_softmax(z)[y_star]

Verify on a small synthetic batch.
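One way to run that verification without extra dependencies is sketched below: compute the loss per example with the single-example function, then compare against a batched computation that indexes the log-softmax matrix directly. The batch size, vocabulary size, seed, and logit scale are arbitrary assumptions.

```python
import numpy as np


def log_softmax(z, axis=-1):
    z = np.asarray(z, dtype=np.float64)
    shifted = z - z.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def cross_entropy_from_logits(z, y_star):
    """Stable CE for one example: negative log-probability of the target class."""
    return -log_softmax(z)[y_star]


rng = np.random.default_rng(0)                  # arbitrary seed
B, V = 4, 5                                     # small synthetic batch
logits = rng.normal(scale=50.0, size=(B, V))    # wide scale to stress stability
targets = rng.integers(0, V, size=B)

# Per-example losses via the single-example function.
per_example = np.array(
    [cross_entropy_from_logits(logits[i], targets[i]) for i in range(B)]
)
# Same quantity computed in one shot with integer-array indexing.
batched = -log_softmax(logits)[np.arange(B), targets]

# Sanity anchor: uniform logits give a loss of log(V) for any target.
uniform_loss = cross_entropy_from_logits(np.zeros(V), 2)
print(per_example, batched, uniform_loss, np.log(V))
```

If PyTorch is available, the mean of `per_example` is the value to compare with F.cross_entropy, whose default reduction is the batch mean.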

Task 5 — property tests

Add to tests/test_phase05_logsumexp.py:

  1. Shift invariance. For any \(c \in \mathbb{R}\): logsumexp(z + c) == logsumexp(z) + c within tolerance.
  2. Softmax-shift invariance. For any \(c\): log_softmax(z + c) equals log_softmax(z) (because the constant cancels).
  3. Reduction sanity. For a batch built by stacking z twice, log_softmax(np.stack([z, z])).sum() == 2 * log_softmax(z).sum() within tolerance, i.e., the result is computed independently per row.
  4. Reference parity. Compare against scipy.special.log_softmax on a battery of inputs (uniform, peaked, large, small, negative).
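The first three properties could be written as plain pytest-style functions along the lines of the sketch below. In the real test file the two functions would be imported from the lab module; they are inlined here only so the sketch runs on its own. The seed, sample counts, shift range, and tolerances are assumptions to tune, and reference parity against scipy is left out.

```python
import numpy as np


def logsumexp(z, axis=-1):
    z = np.asarray(z, dtype=np.float64)
    m = z.max(axis=axis, keepdims=True)
    out = m + np.log(np.exp(z - m).sum(axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def log_softmax(z, axis=-1):
    z = np.asarray(z, dtype=np.float64)
    return z - logsumexp(z, axis=axis)[..., np.newaxis]


RNG = np.random.default_rng(0)  # arbitrary seed so failures are reproducible


def test_logsumexp_shift_invariance():
    for _ in range(100):
        z = RNG.normal(scale=10.0, size=8)
        c = RNG.uniform(-500.0, 500.0)
        np.testing.assert_allclose(
            logsumexp(z + c), logsumexp(z) + c, rtol=1e-12, atol=1e-9
        )


def test_log_softmax_shift_invariance():
    for _ in range(100):
        z = RNG.normal(scale=10.0, size=8)
        c = RNG.uniform(-500.0, 500.0)
        # The constant cancels, so the output should not move.
        np.testing.assert_allclose(log_softmax(z + c), log_softmax(z), rtol=0, atol=1e-9)


def test_log_softmax_is_per_row():
    z = RNG.normal(scale=10.0, size=8)
    stacked = log_softmax(np.stack([z, z]))
    np.testing.assert_allclose(stacked[0], log_softmax(z))
    np.testing.assert_allclose(stacked.sum(), 2.0 * log_softmax(z).sum())
```

The absolute tolerance is looser than 1e-12 on purpose: adding a shift of magnitude up to 500 rounds the inputs themselves, so exact equality is not expected.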

Task 6 — measure speed

Benchmark logsumexp on an input of shape (B, V) = (64, 600):

  1. Time the stable NumPy version.
  2. Time scipy.special.logsumexp.
  3. Time the naïve version on safe inputs only (for context: it would produce inf or NaN on large logits, but it is still a useful comparison).

Save measurements to experiments/<date>-phase-05-logsumexp/timings.csv.
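A minimal timing harness might look like the sketch below, assuming scipy is installed. The repeat counts, the CSV column names, and the temporary output location are placeholders; point `out_path` at the experiment directory named above when running the lab. The naïve function is given an `axis` argument here so it reduces per row like the other two.

```python
import csv
import tempfile
import timeit
from pathlib import Path

import numpy as np
from scipy.special import logsumexp as scipy_logsumexp


def logsumexp_stable(z, axis=-1):
    m = z.max(axis=axis, keepdims=True)
    out = m + np.log(np.exp(z - m).sum(axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def logsumexp_naive(z, axis=-1):
    return np.log(np.exp(z).sum(axis=axis))


rng = np.random.default_rng(0)
z = rng.normal(size=(64, 600))  # standard-normal logits: safe for the naive version

candidates = {
    "stable_numpy": lambda: logsumexp_stable(z),
    "scipy": lambda: scipy_logsumexp(z, axis=-1),
    "naive_numpy": lambda: logsumexp_naive(z),
}

rows = []
for name, fn in candidates.items():
    # Best of 5 repeats of 200 calls each, reported as seconds per call.
    best = min(timeit.repeat(fn, number=200, repeat=5)) / 200
    rows.append({"impl": name, "seconds_per_call": best})

out_path = Path(tempfile.mkdtemp()) / "timings.csv"  # placeholder location
with out_path.open("w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["impl", "seconds_per_call"])
    writer.writeheader()
    writer.writerows(rows)

for row in rows:
    print(row)
```

Taking the minimum over repeats is a common way to reduce noise from other processes; report the machine and library versions alongside the numbers, since absolute timings are not portable.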

Acceptance

  • All 4 inputs in Task 1 documented (naïve fails as predicted).
  • Stable implementations pass on all 4 inputs.
  • Property tests pass.
  • Reference parity against scipy within 1e-12.
  • Timings captured.

Pitfalls to expect

  • np.exp(1002) == np.inf in float64; you'll see RuntimeWarning: overflow encountered in exp — that's the point. Don't suppress the warning; it's diagnostic.
  • np.log(0.0) == -np.inf; downstream multiplication by 0 gives NaN. The stable version avoids this entirely: after subtracting the max, the largest term is exp(0) = 1, so the sum is at least 1 and np.log never receives 0.
  • When subtracting the max, watch axis semantics: z.max(axis=-1, keepdims=True) for batched z.
  • The stable cross_entropy_from_logits is fused: never compute softmax, then index, then take the log; just use -log_softmax(z)[y_star]. PyTorch's F.cross_entropy does the same fusion under the hood.
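The axis pitfall is easy to reproduce. The sketch below uses small hand-picked arrays to show that dropping keepdims=True either fails to broadcast or, when the batch happens to be square, silently subtracts the wrong values. The array contents and variable names are illustrative.

```python
import numpy as np

z = np.array([[1.0, 2.0, 3.0],
              [10.0, 20.0, 30.0]])           # shape (B, V) = (2, 3)

row_max = z.max(axis=-1, keepdims=True)      # shape (2, 1): broadcasts across each row
shifted = z - row_max                        # every row now has max 0

try:
    z - z.max(axis=-1)                       # shape (2,) cannot broadcast against (2, 3)
    broadcast_failed = False
except ValueError:
    broadcast_failed = True

# Square batch: the shapes line up by accident, so there is no error,
# but the row maxima [2, 20] are subtracted column-wise instead of row-wise.
sq = np.array([[1.0, 2.0],
               [10.0, 20.0]])
wrong = sq - sq.max(axis=-1)
right = sq - sq.max(axis=-1, keepdims=True)

print("broadcast failed:", broadcast_failed)
print("wrong:\n", wrong)
print("right:\n", right)
```

The square case is the dangerous one, because nothing raises; a test with B equal to V would pass shape checks and still compute the wrong shift.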

Next: 03-calibration.md