Lab 00 — Mask Logits Against a Regex

Goal: implement the simplest possible logit mask and verify it produces a constrained distribution.

Estimated time: 60–90 minutes.

Prereq: Phase 21's sampler (src/miniinfer/generate.py) committed. Phase 17's MiniGPT importable.


What you produce

A directory experiments/30-regex-mask/ containing:

  • mask.py — your first LogitMask subclass: matches a fixed regex (digits-only, length 4).
  • test_mask.py — pytest cases verifying mask behavior.
  • bench.py — a tiny driver that runs generate(prompt, mask=DigitMask(length=4)) and asserts every output matches ^\d{4}$.
  • results.json — { "n_samples": ..., "all_match_regex": true, "wall_time_s": ... }.
  • manifest.json — per LYNX_CORTEX.md §5.
  • README.md — what you measured.

Plus, the new module:

  • src/ministruct/mask.py — the LogitMask ABC + DigitMask concrete class. This is the first file in src/ministruct/; the BLUEPRINT at src/ministruct/BLUEPRINT.md (pre-written by Claude) is the design source of truth.

The kernel

Pick the smallest possible grammar: "exactly four ASCII digits followed by EOS". Mask everything else.

This is a warm-up. The point is to (a) understand the API contract you'll generalize in lab 01, and (b) verify that masking actually produces digit-only outputs on the trained MiniGPT. If the model wasn't trained to emit digits, the outputs at least show the model's preference among digits.

We start with digits rather than the conjugation grammar because the set of legal tokens is trivial to identify: it is the token IDs for '0'..'9' plus EOS. Lab 01 generalizes this to a real JSON schema.

TODOs

Block A — LogitMask interface

  • In src/ministruct/mask.py, declare an ABC matching the signature in src/ministruct/BLUEPRINT.md:
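    from abc import ABC, abstractmethod
    import numpy as np
    class LogitMask(ABC):
        @abstractmethod
        def reset(self) -> None: ...  # clear per-sample state
        @abstractmethod
        def step(self, last_token_id: int | None) -> np.ndarray: ...
        # returns mask of shape (vocab_size,), values in {0.0, -inf}
        @abstractmethod
        def is_done(self) -> bool: ...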
    
  • step(None) returns the mask for the first token. Subsequent calls pass the token just emitted; the mask returned is for the next token.
  • is_done() returns True once the grammar accepts the current sequence as complete (e.g., 4 digits emitted, ready to terminate).
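
To make the call order concrete, here is a minimal sketch of the protocol using a hypothetical PermissiveMask that allows every token. The class name and the vocab_size constructor argument are assumptions for illustration, not part of the BLUEPRINT. It is duck-typed so the snippet runs on its own; in the real module it would subclass LogitMask.

```python
from __future__ import annotations

import numpy as np


class PermissiveMask:
    """Hypothetical mask: every token is legal and the grammar never completes.

    In the real module this would subclass LogitMask; it is duck-typed
    here so the snippet runs on its own.
    """

    def __init__(self, vocab_size: int) -> None:
        self.vocab_size = vocab_size

    def reset(self) -> None:
        pass  # stateless, nothing to clear

    def step(self, last_token_id: int | None) -> np.ndarray:
        # 0.0 everywhere: adding this to the logits changes nothing.
        return np.zeros(self.vocab_size, dtype=np.float32)

    def is_done(self) -> bool:
        return False


mask = PermissiveMask(vocab_size=8)
mask.reset()
first = mask.step(None)  # mask for the first token
second = mask.step(3)    # token 3 was just emitted; mask for the next token
```

A mask like this is also a convenient fixture for the no-op sanity check in Block E.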

Block B — DigitMask concrete class

  • Constructor: DigitMask(tokenizer, length=4).
  • Internal state: count_emitted (how many digits so far).
  • step(last_token_id):
      • If count_emitted < length: return mask that allows only tokens whose decoded string is one of '0'..'9'. Determine these token IDs once at construction by enumerating the tokenizer.
      • If count_emitted == length: return mask that allows only the EOS token.
      • Update count_emitted based on last_token_id.
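
The two building blocks of DigitMask (finding the digit token IDs once, then turning a set of legal IDs into a mask) might look like the sketch below. ToyTokenizer, its decode / vocab_size / eos_id attributes, and the helper names digit_token_ids and allow_only are all assumptions made for illustration; the real tokenizer's API may differ, so check it before copying anything.

```python
import numpy as np


class ToyTokenizer:
    """Hypothetical stand-in for the course tokenizer: one string per ID."""

    def __init__(self):
        self.pieces = ["<eos>", "a", " 0"] + list("0123456789") + ["10"]
        self.vocab_size = len(self.pieces)
        self.eos_id = 0

    def decode(self, ids):
        return "".join(self.pieces[i] for i in ids)


ASCII_DIGITS = frozenset("0123456789")


def digit_token_ids(tokenizer):
    """IDs whose decoded string is exactly one ASCII digit."""
    ids = [
        i
        for i in range(tokenizer.vocab_size)
        if tokenizer.decode([i]) in ASCII_DIGITS
    ]
    return np.array(ids, dtype=np.int64)


def allow_only(vocab_size, legal_ids):
    """Float mask: 0.0 at legal_ids, -inf everywhere else."""
    mask = np.full(vocab_size, -np.inf, dtype=np.float32)
    mask[legal_ids] = 0.0
    return mask


tok = ToyTokenizer()
digit_ids = digit_token_ids(tok)                  # excludes " 0" and "10"
digit_mask = allow_only(tok.vocab_size, digit_ids)
eos_mask = allow_only(tok.vocab_size, [tok.eos_id])
```

The exact-match test deliberately rejects the space-prefixed " 0" and the two-character "10". Whether that is the right call depends on how the real tokenizer splits digits.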

Block C — wire into the decoder

  • In src/miniinfer/generate.py (or wherever Phase 21's generate lives), add a mask: LogitMask | None = None parameter.
  • At each step: call mask.step(last_token) to get the mask array, add to logits before sampling.
  • After sampling: pass the chosen token back to the mask next iteration.
  • Terminate when mask.is_done() is True (or when the sampler picks EOS, or when max_new_tokens is hit).
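
The wiring described above can be sketched as a standalone loop. This is not Phase 21's generate: masked_decode, the logits_fn callback, and the FirstNMask stub are hypothetical names used only to show where the mask is consulted, where it is added to the logits, and where the empty-mask error is raised.

```python
from __future__ import annotations

import numpy as np


class NoLegalContinuation(RuntimeError):
    """Raised when the mask leaves no token with a finite logit."""


def masked_decode(logits_fn, eos_id, max_new_tokens, rng, mask=None):
    """Sample token IDs; logits_fn(tokens) returns shape (vocab_size,)."""
    tokens: list[int] = []
    last = None  # nothing emitted yet, so the first call is mask.step(None)
    if mask is not None:
        mask.reset()
    for _ in range(max_new_tokens):
        logits = np.asarray(logits_fn(tokens), dtype=np.float64)
        if mask is not None:
            allowed = mask.step(last)  # also reports the token just emitted
            if mask.is_done():
                break
            logits = logits + allowed
            if not np.isfinite(logits).any():
                raise NoLegalContinuation("mask removed every token")
        probs = np.exp(logits - logits.max())  # exp(-inf) is exactly 0.0
        probs /= probs.sum()
        last = int(rng.choice(len(probs), p=probs))
        if last == eos_id:
            break
        tokens.append(last)
    return tokens


class FirstNMask:
    """Hypothetical stub: allow `allowed_ids` until `n` tokens are emitted."""

    def __init__(self, vocab_size, allowed_ids, n):
        self.vocab_size, self.allowed_ids, self.n = vocab_size, allowed_ids, n
        self.count = 0

    def reset(self):
        self.count = 0

    def step(self, last_token_id):
        if last_token_id is not None:
            self.count += 1
        mask = np.full(self.vocab_size, -np.inf)
        mask[self.allowed_ids] = 0.0
        return mask

    def is_done(self):
        return self.count >= self.n


uniform = lambda tokens: np.zeros(5)  # stand-in for the model's forward pass
out = masked_decode(uniform, eos_id=4, max_new_tokens=10,
                    rng=np.random.default_rng(0), mask=FirstNMask(5, [1, 2], 3))
print(out)  # three IDs, each 1 or 2
```

Checking is_done() right after step() assumes the mask updates its state inside step(). If your mask updates elsewhere, move the check accordingly.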

Block D — verify

  • Run bench.py: generate 100 samples with DigitMask(length=4). Assert every output regex-matches ^\d{4}$.
  • Run a control: generate 100 samples with mask=None. Most outputs will NOT match the regex. (This is the existence proof that the mask is doing work.)
  • Record both in results.json.
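
One way to compute the fields of results.json from a list of decoded samples is sketched below. The summarize helper is a hypothetical name, and how the constrained and control runs are combined in the file is left open here. Two details of Python's re module matter for this check: `$` also matches just before a trailing newline, and `\d` matches non-ASCII digits unless re.ASCII is set, so the sketch uses fullmatch with re.ASCII.

```python
import json
import re

# fullmatch anchors both ends; re.ASCII restricts \d to 0-9.
FOUR_DIGITS = re.compile(r"\d{4}", re.ASCII)


def summarize(samples, wall_time_s):
    """Build the results.json fields for one run of decoded strings."""
    return {
        "n_samples": len(samples),
        "all_match_regex": all(
            FOUR_DIGITS.fullmatch(s) is not None for s in samples
        ),
        "wall_time_s": round(wall_time_s, 3),
    }


constrained = summarize(["0412", "9999", "0007"], wall_time_s=0.25)
control = summarize(["0412", "the ", "12a4"], wall_time_s=0.21)
print(json.dumps(constrained))
```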

Block E — tests

In tests/test_ministruct_mask.py:

  • test_digit_mask_first_step_only_digits — mask returned by step(None) has 0.0 exactly at the digit-token indices, -inf everywhere else.
  • test_digit_mask_after_n_digits_only_eos — after length digits emitted, mask allows only EOS.
  • test_digit_mask_is_done_after_full_length — is_done() flips True at the right moment.
  • test_permissive_mask_no_op — a mask that allows all tokens produces samples identical to mask=None under the same RNG seed (the sanity check from theory/02-logit-masks.md).
  • test_empty_legal_raises — a mask that returns all -inf causes the decoder to raise NoLegalContinuation rather than silently producing NaN.
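
Several of these tests reduce to the same check: the mask is 0.0 at exactly a given set of indices and -inf everywhere else. A small helper along these lines (the name assert_allows_exactly is hypothetical) keeps the test bodies short.

```python
import numpy as np


def assert_allows_exactly(mask, allowed_ids):
    """Check a mask is 0.0 at allowed_ids and -inf everywhere else."""
    assert isinstance(mask, np.ndarray) and mask.ndim == 1
    assert np.issubdtype(mask.dtype, np.floating)
    expected = np.full(mask.shape, -np.inf, dtype=mask.dtype)
    expected[list(allowed_ids)] = 0.0
    # array_equal treats -inf == -inf as equal, so this is an exact check.
    assert np.array_equal(mask, expected)


example = np.array([-np.inf, 0.0, 0.0, -np.inf], dtype=np.float32)
assert_allows_exactly(example, [1, 2])
```

In a real test the mask would come from DigitMask.step(None) and allowed_ids from the tokenizer's digit IDs.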

Constraints

  • No outlines, no lm-format-enforcer, no jsonschema. This is the build-before-abstracting phase.
  • The mask must be a NumPy array, not a list or a dict. It is added to the logits, which are NumPy arrays.
  • EOS handling. The tokenizer's EOS token id is known; read it from the tokenizer's API. Don't guess.

Stop conditions

You're done when:

  1. All five tests in Block E pass.
  2. bench.py reports all_match_regex: true over 100 samples.
  3. README.md answers: "what is the KL divergence between the constrained and unconstrained distribution at step 0? Is it small (model wanted to emit a digit anyway) or large (model wanted to emit text)?"
  4. Manifest committed.
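
For the README question, a sketch of the step-0 computation follows. It assumes the direction KL(constrained ‖ unconstrained), since the reverse direction is infinite whenever the unconstrained model puts any mass on an illegal token. With p the unconstrained distribution and Z the total probability p assigns to legal tokens, the constrained distribution is p renormalized over the legal set, so the divergence collapses to −log Z. The function names are hypothetical.

```python
import numpy as np


def log_softmax(logits):
    x = np.asarray(logits, dtype=np.float64)
    x = x - x.max()
    return x - np.log(np.exp(x).sum())


def kl_constrained_vs_unconstrained(logits, mask):
    """KL(q || p), with p = softmax(logits) and q = softmax(logits + mask)."""
    log_p = log_softmax(logits)
    log_q = log_softmax(np.asarray(logits, dtype=np.float64) + mask)
    legal = np.isfinite(log_q)  # q is exactly 0 on masked tokens
    q = np.exp(log_q[legal])
    return float(np.sum(q * (log_q[legal] - log_p[legal])))


logits = np.array([1.0, 0.0, -1.0, 2.0])
mask = np.array([0.0, 0.0, -np.inf, -np.inf])

kl = kl_constrained_vs_unconstrained(logits, mask)
legal_mass = float(np.exp(log_softmax(logits))[:2].sum())  # Z
print(kl, -np.log(legal_mass))  # the two values agree
```

A value near 0 means the model already put almost all its mass on digits; a large value means the mask is overriding what the model wanted to emit.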

Pitfalls

  • Tokenizer "0" might not be one token. If ' 0' (space-prefixed) is the actual token, your set of digit-token-ids is wrong. Print the tokenizer's encoding of each digit and verify.
  • Mask of correct shape but wrong dtype. Mask must be float (so -inf is representable). An integer mask of {0, -2**31} doesn't behave correctly under softmax.
  • Forgot to reset mask state between samples. If you generate two samples back-to-back without mask.reset(), the second sample starts with count_emitted == length and only EOS is legal. Test this.
  • Off-by-one in count_emitted. It is easy to mix up updating the counter before sampling with updating it after. Either works; pick one and stick with it.
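
As an illustration of the dtype pitfall, the sketch below shows one way an integer sentinel can go wrong: when every token is masked, the large negative constant cancels inside softmax and the unmasked distribution comes back silently, whereas a float -inf mask leaves nothing finite for the decoder to sample from. This is a sketch of a single failure mode, not an exhaustive list.

```python
import numpy as np


def softmax(x):
    x = np.asarray(x, dtype=np.float64)
    e = np.exp(x - x.max())
    return e / e.sum()


logits = np.array([2.0, 1.0, 0.5], dtype=np.float32)

# A mask that forbids everything, built two ways.
int_mask = np.full(3, -2**31, dtype=np.int32)
float_mask = np.full(3, -np.inf, dtype=np.float32)

# Integer sentinel: every entry shifts by the same finite constant, so
# softmax returns (almost exactly) the unmasked distribution.
silently_unmasked = softmax(logits + int_mask)

# Float -inf: no finite logit survives, which the decoder can detect.
has_legal_token = bool(np.isfinite(logits + float_mask).any())
```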

When to consult solutions/

After your tests pass and your bench script reports clean results. The solution lives in solutions/00-regex-mask-ref.md — written at phase open, not pre-written, because it depends on what Borja's actual tokenizer reports. Compare; don't pre-read.


Next lab: lab/01-json-schema-mask.md.