English · Español
Lab 00 — Greedy decoding¶
🇪🇸 El sampler más simple:
argmax. Es determinístico, robusto y la línea base contra la que se mide todo lo demás. Implementa el bucle de decode y verifica que"Yesterday I"produzca"worked".
Objetivo¶
Implementar greedy decoding como función sample_greedy(model, prompt, max_tokens) -> list[int], y verificar que elige la forma de conjugación predicha por el Mini-GPT entrenado sobre un puñado de prompts semilla del corpus §A13.
Setup¶
- Checkpoint del Mini-GPT entrenado en la Fase 18 (
experiments/<date>-phase-18-train/model.npz). - Tokenizer y vocabulario de la Fase 12.
numpy,np.random.default_rng(sin usar aquí — greedy es determinístico, pero cada sampler de esta fase recibe un parámetroseedpor consistencia de API).
Tareas¶
- Crea
src/minimodel/sampling.pycon el docstring a nivel de módulo describiendo la API del sampler:
def sample(model, prompt: list[int], *, max_tokens: int, seed: int,
strategy: SamplingStrategy) -> list[int]:
"""Returns the generated tokens (excluding the prompt)."""
- Implementa la estrategia
Greedy. Una estrategia es un callable(logits: np.ndarray, rng: np.random.Generator) -> int. Greedy ignora el rng:
- Implementa el bucle de decode en
sample(): - Inicializa
tokens = list(prompt). - Para cada paso hasta
max_tokens:- Ejecuta
model(np.array(tokens))→ logits(T, V). - Toma
logits[-1]→(V,). - Llama a
strategy(last_logits, rng)→next_token: int. - Añade
next_tokenatokens. - Si
next_token == EOS_TOKEN_ID, rompe el bucle.
- Ejecuta
-
Devuelve
tokens[len(prompt):]. -
Test sobre tres prompts del corpus §A13:
"Yesterday I"→ se espera una forma verbal de pasado (p. ej.,"worked")."Tomorrow she"→ se espera una forma de futuro (p. ej.,"will go"o"is going to walk").-
"He"→ se espera una forma de tercera persona del singular (p. ej.,"works","plays"). -
Verificación de determinismo. Llama a
sample(...)dos veces con la misma semilla (seed) y el mismo prompt; comprueba que las salidas son idénticas.
Mediciones¶
Para cada prompt, registra en experiments/<date>-phase-21-greedy/:
- El prompt (como texto crudo y como IDs de token).
- La salida greedy (como IDs de token y texto decodificado).
- Los 5 logits superiores en el primer paso de generación (para que veas cuán picuda es la distribución — greedy es "robusto" solo si el hueco entre rango 1 y rango 2 es grande).
Aceptación¶
Greedy()(logits, rng)es puro — los mismos logits dan la misma salida independientemente del rng.- El bucle de decode termina: o en
max_tokenso en EOS. - Para
"Yesterday I", el primer token generado decodifica a un verbo regular o irregular en pasado. (Si no, el modelo está infraentrenado — investiga las curvas de pérdida (loss) de la Fase 18 en lugar del sampler.) mypy --strict src/minimodel/sampling.pypasa.
Trampas¶
- Olvidar recortar la última posición.
logitstiene shape(T, V)— quiereslogits[-1], nologits.argmax(axis=-1)[-1](lo segundo es inocuo pero despilfarrador). - Confusión con el id del token EOS. El corpus §A13 puede o no tener un EOS aprendido — revisa el tokenizer de la Fase 12. Si no hay EOS, simplemente ejecuta hasta
max_tokens. - Mutar la lista del prompt. Trabaja siempre sobre una copia (
list(prompt)) para que el prompt del llamador no se modifique. - Llamar a
model(np.array(tokens))sobre una lista de longitud 0. Añade una precondición:assert len(prompt) > 0.
Completado de referencia (no mires antes de intentarlo)¶
La solución del lab vive en solutions/00-greedy.py después de que la fase esté cerrada. No mires hasta que tu test pase.
Siguiente: 01-temperature-sweep.md