Skip to content

English · Español

03 — Multi-Head Attention

Multi-head attention = correr varias attentions en paralelo, cada una en un subespacio de dimensión más pequeña, y concatenar las salidas. Con el mismo presupuesto de parámetros que una single-head, multi-head puede atender simultáneamente a varias relaciones distintas (una cabeza al sujeto-verbo, otra al sustantivo-adjetivo, otra al cierre de paréntesis). Es la formulación que ganó.

Este archivo deriva multi-head attention, explica por qué es equivalente en parámetros a single-head pero más expresivo, y bloquea la superficie de API de la que dependen las fases aguas abajo.


La limitación de single-head

Single-head attention calcula un patrón de atención \(T \times T\) por capa. Ese patrón es una función bilineal de scoring \(x_i^\top (W_Q W_K^\top) x_j\). El modelo está limitado a una noción de similitud por capa.

Pero el lenguaje real tiene muchos tipos de dependencias, todas relevantes en paralelo. Para nuestro alcance de gramática verbal (§A13), incluso la completación de juguete I work, you work, he ___ involucra múltiples relaciones paralelas:

  • Concordancia de persona. El verbo en la posición 7 debe mirar al pronombre sujeto (posición 6, he) para elegir la forma -s vs -.
  • Identificación de tiempo. El verbo en la posición 7 debe mirar los tokens verbales previos (posiciones 1, 4 — ambos work) para elegir el paradigma del present simple.
  • Alineamiento inglés↔español. Al predecir traducciones al español (yo trabajo), la posición del verbo debe alinearse tanto con el token en inglés como con el pronombre en español.
  • Posicional / localidad. Algunas cabezas simplemente miran el token inmediatamente anterior.

Forzar todo esto a través de una sola forma bilineal es un cuello de botella representacional. Multi-head lo arregla.

La construcción

Elige \(H\) — el número de cabezas. Elecciones comunes: 4, 8, 12, 16. El Mini-GPT de la Fase 17 usa \(H = 4\).

Divide la dimensión de la cabeza:

\[ d_k^{\text{head}} = d_k / H, \qquad d_v^{\text{head}} = d_v / H \]

(Requisito: \(d_k, d_v\) divisibles por \(H\). Casi siempre se cumple porque todo el mundo elige potencias de 2.)

Para cada cabeza \(h = 1, \ldots, H\), aprende tres matrices de proyección:

\[ W_Q^{(h)} \in \mathbb{R}^{d \times d_k^{\text{head}}}, \quad W_K^{(h)} \in \mathbb{R}^{d \times d_k^{\text{head}}}, \quad W_V^{(h)} \in \mathbb{R}^{d \times d_v^{\text{head}}} \]

Cada cabeza produce:

\[ \text{head}^{(h)} = \text{Attention}(X W_Q^{(h)}, X W_K^{(h)}, X W_V^{(h)}) \in \mathbb{R}^{T \times d_v^{\text{head}}} \]

Concatena las cabezas a lo largo de la dimensión de características:

\[ \text{Concat} = [\text{head}^{(1)} ; \text{head}^{(2)} ; \ldots ; \text{head}^{(H)}] \in \mathbb{R}^{T \times d_v} \]

Aplica una proyección de salida final:

\[ \boxed{\; \text{MultiHead}(X) = \text{Concat} \cdot W_O \in \mathbb{R}^{T \times d} \;} \]

con \(W_O \in \mathbb{R}^{d_v \times d}\).

Eso es todo. Multi-head = \(H\) single-head attentions paralelas en subespacios más pequeños, pegadas.

Recuento de parámetros

Sea \(d = d_k = d_v\) (el caso estándar).

Single-head con dimensión completa \(d\): - \(W_Q, W_K, W_V\): \(3 d^2\) parámetros. - Total: \(3 d^2\).

Multi-head con \(H\) cabezas: - Por cabeza: \(W_Q^{(h)}, W_K^{(h)}, W_V^{(h)}\) cada una de tamaño \(d \times d/H\), así que \(3 d^2 / H\) por cabeza. - A través de \(H\) cabezas: \(3 d^2\). - Proyección de salida \(W_O\): \(d^2\). - Total: \(4 d^2\).

Multi-head tiene \(d^2\) parámetros más que single-head — ese es el coste de la proyección de salida. En la práctica, esto es un modesto 33% de aumento. La ganancia de expresividad es mucho mayor.

Truco común de implementación: en lugar de \(H\) matrices pequeñas separadas, almacena una matriz grande \(W_Q \in \mathbb{R}^{d \times d}\) y reshape a \((T, H, d/H)\) en tiempo de ejecución. Mismo recuento de parámetros, contabilidad más simple. Usamos este truco en src/minimodel/attention/.

Por qué multi-head supera a una cabeza ancha

Una pregunta natural: ¿por qué no usar simplemente single-head con \(d_k = d\) (en lugar de \(d_k = d/H\))?

Ambos tienen los mismos FLOPs (la dimensión por cabeza cae en \(H\), pero tienes \(H\) cabezas). Ambos tienen recuentos de parámetros similares (aparte de la diferencia de \(W_O\)).

Multi-head gana porque cada cabeza puede especializarse en un subespacio diferente. Con las cuatro cabezas de nuestro Mini-GPT, una especialización aspiracional (post-Fase-18) podría ser:

  • Cabeza 1: atiende al pronombre sujeto (para concordancia de persona). Al predecir una forma verbal, mira hacia atrás a I / you / he.
  • Cabeza 2: atiende a la última raíz verbal (para consistencia de tiempo/aspecto).
  • Cabeza 3: atiende al token inmediatamente anterior (para coherencia local — comas, conjunciones).
  • Cabeza 4: atiende al emparejamiento inglés↔español (para alineamiento de traducción).

Con una sola cabeza ancha, todos estos patrones tienen que ser expresados por la misma forma bilineal \(d \times d\) \(W_Q W_K^\top\). Deben ser compatibles — el modelo tiene que encontrar una matriz que puntúe bien todos los patrones simultáneamente.

Con multi-head, cada cabeza tiene su propio \(W_Q^{(h)} W_K^{(h),\top}\) — matrices de scoring independientes. Las cabezas pueden discrepar. La proyección de salida \(W_O\) decide cómo combinar sus salidas.

Empíricamente: multi-head supera a wide-single-head a igual recuento de parámetros, en cada benchmark, desde 2017. Esto es ahora un axioma arquitectónico.

Advertencia para la Fase 18: la lectura "la cabeza se especializa en X" es aspiracional. A nuestra escala de juguete, la especialización real aprendida es parcial y ruidosa — una historia limpia cabeza-por-cabeza es un tema de investigación, no un resultado garantizado. La Fase 18 visualiza los mapas de atención entrenados; describiremos lo que cada cabeza parece hacer sin exagerar la interpretabilidad.

Intuición de subespacios: en lugar de buscar una función de scoring que sirva para todo, multi-head busca \(H\) funciones de scoring independientes, cada una en un subespacio de \(d/H\) dimensiones. Cabezas distintas se especializan. El \(W_O\) final aprende cómo mezclar las especializaciones.

La proyección de salida \(W_O\)

La gente se salta \(W_O\) al explicar multi-head y es un bug real. Sin \(W_O\):

  • La salida es solo la concatenación de las cabezas.
  • La característica de la posición \(i\) en la salida es la concatenación de \(\text{head}^{(h)}_i\) a lo largo de \(h\).
  • Diferentes dimensiones de características en la salida vienen de diferentes cabezas — no pueden mezclarse.

Con \(W_O\):

  • Cada característica en la salida es una combinación lineal aprendida de las contribuciones de todas las cabezas en esa posición.
  • El modelo puede usar la salida de la cabeza 1 para modular la contribución de la cabeza 2, etc.
  • Las cabezas pueden comunicarse en la frontera de la capa.

Quitar \(W_O\) significaría que cada cabeza está forzada a producir su propia porción de la salida de la capa de forma independiente — estrictamente menos expresivo que la capa completa.

Conclusión: \(W_O\) no es opcional. Es el mecanismo que hace de multi-head una capa, no solo una concatenación.

Superficie de API (bloqueada para src/minimodel/attention/)

class MultiHeadAttention:
    def __init__(self, d_model: int, n_heads: int, seed: int = 0):
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        rng = np.random.default_rng(seed)
        # one big matrix per role, reshape to heads at runtime
        scale = 1.0 / np.sqrt(d_model)
        self.W_Q = rng.standard_normal((d_model, d_model)).astype(np.float32) * scale
        self.W_K = rng.standard_normal((d_model, d_model)).astype(np.float32) * scale
        self.W_V = rng.standard_normal((d_model, d_model)).astype(np.float32) * scale
        self.W_O = rng.standard_normal((d_model, d_model)).astype(np.float32) * scale

    def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
        # x: (T, d_model), mask: (T, T) additive or None
        # returns: (T, d_model)
        ...

El forward pass:

  1. Calcula Q, K, V = x @ W_Q, x @ W_K, x @ W_V — cada una (T, d_model).
  2. Reshape a multi-head: Q.reshape(T, H, d_head).transpose(1, 0, 2) — forma (H, T, d_head).
  3. Para cada cabeza de forma independiente (o vectorizada vía batched matmul):
  4. scores_h = Q_h @ K_h.T / sqrt(d_head) — forma (T, T).
  5. Aplicar máscara si se da.
  6. attn_h = softmax(scores_h) — forma (T, T).
  7. out_h = attn_h @ V_h — forma (T, d_head).
  8. Concatenar cabezas: out.transpose(1, 0, 2).reshape(T, d_model) — forma (T, d_model).
  9. Aplicar out @ W_O — forma (T, d_model).

En NumPy, los pasos 3a–3d pueden ser un único einsum sobre el eje de cabezas. En la Fase 17 esta es la formulación más limpia; en la Fase 15 Borja puede hacer el bucle explícito por claridad. Ambos están testeados.

Cross-attention (un párrafo, por completitud)

En modelos encoder-decoder (traducción, resumen), las capas decoder tienen dos sub-capas de atención:

  1. Self-attention sobre las salidas previas del propio decoder (causal).
  2. Cross-attention sobre las salidas del encoder.

La única diferencia para cross-attention: \(Q\) viene de los hidden states del decoder, mientras que \(K\) y \(V\) vienen de la salida del encoder. Misma ecuación por lo demás.

Los modelos decoder-only (familia GPT — lo que estamos construyendo) no tienen cross-attention. El Mini-GPT es decoder-only. Cross-attention está documentado aquí para que el término no sea misterioso; no lo implementaremos.

Nota sobre cross-attention: la dejamos fuera del currículo activo porque construimos un modelo decoder-only (estilo GPT). Si necesitas un modelo encoder-decoder (T5, BART), cross-attention es trivial: tres líneas más que self-attention, mismo mecanismo.

Lo que este archivo NO cubre

  • Máscara causal (causal mask). Próximo archivo (04-masking.md).
  • Visualización de patrones de atención entrenada. Fase 18. Aquí solo describimos la especialización aspiracional.
  • Head-pruning, grouped-query attention (GQA), multi-query attention (MQA). Optimizaciones en tiempo de inferencia cubiertas en la Fase 27.
  • Cross-attention más allá de la mención de un párrafo. Modelo decoder-only en este currículo.
  • Inicialización de \(W_O\). Fase 18 (entrenamiento); para verificación solo forward, los labs usan gaussiana pequeña.

Recapitulación

  • Multi-head = \(H\) single-head attentions paralelas en subespacios de dimensión \(d/H\), concatenadas + proyectadas.
  • Mismos FLOPs que single-head a dimensión completa; una matriz \(d \times d\) extra (\(W_O\)).
  • Cada cabeza puede especializarse; la proyección de salida las mezcla.
  • \(W_O\) no es opcional — es lo que permite que las cabezas se comuniquen.
  • Superficie de API bloqueada en BLUEPRINT.md. Constructor: (d_model, n_heads). Forward: (x, mask) -> y.

Siguiente: 04-masking.md.