Skip to content

English · Español

Lab 00 — Greedy decoding

🇪🇸 El sampler más simple: argmax. Es determinístico, robusto y la línea base contra la que se mide todo lo demás. Implementa el bucle de decode y verifica que "Yesterday I" produzca "worked".

Objetivo

Implementar greedy decoding como función sample_greedy(model, prompt, max_tokens) -> list[int], y verificar que elige la forma de conjugación predicha por el Mini-GPT entrenado sobre un puñado de prompts semilla del corpus §A13.

Setup

  • Checkpoint del Mini-GPT entrenado en la Fase 18 (experiments/<date>-phase-18-train/model.npz).
  • Tokenizer y vocabulario de la Fase 12.
  • numpy, np.random.default_rng (sin usar aquí — greedy es determinístico, pero cada sampler de esta fase recibe un parámetro seed por consistencia de API).

Tareas

  1. Crea src/minimodel/sampling.py con el docstring a nivel de módulo describiendo la API del sampler:
def sample(model, prompt: list[int], *, max_tokens: int, seed: int,
           strategy: SamplingStrategy) -> list[int]:
    """Returns the generated tokens (excluding the prompt)."""
  1. Implementa la estrategia Greedy. Una estrategia es un callable (logits: np.ndarray, rng: np.random.Generator) -> int. Greedy ignora el rng:
class Greedy:
    def __call__(self, logits, rng):
        return int(np.argmax(logits))
  1. Implementa el bucle de decode en sample():
  2. Inicializa tokens = list(prompt).
  3. Para cada paso hasta max_tokens:
    • Ejecuta model(np.array(tokens)) → logits (T, V).
    • Toma logits[-1](V,).
    • Llama a strategy(last_logits, rng)next_token: int.
    • Añade next_token a tokens.
    • Si next_token == EOS_TOKEN_ID, rompe el bucle.
  4. Devuelve tokens[len(prompt):].

  5. Test sobre tres prompts del corpus §A13:

  6. "Yesterday I" → se espera una forma verbal de pasado (p. ej., "worked").
  7. "Tomorrow she" → se espera una forma de futuro (p. ej., "will go" o "is going to walk").
  8. "He" → se espera una forma de tercera persona del singular (p. ej., "works", "plays").

  9. Verificación de determinismo. Llama a sample(...) dos veces con la misma semilla (seed) y el mismo prompt; comprueba que las salidas son idénticas.

Mediciones

Para cada prompt, registra en experiments/<date>-phase-21-greedy/:

  • El prompt (como texto crudo y como IDs de token).
  • La salida greedy (como IDs de token y texto decodificado).
  • Los 5 logits superiores en el primer paso de generación (para que veas cuán picuda es la distribución — greedy es "robusto" solo si el hueco entre rango 1 y rango 2 es grande).

Aceptación

  • Greedy()(logits, rng) es puro — los mismos logits dan la misma salida independientemente del rng.
  • El bucle de decode termina: o en max_tokens o en EOS.
  • Para "Yesterday I", el primer token generado decodifica a un verbo regular o irregular en pasado. (Si no, el modelo está infraentrenado — investiga las curvas de pérdida (loss) de la Fase 18 en lugar del sampler.)
  • mypy --strict src/minimodel/sampling.py pasa.

Trampas

  • Olvidar recortar la última posición. logits tiene shape (T, V) — quieres logits[-1], no logits.argmax(axis=-1)[-1] (lo segundo es inocuo pero despilfarrador).
  • Confusión con el id del token EOS. El corpus §A13 puede o no tener un EOS aprendido — revisa el tokenizer de la Fase 12. Si no hay EOS, simplemente ejecuta hasta max_tokens.
  • Mutar la lista del prompt. Trabaja siempre sobre una copia (list(prompt)) para que el prompt del llamador no se modifique.
  • Llamar a model(np.array(tokens)) sobre una lista de longitud 0. Añade una precondición: assert len(prompt) > 0.

Completado de referencia (no mires antes de intentarlo)

La solución del lab vive en solutions/00-greedy.py después de que la fase esté cerrada. No mires hasta que tu test pase.

Siguiente: 01-temperature-sweep.md