03 — Multi-Head Attention¶
Multi-head attention means running several attentions in parallel, each in a lower-dimensional subspace, and concatenating the outputs. With essentially the same parameter budget as a single head, multi-head can attend to several distinct relations at once (one head to subject-verb, another to noun-adjective, another to the closing parenthesis). It is the formulation that won.
This file derives multi-head attention, explains why it costs almost the same number of parameters as single-head (one extra output projection) yet is more expressive, and locks the API surface that downstream phases depend on.
The single-head limitation¶
Single-head attention computes one \(T \times T\) attention pattern per layer. That pattern is one bilinear scoring function \(x_i^\top (W_Q W_K^\top) x_j\). The model is limited to one notion of similarity per layer.
But real language has many kinds of dependencies, all relevant in parallel. For our verb-grammar scope (§A13), even the toy completion I work, you work, he ___ involves multiple parallel relations:
- Person agreement. The verb at position 7 must look at the subject pronoun (position 6, `he`) to pick the `-s` form vs. the bare form.
- Tense identification. The verb at position 7 must look at the prior verb tokens (positions 1, 4, both `work`) to pick the present-simple paradigm.
- English↔Spanish alignment. When predicting Spanish translations (`yo trabajo`), the verb position must align with both the English token and the Spanish pronoun.
- Positional / locality. Some heads just look at the immediately-preceding token.
Forcing all of these through one bilinear form is a representational bottleneck. Multi-head fixes it.
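To make the "one bilinear scoring function" concrete, here is a minimal NumPy sketch. The sizes `T` and `d` and the random matrices are illustrative assumptions, not values from the curriculum. It checks that the usual `Q @ K.T` score matrix is the bilinear form \(x_i^\top (W_Q W_K^\top) x_j\) evaluated for every pair \((i, j)\), so a single head has exactly one \(d \times d\) matrix behind its \(T \times T\) pattern.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 5, 8  # illustrative sizes (assumed, not from the curriculum)

X = rng.standard_normal((T, d))
W_Q = rng.standard_normal((d, d))
W_K = rng.standard_normal((d, d))

# Usual formulation: project, then take all pairwise dot products.
scores = (X @ W_Q) @ (X @ W_K).T      # (T, T)

# Same numbers written as one bilinear form with M = W_Q W_K^T.
M = W_Q @ W_K.T                        # (d, d)
scores_bilinear = X @ M @ X.T          # (T, T)

same_scores = np.allclose(scores, scores_bilinear)
print("Q K^T equals X M X^T:", same_scores)
```

Whatever relations the layer needs to score, they all have to be encoded in that single matrix `M`.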
The construction¶
Pick \(H\) — the number of heads. Common choices: 4, 8, 12, 16. Phase 17's Mini-GPT uses \(H = 4\).
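Split the head dimension:
\[
\frac{d_k}{H} \text{ dimensions per head for queries and keys}, \qquad \frac{d_v}{H} \text{ dimensions per head for values}
\]
(Requirement: \(d_k, d_v\) divisible by \(H\). Almost always satisfied because everyone picks powers of 2.)
For each head \(h = 1, \ldots, H\), learn three projection matrices:
\[
W_Q^{(h)} \in \mathbb{R}^{d \times d_k/H}, \qquad W_K^{(h)} \in \mathbb{R}^{d \times d_k/H}, \qquad W_V^{(h)} \in \mathbb{R}^{d \times d_v/H}
\]
Each head produces, for an input \(X \in \mathbb{R}^{T \times d}\):
\[
\text{head}^{(h)} = \text{softmax}\!\left( \frac{(X W_Q^{(h)}) (X W_K^{(h)})^\top}{\sqrt{d_k / H}} \right) X W_V^{(h)} \in \mathbb{R}^{T \times d_v/H}
\]
Concatenate the heads along the feature dimension:
\[
\text{Concat}\!\left(\text{head}^{(1)}, \ldots, \text{head}^{(H)}\right) \in \mathbb{R}^{T \times d_v}
\]
Apply a final output projection:
\[
\text{MultiHead}(X) = \text{Concat}\!\left(\text{head}^{(1)}, \ldots, \text{head}^{(H)}\right) W_O
\]
with \(W_O \in \mathbb{R}^{d_v \times d}\).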
That's it. Multi-head = \(H\) parallel single-head attentions in smaller subspaces, glued together.
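The construction can be sketched head by head in NumPy. This is a minimal illustration rather than the curriculum's implementation: the function name `multi_head_explicit`, the helper `softmax`, the sizes, and the random weights are all assumptions made for the example, and it takes \(d_k = d_v = d\).

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_explicit(X, W_Q_heads, W_K_heads, W_V_heads, W_O):
    """Hypothetical helper. X: (T, d); each per-head matrix: (d, d/H); W_O: (d, d)."""
    heads = []
    for W_Q, W_K, W_V in zip(W_Q_heads, W_K_heads, W_V_heads):
        Q, K, V = X @ W_Q, X @ W_K, X @ W_V           # each (T, d/H)
        scores = Q @ K.T / np.sqrt(Q.shape[-1])       # (T, T), scaled by sqrt(d/H)
        heads.append(softmax(scores) @ V)             # (T, d/H)
    concat = np.concatenate(heads, axis=-1)           # (T, d)
    return concat @ W_O                               # (T, d)

rng = np.random.default_rng(0)
T, d, H = 6, 16, 4                                    # illustrative sizes (assumed)
X = rng.standard_normal((T, d))
make_heads = lambda: [rng.standard_normal((d, d // H)) / np.sqrt(d) for _ in range(H)]
W_Q_heads, W_K_heads, W_V_heads = make_heads(), make_heads(), make_heads()
W_O = rng.standard_normal((d, d)) / np.sqrt(d)

y = multi_head_explicit(X, W_Q_heads, W_K_heads, W_V_heads, W_O)
print(y.shape)
```

The loop makes the structure explicit: nothing is shared between heads until the concatenation and the final multiplication by `W_O`.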
Parameter count¶
Let \(d = d_k = d_v\) (the standard case).
Single-head with full dimension \(d\):
- \(W_Q, W_K, W_V\): \(3 d^2\) parameters.
- Total: \(3 d^2\).
Multi-head with \(H\) heads:
- Per head: \(W_Q^{(h)}, W_K^{(h)}, W_V^{(h)}\) each of size \(d \times d/H\), so \(3 d^2 / H\) per head.
- Across \(H\) heads: \(3 d^2\).
- Output projection \(W_O\): \(d^2\).
- Total: \(4 d^2\).
Multi-head has \(d^2\) more parameters than single-head — that's the cost of the output projection. In practice, this is a modest 33% increase. The expressiveness gain is much larger.
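The arithmetic is easy to check in a few lines. The function names and the value \(d = 64\) below are illustrative assumptions; only \(H = 4\) comes from the text. Biases are not counted, matching the count above.

```python
def single_head_params(d):
    # W_Q, W_K, W_V, each d x d, and no output projection.
    return 3 * d * d

def multi_head_params(d, n_heads):
    assert d % n_heads == 0
    per_head = 3 * d * (d // n_heads)   # three d x (d/H) matrices
    return n_heads * per_head + d * d   # all heads plus W_O (d x d)

d = 64  # illustrative width (assumed)
single = single_head_params(d)
multi = multi_head_params(d, n_heads=4)
print(single, multi, multi / single)
```

The total does not depend on \(H\): more heads means more, narrower matrices, and the only extra cost is the \(d^2\) of \(W_O\).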
Common implementation trick: instead of \(H\) separate small matrices, store one big matrix \(W_Q \in \mathbb{R}^{d \times d}\), compute the full projection \(X W_Q\) once, and reshape the result from \((T, d)\) to \((T, H, d/H)\) at runtime. Same parameter count, simpler bookkeeping. We use this trick in src/minimodel/attention/.
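A small check of why the fused layout is equivalent, assuming row-major NumPy arrays and illustrative sizes: head \(h\)'s private matrix is simply the column block `W_Q[:, h*d_head:(h+1)*d_head]` of the big matrix, so projecting once and reshaping gives the same numbers as \(H\) separate projections.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, H = 5, 8, 2                      # illustrative sizes (assumed)
d_head = d // H

X = rng.standard_normal((T, d))
W_Q = rng.standard_normal((d, d))      # one big matrix for all heads

# Fused: one matmul, then split the feature axis into (H, d_head).
Q_fused = (X @ W_Q).reshape(T, H, d_head).transpose(1, 0, 2)   # (H, T, d_head)

# Separate: one small matmul per head, using that head's column block.
Q_separate = np.stack(
    [X @ W_Q[:, h * d_head:(h + 1) * d_head] for h in range(H)]
)                                                               # (H, T, d_head)

fused_matches = np.allclose(Q_fused, Q_separate)
print("fused == per-head:", fused_matches)
```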
Why multi-head beats one wide head¶
A natural question: why not just use single-head with \(d_k = d\) (instead of \(d_k = d/H\))?
Both have the same FLOPs (the per-head dimension drops by \(H\), but you have \(H\) heads). Both have similar parameter counts (the \(W_O\) difference aside).
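As a rough check of the FLOPs claim, the sketch below counts multiply-accumulates for the Q/K/V projections, the score matrix, and the weighted sum. It is a simplified cost model under stated assumptions: the function name is hypothetical, the softmax, masking, and \(W_O\) costs are left out, and the sizes are illustrative.

```python
def attention_macs(T, d, n_heads):
    """Approximate multiply-accumulate count; softmax and W_O are not counted."""
    d_head = d // n_heads
    projections = 3 * T * d * d               # X @ W_Q, X @ W_K, X @ W_V (fused)
    scores = n_heads * T * T * d_head         # Q_h @ K_h.T for every head
    weighted_sum = n_heads * T * T * d_head   # attn_h @ V_h for every head
    return projections + scores + weighted_sum

T, d = 128, 64  # illustrative sizes (assumed)
wide_single = attention_macs(T, d, n_heads=1)
four_heads = attention_macs(T, d, n_heads=4)
print(wide_single, four_heads)
```

Each head's matmuls are \(H\) times cheaper and there are \(H\) of them, so the two totals coincide in this model.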
Multi-head wins because each head can specialize in a different subspace. With the four heads of our Mini-GPT, an aspirational (post-Phase-18) specialization might be:
- Head 1: attend to the subject pronoun (for person agreement). When predicting a verb form, look back at `I`/`you`/`he`.
- Head 2: attend to the last verb stem (for tense/aspect consistency).
- Head 3: attend to the immediately-preceding token (for local coherence — commas, conjunctions).
- Head 4: attend to the English↔Spanish pairing (for translation alignment).
With a single wide head, all these patterns have to be expressed by the same \(d \times d\) bilinear form \(W_Q W_K^\top\). They must be compatible — the model has to find one matrix that scores well on all the patterns simultaneously.
With multi-head, each head has its own \(W_Q^{(h)} W_K^{(h),\top}\) — independent scoring matrices. Heads can disagree. The output projection \(W_O\) decides how to combine their outputs.
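Empirically, multi-head has been the standard choice since the original 2017 Transformer, whose ablations found single-head attention measurably worse than multi-head at matched compute. At comparable parameter count it is now treated as an architectural default.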
Caveat for Phase 18: the "head specializes in X" reading is aspirational. At our toy scale, the actual learned specialization is partial and noisy — a clean head-by-head story is a research topic, not a guaranteed outcome. Phase 18 visualizes the trained attention maps; we will describe what each head appears to do without overstating the interpretability.
Subspace intuition: instead of searching for one scoring function that works for everything, multi-head searches for \(H\) independent scoring functions, each in a \(d/H\)-dimensional subspace. Different heads specialize. The final \(W_O\) learns how to mix the specializations.
The output projection \(W_O\)¶
People skip \(W_O\) when explaining multi-head and it's a real bug. Without \(W_O\):
- The output is just the concatenation of heads.
- Position \(i\)'s feature in the output is the concatenation of \(\text{head}^{(h)}_i\) across \(h\).
- Different feature dimensions in the output come from different heads — they cannot mix.
With \(W_O\):
- Every feature in the output is a learned linear combination of all heads' contributions at that position.
- The model can let head 1's output reinforce or cancel head 2's contribution in any output feature, etc.
- Heads can communicate at the layer boundary.
Removing \(W_O\) would mean every head is forced to produce its own slice of the layer's output independently — strictly less expressive than the full layer.
Conclusion: \(W_O\) is not optional. It is the mechanism that makes multi-head a layer, not just a concatenation.
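One way to see the mixing is to note that multiplying the concatenation by \(W_O\) is the same as summing each head's output through its own row block of \(W_O\). The sketch below checks that identity with random stand-ins for the head outputs; the sizes and variable names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, H = 4, 8, 2                      # illustrative sizes (assumed)
d_head = d // H

# Stand-ins for the per-head outputs head^(h), each (T, d_head).
heads = [rng.standard_normal((T, d_head)) for _ in range(H)]
W_O = rng.standard_normal((d, d))

concat = np.concatenate(heads, axis=-1)            # (T, d)
concat_then_project = concat @ W_O                 # (T, d)

# Row block h of W_O maps head h's d_head features into all d output features.
sum_of_head_projections = sum(
    heads[h] @ W_O[h * d_head:(h + 1) * d_head, :] for h in range(H)
)

mixes = np.allclose(concat_then_project, sum_of_head_projections)

# Without W_O, the first d_head output features are head 0 and nothing else.
no_mix_without_W_O = np.array_equal(concat[:, :d_head], heads[0])
print(mixes, no_mix_without_W_O)
```

With \(W_O\), every output feature receives a contribution from every head; without it, each feature belongs to exactly one head.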
API surface (locked for src/minimodel/attention/)¶
import numpy as np

class MultiHeadAttention:
    def __init__(self, d_model: int, n_heads: int, seed: int = 0):
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        rng = np.random.default_rng(seed)
        # one big matrix per role, reshape to heads at runtime
        scale = 1.0 / np.sqrt(d_model)
        self.W_Q = rng.standard_normal((d_model, d_model)).astype(np.float32) * scale
        self.W_K = rng.standard_normal((d_model, d_model)).astype(np.float32) * scale
        self.W_V = rng.standard_normal((d_model, d_model)).astype(np.float32) * scale
        self.W_O = rng.standard_normal((d_model, d_model)).astype(np.float32) * scale

    def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
        # x: (T, d_model), mask: (T, T) additive or None
        # returns: (T, d_model)
        ...
The forward pass:
1. Compute `Q, K, V = x @ W_Q, x @ W_K, x @ W_V`, each `(T, d_model)`.
2. Reshape to multi-head: `Q.reshape(T, H, d_head).transpose(1, 0, 2)`, shape `(H, T, d_head)`.
3. For each head independently (or vectorized via batched matmul):
   a. `scores_h = Q_h @ K_h.T / sqrt(d_head)`, shape `(T, T)`.
   b. Apply mask if given.
   c. `attn_h = softmax(scores_h)`, shape `(T, T)`.
   d. `out_h = attn_h @ V_h`, shape `(T, d_head)`.
4. Concatenate heads: `out.transpose(1, 0, 2).reshape(T, d_model)`, shape `(T, d_model)`.
5. Apply `out @ W_O`, shape `(T, d_model)`.
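In NumPy, steps 3a and 3d can each be a single einsum (or batched matmul) over the head axis, with the mask and softmax of steps 3b and 3c applied to the whole `(H, T, T)` score tensor at once. In Phase 17 this is the cleaner formulation; in Phase 15 Borja can do the explicit loop for clarity. Both are tested.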
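The forward pass can be sketched as a standalone function using the vectorized formulation. This is an illustration under assumptions, not the locked implementation: `mha_forward` and `softmax` are hypothetical names, the weights are passed in explicitly instead of living on the class, and the causal mask in the usage lines (large negative values above the diagonal) is a placeholder, since masking is covered in the next file.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def mha_forward(x, W_Q, W_K, W_V, W_O, n_heads, mask=None):
    """Hypothetical sketch. x: (T, d_model); mask: additive (T, T) or None."""
    T, d_model = x.shape
    d_head = d_model // n_heads

    def split(a):
        # (T, d_model) -> (H, T, d_head)
        return a.reshape(T, n_heads, d_head).transpose(1, 0, 2)

    Q, K, V = split(x @ W_Q), split(x @ W_K), split(x @ W_V)
    scores = np.einsum("hid,hjd->hij", Q, K) / np.sqrt(d_head)   # (H, T, T)
    if mask is not None:
        scores = scores + mask                 # (T, T) broadcasts over heads
    attn = softmax(scores, axis=-1)            # each row sums to 1
    out = np.einsum("hij,hjd->hid", attn, V)   # (H, T, d_head)
    out = out.transpose(1, 0, 2).reshape(T, d_model)
    return out @ W_O                           # (T, d_model)

rng = np.random.default_rng(0)
T, d, H = 6, 16, 4                             # illustrative sizes (assumed)
x = rng.standard_normal((T, d))
W_Q, W_K, W_V, W_O = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4))

y = mha_forward(x, W_Q, W_K, W_V, W_O, n_heads=H)

# Assumed additive causal mask: block attention to later positions.
causal = np.triu(np.full((T, T), -1e9), k=1)
y_masked = mha_forward(x, W_Q, W_K, W_V, W_O, n_heads=H, mask=causal)
print(y.shape, y_masked.shape)
```

The transposes in `split` and before the final reshape are inverses of each other, which is what puts each head's features back into its own contiguous slice before `W_O` is applied.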
Cross-attention (one paragraph, for completeness)¶
In encoder-decoder models (translation, summarization), the decoder layers have two attention sub-layers:
- Self-attention over the decoder's own previous outputs (causal).
- Cross-attention over the encoder's outputs.
The only difference for cross-attention: \(Q\) comes from the decoder's hidden states, while \(K\) and \(V\) come from the encoder's output. Same equation otherwise.
Decoder-only models (GPT family — what we're building) don't have cross-attention. The Mini-GPT is decoder-only. Cross-attention is documented here so the term isn't mysterious; we won't implement it.
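For readers who want to see the difference in code, here is a minimal single-head sketch of cross-attention. It is outside the curriculum's scope and every name in it (`cross_attention`, `dec_h`, `enc_out`, the sizes) is an assumption made for illustration. The only change from self-attention is where \(K\) and \(V\) come from.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(dec_h, enc_out, W_Q, W_K, W_V):
    """Hypothetical sketch. dec_h: (T_dec, d); enc_out: (T_enc, d)."""
    Q = dec_h @ W_Q                              # queries from the decoder
    K = enc_out @ W_K                            # keys from the encoder
    V = enc_out @ W_V                            # values from the encoder
    scores = Q @ K.T / np.sqrt(Q.shape[-1])      # (T_dec, T_enc)
    return softmax(scores) @ V                   # (T_dec, d)

rng = np.random.default_rng(0)
T_dec, T_enc, d = 3, 7, 8                        # illustrative sizes (assumed)
dec_h = rng.standard_normal((T_dec, d))
enc_out = rng.standard_normal((T_enc, d))
W_Q, W_K, W_V = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

ctx = cross_attention(dec_h, enc_out, W_Q, W_K, W_V)
print(ctx.shape)
```

Note that the score matrix is rectangular, \(T_{\text{dec}} \times T_{\text{enc}}\), because queries and keys index different sequences.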
Note on cross-attention: we leave it out of the active curriculum because we are building a decoder-only model (GPT-style). If you need an encoder-decoder model (T5, BART), cross-attention is trivial: three more lines than self-attention, same mechanism.
What this file does NOT cover¶
- Causal masking. Next file (`04-masking.md`).
- Trained-attention pattern visualization. Phase 18. Here we only describe the aspirational specialization.
- Head-pruning, grouped-query attention (GQA), multi-query attention (MQA). Inference-time optimizations covered in Phase 27.
- Cross-attention beyond the one-paragraph mention. Decoder-only model in this curriculum.
- Initialization of \(W_O\). Phase 18 (training); for forward-only verification, the labs use small Gaussian.
Recap¶
- Multi-head = \(H\) parallel single-head attentions in \(d/H\)-dimensional subspaces, concatenated + projected.
- Same FLOPs as single-head at full dimension; one extra \(d \times d\) matrix (\(W_O\)).
- Each head can specialize; the output projection mixes them.
- \(W_O\) is not optional — it's what makes heads able to communicate.
- API surface locked in `BLUEPRINT.md`. Constructor: `(d_model, n_heads)`. Forward: `(x, mask) -> y`.
Next: 04-masking.md.