Lab 02 — Causal Mask
Goal: add causal masking to `MultiHeadAttention.forward`; verify with a perturbation test that position \(i\)'s output does NOT depend on the input at position \(i+1\) (or any later position).

Estimated time: 45–60 minutes.

Prereq: labs 00 and 01 committed; `theory/04-masking.md` read.
What you produce
A directory `experiments/15-causal-mask/` containing:
- `mask.py`: causal-mask helper.
- `verify.py`: perturbation test script.
- `verify_output.txt`: captured printout.
- `mask_visual.png`: heatmap of the causal mask and an attention matrix with the mask applied, side by side.
- `manifest.json`.
- `README.md`.
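Background
The causal mask is \(M_{ij} = 0\) if \(j \le i\), else \(-\infty\). It is added to the attention scores before the softmax (this lab uses `-1e9` in place of \(-\infty\); see Constraints). Theory file 04 (`theory/04-masking.md`) is the reference.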
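To make the definition concrete, here is a minimal NumPy sketch of an additive mask applied to a random scores matrix before the softmax. The `softmax` helper and the random `scores` are illustrative stand-ins, not the lab's code; like the lab, it uses `-1e9` for \(-\infty\).

```python
import numpy as np


def softmax(z, axis=-1):
    # Numerically stable softmax: subtract the row max before exponentiating.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


T = 4
rng = np.random.default_rng(0)
scores = rng.standard_normal((T, T)).astype(np.float32)  # stand-in for QK^T / sqrt(d_k)

# Additive causal mask: 0 where j <= i, a large negative number where j > i.
mask = np.triu(np.ones((T, T), dtype=np.float32), k=1) * -1e9

A = softmax(scores + mask)
print(np.round(A, 3))      # entries above the diagonal are exactly 0
print(A.sum(axis=-1))      # each row sums to 1, up to rounding
```

Because `exp` of a number near `-1e9` underflows to 0, the masked entries vanish and the softmax renormalizes over the allowed positions only.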
The perturbation test is the standard way to verify a causal mask in practice. It's much more convincing than reading the code:
- Run the model on input \(X\), capture output \(Y\).
- Run the model on \(X'\) where the last token differs, capture output \(Y'\).
- Assert that \(Y[t]\) matches \(Y'[t]\) for every position \(t = 0, \ldots, T-2\), i.e. all positions before the last one, \(T-1\).
If the mask works, the perturbation cannot propagate backward through time. If positions before \(T-1\) differ between \(Y\) and \(Y'\), the mask is broken.
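The three steps can be written as a small, model-agnostic check. This is a sketch: `perturbation_max_diff` is a hypothetical helper name, and two toy sequence functions stand in for the attention layer: a prefix mean (causal, since output \(t\) sees only inputs \(0..t\)) and a global mean (not causal).

```python
import numpy as np


def perturbation_max_diff(forward, X, X_prime):
    """Per-position max |Y - Y'| for a sequence map `forward`: (T, d) -> (T, d)."""
    Y, Y_prime = forward(X), forward(X_prime)
    return np.abs(Y - Y_prime).max(axis=-1)  # shape (T,)


def prefix_mean(X):
    # Causal toy: output t averages inputs 0..t only.
    return np.cumsum(X, axis=0) / np.arange(1, len(X) + 1)[:, None]


def global_mean(X):
    # Non-causal toy: every output sees every input, including the last one.
    return np.broadcast_to(X.mean(axis=0), X.shape)


T, d = 8, 4
X = np.random.default_rng(0).standard_normal((T, d))
X_prime = X.copy()
X_prime[-1] = np.random.default_rng(1).standard_normal(d)  # perturb only the last token

causal_diff = perturbation_max_diff(prefix_mean, X, X_prime)
leaky_diff = perturbation_max_diff(global_mean, X, X_prime)
print(causal_diff)  # zero everywhere except the last position
print(leaky_diff)   # non-zero at every position
```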
TODOs
Block A — implement the mask helper
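- In `src/minimodel/attention/attention.py`, add `causal_mask(T: int, dtype=np.float32) -> np.ndarray`.
- It returns a \(T \times T\) matrix with zeros on and below the diagonal and `-1e9` above it.
- Use `np.triu(np.ones(...), k=1) * -1e9`: a one-line body.
- Unit test: for \(T = 4\), the mask should look like
  \[
  \begin{pmatrix}
  0 & -10^9 & -10^9 & -10^9 \\
  0 & 0 & -10^9 & -10^9 \\
  0 & 0 & 0 & -10^9 \\
  0 & 0 & 0 & 0
  \end{pmatrix}
  \]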
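A minimal sketch of the helper and its \(T = 4\) unit test, assuming a plain module-level function and a pytest-style test function (the name `test_causal_mask_T4` is hypothetical):

```python
import numpy as np


def causal_mask(T: int, dtype=np.float32) -> np.ndarray:
    # Ones strictly above the diagonal (k=1), scaled to -1e9; zeros on/below it.
    return np.triu(np.ones((T, T), dtype=dtype), k=1) * -1e9


def test_causal_mask_T4():
    M = causal_mask(4)
    expected = np.array(
        [
            [0, -1e9, -1e9, -1e9],
            [0, 0, -1e9, -1e9],
            [0, 0, 0, -1e9],
            [0, 0, 0, 0],
        ],
        dtype=np.float32,
    )
    assert M.shape == (4, 4)
    assert M.dtype == np.float32  # the Python float -1e9 does not upcast the array
    assert np.array_equal(M, expected)


test_causal_mask_T4()
```

The below-diagonal entries come out as `-0.0` (from `0.0 * -1e9`); that compares equal to 0 and is harmless when added to scores.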
Block B — wire it into `forward`
- Update `MultiHeadAttention.forward(x, mask=None)`:
  - If `mask` is given, add it to `scores` before the softmax.
  - Shape contract: `mask` is `(T, T)` and broadcasts across the head dimension.
- Re-run lab 01 with `mask=None` to confirm there is no regression (the lab 01 assertions still pass).
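The internals of `MultiHeadAttention` come from lab 01 and are not shown here, so this is a standalone sketch of just the masked core. It assumes per-head tensors `q`, `k`, `v` of shape `(H, T, d_head)`, and `masked_attention` is a hypothetical name. It shows where the mask enters and how a `(T, T)` mask broadcasts against `(H, T, T)` scores.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def masked_attention(q, k, v, mask=None):
    """q, k, v: (H, T, d_head). mask: additive (T, T) array, or None."""
    d_head = q.shape[-1]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)  # (H, T, T)
    if mask is not None:
        scores = scores + mask  # (T, T) broadcasts over the leading head axis
    weights = softmax(scores, axis=-1)  # (H, T, T)
    return weights @ v, weights  # (H, T, d_head), (H, T, T)


H, T, d_head = 2, 5, 8
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((H, T, d_head)) for _ in range(3))
mask = np.triu(np.ones((T, T)), k=1) * -1e9

out, w = masked_attention(q, k, v, mask)
future = np.triu(np.ones((T, T)), k=1)  # 1 where j > i
print(out.shape, w.shape)  # (2, 5, 8) (2, 5, 5)
print((w * future).max())  # 0.0: no weight on future positions
```

With `mask=None` the function reduces to plain attention, which is the no-regression case for lab 01.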
Block C — the perturbation test
In `verify.py` (Python identifiers cannot contain `'`, so name \(X'\) and \(Y'\) something like `X2` and `Y2`):
- Construct `mha = MultiHeadAttention(d_model=16, n_heads=2, seed=0)`.
- Build \(X\) of shape `(8, 16)` (i.e. \(T = 8\)) with seeded random values.
- Build \(X'\) as `X2 = X.copy()`, then set `X2[7]` to a new random vector drawn with a different seed.
- Build `mask = causal_mask(8)`.
- Compute `Y = mha.forward(X, mask=mask)` and `Y2 = mha.forward(X2, mask=mask)`.
- Assert `np.allclose(Y[0:7], Y2[0:7], atol=1e-6)`. Positions 0..6 must be identical between \(Y\) and \(Y'\): the change at position 7 cannot propagate backward.
- Assert `not np.allclose(Y[7], Y2[7], atol=1e-3)`. Position 7 must change, since its own input changed.
- Print pass/fail and the per-position max diff, and capture the output to `verify_output.txt`.
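A sketch of `verify.py`, made self-contained with a small `ToyMHA` class. `ToyMHA` is a hypothetical stand-in written for this example, not the lab's `MultiHeadAttention`; in your script, import your own class and `causal_mask` instead.

```python
import numpy as np


def causal_mask(T: int, dtype=np.float32) -> np.ndarray:
    # Zeros on/below the diagonal, -1e9 strictly above it.
    return np.triu(np.ones((T, T), dtype=dtype), k=1) * -1e9


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


class ToyMHA:
    """Hypothetical stand-in for the lab's MultiHeadAttention; not its real code."""

    def __init__(self, d_model: int, n_heads: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        scale = 1.0 / np.sqrt(d_model)
        self.Wq, self.Wk, self.Wv, self.Wo = (
            rng.standard_normal((d_model, d_model)) * scale for _ in range(4)
        )

    def _split(self, m):
        # (T, d_model) -> (H, T, d_head)
        T = m.shape[0]
        return m.reshape(T, self.n_heads, self.d_head).transpose(1, 0, 2)

    def forward(self, x, mask=None):
        q, k, v = (self._split(x @ W) for W in (self.Wq, self.Wk, self.Wv))
        scores = q @ k.transpose(0, 2, 1) / np.sqrt(self.d_head)  # (H, T, T)
        if mask is not None:
            scores = scores + mask  # (T, T) broadcasts over heads
        out = softmax(scores) @ v  # (H, T, d_head)
        T = x.shape[0]
        return out.transpose(1, 0, 2).reshape(T, -1) @ self.Wo  # (T, d_model)


T, d_model = 8, 16
mha = ToyMHA(d_model=d_model, n_heads=2, seed=0)
X = np.random.default_rng(1).standard_normal((T, d_model))
X2 = X.copy()
X2[7] = np.random.default_rng(2).standard_normal(d_model)  # new last token, different seed
mask = causal_mask(T)

Y, Y2 = mha.forward(X, mask=mask), mha.forward(X2, mask=mask)
per_pos = np.abs(Y - Y2).max(axis=-1)
past_ok = np.allclose(Y[0:7], Y2[0:7], atol=1e-6)
last_changed = not np.allclose(Y[7], Y2[7], atol=1e-3)

report = [f"pos {t}: max diff = {diff:.3e}" for t, diff in enumerate(per_pos)]
report.append(f"positions 0..6 unchanged: {'PASS' if past_ok else 'FAIL'}")
report.append(f"position 7 changed: {'PASS' if last_changed else 'FAIL'}")
print("\n".join(report))
```

Running `python verify.py | tee verify_output.txt` prints the report and captures it in one step.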
Block D — the failure mode (intentional break)
Verify your understanding by breaking the mask intentionally and watching the test fail:
- Make a copy of the test in which the mask is applied multiplicatively after the softmax (the wrong way; see `theory/04-masking.md`). That is, compute attention as \(A = \mathrm{softmax}(S) \odot B\), where \(S\) is the scores matrix and \(B\) is the 0/1 matrix with ones on and below the diagonal.
- Re-run the perturbation test on this broken version.
- Confirm it FAILS: earlier positions in \(Y'\) now differ from \(Y\). The softmax normalizes each row over all \(T\) positions, including position 7, so the changed score at position 7 leaks into every row through the normalization.
- Capture this output too, and note the result in `README.md`.
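A compact single-head sketch contrasting the two variants on the same perturbation. The `attend` function and its `mode` names are hypothetical, and it assumes the broken variant is the multiplicative post-softmax one described in this block (softmax over all positions, then multiply by the lower-triangular 0/1 matrix).

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attend(x, Wq, Wk, Wv, mode):
    """Single-head attention over x of shape (T, d) with one of two masking modes."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(q.shape[-1])  # (T, T)
    T = x.shape[0]
    if mode == "additive":  # correct: add 0 / -1e9 before the softmax
        A = softmax(scores + np.triu(np.ones((T, T)), k=1) * -1e9)
    elif mode == "post_mult":  # broken: softmax over all positions, then zero j > i
        A = softmax(scores) * np.tril(np.ones((T, T)))
    else:
        raise ValueError(f"unknown mode: {mode}")
    return A @ v


rng = np.random.default_rng(0)
T, d = 8, 16
Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
X = rng.standard_normal((T, d))
X2 = X.copy()
X2[7] = rng.standard_normal(d)  # perturb only the last token

results = {}
for mode in ("additive", "post_mult"):
    diff = np.abs(attend(X, Wq, Wk, Wv, mode) - attend(X2, Wq, Wk, Wv, mode))
    results[mode] = diff[:7].max()  # max over positions 0..6 and all features
    print(f"{mode:>9}: max diff on positions 0..6 = {results[mode]:.3e}")
```

In the broken variant the rows also no longer sum to 1, because the probability mass assigned to future positions is simply discarded.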
Block E — visualize
- Two subplots, side by side:
  - Left: the causal mask itself (0 as white, -1e9 as black).
  - Right: the attention matrix \(A\) obtained by adding the mask to a random scores matrix and applying the softmax (0 as white, 1 as dark).
- In both plots the allowed region (on and below the diagonal) is the lower triangle, and each row of the right plot sums to 1 (since softmax was applied).
- Save as `mask_visual.png`.
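A sketch for Block E. The lab does not name a plotting library, so matplotlib is an assumption here, and `mask_and_attention` and `save_mask_visual` are hypothetical names. The colormaps implement the white/black conventions above: `"gray"` maps low values to black, `"Greys"` maps low values to white.

```python
import numpy as np


def mask_and_attention(T=8, seed=0):
    """Return the causal mask and a masked softmax of random scores."""
    mask = np.triu(np.ones((T, T), dtype=np.float32), k=1) * -1e9
    scores = np.random.default_rng(seed).standard_normal((T, T)).astype(np.float32)
    z = scores + mask
    z = z - z.max(axis=-1, keepdims=True)  # stable softmax
    e = np.exp(z)
    return mask, e / e.sum(axis=-1, keepdims=True)


def save_mask_visual(path="mask_visual.png", T=8, seed=0):
    import matplotlib

    matplotlib.use("Agg")  # render to a file; no display needed
    import matplotlib.pyplot as plt

    mask, A = mask_and_attention(T, seed)
    fig, (ax_m, ax_a) = plt.subplots(1, 2, figsize=(9, 4))
    # "gray": vmin (-1e9) -> black, vmax (0) -> white.
    ax_m.imshow(mask, cmap="gray", vmin=-1e9, vmax=0)
    ax_m.set_title("causal mask (white = 0, black = -1e9)")
    # "Greys": 0 -> white, 1 -> black.
    im = ax_a.imshow(A, cmap="Greys", vmin=0, vmax=1)
    ax_a.set_title("masked attention A (rows sum to 1)")
    fig.colorbar(im, ax=ax_a)
    for ax in (ax_m, ax_a):
        ax.set_xlabel("key position j")
        ax.set_ylabel("query position i")
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
```

Calling `save_mask_visual()` writes `mask_visual.png`; importing matplotlib inside the function keeps `mask_and_attention` usable (and testable) without it.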
Block F — write up
In `README.md`, answer:
- Why does multiplicative post-softmax masking fail the perturbation test? Answer in two sentences, referring to `theory/04-masking.md` §"The critical mistake".
- Why is the row sum of the attention matrix exactly 1 (up to floating-point rounding), even after masking? (Hint: the softmax normalizes whatever survives. Forbidden positions get exactly 0, so the remaining probabilities sum to 1.)
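Writing \(S\) for the scores and \(M\) for the causal mask (with \(M_{ij} = 0\) for \(j \le i\) and \(M_{ij} = -\infty\) for \(j > i\)), the hint corresponds to this short derivation for row \(i\). With `-1e9` in place of \(-\infty\), \(\exp\) of the masked entries underflows to exactly 0 in float32 and float64, so the result holds up to rounding.

\[
A_{ij} = \frac{\exp(S_{ij} + M_{ij})}{\sum_{k=0}^{T-1} \exp(S_{ik} + M_{ik})}
= \begin{cases}
\dfrac{\exp(S_{ij})}{\sum_{k \le i} \exp(S_{ik})} & j \le i \\[1ex]
0 & j > i
\end{cases}
\qquad\Longrightarrow\qquad
\sum_{j=0}^{T-1} A_{ij} = \frac{\sum_{j \le i} \exp(S_{ij})}{\sum_{k \le i} \exp(S_{ik})} = 1 .
\]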
Block G — manifest
```json
{
  "experiment": "15-causal-mask",
  "date": "YYYY-MM-DD",
  "seed": 0,
  "versions": { "python": "3.11.x", "numpy": "X.Y.Z" },
  "config": {
    "T": 8,
    "d_model": 16,
    "n_heads": 2
  },
  "results_summary": {
    "correct_perturbation_max_diff_positions_0_6": null,
    "broken_perturbation_max_diff_positions_0_6": null
  }
}
```
Fill `results_summary` with the max diff over positions 0..6 from each run: the correct version's should be below 1e-6, and the broken version's above 1e-3.
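Rather than filling `manifest.json` by hand, `verify.py` could write it. A sketch with hypothetical function names, using only the standard library plus NumPy to record the versions; the field names are copied from the template above.

```python
import json
import platform
from datetime import date

import numpy as np


def build_manifest(correct_max_diff, broken_max_diff, seed=0, T=8, d_model=16, n_heads=2):
    """Assemble the Block G manifest as a dict."""
    return {
        "experiment": "15-causal-mask",
        "date": date.today().isoformat(),  # YYYY-MM-DD
        "seed": seed,
        "versions": {"python": platform.python_version(), "numpy": np.__version__},
        "config": {"T": T, "d_model": d_model, "n_heads": n_heads},
        "results_summary": {
            # float() because json cannot serialize NumPy float32 scalars.
            "correct_perturbation_max_diff_positions_0_6": float(correct_max_diff),
            "broken_perturbation_max_diff_positions_0_6": float(broken_max_diff),
        },
    }


def write_manifest(path, manifest):
    with open(path, "w") as f:
        json.dump(manifest, f, indent=2)
        f.write("\n")
```

For example, end `verify.py` with `write_manifest("experiments/15-causal-mask/manifest.json", build_manifest(correct_diff, broken_diff))`, where `correct_diff` and `broken_diff` are the two max diffs over positions 0..6.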
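Constraints
- No PyTorch.
- Use `-1e9`, not `-np.inf`. Arithmetic on `inf` can produce `nan` (for example, `-inf - (-inf)` inside a numerically stable softmax when a whole row is masked); a large finite negative number is safer.
- The test must be deterministic: seed `mha`, and seed the random draws for \(X\) and \(X'\).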
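A short illustration of the `-np.inf` hazard, using a hypothetical stable `softmax` helper on a row where every entry is masked. A correct causal mask never masks a whole row, but the off-by-one `k=0` mistake described under Pitfalls masks row 0 entirely.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # stable softmax
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


row_inf = np.array([-np.inf, -np.inf, -np.inf])  # an entirely masked row
row_big = np.array([-1e9, -1e9, -1e9])

with np.errstate(invalid="ignore"):  # silence the "invalid value" warning for the demo
    print(softmax(row_inf))  # [nan nan nan]: -inf - (-inf) is nan
print(softmax(row_big))      # uniform weights, no nan
```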
Stop conditions
Done when:
- All six files committed.
- Correct-version perturbation test passes (max diff < 1e-6 on positions 0..6).
- Broken-version perturbation test demonstrates the failure (max diff > 1e-3).
- `README.md` answers both Block F questions.
Pitfalls
- Off-by-one. Position \(i\) attends to positions \(0, \ldots, i\) inclusive. If you used `k=0` in `np.triu`, the diagonal would be masked as well. That is wrong: position \(i\) could no longer attend to itself, only to the past, and row 0 would have no allowed position at all.
- Shape mismatch. `mask` is `(T, T)`; it must broadcast across the head dimension of `scores`, which is `(H, T, T)`. NumPy broadcasting does this automatically.
- Mask used during inference too. An autoregressive decoder needs causal masking at both training and inference time. Don't disable it at inference.
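A quick sketch of the first two pitfalls. It uses boolean versions of the mask (`True` = masked) so the diagonal is easy to read; the variable names are illustrative.

```python
import numpy as np

T, H = 4, 2
masked_right = np.triu(np.ones((T, T), dtype=bool), k=1)  # True where j > i
masked_wrong = np.triu(np.ones((T, T), dtype=bool), k=0)  # True where j >= i

print(masked_right.diagonal().any())  # False: every position may attend to itself
print(masked_wrong.diagonal().all())  # True: self-attention is blocked
print(masked_wrong[0].all())          # True: row 0 has nothing left to attend to

# Broadcasting: a (T, T) additive mask against (H, T, T) scores.
mask = np.triu(np.ones((T, T)), k=1) * -1e9
scores = np.zeros((H, T, T))
print((scores + mask).shape)  # (2, 4, 4)
```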
When to consult `solutions/`
After all six files are committed and both the correctness and failure tests behave as expected. The solution is at `solutions/02-causal-mask-ref.md`.
Next lab: 03-attention-perf.md.