English · Español
Lab 02 — Muestreo (sampling) top-k y top-p (nucleus)¶
🇪🇸 Top-k recorta a las k mejores; top-p recorta al conjunto mínimo cuya masa supera p. Implementa ambos, demuestra que coinciden en distribuciones picudas y divergen en planas, y verifica la propiedad de no-op
top-p=1.0.
Objetivo¶
Implementar las estrategias de muestreo TopK(k, tau) y TopP(p, tau), verificarlas contra las propiedades de no-op de theory/02-top-k-and-top-p.md, y demostrar cómo difieren sobre distribuciones sintéticas planas vs. picudas.
Setup¶
Greedy,Temperaturede los labs 00-01.- Checkpoint del Mini-GPT entrenado.
- Un vector sintético de 5 bins de logits para el test de divergencia:
- Picudo:
z_peaked = [5.0, 4.0, 0.0, 0.0, 0.0] - Plano:
z_flat = [2.0, 1.9, 1.8, 1.7, 0.0]
Tareas¶
- Implementa
TopK(k, tau)ensrc/minimodel/sampling.py:
@dataclass(frozen=True)
class TopK:
k: int
tau: float = 1.0
def __call__(self, logits, rng):
assert self.k >= 1
scaled = logits / self.tau
# Find the k-th largest logit; mask everything below it
threshold = np.partition(scaled, -self.k)[-self.k]
masked = np.where(scaled >= threshold, scaled, -np.inf)
probs = softmax(masked)
return int(rng.choice(len(probs), p=probs))
Nota sobre desempates: si hay empates en el k-ésimo logit, np.partition puede incluir más de k tokens. Decide si quieres exactamente k (usa np.argpartition y ordena) o "al menos k" (usa umbral >=). Documenta la elección. El lab usa "al menos k" — más simple y aceptable.
- Implementa
TopP(p, tau)ensrc/minimodel/sampling.py:
@dataclass(frozen=True)
class TopP:
p: float
tau: float = 1.0
def __call__(self, logits, rng):
assert 0 < self.p <= 1.0
scaled = logits / self.tau
probs = softmax(scaled)
sorted_idx = np.argsort(-probs) # descending
sorted_probs = probs[sorted_idx]
cumsum = np.cumsum(sorted_probs)
# Smallest K* such that sum >= p — use np.searchsorted + 1
cutoff = int(np.searchsorted(cumsum, self.p)) + 1
nucleus = sorted_idx[:cutoff]
mask = np.zeros_like(probs, dtype=bool)
mask[nucleus] = True
truncated = np.where(mask, probs, 0.0)
truncated /= truncated.sum()
return int(rng.choice(len(probs), p=truncated))
- Property test:
TopP(p=1.0) ≡ Temperature(tau).
Para 100 seeds aleatorias, sobre los logits del primer paso del Mini-GPT para "Tomorrow she":
- Muestrea con Temperature(1.0) → token A.
- Muestrea con TopP(1.0, tau=1.0) → token B.
- Comprueba A == B.
Si esto falla, tienes un off-by-one en searchsorted (muy común). El + 1 es crítico: searchsorted(cumsum, 1.0) devuelve V - 1, pero cumsum[V-1] = 1.0, así que el nucleus debe incluir el último índice. cutoff = searchsorted(...) + 1 = V — coge todo.
- Property test:
TopK(k=1) ≡ Greedymódulo empates.
Sobre los logits del primer paso del Mini-GPT, calcula Greedy()(logits, rng) y TopK(1)(logits, rng) para 100 seeds aleatorias. Deben ser iguales (porque la distribución de un solo elemento es determinística). Si hay empates en el argmax (raro en la práctica), documenta la discrepancia.
- Demuestra la divergencia sobre distribuciones sintéticas.
Para z_peaked = [5, 4, 0, 0, 0]:
- TopK(k=2, tau=1.0): mantiene {5, 4}.
- TopP(p=0.9, tau=1.0): calcula softmax(z_peaked) — el token superior obtiene ~73% de masa; los dos superiores ~99%. Así que top-p también escoge {5, 4} aquí. Coinciden.
Para z_flat = [2.0, 1.9, 1.8, 1.7, 0.0]:
- TopK(k=2): mantiene {2.0, 1.9}.
- TopP(p=0.9): softmax(z_flat) es aproximadamente [0.246, 0.222, 0.201, 0.182, 0.149]; cumulativa [0.246, 0.468, 0.670, 0.852, 1.0]. Así que searchsorted(., 0.9) = 4, cutoff = 5 — el nucleus es el vocabulario completo.
- Divergen.
Grafica ambos: un diagrama de barras lado a lado para cada uno (picudo y plano) mostrando la distribución original, la enmascarada por top-k y la enmascarada por top-p.
- Completado de verbos end-to-end. Genera 5 completados de
"Tomorrow she"con cada uno de: Greedy()Temperature(0.7)TopK(5, tau=0.7)TopP(0.9, tau=0.7)
Compara la diversidad (número de completados únicos) y la gramaticalidad (a ojo — ¿cada uno parece una forma verbal plausible?).
Mediciones¶
Guarda en experiments/<date>-phase-21-truncation/:
topk_vs_topp_peaked.png— diagrama de barras sobrez_peaked.topk_vs_topp_flat.png— diagrama de barras sobrez_flat.completions.json— los 5×4 = 20 completados de la tarea 6.property_test_results.txt— pass/fail para los dos property tests de las tareas 3 y 4.
Aceptación¶
TopP(p=1.0)produce salida idéntica aTemperaturepara al menos 95/100 seeds aleatorias (admitiendo el raro caso de desempate).TopK(k=1)produce salida idéntica aGreedypara al menos 95/100 seeds aleatorias.- Sobre
z_flat, el conjunto activo deTopK(k=2)es estrictamente más pequeño que el deTopP(p=0.9). mypy --strict src/minimodel/sampling.pypasa.
Trampas¶
- Off-by-one en la suma cumulativa (cubierto arriba). Cuidado especial:
numpy.searchsorted(a, v, side='left')devuelve la posición más a la izquierda donde se podría insertarv. Para nuestro uso (cumsum >= p), el corte correcto essearchsorted + 1solo sisearchsortedno incluye ya el token umbral. Testea el caso bordep = suma_de_los_dos_primeros. - Softmax sobre logits enmascarados con
-inf.np.exp(-inf) = 0, que es lo que quieres —softmax(masked)funciona correctamente. Pero si olvidas usar-inf(p. ej., enmascaras con0o un número pequeño), filtrarás silenciosamente masa de probabilidad a los tokens enmascarados. Usa-inf. - Desempates en la frontera k-ésima. Si
logits = [3, 3, 3, 1, 1]yk = 2, "top-2" es ambiguo. Documenta y testea. - Comparación de punto flotante en property tests. Al comprobar
TopP(p=1.0) ≡ Temperature(tau), los vectores de probabilidad internos pueden diferir por1e-16debido al paso de normalización. Compara tokens muestreados, no probabilidades crudas.
Siguiente: 03-diversity-vs-accuracy.md