Lab 00 — Greedy decode

The simplest sampler: argmax. It is deterministic, robust, and the baseline against which everything else is measured. Implement the decode loop and verify that "Yesterday I" produces "worked".

Objective

Implement greedy decoding (conceptually sample_greedy(model, prompt, max_tokens) -> list[int], realized in the tasks below as sample(...) with strategy=Greedy()), and verify that on a handful of seed prompts from the §A13 corpus the trained Mini-GPT produces the expected verb conjugation.

Setup

  • Trained Mini-GPT checkpoint from Phase 18 (experiments/<date>-phase-18-train/model.npz).
  • Tokenizer and vocabulary from Phase 12.
  • numpy and np.random.default_rng. The RNG is unused here because greedy is deterministic, but every sampler in this phase takes a seed parameter for API consistency.

Tasks

  1. Create src/minimodel/sampling.py with the module-level docstring describing the sampler API:
def sample(model, prompt: list[int], *, max_tokens: int, seed: int,
           strategy: SamplingStrategy) -> list[int]:
    """Returns the generated tokens (excluding the prompt)."""
  2. Implement the Greedy strategy. A strategy is a callable (logits: np.ndarray, rng: np.random.Generator) -> int. Greedy ignores the rng:
import numpy as np

class Greedy:
    def __call__(self, logits: np.ndarray, rng: np.random.Generator) -> int:
        # rng is unused: argmax is deterministic.
        return int(np.argmax(logits))
  3. Implement the decode loop in sample():
    • Create rng = np.random.default_rng(seed) and initialize tokens = list(prompt).
    • For each step, up to max_tokens steps:
      • Run model(np.array(tokens)) → (T, V) logits.
      • Take logits[-1] → (V,).
      • Call strategy(last_logits, rng) → next_token: int.
      • Append next_token to tokens.
      • If next_token == EOS_TOKEN_ID, break.
    • Return tokens[len(prompt):].

  4. Test on three prompts from the §A13 corpus:
    • "Yesterday I" → expect a past-tense verb form (e.g., "worked").
    • "Tomorrow she" → expect a future-tense form (e.g., "will go" or "is going to walk").
    • "He" → expect a 3rd-person singular present-tense form (e.g., "works", "plays").

  5. Determinism check. Call sample(...) twice with the same seed and the same prompt; assert that the outputs are identical.
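Tasks 1, 2, 3 and 5 fit together roughly as below. This is a minimal sketch under stated assumptions, not the reference solution: the model is assumed to be any callable mapping a 1-D int array of shape (T,) to a (T, V) logits array; EOS_TOKEN_ID is a placeholder (set it from Phase 12's tokenizer, or leave it as None if there is no EOS); SamplingStrategy is written here as a typing.Protocol; and toy_model is a hypothetical stand-in for the Phase 18 checkpoint so the snippet runs on its own. The sketch does not add a batch dimension or crop the context, which the real Mini-GPT may need.

```python
from typing import Callable, Optional, Protocol

import numpy as np

# Placeholder (assumption): set from Phase 12's tokenizer; None means "no EOS, run to max_tokens".
EOS_TOKEN_ID: Optional[int] = None


class SamplingStrategy(Protocol):
    def __call__(self, logits: np.ndarray, rng: np.random.Generator) -> int: ...


class Greedy:
    def __call__(self, logits: np.ndarray, rng: np.random.Generator) -> int:
        # Deterministic: rng is accepted for API consistency but never used.
        return int(np.argmax(logits))


def sample(model: Callable[[np.ndarray], np.ndarray], prompt: list[int], *,
           max_tokens: int, seed: int, strategy: SamplingStrategy) -> list[int]:
    """Returns the generated tokens (excluding the prompt)."""
    assert len(prompt) > 0, "prompt must be non-empty"
    rng = np.random.default_rng(seed)
    tokens = list(prompt)  # copy so the caller's list is never mutated
    for _ in range(max_tokens):
        logits = model(np.array(tokens))  # (T, V)
        last_logits = logits[-1]          # (V,)
        next_token = strategy(last_logits, rng)
        tokens.append(next_token)
        if EOS_TOKEN_ID is not None and next_token == EOS_TOKEN_ID:
            break
    return tokens[len(prompt):]


def toy_model(tokens: np.ndarray, vocab_size: int = 8) -> np.ndarray:
    # Hypothetical stand-in: each position strongly prefers token (tokens[t] + 1) % V.
    logits = np.zeros((len(tokens), vocab_size))
    logits[np.arange(len(tokens)), (tokens + 1) % vocab_size] = 5.0
    return logits


prompt = [3, 4]
out_a = sample(toy_model, prompt, max_tokens=5, seed=0, strategy=Greedy())
out_b = sample(toy_model, prompt, max_tokens=5, seed=0, strategy=Greedy())
assert out_a == out_b    # determinism check (Task 5)
assert prompt == [3, 4]  # the caller's prompt was not mutated
print(out_a)             # [5, 6, 7, 0, 1]
```

Writing the strategy as a Protocol lets both the Greedy class and plain functions satisfy the type under mypy --strict, which keeps later samplers in this phase drop-in compatible.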

Measurements

For each prompt, log to experiments/<date>-phase-21-greedy/:

  • The prompt (as raw text and as token IDs).
  • The greedy output (as token IDs and decoded text).
  • The top-5 logits at the first generation step (token IDs, decoded text, and values). These show how peaked the distribution is: greedy is "robust" only if the gap between the rank-1 and rank-2 logits is large.
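The first-step measurement can be computed with a small helper such as the one below. It is a sketch: top_k_report is a hypothetical name, the decode callable stands in for Phase 12's tokenizer, serializing to JSON is an assumption about how you store results under experiments/<date>-phase-21-greedy/, and the vocabulary and logits in the usage lines are made up for illustration.

```python
import json
from typing import Callable

import numpy as np


def top_k_report(last_logits: np.ndarray, decode: Callable[[int], str],
                 k: int = 5) -> dict[str, object]:
    """Summarize the k highest logits at one decoding step (requires k >= 2)."""
    assert k >= 2, "need at least two entries to compute the rank-1/rank-2 gap"
    # argsort is ascending: reverse it and keep the first k indices.
    top = np.argsort(last_logits)[::-1][:k]
    entries = [{"id": int(i), "text": decode(int(i)),
                "logit": float(last_logits[i])} for i in top]
    # Gap between rank-1 and rank-2: a large gap means greedy's pick is stable.
    margin = float(last_logits[top[0]] - last_logits[top[1]])
    return {"top_k": entries, "rank1_minus_rank2": margin}


# Usage with made-up logits and a hypothetical decode function:
vocab = ["I", "worked", "work", "went", "will", "."]
logits = np.array([0.1, 4.0, 2.5, 3.2, -1.0, 0.0])
report = top_k_report(logits, decode=lambda i: vocab[i])
print(json.dumps(report, indent=2))
```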

Acceptance

  • Greedy()(logits, rng) is pure: the same logits give the same output regardless of rng.
  • The decode loop terminates: either at max_tokens or at EOS.
  • For "Yesterday I", the first generated token decodes to a past-tense regular or irregular verb. (If it does not, the model is likely undertrained: investigate the Phase 18 loss curves rather than the sampler.)
  • mypy --strict src/minimodel/sampling.py passes.
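The purity criterion can be spot-checked by calling a strategy with several differently seeded generators on the same logits. A sketch, with Greedy repeated from Task 2 so the snippet runs alone; is_pure and uniform_random are hypothetical helpers. Passing this check is necessary but not sufficient: a strategy could agree across the seeds tried and still depend on the rng.

```python
from typing import Callable

import numpy as np

Strategy = Callable[[np.ndarray, np.random.Generator], int]


class Greedy:
    def __call__(self, logits: np.ndarray, rng: np.random.Generator) -> int:
        return int(np.argmax(logits))


def is_pure(strategy: Strategy, logits: np.ndarray, n_seeds: int = 20) -> bool:
    """True if the strategy returns the same token under n_seeds different generators."""
    picks = {strategy(logits, np.random.default_rng(s)) for s in range(n_seeds)}
    return len(picks) == 1


def uniform_random(logits: np.ndarray, rng: np.random.Generator) -> int:
    # Deliberately impure counterexample: ignores the logits and samples uniformly.
    return int(rng.integers(len(logits)))


logits = np.random.default_rng(123).normal(size=50)
print(is_pure(Greedy(), logits))        # True
print(is_pure(uniform_random, logits))  # False
```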

Pitfalls

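  • Forgetting to slice the last position. logits has shape (T, V), so take logits[-1] before calling the strategy. Passing the full array to Greedy is a silent bug: np.argmax without an axis flattens the array and returns an index in [0, T*V), which is not a token ID. logits.argmax(axis=-1)[-1] gives the right answer but computes an argmax for every position, which is wasteful.
  • EOS token ID confusion. The §A13 corpus may or may not have a learned EOS token; check Phase 12's tokenizer. If there is no EOS, just run to max_tokens.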
  • Mutating the prompt list. Always work on a copy (list(prompt)) so the caller's prompt isn't modified.
  • Calling model(np.array(tokens)) on a 0-length list. Add a precondition: assert len(prompt) > 0.
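The slicing and prompt-mutation pitfalls are easy to reproduce in isolation. The (T, V) = (3, 4) logits below are made up for illustration.

```python
import numpy as np

logits = np.array([[0.0, 0.0, 0.0, 0.0],
                   [0.0, 9.0, 0.0, 0.0],
                   [1.0, 0.0, 2.0, 0.0]])  # shape (T, V) = (3, 4)

# Wrong: argmax with no axis flattens (T, V); index 5 is not even a valid token when V = 4.
print(int(np.argmax(logits)))           # 5
# Right: slice the last position first.
print(int(np.argmax(logits[-1])))       # 2
# Same answer, but computes an argmax for every position.
print(int(logits.argmax(axis=-1)[-1]))  # 2

# Mutation pitfall: assignment aliases the caller's list, list() copies it.
prompt = [7, 8]
aliased = prompt
aliased.append(9)
print(prompt)  # [7, 8, 9]  (the caller's prompt changed)

prompt = [7, 8]
copied = list(prompt)
copied.append(9)
print(prompt)  # [7, 8]
```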

Reference completion (do not peek before attempting)

Lab solution lives in solutions/00-greedy.py after the phase is done. Do not look until your test passes.

Next: 01-temperature-sweep.md