

Lab 02 — Causal Mask

Goal: add causal masking to the MultiHeadAttention forward pass, then verify with a perturbation test that position \(i\)'s output does NOT depend on the input at position \(i+1\) or any later position.

Estimated time: 45–60 minutes.

Prereq: labs 00, 01 committed; theory/04-masking.md read.


What you produce

A directory experiments/15-causal-mask/ containing:

  • mask.py — causal-mask helper.
  • verify.py — perturbation test script.
  • verify_output.txt — captured printout.
  • mask_visual.png — heatmap of the causal mask + an attention matrix with the mask applied (side by side).
  • manifest.json.
  • README.md.

Background

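The causal mask is \(M[i, j] = 0\) if \(j \le i\), else \(-\infty\). It is added to the attention scores before the softmax (in code, -1e9 stands in for \(-\infty\); see Constraints). theory/04-masking.md is the reference.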
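Written out for a single head, and assuming the usual \(1/\sqrt{d_k}\) score scaling from lab 01 (an assumption; check your own forward), the additive mask and the row-wise softmax combine as follows:

```latex
S = \frac{Q K^\top}{\sqrt{d_k}}, \qquad
M_{ij} = \begin{cases} 0, & j \le i \\ -\infty, & j > i \end{cases}, \qquad
A = \operatorname{softmax}_{\text{row}}(S + M)

A_{ij} = \begin{cases} \dfrac{\exp(S_{ij})}{\sum_{k \le i} \exp(S_{ik})}, & j \le i \\[4pt] 0, & j > i \end{cases}
\qquad \Longrightarrow \qquad
Y_i = \sum_{j \le i} A_{ij} V_j
```

Row \(i\) of \(Y\) therefore depends only on the inputs at positions \(0, \ldots, i\), which is exactly the property the perturbation test checks.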

The perturbation test is the standard way to verify a causal mask in practice. It's much more convincing than reading the code:

  1. Run the model on input \(X\), capture output \(Y\).
  2. Run the model on \(X'\) where the last token differs, capture output \(Y'\).
  3. Assert that \(Y[0..T-2]\) matches \(Y'[0..T-2]\), i.e. every position before the last one (index \(T-1\)) is unchanged.

If the mask works, the perturbation cannot propagate backward through time. If positions before \(T-1\) differ between \(Y\) and \(Y'\), the mask is broken.
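The three steps can be sketched independently of attention. This minimal example, assuming only NumPy, uses two toy functions as stand-ins (`causal_op`, a running mean, and `non_causal_op`, a global mean; both are hypothetical and not part of the lab code) to show that the test separates a causal map from a non-causal one:

```python
import numpy as np

def prefix_max_diffs(f, X, X_prime):
    """Run f on both inputs and return max |Y - Y'| for each position."""
    Y, Y_prime = f(X), f(X_prime)
    return np.abs(Y - Y_prime).max(axis=-1)

def causal_op(X):
    # Running mean: row i averages rows 0..i only, so it is causal by construction.
    counts = np.arange(1, X.shape[0] + 1)[:, None]
    return np.cumsum(X, axis=0) / counts

def non_causal_op(X):
    # Every row sees the mean of ALL rows, including later ones.
    return np.broadcast_to(X.mean(axis=0), X.shape)

rng = np.random.default_rng(0)
X = rng.standard_normal((8, 4))
X_prime = X.copy()
X_prime[-1] = np.random.default_rng(1).standard_normal(4)  # perturb only the last token

for name, f in [("causal", causal_op), ("non-causal", non_causal_op)]:
    diffs = prefix_max_diffs(f, X, X_prime)
    ok = np.allclose(diffs[:-1], 0.0, atol=1e-6)
    print(f"{name:>10}: prefix unchanged={ok}  per-position max diff={np.round(diffs, 4)}")
```

Only the causal function keeps positions 0..6 untouched; the global mean lets the change at the last position reach every row.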

TODOs

Block A — implement the mask helper

  • In src/minimodel/attention/attention.py, add causal_mask(T: int, dtype=np.float32) -> np.ndarray.
  • Returns a \(T \times T\) matrix with zeros on/below the diagonal, -1e9 above.
  • Use np.triu(np.ones(...), k=1) * -1e9. One line of body.
  • Unit test: for \(T = 4\), the mask should look like
    [[ 0, -1e9, -1e9, -1e9],
     [ 0,  0,   -1e9, -1e9],
     [ 0,  0,    0,   -1e9],
     [ 0,  0,    0,    0  ]]
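A minimal sketch of one way to satisfy Block A, following the one-line np.triu recipe, together with the \(T = 4\) unit test. The trailing `.astype(dtype)` is a defensive choice on our part rather than something the lab requires:

```python
import numpy as np

def causal_mask(T: int, dtype=np.float32) -> np.ndarray:
    # Ones strictly above the diagonal (k=1), scaled to -1e9; zeros on/below it.
    return (np.triu(np.ones((T, T), dtype=dtype), k=1) * -1e9).astype(dtype)

# Unit test for T = 4: zeros on/below the diagonal, -1e9 strictly above.
m = causal_mask(4)
expected = np.array([[0, -1e9, -1e9, -1e9],
                     [0, 0, -1e9, -1e9],
                     [0, 0, 0, -1e9],
                     [0, 0, 0, 0]], dtype=np.float32)
assert m.dtype == np.float32
assert np.array_equal(m, expected)
```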
    

Block B — wire it into forward

  • Update MultiHeadAttention.forward(x, mask=None):
  • If mask is given, add it to scores before softmax.
  • The shape contract: mask is (T, T) and broadcasts across the head dimension.
  • Re-run lab 01 with mask=None to confirm no regression (lab 01 assertions still pass).
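One way to wire the mask in, sketched under two assumptions: your forward has already split Q, K, V into shape (H, T, d_head), and it scales scores by \(1/\sqrt{d_{head}}\) as in lab 01. The function name `attend` is hypothetical. The only masking-specific line is the addition, where the (T, T) mask broadcasts over the leading head axis:

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for stability; masked entries then underflow to exactly 0.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attend(Q, K, V, mask=None):
    """Q, K, V: (H, T, d_head). mask: additive (T, T) array, or None."""
    d_head = Q.shape[-1]
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)  # (H, T, T)
    if mask is not None:
        scores = scores + mask  # (T, T) broadcasts across the head axis
    A = softmax(scores, axis=-1)
    return A @ V, A  # (H, T, d_head), (H, T, T)

rng = np.random.default_rng(0)
H, T, d = 2, 5, 4
Q, K, V = (rng.standard_normal((H, T, d)) for _ in range(3))
mask = np.triu(np.ones((T, T)), k=1) * -1e9
out, A = attend(Q, K, V, mask)
```

Calling `attend(Q, K, V)` with no mask takes the unmodified path, which is what the lab 01 regression check exercises.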

Block C — the perturbation test

In verify.py:

  • Construct mha = MultiHeadAttention(d_model=16, n_heads=2, seed=0).
  • Build \(X\) of shape (T=8, 16) with seeded random values.
  • Build \(X'\) = X.copy(), then set X'[7] = random new vector (different seed).
  • Build mask = causal_mask(8).
  • Compute Y = mha.forward(X, mask=mask) and Y' = mha.forward(X', mask=mask).
  • Assert np.allclose(Y[0:7], Y'[0:7], atol=1e-6). (Positions 0..6 must be identical between Y and Y' — the change at position 7 cannot propagate backward.)
  • Assert not np.allclose(Y[7], Y'[7], atol=1e-3). (Position 7 must change, since its own input changed.)
  • Print pass/fail and the per-position max-diff. Capture to verify_output.txt.
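A sketch of what verify.py could look like. Because lab 01's MultiHeadAttention is not reproduced here, the block defines `TinyMHA`, a hypothetical stand-in with the same constructor arguments and an assumed forward (per-head split, \(1/\sqrt{d_{head}}\) scaling, output projection). In your repo, import your own class and `causal_mask` instead. The seeds for \(X\) (1) and \(X'\) (2) are arbitrary:

```python
import numpy as np

def causal_mask(T, dtype=np.float32):
    return np.triu(np.ones((T, T), dtype=dtype), k=1) * -1e9

class TinyMHA:
    """Hypothetical stand-in for lab 01's MultiHeadAttention."""

    def __init__(self, d_model, n_heads, seed=0):
        rng = np.random.default_rng(seed)
        self.h, self.dh = n_heads, d_model // n_heads
        s = 1.0 / np.sqrt(d_model)
        self.Wq, self.Wk, self.Wv, self.Wo = (
            rng.standard_normal((d_model, d_model)) * s for _ in range(4))

    def _split(self, x):  # (T, d_model) -> (H, T, d_head)
        return x.reshape(x.shape[0], self.h, self.dh).transpose(1, 0, 2)

    def forward(self, x, mask=None):
        Q, K, V = (self._split(x @ W) for W in (self.Wq, self.Wk, self.Wv))
        scores = Q @ K.transpose(0, 2, 1) / np.sqrt(self.dh)  # (H, T, T)
        if mask is not None:
            scores = scores + mask  # additive, pre-softmax
        scores = scores - scores.max(axis=-1, keepdims=True)
        A = np.exp(scores)
        A /= A.sum(axis=-1, keepdims=True)
        out = (A @ V).transpose(1, 0, 2).reshape(x.shape[0], -1)  # merge heads
        return out @ self.Wo

def run_perturbation(mha, T=8, d_model=16):
    """Return max |Y - Y'| per position when only the last token is changed."""
    X = np.random.default_rng(1).standard_normal((T, d_model))
    X_prime = X.copy()
    X_prime[T - 1] = np.random.default_rng(2).standard_normal(d_model)
    mask = causal_mask(T)
    Y, Y_prime = mha.forward(X, mask=mask), mha.forward(X_prime, mask=mask)
    return np.abs(Y - Y_prime).max(axis=-1)

if __name__ == "__main__":
    diffs = run_perturbation(TinyMHA(d_model=16, n_heads=2, seed=0))
    prefix_ok = np.allclose(diffs[:7], 0.0, atol=1e-6)
    last_changed = diffs[7] > 1e-3
    for i, d in enumerate(diffs):
        print(f"position {i}: max diff {d:.3e}")
    print("PASS" if prefix_ok and last_changed else "FAIL")
```

Running `python verify.py > verify_output.txt` is one way to capture the printout.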

Block D — the failure mode (intentional break)

Verify your understanding by breaking the mask intentionally and watching the test fail:

  • Make a copy of the test where the mask is applied multiplicatively post-softmax (the wrong way; see theory/04-masking.md). I.e., compute attention as
    attn = softmax(scores)
    attn = attn * (mask >= 0)  # zeros out forbidden positions, but post-softmax
    out = attn @ V
    
  • Re-run the perturbation test on this broken version.
  • Confirm it FAILS: earlier positions in \(Y'\) now differ from \(Y\), because position 7's changed key still enters every row's softmax denominator, and zeroing the forbidden weights afterwards does not undo that normalization.
  • Capture this output too. Note in README.md.
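To see the failure before touching your real class, here is a single-head sketch (our simplification; the lab itself uses the two-head MultiHeadAttention) that runs the same perturbation through the additive pre-softmax mask and through the multiplicative post-softmax variant. The helper name `single_head` and the layout `W[0], W[1], W[2]` for \(W_Q, W_K, W_V\) are hypothetical:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def single_head(X, W, mask, mode):
    """mode='additive' (correct, pre-softmax) or 'post_multiply' (broken)."""
    Q, K, V = X @ W[0], X @ W[1], X @ W[2]
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    if mode == "additive":
        attn = softmax(scores + mask)
    else:
        attn = softmax(scores) * (mask >= 0)  # zero forbidden weights AFTER normalizing
    return attn @ V

T, d = 8, 16
rng = np.random.default_rng(0)
W = rng.standard_normal((3, d, d)) / np.sqrt(d)
X = rng.standard_normal((T, d))
X_prime = X.copy()
X_prime[-1] = np.random.default_rng(1).standard_normal(d)
mask = np.triu(np.ones((T, T)), k=1) * -1e9

for mode in ("additive", "post_multiply"):
    diff = np.abs(single_head(X, W, mask, mode) - single_head(X_prime, W, mask, mode))
    print(f"{mode:>13}: max diff on positions 0..6 = {diff[:-1].max():.3e}")
```

In the post-multiply branch, position 7's new key has already shifted each row's denominator before the zeroing happens; as a side effect, those rows also no longer sum to 1.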

Block E — visualize

  • Two subplots side-by-side:
  • Left: the causal mask itself (visualize 0 as white, -1e9 as black).
  • Right: the attention matrix \(A\) after applying the mask to a random scores matrix (visualize 0 as white, 1 as dark).
  • Both panels should show the same lower-triangular allowed region (white on the left, shaded on the right), and each row of the right plot should sum to 1 (since softmax was applied).
  • Save as mask_visual.png.
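A sketch of the two panels, assuming matplotlib is available (the lab does not name a plotting library; matplotlib is our choice). `mask_and_attention` and `save_mask_visual` are hypothetical names. matplotlib is imported inside the plotting function so the data helper can be checked on its own:

```python
import numpy as np

def mask_and_attention(T=8, seed=0):
    """Return the causal mask and softmax(scores + mask) for random scores."""
    mask = np.triu(np.ones((T, T)), k=1) * -1e9
    scores = np.random.default_rng(seed).standard_normal((T, T))
    z = scores + mask
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return mask, e / e.sum(axis=-1, keepdims=True)

def save_mask_visual(path="mask_visual.png", T=8):
    import matplotlib
    matplotlib.use("Agg")  # render straight to a file; no display needed
    import matplotlib.pyplot as plt

    mask, A = mask_and_attention(T)
    fig, (ax_m, ax_a) = plt.subplots(1, 2, figsize=(9, 4))
    # "gray": vmin -> black, vmax -> white, so -1e9 is black and 0 is white.
    ax_m.imshow(mask, cmap="gray", vmin=-1e9, vmax=0)
    ax_m.set_title("Causal mask (white = 0, black = -1e9)")
    # "gray_r" reverses it: 0 -> white, 1 -> black.
    im = ax_a.imshow(A, cmap="gray_r", vmin=0, vmax=1)
    ax_a.set_title("Masked attention A (rows sum to 1)")
    for ax in (ax_m, ax_a):
        ax.set_xlabel("key position j")
        ax.set_ylabel("query position i")
    fig.colorbar(im, ax=ax_a)
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
```

Calling `save_mask_visual()` writes mask_visual.png to the working directory.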

Block F — write up

In README.md, answer:

  1. Why does multiplicative-post-softmax masking fail the perturbation test? Two-sentence answer; refer to theory/04-masking.md §"The critical mistake".
  2. Why is the row-sum of the attention matrix exactly 1, even after masking? (Hint: the softmax normalizes whatever survives. Forbidden positions get exactly 0, so the remaining probabilities sum to 1.)
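The hint can be checked numerically: each masked row equals an ordinary softmax taken over the allowed entries only, padded with exact zeros. A minimal sketch with random scores (size and seed are arbitrary):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

T = 6
scores = np.random.default_rng(0).standard_normal((T, T))
mask = np.triu(np.ones((T, T)), k=1) * -1e9
A = softmax(scores + mask)

for i in range(T):
    # Normalizing only the allowed entries 0..i gives the same weights.
    restricted = softmax(scores[i, : i + 1])
    assert np.allclose(A[i, : i + 1], restricted)
    assert np.all(A[i, i + 1 :] == 0.0)  # exp(-1e9 - max) underflows to exactly 0

print(A.sum(axis=-1))  # each entry is 1 up to floating-point rounding
```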

Block G — manifest

{
  "experiment": "15-causal-mask",
  "date": "YYYY-MM-DD",
  "seed": 0,
  "versions": { "python": "3.11.x", "numpy": "X.Y.Z" },
  "config": {
    "T": 8,
    "d_model": 16,
    "n_heads": 2
  },
  "results_summary": {
    "correct_perturbation_max_diff_positions_0_6": null,
    "broken_perturbation_max_diff_positions_0_6": null
  }
}

The correct version's max diff should be < 1e-6; the broken version's should be > 1e-3.
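Rather than hand-editing the template, a small helper can fill in the date, versions, and results. A sketch, assuming you pass in the two max-diff values measured on positions 0..6 in your correct and broken runs; `write_manifest` is a hypothetical name:

```python
import datetime
import json
import platform

import numpy as np

def write_manifest(correct_diff, broken_diff, path="manifest.json"):
    """Fill the Block G template; the two diffs come from your verify runs."""
    manifest = {
        "experiment": "15-causal-mask",
        "date": datetime.date.today().isoformat(),
        "seed": 0,
        "versions": {"python": platform.python_version(), "numpy": np.__version__},
        "config": {"T": 8, "d_model": 16, "n_heads": 2},
        "results_summary": {
            "correct_perturbation_max_diff_positions_0_6": float(correct_diff),
            "broken_perturbation_max_diff_positions_0_6": float(broken_diff),
        },
    }
    with open(path, "w") as f:
        json.dump(manifest, f, indent=2)
    return manifest
```

The `float(...)` casts matter because NumPy scalars such as `np.float32` are not JSON-serializable by the standard json module.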

Constraints

  • No PyTorch.
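  • -1e9, not -np.inf. If a row is ever fully masked (e.g. once padding masks are combined with the causal mask), a max-subtracted softmax computes -inf - (-inf) = nan and the whole row becomes nan; a large finite negative value avoids this.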
  • Test must be deterministic. Seed mha, seed \(X\) and \(X'\).
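The concrete failure with -np.inf appears when a whole row is masked. That cannot happen with a pure causal mask, since the diagonal is always allowed, but it can once other masks are combined with it. A minimal sketch of that case with a max-subtracted softmax (the three scores are arbitrary):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

row = np.array([[0.3, -1.2, 0.8]])

with np.errstate(invalid="ignore"):  # silence the -inf - (-inf) warning for the demo
    fully_masked_inf = softmax(row + np.full_like(row, -np.inf))
fully_masked_big = softmax(row + np.full_like(row, -1e9))

print(fully_masked_inf)  # [[nan nan nan]]
print(fully_masked_big)  # finite weights that sum to 1
```

With -1e9 the row stays finite: every entry is shifted by the same constant, so the softmax simply ignores the shift (up to rounding).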

Stop conditions

Done when:

  1. All six files committed.
  2. Correct-version perturbation test passes (max diff < 1e-6 on positions 0..6).
  3. Broken-version perturbation test demonstrates the failure (max diff > 1e-3).
  4. README.md answers both Block F questions.

Pitfalls

  • Off-by-one. Position \(i\) attends to positions \(0, \ldots, i\) inclusive. If you used k=0 in np.triu, the diagonal would also get -1e9. That is wrong: position \(i\) could no longer attend to itself, only to the past, and row 0 would be masked entirely.
  • Shape mismatch. mask is (T, T). It must broadcast across the head dim of scores (which is (H, T, T)). NumPy does this automatically.
  • Mask used during inference too. Causal masking is needed at both train and inference for an autoregressive decoder. Don't disable it at inference.
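The off-by-one pitfall is easy to see numerically. This sketch (random scores, arbitrary seed) compares k=1 with k=0. With k=0, row 0 is entirely masked, and because softmax is unchanged when the same constant is added to every entry (up to rounding), row 0 ends up attending to every position, future ones included:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

T = 5
scores = np.random.default_rng(0).standard_normal((T, T))
good = softmax(scores + np.triu(np.ones((T, T)), k=1) * -1e9)  # diagonal allowed
bad = softmax(scores + np.triu(np.ones((T, T)), k=0) * -1e9)   # diagonal masked too

print(np.diag(good))     # all positive: each position attends to itself
print(np.diag(bad)[1:])  # exactly 0: rows 1.. can no longer attend to themselves
print(bad[0])            # row 0 fully masked: weight spreads over ALL positions
```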

When to consult solutions/

After all six files are committed and both the correctness and failure tests behave as expected. Solution at solutions/02-causal-mask-ref.md.


Next lab: 03-attention-perf.md.