Lab 00: Mask Logits Against a Regex
Goal: implement the simplest possible logit mask and verify it produces a constrained distribution.
Estimated time: 60–90 minutes.
Prereq: Phase 21's sampler (`src/miniinfer/generate.py`) committed. Phase 17's MiniGPT importable.
What you produce
A directory `experiments/30-regex-mask/` containing:
- `mask.py`: your first `LogitMask` subclass, which matches a fixed regex (digits-only, length 4).
- `test_mask.py`: pytest cases verifying mask behavior.
- `bench.py`: a tiny driver that runs `generate(prompt, mask=DigitMask(tokenizer, length=4))` and asserts every output matches `^\d{4}$`.
- `results.json`: `{ "n_samples": ..., "all_match_regex": true, "wall_time_s": ... }`.
- `manifest.json`: per `LYNX_CORTEX.md` §5.
- `README.md`: what you measured.
Plus, the new module:
- `src/ministruct/mask.py`: the `LogitMask` ABC + `DigitMask` concrete class. This is the first file in `src/ministruct/`; the BLUEPRINT at `src/ministruct/BLUEPRINT.md` (pre-written by Claude) is the design source of truth.
The kernel
Pick the smallest possible grammar: "exactly four ASCII digits followed by EOS". Mask everything else.
This is a warm-up. The point is to (a) understand the API contract you'll generalize in lab 01, (b) verify that masking actually produces digit-only outputs on the trained MiniGPT (or — if the model wasn't trained to emit digits — at least produces the model's preference among digits).
We start with digits rather than the conjugation grammar because the set of legal tokens is trivial to identify: it is the token IDs for `'0'`..`'9'` plus EOS. Lab 01 generalizes this to a real JSON schema.
TODOs
Block A: `LogitMask` interface
- In `src/ministruct/mask.py`, declare an ABC matching the signature in `src/ministruct/BLUEPRINT.md`:
  - `step(None)` returns the mask for the first token. Subsequent calls pass the token just emitted; the mask returned is for the next token.
  - `is_done()` returns True once the grammar accepts the current sequence as complete (e.g., 4 digits emitted, ready to terminate).
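The BLUEPRINT is not reproduced here, so the following is only a sketch of what such an ABC might look like. The type hints, the `reset()` method (mentioned under Pitfalls), and the `PermissiveMask` helper are assumptions for illustration; the signature in `src/ministruct/BLUEPRINT.md` takes precedence wherever it differs.

```python
from __future__ import annotations

from abc import ABC, abstractmethod

import numpy as np


class LogitMask(ABC):
    """Additive logit mask: 0.0 for legal tokens, -inf for illegal ones."""

    @abstractmethod
    def step(self, last_token_id: int | None) -> np.ndarray:
        """Consume the token just emitted (None on the first call) and
        return a float array of shape (vocab_size,) for the next token."""

    @abstractmethod
    def is_done(self) -> bool:
        """True once the grammar accepts the sequence emitted so far."""

    @abstractmethod
    def reset(self) -> None:
        """Return to the initial state so the mask can be reused."""


class PermissiveMask(LogitMask):
    """Hypothetical trivial subclass that allows every token."""

    def __init__(self, vocab_size: int) -> None:
        self.vocab_size = vocab_size

    def step(self, last_token_id: int | None) -> np.ndarray:
        # All zeros: adding this to the logits leaves them unchanged.
        return np.zeros(self.vocab_size, dtype=np.float32)

    def is_done(self) -> bool:
        return False

    def reset(self) -> None:
        pass
```

A permissive subclass like this is also a convenient fixture for `test_permissive_mask_no_op` in Block E.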
Block B: `DigitMask` concrete class
- Constructor: `DigitMask(tokenizer, length=4)`.
- Internal state: `count_emitted` (how many digits so far).
- `step(last_token_id)`:
  - If `count_emitted < length`: return mask that allows only tokens whose decoded string is one of `'0'`..`'9'`. Determine these token IDs once at construction by enumerating the tokenizer.
  - If `count_emitted == length`: return mask that allows only the EOS token.
  - Update `count_emitted` based on `last_token_id`.
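The steps above can be sketched as follows. `ToyTokenizer` and its attributes (`vocab_size`, `decode`, `eos_token_id`) are hypothetical stand-ins, not the Phase 17 tokenizer's API; adapt the names to whatever the real tokenizer exposes. To stay self-contained the class does not inherit from `LogitMask` here, although in the lab it would. The sketch counts the previous token at the start of `step`, which is one of the two valid conventions.

```python
import numpy as np

ASCII_DIGITS = frozenset("0123456789")


class ToyTokenizer:
    """Hypothetical character-level tokenizer used only for this sketch."""

    def __init__(self) -> None:
        self.vocab = list("0123456789abc ") + ["<eos>"]
        self.vocab_size = len(self.vocab)
        self.eos_token_id = self.vocab_size - 1

    def decode(self, ids):
        return "".join(self.vocab[i] for i in ids)


class DigitMask:
    """Allows exactly `length` ASCII digits, then only EOS."""

    def __init__(self, tokenizer, length=4):
        self.length = length
        self.count_emitted = 0
        n = tokenizer.vocab_size
        # Enumerate the vocabulary once; keep ids that decode to one ASCII digit.
        self.digit_ids = {
            i for i in range(n) if tokenizer.decode([i]) in ASCII_DIGITS
        }
        self._digit_mask = np.full(n, -np.inf, dtype=np.float32)
        self._digit_mask[sorted(self.digit_ids)] = 0.0
        self._eos_mask = np.full(n, -np.inf, dtype=np.float32)
        self._eos_mask[tokenizer.eos_token_id] = 0.0

    def step(self, last_token_id):
        # Count the token that was just emitted (None on the first call),
        # then return the mask for the next position.
        if last_token_id in self.digit_ids:
            self.count_emitted += 1
        if self.count_emitted < self.length:
            return self._digit_mask.copy()
        return self._eos_mask.copy()

    def is_done(self):
        return self.count_emitted >= self.length

    def reset(self):
        self.count_emitted = 0
```

Returning copies keeps callers from mutating the cached arrays; if profiling shows the copy matters, returning read-only views is another option.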
Block C: wire into the decoder
- In `src/miniinfer/generate.py` (or wherever Phase 21's `generate` lives), add a `mask: LogitMask | None = None` parameter.
- At each step: call `mask.step(last_token)` to get the mask array, add to logits before sampling.
- After sampling: pass the chosen token back to the mask next iteration.
- Terminate when `mask.is_done()` is True (or when the sampler picks EOS, or when `max_new_tokens` is hit).
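A minimal sketch of that loop, with the model replaced by a hypothetical `logits_fn(tokens)` callable that returns one logit vector. Phase 21's real `generate` will have a different signature; `sample_masked`, `logits_fn`, `eos_token_id`, and `seed` are illustrative names. The sketch assumes the mask updates its state inside `step`, so `is_done()` is checked right after that call.

```python
import numpy as np


class NoLegalContinuation(RuntimeError):
    """Raised when a mask leaves no token with a finite logit."""


def sample_masked(logits, mask, rng):
    """Add the mask to the logits, then sample from the softmax."""
    masked = np.asarray(logits, dtype=np.float64) + mask
    finite = np.isfinite(masked)
    if not finite.any():
        raise NoLegalContinuation("mask rejected every token")
    # Subtract the largest finite logit for numerical stability;
    # exp(-inf) is exactly 0, so illegal tokens get probability 0.
    probs = np.exp(masked - masked[finite].max())
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))


def generate(logits_fn, mask=None, max_new_tokens=16, eos_token_id=None, seed=0):
    """Sample up to max_new_tokens ids, optionally constrained by `mask`."""
    rng = np.random.default_rng(seed)
    tokens = []
    last_token = None
    for _ in range(max_new_tokens):
        if mask is not None:
            # Feeds the previous choice back and yields the next mask.
            step_mask = mask.step(last_token)
            if mask.is_done():
                break
        else:
            step_mask = 0.0  # adding a scalar zero is a no-op
        last_token = sample_masked(logits_fn(tokens), step_mask, rng)
        if last_token == eos_token_id:
            break
        tokens.append(last_token)
    return tokens
```

Checking for "no finite logit" before the softmax is what turns a silent NaN into the explicit `NoLegalContinuation` that Block E asks for.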
Block D: verify
- Run `bench.py`: generate 100 samples with `DigitMask(tokenizer, length=4)`. Assert every output regex-matches `^\d{4}$`.
- Run a control: generate 100 samples with `mask=None`. Most outputs will NOT match the regex. (This is the existence proof that the mask is doing work.)
- Record both in `results.json`.
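One way the check and the summary record might look, assuming the outputs have already been decoded to strings. The `summarize` helper is a hypothetical name. Note two details of Python's `re` module: `$` also matches just before a trailing newline, and `\d` matches non-ASCII digits in `str` patterns, so this sketch uses `fullmatch` with `[0-9]` to enforce "exactly four ASCII digits".

```python
import json
import re

# Stricter equivalent of ^\d{4}$: no trailing newline, ASCII digits only.
FOUR_DIGITS = re.compile(r"[0-9]{4}")


def summarize(outputs, wall_time_s):
    """Build the results.json record for one batch of decoded samples."""
    return {
        "n_samples": len(outputs),
        "all_match_regex": all(
            FOUR_DIGITS.fullmatch(s) is not None for s in outputs
        ),
        "wall_time_s": wall_time_s,
    }


if __name__ == "__main__":
    # Hand-written placeholder strings, not real model output.
    placeholder = ["0042", "1234"]
    print(json.dumps(summarize(placeholder, wall_time_s=0.0), indent=2))
```

To record both runs, the masked and control summaries could be stored under two separate keys of the same file; the exact layout is left to you.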
Block E: tests
In `tests/test_ministruct_mask.py`:
- `test_digit_mask_first_step_only_digits`: mask returned by `step(None)` has `0.0` exactly at the digit-token indices, `-inf` everywhere else.
- `test_digit_mask_after_n_digits_only_eos`: after `length` digits emitted, mask allows only EOS.
- `test_digit_mask_is_done_after_full_length`: `is_done()` flips True at the right moment.
- `test_permissive_mask_no_op`: a mask that allows all tokens produces samples identical to `mask=None` under the same RNG seed (the sanity check from `theory/02-logit-masks.md`).
- `test_empty_legal_raises`: a mask that returns all `-inf` causes the decoder to raise `NoLegalContinuation` rather than silently producing NaN.
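The last two tests rest on two numerical facts that are easy to confirm in isolation. The snippet below is a standalone illustration with a plain softmax, not the lab's decoder: an all-zero mask leaves the distribution bit-for-bit unchanged, and an all `-inf` mask drives a naive softmax to NaN, which is why the decoder should raise instead.

```python
import numpy as np


def softmax(x):
    """Plain softmax with the usual max-subtraction for stability."""
    shifted = x - np.max(x)
    e = np.exp(shifted)
    return e / e.sum()


logits = np.array([2.0, -1.0, 0.5, 0.0])

# A permissive mask is all zeros, so adding it changes nothing.
permissive = np.zeros_like(logits)
assert np.array_equal(softmax(logits + permissive), softmax(logits))

# An all -inf mask makes every logit -inf. The max is then -inf too,
# and (-inf) - (-inf) is NaN, which propagates through exp and the division.
empty = np.full_like(logits, -np.inf)
with np.errstate(invalid="ignore"):
    probs = softmax(logits + empty)
assert np.isnan(probs).all()
```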
Constraints
- No `outlines`, no `lm-format-enforcer`, no `jsonschema`. This is the build-before-abstracting phase.
- Mask must be a NumPy array, not a list or a dict. It is added to the logits, which are NumPy arrays.
- EOS handling. The tokenizer's EOS token id is known; read it from the tokenizer's API. Don't guess.
Stop conditions
You're done when:
- All five tests in Block E pass.
- `bench.py` reports `all_match_regex: true` over 100 samples.
- `README.md` answers: "what is the KL divergence between the constrained and unconstrained distribution at step 0? Is it small (model wanted to emit a digit anyway) or large (model wanted to emit text)?"
- Manifest committed.
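For the README question, a useful identity: if the constrained distribution q is the unconstrained p renormalized over the legal set, and Z is the total probability p assigns to legal tokens, then KL(q || p) = -log Z. A value near 0 means the model already wanted a digit; a large value means the mask is overriding it. The README question does not fix the direction, so taking KL(q || p) is an assumption; the reverse direction is infinite whenever p gives any mass to an illegal token. `step0_kl` is a hypothetical helper name.

```python
import numpy as np


def step0_kl(logits, legal_ids):
    """KL(constrained || unconstrained) in nats for one decoding step."""
    logits = np.asarray(logits, dtype=np.float64)
    # Log-softmax of the unconstrained logits.
    log_p = logits - np.logaddexp.reduce(logits)
    # Log of the probability mass that falls on legal tokens.
    log_z = np.logaddexp.reduce(log_p[legal_ids])
    return -log_z
```

As a sanity check, uniform logits over 10 tokens with 2 legal tokens give Z = 0.2, so the divergence is log 5 nats.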
Pitfalls
- Tokenizer "0" might not be one token. If `' 0'` (space-prefixed) is the actual token, your set of digit-token-ids is wrong. Print the tokenizer's encoding of each digit and verify.
- Mask of correct shape but wrong dtype. Mask must be float (so `-inf` is representable). An integer mask of `{0, -2**31}` doesn't behave correctly under softmax.
- Forgot to reset mask state between samples. If you generate two samples back-to-back without `mask.reset()`, the second sample starts with `count_emitted == length` and only EOS is legal. Test this.
- Off-by-one in `count_emitted`. Easy to update before vs after sampling. Either works; pick one and stick with it.
When to consult `solutions/`
After your tests pass and your bench script reports clean results. The solution lives in `solutions/00-regex-mask-ref.md`. It is written at phase open, not pre-written, because it depends on what Borja's actual tokenizer reports. Compare; don't pre-read.
Next lab: `lab/01-json-schema-mask.md`.