02 — Logit Masks: The Derivation
Restricting sampling to a subset \(\mathcal{L}\) of legal tokens is exactly equivalent to adding \(-\infty\) to the logits outside \(\mathcal{L}\) (that is, setting them to \(-\infty\)) before the softmax. The resulting distribution is the original one conditioned on the event \(\{t \in \mathcal{L}\}\). No magic, just Bayes.
This is the load-bearing theory page. Derive the formula, internalize it, and the rest of the phase is plumbing.
Setup
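The model produces a logit vector \(z \in \mathbb{R}^{|V|}\) at each step, conditioned on the prefix \(x_{<i}\). The unconstrained sampling distribution is
\[
p(t) \equiv p(t \mid x_{<i}) = \mathrm{softmax}(z)_t = \frac{\exp(z_t)}{\sum_{u \in V} \exp(z_u)}.
\]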
We have a constraint: only tokens in a subset \(\mathcal{L}_i \subseteq V\) are legal at step \(i\) (where \(\mathcal{L}_i\) depends on the prefix and the grammar). We want to sample from
\[
p_\text{constr}(t) = p(t \mid x_{<i},\ t \in \mathcal{L}_i).
\]
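Apply the definition of conditional probability:
\[
p_\text{constr}(t) = \frac{p(t)\,\mathbb{1}[t \in \mathcal{L}_i]}{\sum_{u \in \mathcal{L}_i} p(u)}.
\]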
Numerator: the original probability for legal tokens, zero for illegal. Denominator: the total mass on legal tokens, used to renormalize.
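As a concrete check of the formula, here is a minimal NumPy sketch that computes \(p_\text{constr}\) literally: full softmax, zero the illegal entries, divide by the legal mass. The function name constrained_probs, the toy logits, and the boolean legal array standing in for \(\mathcal{L}_i\) are illustrative assumptions, not the lab's API.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over a 1-D array."""
    z = np.asarray(z, dtype=np.float64)
    e = np.exp(z - z.max())
    return e / e.sum()

def constrained_probs(logits, legal):
    """p_constr by the conditional-probability formula.

    logits: shape (|V|,) float array
    legal:  shape (|V|,) bool array, True where t is in L_i
    """
    p = softmax(logits)        # unconstrained p(t)
    numer = p * legal          # numerator: p(t) * 1[t in L_i]
    Z = numer.sum()            # denominator: total mass on legal tokens
    return numer / Z           # renormalize

logits = np.array([2.0, 1.0, 0.5, -1.0])
legal = np.array([False, True, True, False])
p_constr = constrained_probs(logits, legal)
print(p_constr)   # illegal entries are exactly 0; the legal ones sum to 1
```

Legal tokens keep their original relative odds: the ratio of tokens 1 and 2 is still \(\exp(1.0 - 0.5)\), only the scale changes.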
The logit-mask trick
Computing \(p_\text{constr}\) via that formula requires:
1. Softmax the full logits (\(|V|\) exp, \(|V|\) sum).
2. Multiply by the indicator mask.
3. Renormalize by the sum over legal tokens.
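That's 3 passes over the vocabulary. The mask trick is to push the indicator into the logits before softmax:
\[
\tilde{z}_t = \begin{cases} z_t & \text{if } t \in \mathcal{L}_i, \\ -\infty & \text{otherwise.} \end{cases}
\]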
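Then \(\exp(\tilde{z}_t) = \exp(z_t) \cdot \mathbb{1}[t \in \mathcal{L}_i]\) (since \(\exp(-\infty) = 0\)), so \(\mathrm{softmax}(\tilde{z})_t = \exp(z_t)\,\mathbb{1}[t \in \mathcal{L}_i] / \sum_{u \in \mathcal{L}_i} \exp(z_u)\). Dividing numerator and denominator by \(\sum_{u \in V} \exp(z_u)\) gives exactly \(p_\text{constr}(t)\). Two passes (mask, then a softmax that does its own renormalization), identical distribution.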
In code:
```python
masked_logits = logits + mask  # mask is 0 for legal, -inf for illegal
probs = softmax(masked_logits)
```
That mask is exactly the deliverable Phase 30 builds. Everything downstream — temperature, top-k, top-p, the sampler from Phase 21 — operates on masked_logits (or probs) unchanged.
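A small numerical check that the mask trick and the three-pass formula produce the same distribution. The helper make_mask, the toy logits, and the legal array are illustrative assumptions.

```python
import numpy as np

def softmax(z):
    z = np.asarray(z, dtype=np.float64)
    e = np.exp(z - z.max())      # assumes at least one finite entry
    return e / e.sum()

def make_mask(legal):
    """0.0 for legal tokens, -inf for illegal ones."""
    return np.where(legal, 0.0, -np.inf)

logits = np.array([2.0, 1.0, 0.5, -1.0])
legal = np.array([False, True, True, False])

# Mask trick: push the indicator into the logits, then a single softmax.
probs_masked = softmax(logits + make_mask(legal))

# Reference: full softmax, zero the illegal entries, renormalize.
p = softmax(logits)
probs_reference = p * legal / (p * legal).sum()

print(np.allclose(probs_masked, probs_reference))  # True
```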
Numerical considerations
-inf is a valid IEEE 754 floating-point value (np.float32(-np.inf) is representable), and exp(-inf) == 0.0 exactly, so masked entries contribute exactly zero. There's no rounding issue.
The danger is that every logit gets masked. Then masked_logits is all -inf and the softmax returns NaN (in the naive form, 0/0). This is the "empty \(\mathcal{L}_i\)" case (pitfall 3 in PHASE_30_PLAN.md §5). Defensive code: check (mask > -inf).any() before the softmax, and raise a NoLegalContinuation exception if it fails, because the grammar has painted itself into a corner.
For stability we use the standard softmax with the max-subtraction trick (from Phase 2). Caveat: the trick subtracts the maximum of the masked logits, which is -inf when every entry is masked, and -inf - (-inf) is NaN. The defensive check above keeps that case from ever reaching the softmax; with at least one legal token the maximum is finite and the trick works as usual.
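A sketch of the defensive softmax. It assumes NoLegalContinuation is a plain Exception subclass (the document names the exception but not its definition) and that the mask is a float array of 0.0 / -inf entries; masked_softmax is a hypothetical name.

```python
import numpy as np

class NoLegalContinuation(Exception):
    """Raised when the grammar leaves no legal next token."""

def masked_softmax(logits, mask):
    """Stable softmax of logits + mask, where mask holds 0.0 (legal) or -inf (illegal)."""
    # Check for an empty legal set first: max() of an all -inf array is -inf,
    # and (-inf) - (-inf) is NaN, which would poison every probability.
    if not (mask > -np.inf).any():
        raise NoLegalContinuation("every token is masked at this step")
    z = np.asarray(logits, dtype=np.float64) + mask
    z = z - z.max()          # finite max: at least one legal entry
    e = np.exp(z)            # exp(-inf) == 0.0 exactly for masked entries
    return e / e.sum()

# Without the check, the max-subtraction trick turns an all -inf vector into NaN.
all_masked = np.full(4, -np.inf)
with np.errstate(invalid="ignore"):
    naive = np.exp(all_masked - all_masked.max())
print(naive)   # [nan nan nan nan]

try:
    masked_softmax(np.zeros(4), all_masked)
except NoLegalContinuation as err:
    print("caught:", err)
```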
Composition with sampling strategies
Phase 21's sampler supports temperature, top-k, top-p, and repetition penalties. How does the mask compose with each of these?
Temperature. Temperature scales the logits before softmax: \(z'_t = z_t / T\). Each mask entry \(m_t\) is \(0\) or \(-\infty\), and dividing either by \(T > 0\) leaves it unchanged (\(0 / T = 0\), \(-\infty / T = -\infty\)), so \((z_t + m_t)/T = z_t/T + m_t\): mask and temperature commute, and either can be applied first. Convention: apply the mask first (the cheaper order, since a careful implementation can skip the division for masked entries).
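A quick check of the commutation claim on made-up logits and a hand-written mask:

```python
import numpy as np

logits = np.array([1.5, -0.3, 0.8, 2.2])
mask = np.array([0.0, -np.inf, 0.0, -np.inf])
T = 0.7

mask_then_temp = (logits + mask) / T
temp_then_mask = logits / T + mask

# Legal entries agree exactly; illegal entries are -inf in both orders.
print(np.array_equal(mask_then_temp, temp_then_mask))   # True
```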
Top-k. Top-k keeps the \(k\) highest logits. If you apply top-k first then mask, you might end up with fewer than \(k\) legal tokens (some of the top-\(k\) get masked). If you mask first then top-k, you pick the top-\(k\) among legal tokens. The latter is what you want. Mask before top-k.
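A toy illustration of why the order matters, using a hypothetical top_k_filter helper that sets every logit outside the top \(k\) to -inf. Here the two highest-scoring tokens are both illegal, so the wrong order leaves nothing at all to sample.

```python
import numpy as np

def top_k_filter(logits, k):
    """Keep the k largest logits; set the rest to -inf."""
    out = np.full_like(logits, -np.inf)
    idx = np.argsort(logits)[-k:]      # indices of the k largest values
    out[idx] = logits[idx]
    return out

logits = np.array([3.0, 2.5, 2.0, 0.5, 0.1])
mask = np.array([-np.inf, -np.inf, 0.0, 0.0, 0.0])   # tokens 0 and 1 are illegal
k = 2

topk_then_mask = top_k_filter(logits, k) + mask
mask_then_topk = top_k_filter(logits + mask, k)

print(np.isfinite(topk_then_mask).sum())   # 0: no legal token survives
print(np.isfinite(mask_then_topk).sum())   # 2: the two best legal tokens
```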
Top-p. Top-p (nucleus) keeps the smallest set of tokens whose probability sum is ≥ \(p\). Same argument: mask first, then run top-p on the renormalized distribution. Otherwise the nucleus may contain illegal tokens that the mask then drops, so the surviving legal mass falls below \(p\); if the nucleus held only illegal tokens, there is nothing left to sample.
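The same experiment for nucleus sampling, with a hypothetical top_p_filter that keeps the smallest highest-probability set reaching mass \(p\) and renormalizes it. The logits are made up so that the two illegal tokens hold almost all of the unconstrained mass.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def top_p_filter(probs, p):
    """Zero everything outside the smallest set with mass >= p, then renormalize."""
    order = np.argsort(probs)[::-1]               # descending probability
    cumulative = np.cumsum(probs[order])
    cutoff = np.searchsorted(cumulative, p) + 1   # number of tokens kept
    keep = order[:cutoff]
    out = np.zeros_like(probs)
    out[keep] = probs[keep]
    return out / out.sum()

logits = np.array([3.0, 2.8, 0.0, -0.5])
legal = np.array([False, False, True, True])
p = 0.9

# Wrong order: build the nucleus first, then drop illegal tokens.
nucleus = top_p_filter(softmax(logits), p)
print((nucleus * legal).sum())      # 0.0: the nucleus held only illegal tokens

# Right order: mask first, then the nucleus over the renormalized legal distribution.
masked = softmax(np.where(legal, logits, -np.inf))
print(top_p_filter(masked, p))
```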
Repetition penalty. Rescales the logits of recently emitted tokens by a positive factor. It commutes with the mask, because scaling \(-\infty\) by any positive factor leaves it at \(-\infty\) (illegal stays illegal). Order doesn't matter, but applying the mask first is cheaper.
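The paragraph above does not pin down the exact penalty formula, so this sketch assumes the common CTRL-style rule (divide positive logits by the penalty, multiply negative ones by it). The function apply_repetition_penalty and the recent list of token ids are illustrative assumptions.

```python
import numpy as np

def apply_repetition_penalty(logits, recent, penalty=1.3):
    """Assumed CTRL-style penalty: push recently emitted tokens' logits down."""
    out = logits.copy()
    idx = np.asarray(recent)
    vals = out[idx]
    # -inf is negative, so it is multiplied by penalty > 0 and stays -inf.
    out[idx] = np.where(vals > 0, vals / penalty, vals * penalty)
    return out

logits = np.array([2.0, -1.0, 0.5, 1.0])
mask = np.array([0.0, -np.inf, 0.0, -np.inf])
recent = [0, 1, 3]

a = apply_repetition_penalty(logits + mask, recent)   # mask first
b = apply_repetition_penalty(logits, recent) + mask   # penalty first
print(np.array_equal(a, b))   # True
```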
General rule: apply the mask first; every other operation then works on the masked logits or on the probabilities derived from them. In code:
```python
def sample(logits, mask, temperature, top_p, ...):
    logits = logits + mask                    # 1. mask
    logits = logits / temperature             # 2. temperature
    logits = apply_repetition_penalty(...)    # 3. rep penalty
    probs = softmax(logits)                   # 4. softmax
    probs = apply_top_p(probs, top_p)         # 5. top-p
    token = sample_from(probs)                # 6. sample
    return token
```
The Phase 30 lab implements exactly this order and adds a test that verifies an illegal token is never sampled.
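One way to make the pseudocode above runnable is sketched below with NumPy. Everything beyond the six-step order is an assumption for illustration: the helper bodies, the CTRL-style repetition penalty, the extra recent / penalty / rng parameters, and the NoLegalContinuation definition. The lab's own implementation may differ.

```python
import numpy as np

class NoLegalContinuation(Exception):
    """Raised when the grammar leaves no legal next token."""

def softmax(z):
    e = np.exp(z - z.max())   # callers guarantee at least one finite entry
    return e / e.sum()

def apply_repetition_penalty(logits, recent, penalty):
    """Assumed CTRL-style rule; -inf (illegal) entries stay -inf."""
    out = logits.copy()
    if len(recent) > 0:
        idx = np.asarray(recent)
        vals = out[idx]
        out[idx] = np.where(vals > 0, vals / penalty, vals * penalty)
    return out

def apply_top_p(probs, top_p):
    """Keep the smallest highest-probability set with mass >= top_p, renormalize."""
    order = np.argsort(probs)[::-1]
    cutoff = np.searchsorted(np.cumsum(probs[order]), top_p) + 1
    keep = order[:cutoff]
    out = np.zeros_like(probs)
    out[keep] = probs[keep]
    return out / out.sum()

def sample_from(probs, rng):
    return int(rng.choice(len(probs), p=probs))

def sample(logits, mask, temperature, top_p, recent=(), penalty=1.0, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    if not (mask > -np.inf).any():
        raise NoLegalContinuation("no legal token at this step")
    logits = logits + mask                                      # 1. mask
    logits = logits / temperature                               # 2. temperature
    logits = apply_repetition_penalty(logits, recent, penalty)  # 3. rep penalty
    probs = softmax(logits)                                     # 4. softmax
    probs = apply_top_p(probs, top_p)                           # 5. top-p
    return sample_from(probs, rng)                              # 6. sample

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.5, 3.0, -1.0])
mask = np.array([0.0, -np.inf, 0.0, -np.inf, 0.0])   # tokens 1 and 3 are illegal
draws = [sample(logits, mask, temperature=0.8, top_p=0.95, recent=[0], penalty=1.2, rng=rng)
         for _ in range(1000)]
print(sorted(set(draws)))
```

However many draws you take, only tokens 0, 2, or 4 can appear: masked tokens reach the sampler with probability exactly 0, which is the property the lab's test is described as checking.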
The KL diagnostic
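How much did the mask distort the model? Measure with KL divergence:
\[
\mathrm{KL}(p_\text{constr} \,\|\, p) = \sum_{t \in \mathcal{L}_i} p_\text{constr}(t) \log \frac{p_\text{constr}(t)}{p(t)}.
\]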
A simpler form using \(Z = \sum_{t \in \mathcal{L}_i} p(t)\) (the total legal mass):
\[
\mathrm{KL}(p_\text{constr} \,\|\, p) = -\log Z.
\]
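Derivation: \(p_\text{constr}(t) = p(t) / Z\) for \(t \in \mathcal{L}_i\), so \(\log(p_\text{constr}(t) / p(t)) = -\log Z\) for every legal \(t\). That log term is constant and \(p_\text{constr}\) sums to 1 over \(\mathcal{L}_i\), so the weighted sum is just \(-\log Z\). Since \(Z \le 1\), the KL is never negative.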
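A numerical check, on toy logits, that the direct KL sum matches the closed form \(-\log Z\):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.5, -1.0])
legal = np.array([False, True, True, False])

p = softmax(logits)
Z = p[legal].sum()                        # total legal mass
p_constr = np.where(legal, p / Z, 0.0)

# KL(p_constr || p), summed over the support of p_constr (t in L_i).
kl_direct = np.sum(p_constr[legal] * np.log(p_constr[legal] / p[legal]))
kl_closed_form = -np.log(Z)

print(np.isclose(kl_direct, kl_closed_form))   # True
```

In a decode loop, logging -np.log(Z) per step is enough to track the diagnostic; the full sum is not needed.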
Interpretation. If \(Z \approx 1\) (the model was already going to emit a legal token), \(\mathrm{KL} \approx 0\) — masking is a no-op. If \(Z \approx 0\) (the model wanted to emit something illegal), \(\mathrm{KL} \to \infty\) — we're forcing the model very far from its preferred distribution.
A consistently high \(\mathrm{KL}\) across decode steps means the model is fighting the grammar. This is a signal, not a failure: it tells you the model wasn't trained on this format and is being coerced, which may produce semantically poor output even if it parses. We log this in experiments/30-mask-overhead/.
Computing the mask
The mask depends on the current state of the grammar parser given the prefix emitted so far. For our JSON-schema use case:
```
state := one of {
    EXPECT_OPEN_BRACE,
    EXPECT_KEY_OPEN_QUOTE,
    EXPECT_KEY_CHARS,
    EXPECT_KEY_CLOSE_QUOTE,
    EXPECT_COLON,
    EXPECT_VALUE_BY_KEY,        # depends on which key
    EXPECT_COMMA_OR_CLOSE,
    DONE,
}
```
At each step, given the current state, we enumerate the vocabulary and ask: "if I emit token \(t\), what state does the parser end up in?" If the parser accepts every character of \(t\) and ends in a state from which the output can still be completed, \(t\) is legal; otherwise we mask it.
The complication is multi-character tokens. A token like "verb": is one BPE token but passes through EXPECT_KEY_OPEN_QUOTE → EXPECT_KEY_CHARS → EXPECT_KEY_CLOSE_QUOTE → EXPECT_COLON → EXPECT_VALUE_BY_KEY (five parser states within a single decode step). The mask logic must simulate the whole character sequence and accept only if every intermediate state is legal and the final state is one we wanted to reach.
For Phase 30 we implement this simulation naively: for each candidate token, decode it to a string, run the parser one character at a time, accept or reject. O(\(|V| \cdot \bar{L}\)) per step where \(\bar{L}\) is average token length. Slow but correct.
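A compact sketch of this naive simulation for a hypothetical two-field toy schema. SCHEMA, step, advance, and token_mask are illustrative names, and the states only loosely follow the list above: the key's closing quote is handled inside EXPECT_KEY_CHARS, and EXPECT_VALUE_BY_KEY is split into an opening-quote state and a character state. A return value of None means a character was rejected; every other state can still be completed into a valid document, so a token is legal exactly when all of its characters are accepted.

```python
import numpy as np

# Hypothetical toy schema: fields in this fixed order, each value a closed enum.
SCHEMA = [("verb", ["hablar", "comer"]), ("tense", ["present", "past"])]
START = ("EXPECT_OPEN_BRACE", 0, "")   # (phase, field index, chars buffered so far)

def step(state, ch):
    """Feed one character to the parser; return the next state, or None if illegal."""
    phase, i, buf = state
    key, values = SCHEMA[i] if i < len(SCHEMA) else (None, [])
    if phase == "EXPECT_OPEN_BRACE":
        return ("EXPECT_KEY_OPEN_QUOTE", 0, "") if ch == "{" else None
    if phase == "EXPECT_KEY_OPEN_QUOTE":
        return ("EXPECT_KEY_CHARS", i, "") if ch == '"' else None
    if phase == "EXPECT_KEY_CHARS":              # closing quote handled here too
        if ch == '"':
            return ("EXPECT_COLON", i, "") if buf == key else None
        return ("EXPECT_KEY_CHARS", i, buf + ch) if key.startswith(buf + ch) else None
    if phase == "EXPECT_COLON":
        return ("EXPECT_VALUE_OPEN_QUOTE", i, "") if ch == ":" else None
    if phase == "EXPECT_VALUE_OPEN_QUOTE":
        return ("EXPECT_VALUE_CHARS", i, "") if ch == '"' else None
    if phase == "EXPECT_VALUE_CHARS":            # value must be one of the enum members
        if ch == '"':
            return ("EXPECT_COMMA_OR_CLOSE", i + 1, "") if buf in values else None
        ok = any(v.startswith(buf + ch) for v in values)
        return ("EXPECT_VALUE_CHARS", i, buf + ch) if ok else None
    if phase == "EXPECT_COMMA_OR_CLOSE":
        if ch == "," and i < len(SCHEMA):
            return ("EXPECT_KEY_OPEN_QUOTE", i, "")
        if ch == "}" and i == len(SCHEMA):
            return ("DONE", i, "")
    return None                                  # includes DONE: nothing may follow

def advance(state, text):
    """Run the parser over every character of text; None if any character is rejected."""
    for ch in text:
        state = step(state, ch)
        if state is None:
            return None
    return state

def token_mask(state, vocab):
    """Naive O(|V| * average token length) mask: 0.0 for legal tokens, -inf otherwise."""
    mask = np.full(len(vocab), -np.inf)
    for t, text in enumerate(vocab):
        if text and advance(state, text) is not None:
            mask[t] = 0.0
    return mask

vocab = ["{", '"', "verb", '"verb":', ":", '"ha', "blar", '",', "tense", "}", "x"]
state = advance(START, "{")
mask = token_mask(state, vocab)
print([tok for tok, m in zip(vocab, mask) if m == 0.0])   # ['"', '"verb":']
```

Note that the multi-character token "verb": is accepted right after {, because all seven of its characters walk the parser through valid states.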
A sanity check: identity at \(\mathcal{L}_i = V\)
If every token is legal (\(\mathcal{L}_i = V\)), the mask is all zeros and \(p_\text{constr} = p\). Masking is a no-op in this case. Phase 30's tests verify this: a "permissive" mask that allows everything produces identical samples to no mask at all (modulo any RNG state differences, which our deterministic sampler avoids).
This is the first thing to test. A mask implementation that fails this test is broken.
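A minimal version of that identity test, assuming a seeded NumPy Generator plays the role of the deterministic sampler; sample_token is a hypothetical helper, and the lab's actual test may be structured differently.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def sample_token(logits, mask, rng):
    """Sample one token id from softmax(logits + mask)."""
    probs = softmax(logits + mask)
    return int(rng.choice(len(probs), p=probs))

logits = np.random.default_rng(123).normal(size=50)
permissive = np.zeros_like(logits)        # L_i = V: every entry is 0.0

# Same seed on both sides, so any difference would have to come from the mask.
rng_a = np.random.default_rng(7)
rng_b = np.random.default_rng(7)
with_mask = [sample_token(logits, permissive, rng_a) for _ in range(200)]
no_mask = [int(rng_b.choice(50, p=softmax(logits))) for _ in range(200)]
print(with_mask == no_mask)   # True
```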
A second sanity check: degenerate at \(\mathcal{L}_i = \{t^*\}\)
If only one token \(t^*\) is legal, \(p_\text{constr}\) is a point mass on \(t^*\): sampling always returns \(t^*\), and the decoder is effectively forced on this step. The conjugation schema has several steps where the grammar admits only one continuation (e.g., after { only " can follow; after "verb" only : can follow).
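A small illustration of the degenerate case on toy logits: the model strongly prefers token 0, but only token 2 (playing the role of \(t^*\)) is legal, so every draw returns it regardless of the logits.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

logits = np.array([4.0, 1.0, -2.0, 0.3])   # the model strongly prefers token 0
mask = np.full(4, -np.inf)
mask[2] = 0.0                              # only token 2 (t*) is legal

probs = softmax(logits + mask)
print(probs)                               # [0. 0. 1. 0.]

rng = np.random.default_rng(0)
draws = {int(rng.choice(4, p=probs)) for _ in range(100)}
print(draws)                               # {2}
```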
These "forced" steps are where the model has zero say. The interesting steps are inside the value fields, where the model picks the actual verb, tense, or person from the closed enum. The mask there allows only the legal enum values; the model picks among them.
What this means operationally
For the conjugation schema, large stretches of the output are forced: the punctuation and the field names. The model only "creates" inside the values, and even there it chooses from a small enum (20 verbs, 5 tenses, 3 persons). This dramatically reduces the entropy of the generated sequence. In practice, structured generation often emits tens of tokens of overhead (the JSON scaffolding) per few-token content prediction. For our tiny model that is acceptable: the forced tokens involve no real decision, although each one still costs an ordinary decode step (a forward pass to extend the KV cache) whose logits are effectively ignored.
A trap: the model doesn't "know" it's being masked
The model's next-token distribution is computed as if the continuation were unconstrained. Its hidden state encodes its own expectation of the future, not the grammar's. If the grammar forces it down a path it considered unlikely, the context it has built up in the KV cache may be a poor basis for the remaining predictions. In practice this is fine for small schemas; in principle it is a source of distribution shift, which production systems sometimes work around by fine-tuning the model to emit the format natively (Phase 28 territory).
For Phase 30 we ignore this — our model is tiny and the schema is small.
What this phase does NOT cover
- Distribution-shift correction. A "speculative re-prefill" with a fine-tuned model that already knows the format is a Phase 28 (LoRA) extension; we don't do it here.
- Per-beam mask state. Phase 30 does greedy / top-p only. Beam search × mask interaction is a separate subject; it is mentioned in PHASE_30_PLAN.md §7.
- Mid-token (sub-BPE) masking. Production Outlines splits tokens further when they cross grammar transitions awkwardly; we accept whatever token the BPE gives us and validate at the token boundary.
- Non-JSON grammars. Our mask is JSON-schema-specific. CFG / GBNF descriptions are in 03-grammar-as-dfa.md, not implemented.
Next: theory/03-grammar-as-dfa.md — how production implementations precompile the grammar into an automaton, and why we don't do that here.