Skip to content

English · Español

01 — La recurrencia del online softmax

🇪🇸 La clave matemática para Flash: poder calcular softmax(s) @ V por trozos sin haber visto todo s antes. Mantienes el máximo corriente y la suma corriente, y al añadir un trozo nuevo, reescalas lo anterior por exp(m_viejo - m_nuevo). Una línea de álgebra, todo el resto del fenómeno depende de ella.


El softmax clásico

Para un vector s ∈ ℝ^N, el softmax numéricamente estable es:

m = max(s)
p = exp(s - m) / sum(exp(s - m))

La resta de m mantiene los argumentos de exp ≤ 0, así que no hay overflow. Esta es la formulación numéricamente estable; la versión inestable (simplemente exp(s) / sum(exp(s))) hace overflow para cualquier s_i > log(float_max) ≈ 88 (fp32) o > log(half_max) ≈ 11 (fp16).

Para la attention con Q, K de dimensión de cabeza d, un único valor pre-softmax s_i = (Q[i, :] · K[j, :]) / √d puede fácilmente superar 11 en fp16. El softmax estable (es decir, restar el máximo) es obligatorio, no opcional, en attention fp16.

El problema: la formulación estándar requiere el vector completo s para calcular m = max(s). Eso impide el cómputo en streaming.

El setup: attention en streaming

En FlashAttention, no tenemos la fila completa de s = Q[i, :] @ K^T ∈ ℝ^N de golpe. La tenemos tile-a-tile: trozos s_1, s_2, ..., s_{N/B_c} cada uno de longitud B_c. Para cada trozo s_k, queremos actualizar un O[i, :] ∈ ℝ^d corriente (la salida parcial de attention) tal que, tras consumir todos los trozos, O[i, :] = softmax(s) @ V.

La pregunta: ¿podemos actualizar O correctamente usando sólo el trozo actual y una pequeña cantidad de estado corriente?

Sí. Así.

La recurrencia

Mantén tres piezas de estado corriente para la fila de salida i:

  • m ∈ ℝ — máximo corriente de s visto hasta ahora.
  • ℓ ∈ ℝ — denominador corriente sum(exp(s_seen - m)).
  • O ∈ ℝ^d — salida no normalizada corriente sum_j exp(s_j - m) · V_j.

Cuando llega el trozo s_new ∈ ℝ^{B_c} (con el correspondiente V_new ∈ ℝ^{B_c × d}):

  1. Nuevo máximo: $$ m' = \max(m, \max(s_{\text{new}})) $$

  2. Reescala el estado viejo al nuevo máximo: $$ \alpha = \exp(m - m') $$ El y O viejos se calcularon relativos al m viejo. Para ponerlos al mismo nivel que el trozo nuevo (que calcularemos relativo a m'), multiplica ambos por α: $$ ℓ \leftarrow \alpha \cdot ℓ \qquad O \leftarrow \alpha \cdot O $$

  3. Añade la contribución del nuevo trozo: $$ p_{\text{new}} = \exp(s_{\text{new}} - m') \in \mathbb{R}^{B_c} $$ $$ ℓ \leftarrow ℓ + \sum p_{\text{new}} $$ $$ O \leftarrow O + p_{\text{new}} \cdot V_{\text{new}} \quad \text{(producto matriz-vector sobre } B_c \text{ términos)} $$

  4. Actualiza m: $$ m \leftarrow m' $$

Al final, divide una vez:

\[ O_{\text{final}} = O / ℓ \]

Esa es la recurrencia entera. Seis líneas de pseudocódigo; un cómputo corriente O(N·d); matemáticamente idéntico al softmax todo-a-la-vez salvo redondeo fp.

Demostración de corrección

Afirmación: tras procesar todos los trozos, O / ℓ = softmax(s) @ V.

Sea s = [s_1, ..., s_K] los trozos concatenados (cada uno de longitud B_c). Sea m_global = max(s). Por construcción, tras todos los trozos, m = m_global (el máximo corriente acumula correctamente).

Tras todos los trozos: - ℓ = sum_{j=1..N} exp(s_j - m_global). - O = sum_{j=1..N} exp(s_j - m_global) · V_j.

El resultado todo-a-la-vez es O_all = (sum_j exp(s_j - m_global) · V_j) / sum_j exp(s_j - m_global) = O / ℓ. Mismo valor. □

El único paso sutil es la reescala. Supón que hemos procesado los trozos 1..k y estamos a punto de procesar el trozo k+1. Justo antes de procesar k+1, nuestro estado es:

  • m = max(s_1, ..., s_k). Llámalo m_k.
  • ℓ = sum_{j ∈ primeros k trozos} exp(s_j - m_k).
  • O = sum_{j ∈ primeros k trozos} exp(s_j - m_k) · V_j.

Tras observar el trozo k+1, el máximo verdadero es m_{k+1} = max(m_k, max(s_{k+1})). Para rebasar al nuevo máximo:

ℓ_rebased = sum_{j ∈ primeros k trozos} exp(s_j - m_{k+1})
          = sum_{j ∈ primeros k trozos} exp(s_j - m_k + m_k - m_{k+1})
          = exp(m_k - m_{k+1}) × sum_{j ∈ primeros k trozos} exp(s_j - m_k)
          = α × ℓ

donde α = exp(m_k - m_{k+1}) ≤ 1. Misma álgebra para O. Luego añade las contribuciones del trozo k+1 (que ya están calculadas relativas a m_{k+1}). □

Propiedades numéricas

Tres observaciones:

  1. Sin overflow. Cada argumento de exp es ≤ 0 por construcción (siempre restamos el máximo actual o nuevo antes de exponenciar). Seguro en fp16.
  2. Sin cancelación catastrófica en la reescala. α = exp(m_k - m_{k+1}) ≤ 1 está acotado; multiplicar y O por él sólo puede hacer underflow a cero si m saltó más de ~88 (fp32) o ~11 (fp16). Para valores de attention, esto es raro pero posible — la discusión de estabilidad del paper de Flash lo maneja vía acumuladores fp32 para m, ℓ, O.
  3. El orden de los trozos no importa. El resultado final es invariante al orden de los trozos (módulo redondeo). Esto es lo que hace que Flash funcione dentro de un kernel tileado: los tiles pueden procesarse en cualquier orden que el scheduler prefiera.

Un ejemplo trabajado

Vector s = [1, 2, 3, 10], V = [[1], [1], [1], [1]] (así cada V de trozo es un vector 1×1; la respuesta debería converger a softmax(s) @ V = 1.0 puesto que todas las entradas de V son 1).

Procesar en dos trozos: s_1 = [1, 2], s_2 = [3, 10].

Tras el trozo 1: - m = 2 - p_1 = exp([1-2, 2-2]) = [exp(-1), 1] ≈ [0.368, 1] - ℓ = 1.368 - O = 0.368 + 1 = 1.368

Tras el trozo 2: - max(s_2) = 10, así que m' = max(2, 10) = 10. - α = exp(2 - 10) = exp(-8) ≈ 3.35e-4 - Reescala: ℓ = 0.000458, O = 0.000458. - p_2 = exp([3-10, 10-10]) = [exp(-7), 1] ≈ [9.12e-4, 1] - ℓ = 0.000458 + 0.000912 + 1 = 1.00137 - O = 0.000458 + 0.000912 + 1 = 1.00137

Final: O / ℓ = 1.00137 / 1.00137 = 1.0. ✓

Compara con todo-a-la-vez: softmax([1, 2, 3, 10]) ≈ [1.23e-4, 3.34e-4, 9.08e-4, 0.9985]. Producto punto con [1, 1, 1, 1] = 1.0. Misma respuesta.

Lo que esto permite

Una vez tenemos la recurrencia del online softmax, el resto de Flash es "sólo" tiling. Procesamos la matriz de attention S = QK^T un tile a la vez, sin materializar nunca la cosa entera en HBM. Para cada fila de salida de O, mantenemos una cantidad minúscula de estado (m, ℓ, O) y la actualizamos conforme los tiles de K, V van llegando.

La recurrencia es la razón matemática por la que Flash es correcto. El tiling (siguiente archivo de teoría) es la razón algorítmica por la que es rápido.

Problemas de práctica

Soluciones al abrir la fase en solutions/01-online-softmax-ref.md. Inténtalo sin código.

  1. Calcula el online softmax sobre s = [0, 5] con V = [[2], [3]] troceado como s_1 = [0], s_2 = [5]. Muestra los cuatro pasos.
  2. La recurrencia requiere α = exp(m_k - m_{k+1}). ¿Bajo qué condiciones α hace underflow a 0 en fp16? ¿Qué falla en si ocurre? (Pista: nada falla — la matemática sigue siendo correcta en el límite.)
  3. Muestra que procesar los trozos en orden inverso produce el mismo O/ℓ final (módulo redondeo fp). Esboza el argumento; no simules.
  4. Supón que paralelizas la recurrencia entre P hilos, cada uno manejando 1/P de los trozos, luego reduces al final. ¿Cuál es la operación de reducción? Muestra que es asociativa.

Recap de un párrafo

La recurrencia del online softmax mantiene un máximo corriente m, un denominador corriente , y una salida no normalizada corriente O mientras procesa trozos de s y V uno a uno. Cuando llega un trozo nuevo, reescala los ℓ, O viejos por α = exp(m_old - m_new) para alinearlos con el nuevo máximo, luego añade las contribuciones del trozo nuevo. La respuesta final es O / ℓ. El álgebra es una línea, el cómputo es O(N·d) sin materializar nunca el softmax completo, y el resultado es exacto salvo redondeo de coma flotante. Esta es la clave matemática que hace posible FlashAttention — sin ella, no podrías procesar attention en tiles.

Siguiente: theory/02-flash-attention.md.