03 — Multi-Head Attention

Multi-head attention = running several attentions in parallel, each in a lower-dimensional subspace, and concatenating the outputs. With roughly the same parameter budget as single-head, multi-head can attend simultaneously to several distinct relations (one head to subject-verb, another to noun-adjective, another to the closing parenthesis). It is the formulation that won.

This file derives multi-head attention, explains why it costs roughly the same parameters as single-head (one extra output projection) yet is more expressive, and locks the API surface that downstream phases depend on.


The single-head limitation

Single-head attention computes one \(T \times T\) attention pattern per layer. That pattern is one bilinear scoring function \(x_i^\top (W_Q W_K^\top) x_j\). The model is limited to one notion of similarity per layer.
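To make the "one bilinear form" statement concrete, here is a minimal NumPy sketch. The sizes and random weights are hypothetical stand-ins, not values from the curriculum; the point is only that projecting to queries and keys and then taking dot products gives the same scores as applying the single matrix \(W_Q W_K^\top\) directly.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, d_k = 5, 8, 8                      # hypothetical toy sizes
X = rng.standard_normal((T, d))          # one row per token
W_Q = rng.standard_normal((d, d_k))
W_K = rng.standard_normal((d, d_k))

# Usual route: project to queries and keys, then take dot products.
scores_qk = (X @ W_Q) @ (X @ W_K).T      # (T, T)

# Same numbers via the single bilinear form M = W_Q W_K^T.
M = W_Q @ W_K.T                          # (d, d)
scores_bilinear = X @ M @ X.T            # entry (i, j) = x_i^T M x_j

same_scores = np.allclose(scores_qk, scores_bilinear)
```

Whatever the layer wants to score, it has to be encoded in that one matrix `M`.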

But real language has many kinds of dependencies, all relevant in parallel. For our verb-grammar scope (§A13), even the toy completion "I work, you work, he ___" involves multiple parallel relations (positions are 0-indexed and commas count as tokens):

  • Person agreement. The verb at position 7 must look at the subject pronoun (position 6, he) to choose between the -s form and the bare form.
  • Tense identification. The verb at position 7 must look at the prior verb tokens (positions 1, 4 — both work) to pick the present-simple paradigm.
  • English↔Spanish alignment. When predicting Spanish translations (yo trabajo), the verb position must align with both the English token and the Spanish pronoun.
  • Positional / locality. Some heads just look at the immediately-preceding token.

Forcing all of these through one bilinear form is a representational bottleneck. Multi-head fixes it.

The construction

Pick \(H\) — the number of heads. Common choices: 4, 8, 12, 16. Phase 17's Mini-GPT uses \(H = 4\).

Split the projection dimensions evenly across the heads:

\[ d_k^{\text{head}} = d_k / H, \qquad d_v^{\text{head}} = d_v / H \]

(Requirement: \(d_k, d_v\) divisible by \(H\). In practice the model width is chosen as a multiple of \(H\), so this holds by construction.)

For each head \(h = 1, \ldots, H\), learn three projection matrices:

\[ W_Q^{(h)} \in \mathbb{R}^{d \times d_k^{\text{head}}}, \quad W_K^{(h)} \in \mathbb{R}^{d \times d_k^{\text{head}}}, \quad W_V^{(h)} \in \mathbb{R}^{d \times d_v^{\text{head}}} \]

Each head produces:

\[ \text{head}^{(h)} = \text{Attention}(X W_Q^{(h)}, X W_K^{(h)}, X W_V^{(h)}) \in \mathbb{R}^{T \times d_v^{\text{head}}} \]

Concatenate the heads along the feature dimension:

\[ \text{Concat} = [\text{head}^{(1)} ; \text{head}^{(2)} ; \ldots ; \text{head}^{(H)}] \in \mathbb{R}^{T \times d_v} \]

Apply a final output projection:

\[ \boxed{\; \text{MultiHead}(X) = \text{Concat} \cdot W_O \in \mathbb{R}^{T \times d} \;} \]

with \(W_O \in \mathbb{R}^{d_v \times d}\).

That's it. Multi-head = \(H\) parallel single-head attentions in smaller subspaces, glued together.
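The construction above can be sketched directly, one head at a time. This is an illustrative sketch, not the locked implementation: the helper names `softmax` and `attention`, the toy sizes, and the random weights are all assumptions, and no mask is applied.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the row max for numerical stability.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # Scaled dot-product attention for a single head (no mask).
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    return softmax(scores, axis=-1) @ V

rng = np.random.default_rng(0)
T, d, H = 6, 16, 4                       # hypothetical sizes, d = d_k = d_v
d_head = d // H
scale = 1.0 / np.sqrt(d)

X = rng.standard_normal((T, d))
# One (W_Q, W_K, W_V) triple per head, each of shape (d, d_head).
W_Q = [rng.standard_normal((d, d_head)) * scale for _ in range(H)]
W_K = [rng.standard_normal((d, d_head)) * scale for _ in range(H)]
W_V = [rng.standard_normal((d, d_head)) * scale for _ in range(H)]
W_O = rng.standard_normal((d, d)) * scale

heads = [attention(X @ W_Q[h], X @ W_K[h], X @ W_V[h]) for h in range(H)]
concat = np.concatenate(heads, axis=-1)  # (T, d)
out = concat @ W_O                       # (T, d)
```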

Parameter count

Let \(d = d_k = d_v\) (the standard case).

Single-head with full dimension \(d\):

  • \(W_Q, W_K, W_V\): \(3 d^2\) parameters.
  • Total: \(3 d^2\).

Multi-head with \(H\) heads:

  • Per head: \(W_Q^{(h)}, W_K^{(h)}, W_V^{(h)}\) each of size \(d \times d/H\), so \(3 d^2 / H\) per head.
  • Across \(H\) heads: \(3 d^2\).
  • Output projection \(W_O\): \(d^2\).
  • Total: \(4 d^2\).

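Multi-head has \(d^2\) more parameters than single-head; that is the cost of the output projection. Relative to the \(3 d^2\) of a single head without an output projection, this is a 33% increase (biases are ignored throughout). It is a modest price for the added expressiveness discussed in the following sections.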
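The count is easy to check with a few lines of arithmetic. The function name and the width \(d = 64\) are hypothetical, chosen only for illustration; biases are ignored, as in the text.

```python
def attention_param_counts(d: int, n_heads: int) -> dict:
    # Assumes d = d_k = d_v and d divisible by n_heads; biases ignored.
    d_head = d // n_heads
    single_head = 3 * d * d                       # W_Q, W_K, W_V at full width
    multi_head_qkv = n_heads * 3 * d * d_head     # H heads, three (d, d/H) matrices each
    w_o = d * d                                   # output projection
    return {
        "single_head": single_head,
        "multi_head_qkv": multi_head_qkv,
        "w_o": w_o,
        "multi_head_total": multi_head_qkv + w_o,
        "relative_increase": w_o / single_head,
    }

counts = attention_param_counts(d=64, n_heads=4)
# single_head = 12288, multi_head_qkv = 12288, w_o = 4096, multi_head_total = 16384
```

The per-head projections add up to exactly the single-head count; the whole difference is \(W_O\).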

Common implementation trick: instead of \(H\) separate small matrices, store one big matrix \(W_Q \in \mathbb{R}^{d \times d}\) whose column blocks are the per-head matrices, compute \(X W_Q\) once, and reshape the \((T, d)\) result to \((T, H, d/H)\) at runtime. Same parameter count, simpler bookkeeping. We use this trick in src/minimodel/attention/.
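A short check that the fused layout is equivalent to separate per-head matrices, assuming the per-head matrices are stacked as contiguous column blocks of the big one (the layout implied by the reshape). Sizes and variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, H = 5, 12, 3                       # hypothetical toy sizes
d_head = d // H
X = rng.standard_normal((T, d))

# H separate per-head query matrices, each (d, d_head).
W_Q_heads = [rng.standard_normal((d, d_head)) for _ in range(H)]

# Fused layout: head h occupies columns [h*d_head, (h+1)*d_head).
W_Q_big = np.concatenate(W_Q_heads, axis=1)                   # (d, d)

Q_big = X @ W_Q_big                                           # (T, d), one matmul
Q_split = Q_big.reshape(T, H, d_head).transpose(1, 0, 2)      # (H, T, d_head)

matches = all(np.allclose(Q_split[h], X @ W_Q_heads[h]) for h in range(H))
```

One large matmul followed by a reshape yields exactly the per-head queries, so nothing is lost by storing a single matrix per role.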

Why multi-head beats one wide head

A natural question: why not just use single-head with \(d_k = d\) (instead of \(d_k = d/H\))?

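Both have the same matmul FLOPs in the attention core: the per-head dimension drops by a factor of \(H\), but there are \(H\) heads, so the score and value products each cost \(T^2 d\) multiply-adds either way. (Multi-head does compute \(H\) softmaxes of size \(T \times T\) instead of one.) Both have similar parameter counts (the \(W_O\) difference aside).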
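The FLOP comparison can be spelled out as arithmetic. This is a rough count of multiply-adds for the two matmuls inside attention only; it leaves out the Q/K/V projections (identical in both cases) and \(W_O\). The function name and sizes are hypothetical.

```python
def attention_core_macs(T: int, d: int, n_heads: int):
    # Multiply-adds for the score and value matmuls, summed over heads.
    d_head = d // n_heads
    score_macs = n_heads * T * T * d_head    # Q_h @ K_h^T for every head
    value_macs = n_heads * T * T * d_head    # attn_h @ V_h for every head
    softmax_entries = n_heads * T * T        # entries that go through softmax
    return score_macs + value_macs, softmax_entries

wide_macs, wide_softmax = attention_core_macs(T=8, d=64, n_heads=1)
multi_macs, multi_softmax = attention_core_macs(T=8, d=64, n_heads=4)
# wide_macs == multi_macs == 8192; softmax entries: 64 vs 256
```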

Multi-head wins because each head can specialize in a different subspace. With the four heads of our Mini-GPT, an aspirational (post-Phase-18) specialization might be:

  • Head 1: attend to the subject pronoun (for person agreement). When predicting a verb form, look back at I / you / he.
  • Head 2: attend to the last verb stem (for tense/aspect consistency).
  • Head 3: attend to the immediately-preceding token (for local coherence — commas, conjunctions).
  • Head 4: attend to the English↔Spanish pairing (for translation alignment).

With a single wide head, all these patterns have to be expressed by the same \(d \times d\) bilinear form \(W_Q W_K^\top\). They must be compatible — the model has to find one matrix that scores well on all the patterns simultaneously.

With multi-head, each head has its own \(W_Q^{(h)} W_K^{(h),\top}\) — independent scoring matrices. Heads can disagree. The output projection \(W_O\) decides how to combine their outputs.
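One way to see the "smaller subspace" trade-off numerically: each head's scoring matrix \(W_Q^{(h)} W_K^{(h),\top}\) is \(d \times d\) but has rank at most \(d/H\), whereas the wide head's matrix can reach rank \(d\). The rank bound is a standard linear-algebra fact rather than something stated in this curriculum, and the random weights below are hypothetical stand-ins for trained ones.

```python
import numpy as np

rng = np.random.default_rng(0)
d, H = 16, 4                             # hypothetical toy sizes
d_head = d // H

head_ranks = []
for _ in range(H):
    W_Q_h = rng.standard_normal((d, d_head))
    W_K_h = rng.standard_normal((d, d_head))
    M_h = W_Q_h @ W_K_h.T                # (d, d) scoring matrix of one head
    # Explicit tolerance so floating-point noise is not counted as rank.
    head_ranks.append(int(np.linalg.matrix_rank(M_h, tol=1e-8)))

W_Q_wide = rng.standard_normal((d, d))
W_K_wide = rng.standard_normal((d, d))
wide_rank = int(np.linalg.matrix_rank(W_Q_wide @ W_K_wide.T, tol=1e-8))
```

So multi-head trades one high-rank scoring function for \(H\) independent low-rank ones. That each head only "sees" a \(d/H\)-dimensional projection is the price of letting heads disagree.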

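Empirically, multi-head attention has been the default in Transformer architectures since 2017, and ablations, starting with the single-head comparison in the original Transformer paper, generally favor several heads over one wide head at a comparable budget. The gain is not monotonic in \(H\): with too many heads each subspace becomes too small and quality drops again. Treat multi-head as a well-supported default rather than a theorem.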

Caveat for Phase 18: the "head specializes in X" reading is aspirational. At our toy scale, the actual learned specialization is partial and noisy — a clean head-by-head story is a research topic, not a guaranteed outcome. Phase 18 visualizes the trained attention maps; we will describe what each head appears to do without overstating the interpretability.

Subspace intuition: instead of looking for one scoring function that works for everything, multi-head looks for \(H\) independent scoring functions, each in a \(d/H\)-dimensional subspace. Different heads specialize. The final \(W_O\) learns how to mix the specializations.

The output projection \(W_O\)

Explanations of multi-head attention often skip \(W_O\), and that omission is a real gap. Without \(W_O\):

  • The output is just the concatenation of heads.
  • Position \(i\)'s feature in the output is the concatenation of \(\text{head}^{(h)}_i\) across \(h\).
  • Different feature dimensions in the output come from different heads — they cannot mix.

With \(W_O\):

  • Every feature in the output is a learned linear combination of all heads' contributions at that position.
  • A single output feature can add up, or cancel, evidence from several heads (the mixing is linear, so this is weighting rather than gating).
  • Heads can communicate at the layer boundary.

Removing \(W_O\) would mean every head is forced to produce its own slice of the layer's output independently — strictly less expressive than the full layer.

Conclusion: \(W_O\) is not optional. It is the mechanism that makes multi-head a layer, not just a concatenation.
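A small numerical sketch of what the projection does. It uses random stand-in head outputs (hypothetical, not trained values) and checks two things: `concat @ W_O` equals a sum of per-head contributions through the matching row block of \(W_O\), and a change in one head reaches every output feature only when \(W_O\) is present.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, H = 5, 12, 3                       # hypothetical toy sizes
d_head = d // H
heads = [rng.standard_normal((T, d_head)) for _ in range(H)]  # stand-in head outputs
W_O = rng.standard_normal((d, d))

concat = np.concatenate(heads, axis=-1)                       # (T, d)
out_concat = concat @ W_O

# Same result as summing each head through its own row block of W_O.
out_sum = sum(heads[h] @ W_O[h * d_head:(h + 1) * d_head, :] for h in range(H))
same = np.allclose(out_concat, out_sum)

# Zero out head 0 and see which output columns change.
zeroed = np.concatenate([np.zeros((T, d_head))] + heads[1:], axis=-1)
changed_without_wo = np.any(~np.isclose(concat, zeroed), axis=0)              # (d,)
changed_with_wo = np.any(~np.isclose(out_concat, zeroed @ W_O), axis=0)       # (d,)
```

Without \(W_O\) the change stays inside head 0's own slice of columns; with \(W_O\) it shows up in all \(d\) output features.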

API surface (locked for src/minimodel/attention/)

import numpy as np


class MultiHeadAttention:
    def __init__(self, d_model: int, n_heads: int, seed: int = 0):
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        rng = np.random.default_rng(seed)
        # one big matrix per role, reshape to heads at runtime
        scale = 1.0 / np.sqrt(d_model)
        self.W_Q = rng.standard_normal((d_model, d_model)).astype(np.float32) * scale
        self.W_K = rng.standard_normal((d_model, d_model)).astype(np.float32) * scale
        self.W_V = rng.standard_normal((d_model, d_model)).astype(np.float32) * scale
        self.W_O = rng.standard_normal((d_model, d_model)).astype(np.float32) * scale

    def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
        # x: (T, d_model), mask: (T, T) additive or None
        # returns: (T, d_model)
        ...

The forward pass:

  1. Compute Q, K, V = x @ W_Q, x @ W_K, x @ W_V, each of shape (T, d_model).
  2. Reshape to multi-head: Q.reshape(T, H, d_head).transpose(1, 0, 2), shape (H, T, d_head). Same for K and V.
  3. For each head independently (or vectorized via batched matmul):
       a. scores_h = Q_h @ K_h.T / sqrt(d_head), shape (T, T).
       b. Apply mask if given.
       c. attn_h = softmax(scores_h), shape (T, T).
       d. out_h = attn_h @ V_h, shape (T, d_head).
  4. Concatenate heads: out.transpose(1, 0, 2).reshape(T, d_model), shape (T, d_model).
  5. Apply out @ W_O, shape (T, d_model).

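In NumPy, steps 3a–3d can be vectorized over the head axis with two einsum calls (one for the scores, one for the weighted sum of values) and a softmax in between. Phase 17 uses this vectorized formulation; in Phase 15 Borja can write the explicit loop for clarity. Both are tested.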
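A possible vectorized forward pass following the steps listed above, written as a free function so it stays self-contained. This is a sketch under assumptions, not the code in src/minimodel/attention/: the function name `multi_head_forward`, the `softmax` helper, and the causal mask built with `np.triu` are all hypothetical (masking itself is the subject of the next file).

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_forward(x, W_Q, W_K, W_V, W_O, n_heads, mask=None):
    # x: (T, d_model); mask: additive (T, T) or None; returns (T, d_model).
    T, d_model = x.shape
    d_head = d_model // n_heads

    def split(a):
        # (T, d_model) -> (H, T, d_head)
        return a.reshape(T, n_heads, d_head).transpose(1, 0, 2)

    Q, K, V = split(x @ W_Q), split(x @ W_K), split(x @ W_V)
    scores = np.einsum("hid,hjd->hij", Q, K) / np.sqrt(d_head)   # (H, T, T)
    if mask is not None:
        scores = scores + mask                 # (T, T) broadcasts over heads
    attn = softmax(scores, axis=-1)            # rows sum to 1 over keys
    out = np.einsum("hij,hjd->hid", attn, V)   # (H, T, d_head)
    out = out.transpose(1, 0, 2).reshape(T, d_model)             # concatenate heads
    return out @ W_O

rng = np.random.default_rng(0)
T, d_model, H = 6, 16, 4
scale = 1.0 / np.sqrt(d_model)
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) * scale for _ in range(4))
x = rng.standard_normal((T, d_model))

# Hypothetical additive causal mask: large negative above the diagonal.
causal = np.triu(np.full((T, T), -1e9), k=1)
y = multi_head_forward(x, W_Q, W_K, W_V, W_O, H, mask=causal)
```

With the causal mask in place, changing the last token should leave the outputs at all earlier positions untouched, which is a convenient sanity check for any implementation of these steps.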

Cross-attention (one paragraph, for completeness)

In encoder-decoder models (translation, summarization), the decoder layers have two attention sub-layers:

  1. Self-attention over the decoder's own previous outputs (causal).
  2. Cross-attention over the encoder's outputs.

The only difference for cross-attention: \(Q\) comes from the decoder's hidden states, while \(K\) and \(V\) come from the encoder's output. Same equation otherwise.
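For readers who want to see that difference in code, here is a single-head sketch. It is purely illustrative and not part of the Mini-GPT code base; the function name, argument names, and sizes are hypothetical.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(dec_hidden, enc_out, W_Q, W_K, W_V):
    # Queries come from the decoder; keys and values come from the encoder.
    Q = dec_hidden @ W_Q                       # (T_dec, d_k)
    K = enc_out @ W_K                          # (T_enc, d_k)
    V = enc_out @ W_V                          # (T_enc, d_v)
    scores = Q @ K.T / np.sqrt(Q.shape[-1])    # (T_dec, T_enc)
    return softmax(scores, axis=-1) @ V        # (T_dec, d_v)

rng = np.random.default_rng(0)
T_dec, T_enc, d = 3, 7, 8                      # hypothetical toy sizes
dec_hidden = rng.standard_normal((T_dec, d))
enc_out = rng.standard_normal((T_enc, d))
W_Q, W_K, W_V = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

ctx = cross_attention(dec_hidden, enc_out, W_Q, W_K, W_V)
```

Note that the score matrix is no longer square: it has one row per decoder position and one column per encoder position.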

Decoder-only models (GPT family — what we're building) don't have cross-attention. The Mini-GPT is decoder-only. Cross-attention is documented here so the term isn't mysterious; we won't implement it.

Note on cross-attention: we leave it out of the active curriculum because we are building a decoder-only (GPT-style) model. If you need an encoder-decoder model (T5, BART), cross-attention is trivial: three more lines than self-attention, same mechanism.

What this file does NOT cover

  • Causal masking. Next file (04-masking.md).
  • Trained-attention pattern visualization. Phase 18. Here we only describe the aspirational specialization.
  • Head-pruning, grouped-query attention (GQA), multi-query attention (MQA). Inference-time optimizations covered in Phase 27.
  • Cross-attention beyond the one-paragraph mention. Decoder-only model in this curriculum.
  • Initialization of \(W_O\). Phase 18 (training); for forward-only verification, the labs use small Gaussian.

Recap

  • Multi-head = \(H\) parallel single-head attentions in \(d/H\)-dimensional subspaces, concatenated + projected.
  • Same attention-core matmul FLOPs as single-head at full dimension; one extra \(d \times d\) matrix (\(W_O\)), which adds \(d^2\) parameters.
  • Each head can specialize; the output projection mixes them.
  • \(W_O\) is not optional — it's what makes heads able to communicate.
  • API surface locked in BLUEPRINT.md. Constructor: (d_model, n_heads, seed=0). Forward: (x, mask=None) -> y.

Next: 04-masking.md.