Skip to content

English · Español

Lab 00 — Softmax online en Python puro

Objetivo: implementar la recurrencia del softmax online y verificar que coincide con el softmax por lotes sobre datos sintéticos, tanto en fp32 como en fp16.

Tiempo estimado: 2–3 horas.

Prerrequisito: theory 01 leída y la recurrencia rederivable de memoria.


Lo que produces

Un directorio experiments/27-online-softmax/ que contenga:

  • online_softmax.py — implementación en Python puro (NumPy). Dos funciones: softmax_batched (formulación clásica) y softmax_online_chunked(chunks, V_chunks, m_init=-inf, ℓ_init=0, O_init=0).
  • test_equivalence.py — verifica que las dos producen salidas idénticas sobre entradas aleatorias.
  • results.json — medidas del max-abs-error entre {fp32, fp16, bf16} y entre tamaños de chunk {1, 4, 16, 64}.
  • manifest.json.
  • README.md — interpretación.

No hay entregable en src/ para este lab; la implementación es código pedagógico deliberadamente desechable.

TODOs

Bloque A — implementar softmax por lotes como referencia

  • softmax_batched(s: ndarray, V: ndarray) -> ndarray: devuelve softmax(s) @ V (s 1D, V 2D de forma (N, d)). Usa la forma numéricamente estable.

Bloque B — implementar la versión online

  • Firma: softmax_online(chunks: list[ndarray], V_chunks: list[ndarray]) -> ndarray.
  • Bucle sobre chunks; mantén m, ℓ, O según la recurrencia de theory 01.
  • Devuelve O / ℓ final.

Bloque C — verificar la equivalencia

Para cada (N, d, dtype, chunk_size) en una rejilla pequeña:

  • Genera s ∈ ℝ^N, V ∈ ℝ^{N × d} aleatorios.
  • Calcula las versiones por lotes y online.
  • Registra max_abs_error = max|O_batched - O_online|.

Rangos esperados: - fp32: < 1e-6 (sólo redondeo). - fp16: < 1e-3. - bf16: < 5e-3.

Bloque D — sensibilidad al tamaño de chunk

Para fp16:

  • Barre chunk_size ∈ {1, 4, 16, 64, 256}. Plotea max_abs_error vs chunk_size.
  • Esperado: el error es aproximadamente constante entre tamaños de chunk (la recurrencia es exacta algebraicamente; sólo importa el redondeo, y el redondeo está acotado por O(N) independientemente del tamaño de chunk).
  • Si el error varía bruscamente con chunk_size, tu reescalado α tiene un bug.

Bloque E — entradas patológicas

Prueba la recurrencia sobre:

  • s = [60, 0, 0, 0] (un valor enorme seguido de otros minúsculos). En fp16, exp(60) desborda — pero el softmax online debería manejarlo gracias a la sustracción del máximo móvil.
  • s = [-100, -100, -100, 0] (uno casi-cero, el resto muy negativos). Verifica que softmax_onlineO ≈ V[3].
  • s = todo ceros. Atención uniforme; la salida es mean(V).

Bloque E' — entradas realistas del corpus de verbos

El vocabulario del corpus de verbos es pequeño (~600 formas), así que los logits de atención sobre una secuencia de 64 tokens tienen una distribución muy picuda tras unas pocas épocas de entrenamiento (el modelo está seguro de cada verbo en contexto).

  • Genera s muestreando de N(0, 5) (una distribución picuda-pero-no-patológica que imita los logits de atención post-entrenamiento sobre un vocab pequeño).
  • Establece V con forma (64, 64) (N = 64 coincide con la longitud de secuencia del corpus de verbos; d = 64 es una dimensión de cabeza típica para MiniGPT).
  • Verifica que softmax_online con chunk_size=16 coincide con el softmax por lotes dentro de la tolerancia fp16.
  • Comenta en README.md: ¿cómo ayuda o perjudica la distribución picuda a la estabilidad de la recurrencia online?

Bloque F — interpretar en README.md

Tres preguntas:

  1. ¿Cuál es el peor error fp16 que observaste? ¿Está bajo 1e-3? Si no, ¿de dónde vino el error adicional?
  2. ¿El tamaño de chunk afecta a la precisión? Si observaste sensibilidad, ¿por qué?
  3. ¿Qué ocurre en fp16 cuando s contiene un valor > 11? (exp(11) ≈ half_max.) ¿Maneja la recurrencia online el caso, o necesitas acumuladores fp32 para m, ℓ, O?

Condiciones de parada

  • Los cinco archivos commiteados.
  • max_abs_error fp32 < 1e-6; fp16 < 1e-3.
  • README responde las tres preguntas del Bloque F.

Errores típicos

  • Off-by-one en el bucle de chunks. Si los chunks no teselan perfectamente s, el manejo del último chunk puede descartar entradas. Usa una lista de chunks en Python en lugar de aritmética de índices para mayor claridad.
  • Reescalado α aplicado a lo que no toca. Tanto como O necesitan α. Si olvidas uno, los errores se acumulan cuadráticamente con N.
  • exp(m - m') subdesborda. Cuando m_new ≫ m_old, α puede subdesbordar. Esto es matemáticamente fino (α → 0 sólo significa que las contribuciones antiguas quedan eclipsadas por las nuevas), pero si calculas ℓ * α en fp16 y era pequeño, puedes perderlo del todo. Usa acumuladores fp32 para el estado móvil en implementaciones fp16.

Cuándo consultar solutions/

Tras cumplir todas las condiciones de parada. solutions/00-online-softmax-ref.md (apertura de fase) compara estructura y números.


Siguiente lab: lab/01-flash-bytes.md.