Skip to content

English · Español

02 — Scaled Dot-Product Attention: la derivación completa

Pruébalo — un heatmap de attention causal

La ecuación central del transformer es \(\text{softmax}(Q K^\top / \sqrt{d_k}) V\). Aquí derivamos cada pieza: por qué producto escalar (similitud rotacionalmente invariante), por qué softmax (convertir similitudes en distribución), por qué dividir entre \(\sqrt{d_k}\) (varianza unitaria), y la reescritura numéricamente estable que toda implementación real usa (restar el máximo antes de exponenciar). Lee este archivo dos veces.

Este archivo es el más denso de la Fase 15. Derivamos cada término de

\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V \]

desde primeros principios. Cinco piezas: similitud por producto escalar, la matriz de scores \(T \times T\), el escalado por \(\sqrt{d_k}\), softmax por filas, el matmul final con V.


Pieza 1 — Producto escalar como similitud

Necesitamos puntuar, para cada query \(Q_i \in \mathbb{R}^{d_k}\) y cada key \(K_j \in \mathbb{R}^{d_k}\), "¿cuán relevante es \(j\) para \(i\)?". Existen muchas opciones para la función de scoring (similitud coseno, aditivo, bilineal, MLP aprendido — la atención de Bahdanau de 2015 usaba un MLP aditivo). El paper del transformer eligió el producto escalar:

\[ \text{score}(Q_i, K_j) = Q_i \cdot K_j = \sum_{a=1}^{d_k} Q_{i,a} K_{j,a} \]

Dos razones para producto escalar sobre las alternativas:

  1. Barato y paralelo. Todos los \(T^2\) scores son una multiplicación de matrices: \(Q K^\top \in \mathbb{R}^{T \times T}\). En una GPU, un matmul grande es más rápido que \(T^2\) operaciones pequeñas o \(T\) MLPs.
  2. Empíricamente tan bueno como el scoring aditivo. El paper del transformer de 2017 lo demuestra. Para \(d_k\) por encima de unas docenas, attention por producto escalar iguala o supera la attention aditiva, a menor coste.

El coste de elegir producto escalar es la dependencia de escala — abordada en la Pieza 3.

Pieza 2 — La matriz completa de scores

Apila todas las queries como \(Q \in \mathbb{R}^{T \times d_k}\) y todas las keys como \(K \in \mathbb{R}^{T \times d_k}\). Los scores por pares son:

\[ S = Q K^\top \in \mathbb{R}^{T \times T}, \qquad S_{ij} = Q_i \cdot K_j \]

Cada fila de \(S\) son los scores de una query contra todas las keys. Cada columna son los scores de una key desde todas las queries.

Coste computacional. - Memoria: \(T^2\) floats. Para \(T = 2048\) y fp32, eso son 16 MiB por capa por cabeza. Con 24 capas y 16 cabezas, 6 GiB solo para los scores. Este es el cuello de botella que ataca Flash Attention (Fase 27). - FLOPs: \(2 T^2 d_k\) para el matmul. Cuadrático en \(T\) — la famosa "atención cuadrática".

Para la Fase 15, no intentamos ser listos. Calculamos \(S\) como una matriz densa. La Fase 27 lo revisitará.

Pieza 3 — Por qué dividir entre \(\sqrt{d_k}\)

Esta es la pieza más importante de la derivación. El argumento es control de varianza.

Supongamos que \(Q_i\) y \(K_j\) son vectores aleatorios con componentes i.i.d.: \(Q_{i,a}, K_{j,a} \sim \mathcal{N}(0, 1)\), independientes a lo largo de \(a\) y a lo largo de \(i, j\). Entonces

\[ \mathbb{E}[Q_i \cdot K_j] = \sum_a \mathbb{E}[Q_{i,a}] \mathbb{E}[K_{j,a}] = 0 \]
\[ \text{Var}(Q_i \cdot K_j) = \sum_a \text{Var}(Q_{i,a} K_{j,a}) = \sum_a \mathbb{E}[Q_{i,a}^2] \mathbb{E}[K_{j,a}^2] = \sum_a 1 = d_k \]

Así que \(Q_i \cdot K_j \sim \mathcal{N}(0, d_k)\) aproximadamente. Desviación estándar = \(\sqrt{d_k}\).

El problema: para \(d_k = 64\), los scores tienen desviación estándar \(8\). Para \(d_k = 256\), std = \(16\). El score más grande en una fila de \(T\) tales scores puede fácilmente estar 3–4 desviaciones estándar por encima de la media.

Ahora aplica softmax. \(\text{softmax}([s_1, \ldots, s_T])_i = e^{s_i} / \sum_j e^{s_j}\). Si un \(s_i\) es mucho mayor que el resto, \(e^{s_i}\) domina y la salida del softmax es casi one-hot — una entrada cerca de 1, todas las demás cerca de 0.

Por qué esto es malo: en el régimen saturado, el gradiente del softmax es casi cero. El modelo no puede aprender. Específicamente, \(\partial \text{softmax}_i / \partial s_j = \text{softmax}_i (\delta_{ij} - \text{softmax}_j)\). Cuando el softmax está cerca de one-hot, este producto está cerca de cero para todas las entradas.

El arreglo: dividir los scores entre \(\sqrt{d_k}\) antes del softmax. Ahora

\[ \text{Var}\left(\frac{Q_i \cdot K_j}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1 \]

— la varianza vuelve a 1, independientemente de \(d_k\). El softmax ve scores en un rango razonable, no se satura, el gradiente fluye. El escalado es independiente de los datos de entrenamiento — es una función de la elección arquitectónica \(d_k\).

Resumen del argumento de varianza: sin escalar, los scores tienen std \(\sqrt{d_k}\). Si \(d_k\) es grande, los scores son grandes, el softmax se satura, el gradiente muere. Dividir entre \(\sqrt{d_k}\) restaura std = 1 y mantiene el softmax en su régimen útil.

Comprobación de sanidad: ¿y si \(d_k\) es minúsculo?

Para \(d_k = 2\), la std de los scores es \(\sqrt{2} \approx 1.4\). El softmax está bien sin escalar. El arreglo es innecesario a \(d_k\) pequeño, pero no hace daño — dividir entre \(\sqrt{2}\) deja el comportamiento del softmax casi sin cambios.

Para los ejemplos de juguete de la Fase 15 (\(d_k = 2\)), el escalado es apenas visible. Para el Mini-GPT de la Fase 17 (\(d_k = 16\) por cabeza con 4 cabezas, \(d_\text{model} = 64\)), importa. Para los LLM modernos (\(d_k = 128+\) por cabeza), es esencial.

Verificación en el Lab 00

En lab/00-attention-by-hand.md, Borja ejecutará attention dos veces sobre los mismos Q, K, V — una con escalado, una sin él — con \(d_k = 64\), y observará que la matriz de atención de la versión sin escalar es casi one-hot. Lo visual convence.

Pieza 4 — Softmax (con estabilidad numérica)

El softmax ingenuo:

\[ \text{softmax}(s_i) = \frac{e^{s_i}}{\sum_j e^{s_j}} \]

Para \(s\) positivo grande, \(e^s\) desborda fp32 (máximo en torno a \(e^{88}\)). Para nuestros scores escalados rara vez es un problema, pero en implementaciones reales eventualmente muerde — sobre todo durante el entrenamiento cuando el gradiente ocasionalmente produce valores grandes.

Reescritura numéricamente estable:

\[ \text{softmax}(s_i) = \frac{e^{s_i - m}}{\sum_j e^{s_j - m}}, \qquad m = \max_k s_k \]

Tras restar el máximo, el mayor valor es \(0\), así que \(e^{s_i - m} \leq 1\). Sin desbordamiento. La salida es idéntica a la forma ingenua (el factor \(e^{-m}\) se cancela entre numerador y denominador).

Toda implementación de attention de producción hace esta resta-del-máximo. La Fase 27 (Flash Attention) lo hace incrementalmente en bloques (tiles) — misma idea, más contabilidad.

Nota de implementación para src/minimodel/attention/:

def softmax_stable(s, axis=-1):
    m = s.max(axis=axis, keepdims=True)
    exp = np.exp(s - m)
    return exp / exp.sum(axis=axis, keepdims=True)

Tres líneas. Usa siempre esta forma. Nunca el ingenuo np.exp(s) / np.exp(s).sum().

Pieza 5 — Multiplicar por V

El softmax produce probabilidades por filas: \(A = \text{softmax}(S / \sqrt{d_k}) \in \mathbb{R}^{T \times T}\). La fila \(i\) de \(A\) es una distribución de probabilidad sobre las \(T\) posiciones.

La salida final es

\[ \text{Attention}(Q, K, V) = A V \in \mathbb{R}^{T \times d_v} \]

La fila \(i\) de la salida es el promedio ponderado de filas de value: \(\text{out}_i = \sum_j A_{ij} V_j\).

Ese es el forward completo. Seis líneas de código (recuento literal de LOC abajo).

Poniéndolo todo junto

def single_head_attention(Q, K, V, mask=None):
    # Q, K: (T, d_k), V: (T, d_v)
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)         # (T, T)
    if mask is not None:
        scores = scores + mask              # additive -inf, see file 04
    attn = softmax_stable(scores, axis=-1)  # (T, T) row-normalized
    return attn @ V                         # (T, d_v)

Cinco líneas efectivas. Este es el mecanismo de atención completo. El resto del currículo se construye sobre esto.

Backward (esbozo)

No implementarás el backward a mano en la Fase 15 — el autograd de la Fase 8 lo gestiona. Pero la estructura del gradiente vale la pena conocerla para la vista previa de flash-attention de la Fase 27:

Sea \(L\) la pérdida aguas abajo, \(\delta_{\text{out}} = \partial L / \partial \text{out}\).

\[ \frac{\partial L}{\partial V} = A^\top \delta_{\text{out}}, \qquad \frac{\partial L}{\partial A} = \delta_{\text{out}} V^\top \]

A través del softmax (derivación estándar):

\[ \frac{\partial L}{\partial S} = A \odot \left( \frac{\partial L}{\partial A} - \left( \frac{\partial L}{\partial A} \odot A \right) \mathbf{1} \right) \]

(aproximadamente; la forma exacta está en cualquier referencia de derivación de autograd).

Luego a través del matmul \(S = Q K^\top / \sqrt{d_k}\):

\[ \frac{\partial L}{\partial Q} = \frac{1}{\sqrt{d_k}} \frac{\partial L}{\partial S} K, \qquad \frac{\partial L}{\partial K} = \frac{1}{\sqrt{d_k}} \left(\frac{\partial L}{\partial S}\right)^\top Q \]

Observación clave: calcular \(\partial L / \partial Q\) requiere la matriz \(A\) completa — memoria \(O(T^2)\). Esto es lo que Flash Attention recomputa en bloques para evitar materializar.

No implementes esto — el autograd lo hace. Solo nota el requisito de memoria \(O(T^2)\).

Resumen de complejidad

Operación FLOPs Memoria
\(Q K^\top\) \(2 T^2 d_k\) \(T^2\)
Escala \(T^2\)
Softmax \(T^2\) \(T^2\)
\(A V\) \(2 T^2 d_v\) \(T^2 + T d_v\)
Total \(\sim 4 T^2 d_k\) (con \(d_k = d_v\)) \(\sim T^2\)

Para \(T = 256, d_k = 16\): FLOPs \(\approx 4 \cdot 65536 \cdot 16 \approx 4\) MFLOP. Trivial en el i5-8250U. Para \(T = 2048, d_k = 64\): FLOPs \(\approx 4 \cdot 4 \cdot 10^6 \cdot 64 \approx 1\) GFLOP por capa por cabeza. Con 24 capas y 16 cabezas, ~400 GFLOP por forward pass. 2 segundos en el i5-8250U a 200 GFLOPS de pico. Por eso esperamos a la GPU en cloud en la Fase 23.

Análisis de roofline (vista previa de la Fase 27)

Intensidad aritmética del matmul de attention \(Q K^\top\):

  • FLOPs: \(2 T^2 d_k\).
  • Bytes: leer \(Q\) (\(T d_k\) fp32 = \(4 T d_k\) bytes), leer \(K\) (\(4 T d_k\)), escribir \(S\) (\(4 T^2\)). Total \(\approx 8 T d_k + 4 T^2\) bytes.
  • Intensidad: \(\frac{2 T^2 d_k}{8 T d_k + 4 T^2} = \frac{T d_k}{4(d_k + T)}\).

Para \(T \gg d_k\) (secuencias largas), intensidad \(\approx d_k / 4\). Memory-bound en el i5-8250U si \(d_k < 40\) (ya que el balance de la máquina es 10 FLOPs/byte de la Fase 1 — pero espera, 200 GFLOPS / 20 GB/s = 10 FLOPs/byte, así que \(d_k = 40\) es el cruce). Para \(d_k = 64\) típico, compute-bound; para \(d_k < 16\) por cabeza como en nuestro Mini-GPT, memory-bound.

El pase del softmax es siempre memory-bound (1–2 FLOPs por byte). Este es el kernel que ataca Flash Attention.

No arreglamos esto en la Fase 15 — solo lo vemos. El Lab 03 mide.

Para conectar con la Fase 1: attention en secuencias cortas es compute-bound; en secuencias largas con cabezas pequeñas, memory-bound. La parte de softmax es siempre memory-bound. Esto motiva Flash Attention (Fase 27), que no cambia los FLOPs sino la cantidad de bytes movidos.

Lo que este archivo NO cubre

  • Extensión multi-head. Próximo archivo (03-multi-head.md). Aquí hicimos single-head.
  • Máscara causal (causal mask). 04-masking.md. El parámetro mask en el código anterior se insinúa pero no se deriva aquí.
  • Implementación del backward. El autograd de la Fase 8 lo gestiona; esbozamos las matemáticas solo para contexto.
  • Attention eficiente en memoria. Flash Attention es la Fase 27. La Fase 15 implementa la forma ingenua \(O(T^2)\).
  • Términos de sesgo en las proyecciones Q/K/V. El estilo GPT-2 los elimina; seguimos la convención sin re-derivar.
  • Funciones de similitud alternativas. Bahdanau (aditiva), bilineal, MLP aprendido. Mencionadas como contexto histórico; solo implementamos producto escalar.

Qué memorizar

Antes del lab, Borja debería ser capaz de escribir — de memoria, en papel — lo siguiente en menos de 5 minutos:

  1. La ecuación completa: \(\text{Attention}(Q, K, V) = \text{softmax}(Q K^\top / \sqrt{d_k}) V\).
  2. Las formas: \(Q, K \in \mathbb{R}^{T \times d_k}\), \(V \in \mathbb{R}^{T \times d_v}\), salida \(\in \mathbb{R}^{T \times d_v}\).
  3. El argumento de varianza para el escalado \(\sqrt{d_k}\). (Var de \(q \cdot k\) = \(d_k\) cuando las componentes son \(\mathcal{N}(0, 1)\); el escalado restaura varianza unitaria.)
  4. El truco de resta-del-máximo para softmax estable.
  5. La implementación NumPy de cinco líneas.

El conjunto /quiz 15 comprueba estos.


Siguiente: 03-multi-head.md.