Lab 02 — Top-k and Top-p (nucleus) sampling
Top-k truncates to the k highest-scoring tokens; top-p truncates to the smallest set whose cumulative mass reaches p. Implement both, show that they coincide on peaked distributions and diverge on flat ones, and verify the no-op property of top-p = 1.0.
Objective
Implement TopK(k, tau) and TopP(p, tau) sampling strategies, verify them against the no-op properties from theory/02-top-k-and-top-p.md, and demonstrate how they differ on synthetic flat vs. peaked distributions.
Setup
- `Greedy`, `Temperature` from labs 00-01.
- Trained Mini-GPT checkpoint.
- Two synthetic 5-bin logit vectors for the divergence test:
  - Peaked: `z_peaked = [5.0, 4.0, 0.0, 0.0, 0.0]`
  - Flat: `z_flat = [2.0, 1.9, 1.8, 1.7, 0.0]`
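Tasks
1. Implement `TopK(k, tau)` in `src/minimodel/sampling.py`:
```python
from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray

# `softmax` is assumed to be the module's existing numerically stable helper
# (defined alongside `Temperature` from labs 00-01).


@dataclass(frozen=True)
class TopK:
    k: int
    tau: float = 1.0

    def __call__(self, logits: NDArray[np.float64], rng: np.random.Generator) -> int:
        assert self.k >= 1
        scaled = logits / self.tau
        # np.partition raises if k exceeds the vocabulary size, so clamp it
        k = min(self.k, scaled.shape[0])
        # Find the k-th largest logit; mask everything below it with -inf
        threshold = np.partition(scaled, -k)[-k]
        masked = np.where(scaled >= threshold, scaled, -np.inf)
        probs = softmax(masked)
        return int(rng.choice(len(probs), p=probs))
```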
Tie-breaking note: if several logits tie with the k-th largest value, the `scaled >= threshold` mask keeps all of them, so more than k tokens can survive. Decide whether you want exactly k (select indices with `np.argpartition`, then sort) or "at least k" (keep `>= threshold`), and document the choice. This lab uses "at least k": it is simpler and acceptable.
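2. Implement `TopP(p, tau)` in `src/minimodel/sampling.py`:
```python
from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray


@dataclass(frozen=True)
class TopP:
    p: float
    tau: float = 1.0

    def __call__(self, logits: NDArray[np.float64], rng: np.random.Generator) -> int:
        assert 0 < self.p <= 1.0
        scaled = logits / self.tau
        probs = softmax(scaled)
        sorted_idx = np.argsort(-probs)  # descending
        sorted_probs = probs[sorted_idx]
        cumsum = np.cumsum(sorted_probs)
        # Smallest K* such that sum >= p: searchsorted (side='left') returns the
        # index of the first cumsum >= p, and + 1 turns it into a token count.
        # min() guards against cumsum[-1] rounding to just below p.
        cutoff = min(int(np.searchsorted(cumsum, self.p)) + 1, len(probs))
        nucleus = sorted_idx[:cutoff]
        mask = np.zeros_like(probs, dtype=bool)
        mask[nucleus] = True
        truncated = np.where(mask, probs, 0.0)
        truncated /= truncated.sum()  # renormalise the nucleus
        return int(rng.choice(len(probs), p=truncated))
```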
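3. Property test: `TopP(p=1.0) ≡ Temperature(tau)`.
For 100 random seeds, on the Mini-GPT first-step logits for "Tomorrow she":
- Sample with `Temperature(1.0)` → token A.
- Sample with `TopP(1.0, tau=1.0)` → token B, using a fresh generator built from the same seed so both draws start from the same RNG state.
- Assert A == B.

If this fails, you probably have an off-by-one in `searchsorted` (a very common bug). With the default `side='left'`, `np.searchsorted(cumsum, p)` returns the index of the first cumulative sum that is `>= p`, i.e. the position of the token that crosses the threshold, so the `+ 1` is what makes the slice include that token. When the cumulative sum first reaches 1.0 at the last index, `searchsorted(cumsum, 1.0)` returns `V - 1` and `cutoff = V`: the nucleus is the entire vocabulary. Without the `+ 1`, the least likely token would be silently dropped.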
4. Property test: `TopK(k=1) ≡ Greedy` modulo ties.
On the Mini-GPT first-step logits, compute `Greedy()(logits, rng)` and `TopK(1)(logits, rng)` for 100 random seeds. They should be equal, since a single-token distribution is deterministic. If there are ties at the argmax (rare in practice), the "at least k" rule keeps every tied token and `TopK(1)` samples among them, so document the discrepancy.
5. Demonstrate divergence on synthetic distributions.
For `z_peaked = [5, 4, 0, 0, 0]`:
- `TopK(k=2, tau=1.0)`: keeps {5, 4}.
- `TopP(p=0.9, tau=1.0)`: `softmax(z_peaked)` ≈ [0.720, 0.265, 0.005, 0.005, 0.005], so the top token gets ~72% of the mass and the top two get ~98.5%. Top-p therefore also picks {5, 4}: they coincide.
For `z_flat = [2.0, 1.9, 1.8, 1.7, 0.0]`:
- `TopK(k=2)`: keeps {2.0, 1.9}.
- `TopP(p=0.9)`: `softmax(z_flat)` ≈ [0.278, 0.251, 0.227, 0.206, 0.038]; cumulative ≈ [0.278, 0.529, 0.757, 0.962, 1.0]. So `searchsorted(cumsum, 0.9) = 3` and `cutoff = 4`: the nucleus is {2.0, 1.9, 1.8, 1.7}, every token except the 0.0 logit.

They diverge: top-k keeps 2 tokens, top-p keeps 4.
Plot both: for each distribution (peaked and flat), draw a side-by-side bar chart of the original distribution, the top-k truncated distribution, and the top-p truncated distribution.
6. End-to-end verb completion. Generate 5 completions of "Tomorrow she" with each of:
- `Greedy()`
- `Temperature(0.7)`
- `TopK(5, tau=0.7)`
- `TopP(0.9, tau=0.7)`

Compare the diversity (number of unique completions) and grammaticality (by eye: does each look like a plausible verb form?).
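The divergence numbers above can be checked without the model. The sketch below is a standalone NumPy reimplementation of the two masks at `tau = 1.0`; the helper names `softmax`, `topk_active` and `topp_active` are hypothetical and not part of `sampling.py`. It assumes the same "at least k" rule and `searchsorted + 1` cutoff as the `TopK`/`TopP` classes.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    # Numerically stable softmax; exp(-inf) = 0 handles masked entries.
    e = np.exp(z - np.max(z))
    return e / e.sum()


def topk_active(z: np.ndarray, k: int) -> set[int]:
    # "At least k" rule: keep every logit >= the k-th largest.
    threshold = np.partition(z, -k)[-k]
    return {int(i) for i in np.flatnonzero(z >= threshold)}


def topp_active(z: np.ndarray, p: float) -> set[int]:
    # Smallest prefix of the descending-sorted distribution whose mass reaches p.
    probs = softmax(z)
    order = np.argsort(-probs)
    cumsum = np.cumsum(probs[order])
    cutoff = min(int(np.searchsorted(cumsum, p)) + 1, len(probs))
    return {int(i) for i in order[:cutoff]}


z_peaked = np.array([5.0, 4.0, 0.0, 0.0, 0.0])
z_flat = np.array([2.0, 1.9, 1.8, 1.7, 0.0])

for name, z in [("peaked", z_peaked), ("flat", z_flat)]:
    probs = softmax(z)
    print(name, np.round(probs, 3), np.round(np.cumsum(probs), 3))
    print("  top-k(2):  ", sorted(topk_active(z, 2)))
    print("  top-p(0.9):", sorted(topp_active(z, 0.9)))
```

The active sets are token indices: on `z_peaked` both rules keep indices {0, 1}, while on `z_flat` top-p keeps {0, 1, 2, 3}, which is the strict-superset relationship the acceptance criteria ask for.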
Measurements
Save to `experiments/<date>-phase-21-truncation/`:
- `topk_vs_topp_peaked.png`: bar chart on `z_peaked`.
- `topk_vs_topp_flat.png`: bar chart on `z_flat`.
- `completions.json`: the 5 × 4 = 20 completions from task 6.
- `property_test_results.txt`: pass/fail for the two property tests in tasks 3 and 4.
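A sketch of how `completions.json` and the per-sampler diversity count could be produced. The `generate(prompt, sampler_label, seed)` interface is hypothetical: it stands in for however your Mini-GPT decoding loop is exposed. The `toy_generate` function is a toy substitute so the snippet runs on its own; it does not reflect real model output.

```python
import json
from pathlib import Path
from typing import Callable

# Hypothetical interface: generate(prompt, sampler_label, seed) -> completion.
# Swap in your Mini-GPT decoding loop with the matching sampler object.
Generate = Callable[[str, str, int], str]

SAMPLERS = ["Greedy()", "Temperature(0.7)", "TopK(5, tau=0.7)", "TopP(0.9, tau=0.7)"]


def collect_completions(generate: Generate, prompt: str, n: int = 5) -> dict[str, list[str]]:
    # n completions per sampler, one seed each, so stochastic samplers can differ.
    return {label: [generate(prompt, label, seed) for seed in range(n)] for label in SAMPLERS}


def diversity(completions: dict[str, list[str]]) -> dict[str, int]:
    # Number of unique completions per sampler.
    return {label: len(set(texts)) for label, texts in completions.items()}


def save_completions(completions: dict[str, list[str]], out_dir: Path) -> Path:
    # Writes completions.json into the experiment directory.
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / "completions.json"
    path.write_text(json.dumps(completions, indent=2, ensure_ascii=False), encoding="utf-8")
    return path


def toy_generate(prompt: str, label: str, seed: int) -> str:
    # Toy stand-in for the model: greedy is deterministic, the rest vary by seed.
    verbs = ["will go", "goes", "will leave", "is going"]
    if label == "Greedy()":
        return f"{prompt} will go"
    return f"{prompt} {verbs[seed % len(verbs)]}"


completions = collect_completions(toy_generate, "Tomorrow she")
print(diversity(completions))
```

With the real model, pass the dated experiment directory to `save_completions`. Whatever the model, `Greedy()` should always report a diversity of 1, because it ignores the seed.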
Acceptance
- `TopP(p=1.0)` produces identical output to `Temperature` for at least 95/100 random seeds (allowing for rare floating-point edge cases).
- `TopK(k=1)` produces identical output to `Greedy` for at least 95/100 random seeds (allowing for ties at the argmax).
- On `z_flat`, `TopK(k=2)`'s active set is strictly smaller than `TopP(p=0.9)`'s active set.
- `mypy --strict src/minimodel/sampling.py` passes.
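A minimal harness for the two property tests, assuming each sampler is called as `sampler(logits, rng)` like the lab's classes. To run on its own it uses small stand-in functions (`temperature`, `greedy`, `top_k`, `top_p` are hypothetical names mirroring `Temperature`, `Greedy`, `TopK`, `TopP`) and random logits in place of the Mini-GPT first-step logits. The key detail is that each sampler gets its own `np.random.default_rng(seed)`, so both draws start from the same RNG state.

```python
from typing import Callable

import numpy as np

Sampler = Callable[[np.ndarray, np.random.Generator], int]


def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - np.max(z))
    return e / e.sum()


# Stand-ins following the same logic as the lab's Temperature / Greedy / TopK / TopP.
def temperature(tau: float) -> Sampler:
    def sample(logits: np.ndarray, rng: np.random.Generator) -> int:
        probs = softmax(logits / tau)
        return int(rng.choice(len(probs), p=probs))
    return sample


def greedy() -> Sampler:
    return lambda logits, rng: int(np.argmax(logits))


def top_k(k: int, tau: float = 1.0) -> Sampler:
    def sample(logits: np.ndarray, rng: np.random.Generator) -> int:
        scaled = logits / tau
        threshold = np.partition(scaled, -k)[-k]
        probs = softmax(np.where(scaled >= threshold, scaled, -np.inf))
        return int(rng.choice(len(probs), p=probs))
    return sample


def top_p(p: float, tau: float = 1.0) -> Sampler:
    def sample(logits: np.ndarray, rng: np.random.Generator) -> int:
        probs = softmax(logits / tau)
        order = np.argsort(-probs)
        cutoff = min(int(np.searchsorted(np.cumsum(probs[order]), p)) + 1, len(probs))
        truncated = np.zeros_like(probs)
        truncated[order[:cutoff]] = probs[order[:cutoff]]
        return int(rng.choice(len(probs), p=truncated / truncated.sum()))
    return sample


def agreement(a: Sampler, b: Sampler, logits: np.ndarray, n_seeds: int = 100) -> int:
    # Count seeds on which both samplers pick the same token.
    # Each call gets a fresh generator built from the same seed.
    return sum(
        a(logits, np.random.default_rng(seed)) == b(logits, np.random.default_rng(seed))
        for seed in range(n_seeds)
    )


logits = np.random.default_rng(0).normal(size=50)  # stand-in for Mini-GPT logits
print("TopP(1.0) vs Temperature(1.0):", agreement(top_p(1.0), temperature(1.0), logits), "/ 100")
print("TopK(1) vs Greedy:", agreement(top_k(1), greedy(), logits), "/ 100")
```

The counts are what would go into `property_test_results.txt`. Comparing sampled tokens rather than probability vectors sidesteps the `1e-16`-level renormalisation differences.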
Pitfalls
- Off-by-one in the cumulative sum (covered above). Be especially careful: `numpy.searchsorted(a, v, side='left')` returns the leftmost position where `v` could be inserted while keeping `a` sorted, i.e. the index of the first element `>= v`. For our use (`cumsum >= p`) that index points at the token that crosses the threshold, so the cutoff is `searchsorted + 1`. With `side='right'`, an exact hit (`cumsum[i] == p`) already returns `i + 1`, and adding 1 again keeps one token too many. Test on the edge case where `p` equals the mass of the first two tokens (`p = cumsum[1]`).
- Softmax on `-inf`-masked logits. `np.exp(-inf) = 0`, which is what you want: `softmax(masked)` works correctly. But if you forget to use `-inf` (e.g. you mask with `0` or a small negative constant), you'll silently leak probability mass to the masked-out tokens. Use `-inf`.
- Tie-breaking at the k-th boundary. If `logits = [3, 3, 3, 1, 1]` and `k = 2`, "top-2" is ambiguous: the "at least k" rule keeps all three tokens with logit 3. Document and test.
- Floating-point comparison in property tests. When checking `TopP(p=1.0) ≡ Temperature(tau)`, the internal probability vectors may differ on the order of `1e-16` because of the renormalisation step. Compare sampled tokens, not raw probabilities.
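The first three pitfalls can be pinned down with small regression checks. The sketch below is a standalone NumPy illustration; `softmax`, `nucleus_size`, `topk_at_least` and `topk_exact` are hypothetical helper names, not the lab's API. `topk_exact` follows the `np.argpartition` route for exactly k tokens; which of the tied indices it returns is implementation-defined, so the check only asserts the count.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - np.max(z))
    return e / e.sum()


def nucleus_size(probs: np.ndarray, p: float, side: str = "left") -> int:
    # Number of tokens kept by top-p using searchsorted(...) + 1.
    cumsum = np.cumsum(np.sort(probs)[::-1])
    return min(int(np.searchsorted(cumsum, p, side=side)) + 1, len(probs))


def topk_at_least(z: np.ndarray, k: int) -> np.ndarray:
    # "At least k": every logit tied with the k-th largest survives.
    threshold = np.partition(z, -k)[-k]
    return np.flatnonzero(z >= threshold)


def topk_exact(z: np.ndarray, k: int) -> np.ndarray:
    # Exactly k indices; the choice among tied logits is implementation-defined.
    return np.sort(np.argpartition(-z, k - 1)[:k])


# 1. Off-by-one: p exactly equal to the mass of the first two tokens.
probs = softmax(np.array([2.0, 1.9, 1.8, 1.7, 0.0]))
p_edge = float(np.cumsum(np.sort(probs)[::-1])[1])
print(nucleus_size(probs, p_edge))                # side='left': keeps 2 tokens
print(nucleus_size(probs, p_edge, side="right"))  # side='right' + 1: keeps 3

# 2. Masking with 0 instead of -inf leaks mass to "removed" tokens
#    (here 0 equals the original logits, so nothing is removed at all).
z = np.array([5.0, 4.0, 0.0, 0.0, 0.0])
keep = z >= np.partition(z, -2)[-2]
print(softmax(np.where(keep, z, -np.inf))[2:].sum())  # exactly zero
print(softmax(np.where(keep, z, 0.0))[2:].sum())      # leaked mass, > 0

# 3. Ties at the k-th boundary.
z_tie = np.array([3.0, 3.0, 3.0, 1.0, 1.0])
print(topk_at_least(z_tie, 2))  # all three tokens with logit 3 survive
print(topk_exact(z_tie, 2))     # exactly two of them
```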