English · Español
Break — Quitar el desplazamiento -max del softmax¶
🇪🇸 La ruptura más clásica de toda la fase 2: quita el
- max(x)del softmax y observa cómo nan se propaga en cuanto algún logit supera ~88 (dondeexpdesborda en fp32).
Objetivo: cualquier implementación de softmax (la tuya del lab 01, o una nueva para el clasificador de 5 tiempos de §A13).
Hipótesis¶
El aprendiz predice: "Quitar el desplazamiento - max(x) hará que el softmax funcione silenciosamente para logits pequeños, y luego caiga catastróficamente a NaN en cuanto un solo logit exceda ~88 (el umbral de overflow de fp32 para exp)."
La ruptura¶
En tu función softmax(x):
def softmax(x: np.ndarray) -> np.ndarray:
- m = x.max()
- e = np.exp(x - m)
+ e = np.exp(x)
return e / e.sum()
Procedimiento de ejecución¶
Testea sobre tres vectores de logits de tiempos de §A13 de magnitud creciente:
uv run python -c "
import numpy as np
def softmax_naive(x):
e = np.exp(x)
return e / e.sum()
cases = {
'small': np.array([1.2, 4.7, 3.1, 0.5, 2.9], dtype=np.float32),
'medium': np.array([10.0, 50.0, 30.0, 20.0, 40.0], dtype=np.float32),
'large': np.array([10.0, 92.0, 30.0, 20.0, 40.0], dtype=np.float32),
}
for name, x in cases.items():
p = softmax_naive(x)
print(f'{name:6} max(x)={x.max():6.1f} p={p} sum={p.sum():.4f}')
"
Modo de fallo esperado¶
small max(x)= 4.7 p=[0.018 0.612 0.124 0.009 0.237] sum=1.0000
medium max(x)= 50.0 p=[... finite ...] sum=1.0000
large max(x)= 92.0 p=[ 0. nan 0. 0. 0.] sum=nan
El precipicio es nítido: exp(88.7) ≈ 3.4e38 es aproximadamente el máximo fp32. Más allá de eso, np.exp devuelve inf; inf / inf = nan. Con el desplazamiento -max, el mayor exponente es siempre 0, así que exp(0) = 1 y el precipicio es imposible.
Diagnóstico¶
Desde los logs de entrenamiento solos el síntoma es "la pérdida se volvió nan alrededor del paso N". Chequeos diagnósticos:
assert not np.isnan(loss).any()tras cada paso. Barato, pilla el síntoma en el paso N en vez del N+1000.- Loguea
logits.max()por paso. Si cruza ~80, estás a un paso del precipicio en fp32 (o ~10 en fp16 — fp16 desborda en ~65000). - Compara tu
softmaxcontrascipy.special.softmaxen la entrada que falla. scipy usa el desplazamiento-max; si difieren enlarge, tu implementación es ingenua.
Lección¶
El desplazamiento -max es algebraicamente un no-op: exp(-m) se cancela en numerador y denominador. Es numéricamente un requisito duro: sin él, el softmax explota en cuanto algún logit excede ~88. El coste es una reducción extra (max sobre el vector); el ahorro es "nunca produce nan para entrada finita". Paga siempre el coste.
La misma idea generaliza a logsumexp(x) = max(x) + log Σ exp(x - max(x)). La página de teoría de la información de la Fase 5 la reutiliza para la cross-entropy.
Referencias¶
- Goldberg, What Every Computer Scientist Should Know About Floating-Point Arithmetic, §2.2 (overflow y el truco del desplazamiento).
- La fuente de NumPy:
scipy/special/_logsumexp.pyes la implementación estable de libro de texto; léela una vez.