Skip to content

English · Español

00 — El bucle de decode y por qué importa el muestreo (sampling)

🇪🇸 El modelo emite logits, no texto. La función que los convierte en un token concreto es el sampler. Esa función es donde se decide si el modelo suena correcto, creativo, repetitivo o roto.

El bucle de decode en una pantalla

Después de que la Fase 18 entrenara un Mini-GPT, generar texto a partir de él es literalmente esto:

def generate(model, prompt: list[int], *, max_tokens: int, sampler) -> list[int]:
    tokens = list(prompt)                                 # work with a mutable copy
    for _ in range(max_tokens):
        logits = model(np.array(tokens))                  # (T, V)
        next_logits = logits[-1]                          # (V,) — only the last position matters
        next_token = sampler(next_logits)                 # int
        tokens.append(next_token)
        if next_token == EOS_TOKEN_ID:
            break
    return tokens[len(prompt):]                           # only the generated suffix

Cinco operaciones: forward, recortar la última posición, muestrear, añadir, comprobar parada. Repetir. Eso es todo.

Lo que cambia entre greedy y nucleus y "el muestreo estilo GPT-4" es solamente la función sampler. El bucle de decode en sí es compartido.

Qué puede significar "sampler"

Tres familias:

  1. Determinístico. Greedy decoding (argmax) siempre devuelve el mismo token para los mismos logits. Sin aleatoriedad.
  2. Estocástico sobre la distribución completa del modelo. Softmax escalada por temperatura + muestreo. Se consulta el vector de probabilidades entero del modelo.
  3. Estocástico con truncado. Top-k o top-p primero ponen a cero los tokens improbables, y después muestrean sobre la masa restante. Habitual en producción porque el muestreo puro sobre la distribución completa puede producir tokens de baja probabilidad de vez en cuando.

Cada familia hace un compromiso distinto entre diversidad y corrección, que el lab 03 medirá.

Por qué nos importa el muestreo sobre el corpus de verbos

Para el prompt "Tomorrow she", el Mini-GPT entrenado debería emitir un token que sea:

  1. Sintácticamente válido — una forma verbal, no un token de puntuación.
  2. Correcto en tiempowill go, is going to go o formas similares en futuro.
  3. Correcto en persona — tercera persona del singular, no primera persona.

Los logits del modelo codifican probabilidades para todo lo anterior. El sampler elige cuál observamos. Con greedy obtenemos la única continuación más probable, siempre. Con temperatura > 1 obtenemos variedad (a veces "will go", a veces "is going to walk"). Con temperatura demasiado alta obtenemos basura (tokens aleatorios).

Ajustar el sampler es parte de lo que hace que el agente tutor de la Fase 32 parezca competente en lugar de robótico.

Un matiz sobre probabilidad vs logits

El modelo emite logits (log-probabilidades sin normalizar). Para obtener probabilidades aplicas softmax. Muchos samplers pueden operar directamente sobre los logits sin calcular softmax — al argmax no le importa la normalización, top-k solo ordena por rango, y solo el escalado por temperatura y el nucleus necesitan los valores explícitos de probabilidad.

Higiene numérica: si calculas softmax y después muestreas, usa el truco log-sum-exp (Fase 05). Si puedes muestrear sin materializar nunca el vector completo de probabilidades (p. ej., el truco Gumbel-max), mejor todavía. Para nuestro \(V = 64\) esto da igual; para \(V = 50000\) sí importaría.

El coste: un forward por token

Cada iteración del bucle llama a model(tokens) — un forward completo a través del Mini-GPT sobre una secuencia de longitud \(T_\text{current}\). A medida que \(T_\text{current}\) crece, cada paso se vuelve más caro:

  • Token 0: forward sobre longitud \(L\) (el prompt).
  • Token 1: forward sobre longitud \(L + 1\).
  • Token 2: forward sobre longitud \(L + 2\).
  • ...
  • Token \(T\): forward sobre longitud \(L + T\).

Coste total: \(\sum_{t=0}^{T-1} O((L + t) \cdot d^2) \approx O(T(L + T) d^2)\). Para \(L = 8, T = 20, d = 64\): ~\(28 \cdot 4096 = 115{,}000\) "operaciones" — rápido en una CPU.

Pero fíjate: el forward del token 0 ve [prompt]; el forward del token 1 ve [prompt, gen_0]. Las primeras \(L\) posiciones no cambian entre llamadas. La mayor parte del cómputo del forward se recomputa de forma redundante en cada paso.

Por eso la Fase 22 introduce la KV cache: cachear las claves/valores de atención de posiciones anteriores para que solo computemos la nueva última posición en cada paso. La Fase 21 explícitamente no usa cache para que sientas el coste.

Lo que este archivo NO cubre

  • La matemática de la temperatura. Siguiente archivo.
  • Estrategias de truncado. Archivo 02.
  • Implementación de la cache. Fase 22.

Siguiente: 01-temperature.md