Skip to content

English · Español

Break — Quitar el desplazamiento -max del softmax

🇪🇸 La ruptura más clásica de toda la fase 2: quita el - max(x) del softmax y observa cómo nan se propaga en cuanto algún logit supera ~88 (donde exp desborda en fp32).

Objetivo: cualquier implementación de softmax (la tuya del lab 01, o una nueva para el clasificador de 5 tiempos de §A13).

Hipótesis

El aprendiz predice: "Quitar el desplazamiento - max(x) hará que el softmax funcione silenciosamente para logits pequeños, y luego caiga catastróficamente a NaN en cuanto un solo logit exceda ~88 (el umbral de overflow de fp32 para exp)."

La ruptura

En tu función softmax(x):

 def softmax(x: np.ndarray) -> np.ndarray:
-    m = x.max()
-    e = np.exp(x - m)
+    e = np.exp(x)
     return e / e.sum()

Procedimiento de ejecución

Testea sobre tres vectores de logits de tiempos de §A13 de magnitud creciente:

uv run python -c "
import numpy as np
def softmax_naive(x):
    e = np.exp(x)
    return e / e.sum()

cases = {
    'small':  np.array([1.2, 4.7, 3.1, 0.5, 2.9], dtype=np.float32),
    'medium': np.array([10.0, 50.0, 30.0, 20.0, 40.0], dtype=np.float32),
    'large':  np.array([10.0, 92.0, 30.0, 20.0, 40.0], dtype=np.float32),
}
for name, x in cases.items():
    p = softmax_naive(x)
    print(f'{name:6}  max(x)={x.max():6.1f}  p={p}  sum={p.sum():.4f}')
"

Modo de fallo esperado

small   max(x)=   4.7  p=[0.018 0.612 0.124 0.009 0.237]  sum=1.0000
medium  max(x)=  50.0  p=[... finite ...]                 sum=1.0000
large   max(x)=  92.0  p=[ 0.  nan  0.  0.  0.]            sum=nan

El precipicio es nítido: exp(88.7) ≈ 3.4e38 es aproximadamente el máximo fp32. Más allá de eso, np.exp devuelve inf; inf / inf = nan. Con el desplazamiento -max, el mayor exponente es siempre 0, así que exp(0) = 1 y el precipicio es imposible.

Diagnóstico

Desde los logs de entrenamiento solos el síntoma es "la pérdida se volvió nan alrededor del paso N". Chequeos diagnósticos:

  1. assert not np.isnan(loss).any() tras cada paso. Barato, pilla el síntoma en el paso N en vez del N+1000.
  2. Loguea logits.max() por paso. Si cruza ~80, estás a un paso del precipicio en fp32 (o ~10 en fp16 — fp16 desborda en ~65000).
  3. Compara tu softmax contra scipy.special.softmax en la entrada que falla. scipy usa el desplazamiento -max; si difieren en large, tu implementación es ingenua.

Lección

El desplazamiento -max es algebraicamente un no-op: exp(-m) se cancela en numerador y denominador. Es numéricamente un requisito duro: sin él, el softmax explota en cuanto algún logit excede ~88. El coste es una reducción extra (max sobre el vector); el ahorro es "nunca produce nan para entrada finita". Paga siempre el coste.

La misma idea generaliza a logsumexp(x) = max(x) + log Σ exp(x - max(x)). La página de teoría de la información de la Fase 5 la reutiliza para la cross-entropy.

Referencias

  • Goldberg, What Every Computer Scientist Should Know About Floating-Point Arithmetic, §2.2 (overflow y el truco del desplazamiento).
  • La fuente de NumPy: scipy/special/_logsumexp.py es la implementación estable de libro de texto; léela una vez.