Skip to content

English · Español

Lab 00 — Enmascarar logits contra una regex

Objetivo: implementar la máscara de logits más simple posible y verificar que produce una distribución restringida.

Tiempo estimado: 60–90 minutos.

Prerrequisito: el sampler de la Fase 21 (src/miniinfer/generate.py) commiteado. MiniGPT de la Fase 17 importable.


Lo que produces

Un directorio experiments/30-regex-mask/ que contiene:

  • mask.py — tu primera subclase de LogitMask: coincide con una regex fija (solo dígitos, longitud 4).
  • test_mask.py — casos pytest que verifican el comportamiento de la máscara.
  • bench.py — un driver minúsculo que ejecuta generate(prompt, mask=DigitMask(length=4)) y asserta que cada salida coincide con ^\d{4}$.
  • results.json{ "n_samples": ..., "all_match_regex": true, "wall_time_s": ... }.
  • manifest.json — según LYNX_CORTEX.md §5.
  • README.md — qué mediste.

Más, el nuevo módulo:

  • src/ministruct/mask.py — la ABC LogitMask + la clase concreta DigitMask. Este es el primer archivo en src/ministruct/; el BLUEPRINT en src/ministruct/BLUEPRINT.md (pre-escrito por Claude) es la fuente de verdad del diseño.

El kernel

Elige el grammar más pequeño posible: "exactamente cuatro dígitos ASCII seguidos de EOS". Enmascara todo lo demás.

Esto es un calentamiento. El objetivo es (a) entender el contrato de API que generalizarás en el lab 01, (b) verificar que el enmascarado realmente produce salidas solo de dígitos en el MiniGPT entrenado (o — si el modelo no fue entrenado para emitir dígitos — al menos produzca la preferencia del modelo entre los dígitos).

Empezamos con dígitos en vez de con el grammar de conjugaciones porque el conjunto de tokens legales se identifica trivialmente — son los IDs de tokens '0'..'9' y EOS. El lab 01 generaliza esto a un JSON Schema real.

TODOs

Bloque A — interfaz LogitMask

  • En src/ministruct/mask.py, declara una ABC que coincida con la firma de src/ministruct/BLUEPRINT.md:
    class LogitMask(ABC):
        def reset(self) -> None: ...
        def step(self, last_token_id: int | None) -> np.ndarray: ...
        # devuelve máscara de forma (vocab_size,), valores en {0.0, -inf}
        def is_done(self) -> bool: ...
    
  • step(None) devuelve la máscara para el primer token. Las llamadas posteriores pasan el token recién emitido; la máscara devuelta es para el siguiente token.
  • is_done() devuelve True una vez que el grammar acepta la secuencia actual como completa (p. ej., 4 dígitos emitidos, lista para terminar).

Bloque B — clase concreta DigitMask

  • Constructor: DigitMask(tokenizer, length=4).
  • Estado interno: count_emitted (cuántos dígitos hasta ahora).
  • step(last_token_id):
  • Si count_emitted < length: devuelve una máscara que permite solo tokens cuyo string decodificado sea uno de '0'..'9'. Determina estos IDs de token una vez en la construcción enumerando el tokenizer.
  • Si count_emitted == length: devuelve una máscara que permite solo el token EOS.
  • Actualiza count_emitted basado en last_token_id.

Bloque C — cablear en el decoder

  • En src/miniinfer/generate.py (o dondequiera que viva el generate de la Fase 21), añade un parámetro mask: LogitMask | None = None.
  • En cada paso: llama a mask.step(last_token) para obtener el array de máscara, súmalo a los logits antes de muestrear.
  • Tras muestrear: pasa el token elegido a la máscara para la siguiente iteración.
  • Termina cuando mask.is_done() sea True (o cuando el sampler elija EOS, o cuando se alcance max_new_tokens).

Bloque D — verificar

  • Ejecuta bench.py: genera 100 muestras con DigitMask(length=4). Asserta que cada salida coincida con la regex ^\d{4}$.
  • Ejecuta un control: genera 100 muestras con mask=None. La mayoría de salidas NO coincidirán con la regex. (Esta es la prueba existencial de que la máscara está haciendo trabajo.)
  • Registra ambos en results.json.

Bloque E — tests

En tests/test_ministruct_mask.py:

  • test_digit_mask_first_step_only_digits — la máscara devuelta por step(None) tiene 0.0 exactamente en los índices de tokens de dígitos, -inf en todo lo demás.
  • test_digit_mask_after_n_digits_only_eos — tras length dígitos emitidos, la máscara permite solo EOS.
  • test_digit_mask_is_done_after_full_lengthis_done() cambia a True en el momento correcto.
  • test_permissive_mask_no_op — una máscara que permite todos los tokens produce muestras idénticas a mask=None bajo la misma semilla de RNG (la comprobación de cordura de theory/02-logit-masks.md).
  • test_empty_legal_raises — una máscara que devuelve todo -inf hace que el decoder lance NoLegalContinuation en vez de producir NaN silenciosamente.

Restricciones

  • Sin outlines, sin lm-format-enforcer, sin jsonschema. Esta es la fase de construir antes de abstraer.
  • La máscara debe ser un array NumPy, no una lista ni un dict. Se suma a logits que son NumPy.
  • Manejo de EOS. El ID del token EOS del tokenizer se conoce; léelo desde la API del tokenizer. No adivines.

Condiciones de parada

Has terminado cuando:

  1. Los cinco tests del Bloque E pasan.
  2. bench.py reporta all_match_regex: true sobre 100 muestras.
  3. README.md responde: "¿cuál es la divergencia KL entre la distribución restringida y la no restringida en el paso 0? ¿Es pequeña (el modelo iba a emitir un dígito de todas formas) o grande (el modelo quería emitir texto)?"
  4. Manifest commiteado.

Trampas

  • "0" del tokenizer puede no ser un solo token. Si ' 0' (con espacio delante) es el token real, tu conjunto de IDs de tokens de dígitos está equivocado. Imprime la codificación del tokenizer de cada dígito y verifica.
  • Máscara con forma correcta pero dtype equivocado. La máscara debe ser float (para que -inf sea representable). Una máscara entera de {0, -2**31} no se comporta correctamente bajo softmax.
  • Se te olvida resetear el estado de la máscara entre muestras. Si generas dos muestras seguidas sin mask.reset(), la segunda muestra arranca con count_emitted == length y solo EOS es legal. Testea esto.
  • Off-by-one en count_emitted. Fácil de actualizar antes vs después de muestrear. Cualquiera funciona; elige uno y cíñete a él.

Cuándo consultar solutions/

Después de que tus tests pasen y tu script de bench reporte resultados limpios. La solución vive en solutions/00-regex-mask-ref.md — escrita al abrir la fase, no pre-escrita, porque depende de lo que reporte el tokenizer real de Borja. Compara; no leas antes.


Siguiente lab: lab/01-json-schema-mask.md.