Skip to content

English · Español

04 — Log-sum-exp y estabilidad numérica

🇪🇸 La matemática es bella, pero los exponentes se desbordan. Esta es la última pieza: cómo calcular softmax y cross-entropy sin que numerosos detalles te dejen con NaN.

El problema

El softmax de los logits \(z \in \mathbb{R}^V\) es:

\[\sigma(z)_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}.\]

Dos modos de fallo en float32:

  1. Overflow. Si \(\max_i z_i \gg 88\), \(\exp(z_i)\) se desborda a \(+\infty\).
  2. Underflow. Si \(z_i \ll -88\), \(\exp(z_i)\) subfluye a 0; el \(p_i\) correspondiente es 0 aun cuando debería ser \(\sim 10^{-40}\).

Ambos ocurren rutinariamente. Los logits de un transformer moderno pueden estar en los cientos (especialmente tras escalado por temperatura) y la pérdida de log-verosimilitud negativa calculada vía softmax + log ingenuo producirá silenciosamente NaN o gradiente cero cuando esto pase.

La identidad log-sum-exp

Para cualquier constante \(c \in \mathbb{R}\):

\[\log \sum_j \exp(z_j) = c + \log \sum_j \exp(z_j - c).\]

Elige \(c = \max_j z_j\). Entonces el máximo exponente dentro de la suma es \(\exp(0) = 1\) (sin overflow), y el más pequeño es \(\exp(z_{\min} - z_{\max})\), que puede subfluir a 0 pero sólo contribuye de forma despreciable a la suma.

Log-sum-exp estable:

def logsumexp(z):
    c = z.max()
    return c + np.log(np.exp(z - c).sum())

Log-softmax estable

\[\log \sigma(z)_i = z_i - \log \sum_j \exp(z_j) = (z_i - c) - \log \sum_j \exp(z_j - c).\]
def log_softmax(z):
    c = z.max()
    z_shifted = z - c
    return z_shifted - np.log(np.exp(z_shifted).sum())

Borja implementa esto en el lab 02. Compara con el \(\log(\sigma(z))\) ingenuo — deberían coincidir en entradas bien condicionadas y discrepar (uno siendo correcto, el otro NaN) en entradas adversariales como z = [1000, 1001, 1002].

Entropía cruzada estable desde logits

Para un único ejemplo con etiqueta verdadera \(y^*\):

\[\mathcal{L} = -\log \sigma(z)_{y^*} = -z_{y^*} + \log \sum_j \exp(z_j) = -z_{y^*} + \text{logsumexp}(z).\]

Crucialmente: no calcules primero el softmax y luego el log del resultado. Ve siempre de logits → log-softmax en un paso. Implementación en NumPy:

def cross_entropy_logits(z, y_star):
    log_probs = log_softmax(z)
    return -log_probs[y_star]

Vectorizado sobre un batch de tamaño \(B\):

def cross_entropy_batch(Z, Y):  # Z: (B, V), Y: (B,)
    c = Z.max(axis=1, keepdims=True)
    Z_shifted = Z - c
    log_probs = Z_shifted - np.log(np.exp(Z_shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(B), Y].mean()

El gradiente (presagio de la Fase 07)

Lo derivaremos con detalle en la Fase 07, pero vale la pena verlo ahora: el gradiente de CE-desde-logits respecto a los logits es:

\[\frac{\partial \mathcal{L}}{\partial z_i} = \sigma(z)_i - \mathbb{1}[i = y^*] = q_i - \delta_{i, y^*}.\]

Es elegante y numéricamente estable: sin divisiones por probabilidades diminutas, sin logs de números pequeños. También es la razón por la que cada framework fusiona log_softmax + nll_loss en una sola operación (cross_entropy_loss) — tanto por velocidad como por estabilidad.

La variante con label smoothing

Una variante común usa etiquetas suaves: en lugar de \(\mathbb{1}[i = y^*]\), objetivo \(p_i = (1 - \alpha) \mathbb{1}[i = y^*] + \alpha / V\) para algún suavizado \(\alpha \in [0, 1)\). La CE se vuelve:

\[\mathcal{L} = -(1 - \alpha) \log q_{y^*} - \frac{\alpha}{V} \sum_i \log q_i.\]

Borja no tiene que implementar label smoothing en la Fase 05 (aterriza en la Fase 18 como regularización opcional), pero vale la pena ver la forma.

Catálogo de escollos numéricos

Bug Síntoma Cura
Softmax ingenuo con logits grandes inf o NaN en probabilidades Log-softmax estable (arriba)
Entropía cruzada calculada como -(p * log(q)).sum() con q conteniendo ceros Pérdida inf Calcular desde logits, no desde probs
Olvidar enmascarar las posiciones de padding en CE de secuencia La pérdida incluye tokens de padding Multiplicar la pérdida por token por la máscara de atención antes de promediar
Olvidar .detach() del target cuando el target tiene gradientes (raro pero real) Gradiente incorrecto Usa etiquetas one-hot o de índice, no etiquetas suaves con gradientes
Calcular la pérdida del batch como sum en lugar de mean La pérdida escala con el tamaño del batch; acoplada al LR Promediar por número de tokens, no por número de secuencias
Usar FP16 en toda la entropía cruzada Underflow en el log-softmax Promocionar a FP32 para el log-softmax + loss; mantener FP16 para el matmul

El último ítem es un clásico — el entrenamiento en precisión mixta (Fase 18) requiere exactamente este patrón de upcasting.

Disciplina de pruebas

El lab 02 de Borja debe incluir:

  1. Property test: log_softmax(z) ≈ log_softmax(z + c) para cualquier escalar \(c\) (invarianza por desplazamiento).
  2. Test de referencia: comparar contra scipy.special.log_softmax en una batería de entradas.
  3. Stress test: entradas como z = [1000, 1001, 1002] y z = [-1000, -999, -998] producen salidas finitas y sensatas.

Estos tests deben aterrizar en tests/test_logsoftmax.py y mantenerse verdes durante el resto del proyecto. La Fase 19 (debugging de dinámica de entrenamiento) referenciará esta disciplina de pruebas.

Lo que este archivo NO cubre

  • Análisis numérico de derivadas de orden superior (rara vez necesario).
  • Álgebra tropical / de semianillos relacionada con softmax (teóricamente bonita, prácticamente no aporta).
  • Detalles de implementación de precisión mixta (tratados en la Fase 18).

Siguiente: ../lab/00-entropy-by-hand.md