English · Español
04 — Coding Drills: 12 Ejercicios Implement-It-From-Scratch¶
12 ejercicios de coding tipo entrevista (attention, BPE, gradient checkpointing, KV cache, LoRA, DPO loss, top-p, LayerNorm fwd+bwd, RMSNorm, RoPE, beam search, continuous batcher). 30 minutos cada uno + stub de solución de referencia.
Cómo drillear¶
- Presupuesto de 30 minutos. Pizarra o archivo
.pyen blanco sin internet. - NumPy preferido (más portable a pizarra); PyTorch aceptable para los drills más difíciles.
- El stub de solución al final de cada sección es ~40 líneas; escribe más si hace falta pero la referencia es corta.
- Enlace cruzado: cada drill nombra la fase que te preparó para él.
Por qué estos 12 específicamente¶
Estos son los drills que aparecen repetidamente en phone screens y on-sites 2024-2026 en Anthropic, OpenAI, Google Brain, DeepMind, xAI, Mistral, Cohere, y los principales scaleups. El drill "implementa attention" en particular es el filtro probado — la mayoría de candidatos conocen attention conceptualmente pero no pueden escribirla en un archivo en blanco en 20 minutos.
Drill 01 — Implementa scaled dot-product attention (solo forward)¶
Brief. Dados Q, K, V de shape (batch, n_heads, seq, d_head) y una máscara causal opcional, devuelve el output de attention del mismo shape.
Permitido. NumPy o PyTorch. Sin torch.nn.functional.scaled_dot_product_attention.
Restricciones. 20 minutos. Debe incluir escalado sqrt(d_k). Debe manejar masking causal.
Stub de solución.
import numpy as np
def scaled_dot_product_attention(Q, K, V, causal=False):
# Q, K, V: (B, H, N, D)
B, H, N, D = Q.shape
scores = np.matmul(Q, np.swapaxes(K, -1, -2)) / np.sqrt(D) # (B, H, N, N)
if causal:
mask = np.triu(np.ones((N, N), dtype=bool), k=1)
scores = np.where(mask, -1e9, scores)
# Numerically stable softmax
scores = scores - scores.max(axis=-1, keepdims=True)
weights = np.exp(scores)
weights = weights / weights.sum(axis=-1, keepdims=True)
out = np.matmul(weights, V) # (B, H, N, D)
return out
Bugs comunes a señalar.
- Olvidar el escalado sqrt(d_k).
- Máscara causal aplicada en el eje equivocado (error de transpuesta).
- Overflow del softmax porque se omite el truco - max(scores).
→ Fase 15.
Drill 02 — Implementa tokenizer BPE (entrenamiento + encode)¶
Brief. Entrena un tokenizer BPE byte-level en un corpus pequeño a vocab_size = 1000, luego codifica un string nuevo con los merges aprendidos.
Permitido. Biblioteca estándar de Python.
Restricciones. 30 minutos. Debe manejar input byte-level (bytes.encode("utf-8")).
Stub de solución.
from collections import Counter
def train_bpe(corpus_bytes, vocab_size=1000):
# Initial tokenization: each byte is a token.
tokens = [list(b) for b in corpus_bytes] # list of list of int
merges = []
next_id = 256
while next_id < vocab_size:
pair_counts = Counter()
for seq in tokens:
for a, b in zip(seq, seq[1:]):
pair_counts[(a, b)] += 1
if not pair_counts:
break
best_pair, _ = pair_counts.most_common(1)[0]
merges.append((best_pair, next_id))
# Apply merge.
new_tokens = []
for seq in tokens:
merged = []
i = 0
while i < len(seq):
if i + 1 < len(seq) and (seq[i], seq[i+1]) == best_pair:
merged.append(next_id)
i += 2
else:
merged.append(seq[i])
i += 1
new_tokens.append(merged)
tokens = new_tokens
next_id += 1
return merges
def encode(text, merges):
seq = list(text.encode("utf-8"))
for pair, new_id in merges:
out = []
i = 0
while i < len(seq):
if i + 1 < len(seq) and (seq[i], seq[i+1]) == pair:
out.append(new_id)
i += 2
else:
out.append(seq[i])
i += 1
seq = out
return seq
Pit-of-failure. Aplicar merges en cualquier orden en vez de en train-order. Señales: nunca depuraste un tokenizer.
→ Fase 11.
Drill 03 — Implementa gradient checkpointing para un MLP de 4 capas¶
Brief. Construye un MLP de 4 capas. Implementa gradient checkpointing manualmente: no almacenes activaciones entre las capas 2 y 3; recomputa en backward.
Permitido. PyTorch. torch.utils.checkpoint.checkpoint no permitido — escríbelo tú con torch.autograd.Function.
Restricciones. 25 minutos.
Stub de solución.
import torch
from torch.autograd import Function
class Checkpoint(Function):
@staticmethod
def forward(ctx, run_function, x, *params):
ctx.run_function = run_function
ctx.save_for_backward(x, *params)
with torch.no_grad():
y = run_function(x, *params)
return y
@staticmethod
def backward(ctx, grad_y):
x, *params = ctx.saved_tensors
# Recompute with grad enabled
x = x.detach().requires_grad_(True)
params = [p.detach().requires_grad_(True) for p in params]
with torch.enable_grad():
y = ctx.run_function(x, *params)
grads = torch.autograd.grad(y, [x, *params], grad_y)
return (None, *grads)
def checkpoint(fn, x, *params):
return Checkpoint.apply(fn, x, *params)
Pit-of-failure. Olvidar detachar inputs antes de requires_grad-earlos. Señales: nunca escribiste un autograd Function.
→ Fase 07, Fase 08.
Drill 04 — Implementa un KV cache para decoding autorregresivo¶
Brief. Dado un bloque transformer, implementa step(x_new, cache) que use el cache para el prefijo y solo compute attention para el token nuevo.
Permitido. PyTorch.
Restricciones. 25 minutos. Debe manejar la primera llamada (cache vacío) correctamente.
Stub de solución.
import torch
class KVCache:
def __init__(self):
self.k = None # (B, H, N, D)
self.v = None
def update(self, k_new, v_new):
# k_new, v_new: (B, H, 1, D) typically
if self.k is None:
self.k = k_new
self.v = v_new
else:
self.k = torch.cat([self.k, k_new], dim=2)
self.v = torch.cat([self.v, v_new], dim=2)
return self.k, self.v
def attention_step(x_new, W_q, W_k, W_v, W_o, cache, n_heads):
# x_new: (B, 1, d_model)
B, _, d = x_new.shape
d_head = d // n_heads
q = (x_new @ W_q).view(B, 1, n_heads, d_head).transpose(1, 2)
k = (x_new @ W_k).view(B, 1, n_heads, d_head).transpose(1, 2)
v = (x_new @ W_v).view(B, 1, n_heads, d_head).transpose(1, 2)
K, V = cache.update(k, v)
# Attention over the full prefix (no need for causal mask: only one new query)
scores = (q @ K.transpose(-1, -2)) / (d_head ** 0.5) # (B, H, 1, N)
weights = torch.softmax(scores, dim=-1)
attn = weights @ V # (B, H, 1, D)
out = attn.transpose(1, 2).reshape(B, 1, d) @ W_o
return out
Pit-of-failure. Concatenar en el eje equivocado; olvidar que la query es solo 1 token pero key/value abarcan el prefijo entero.
→ Fase 22.
Drill 05 — Implementa una capa LoRA Linear¶
Brief. Construye un módulo LoRALinear(in_features, out_features, rank, alpha). El forward debe equivaler a x @ W^T + (alpha/rank) * x @ A @ B. El W base está congelado.
Permitido. PyTorch.
Restricciones. 20 minutos. Inicialización: A Kaiming, B ceros, así el delta LoRA arranca en cero.
Stub de solución.
import torch
import torch.nn as nn
class LoRALinear(nn.Module):
def __init__(self, in_features, out_features, rank=8, alpha=16):
super().__init__()
self.base = nn.Linear(in_features, out_features, bias=False)
for p in self.base.parameters():
p.requires_grad = False
# Low-rank factors
self.A = nn.Parameter(torch.empty(in_features, rank))
self.B = nn.Parameter(torch.zeros(rank, out_features))
nn.init.kaiming_uniform_(self.A, a=5 ** 0.5)
self.scale = alpha / rank
def forward(self, x):
# x: (..., in_features)
return self.base(x) + self.scale * (x @ self.A @ self.B)
Pit-of-failure. Olvidar el escalado (alpha/rank); inicializar tanto A como B aleatoriamente (comportamiento del modelo corrompido en el paso 0).
→ Fase 28.
Drill 06 — Implementa la pérdida DPO¶
Brief. Dados los log-probs de respuestas chosen y rejected bajo la política y bajo una referencia congelada, y un hiperparámetro beta, computa la pérdida DPO.
Permitido. PyTorch.
Restricciones. 15 minutos. Debe usar logsigmoid para estabilidad numérica.
Stub de solución.
import torch
import torch.nn.functional as F
def dpo_loss(
policy_chosen_logps: torch.Tensor, # (B,)
policy_rejected_logps: torch.Tensor,
ref_chosen_logps: torch.Tensor,
ref_rejected_logps: torch.Tensor,
beta: float = 0.1,
):
"""Direct Preference Optimization loss (Rafailov 2023)."""
pi_logratio = policy_chosen_logps - policy_rejected_logps
ref_logratio = ref_chosen_logps - ref_rejected_logps
logits = beta * (pi_logratio - ref_logratio)
loss = -F.logsigmoid(logits).mean()
# Implicit reward margin for logging
chosen_reward = beta * (policy_chosen_logps - ref_chosen_logps).detach()
rejected_reward = beta * (policy_rejected_logps - ref_rejected_logps).detach()
return loss, chosen_reward, rejected_reward
Pit-of-failure. Usar log(sigmoid(...)) en vez de logsigmoid — overflow en pares confidentes. Señales: nunca entrenaste un modelo DPO.
→ Módulo X3.
Drill 07 — Implementa top-p (nucleus) sampling¶
Brief. Dados logits y top_p, muestrea un token.
Permitido. NumPy o PyTorch.
Restricciones. 15 minutos. Debe manejar el caso borde donde el top token ya excede top_p.
Stub de solución.
import torch
def top_p_sample(logits: torch.Tensor, top_p: float = 0.9, temperature: float = 1.0):
# logits: (vocab,)
logits = logits / max(temperature, 1e-6)
probs = torch.softmax(logits, dim=-1)
sorted_probs, sorted_idx = torch.sort(probs, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=-1)
# Keep the smallest set whose cumsum >= top_p.
# We include the first index that crosses top_p (otherwise empty when top_1 > top_p).
cutoff_mask = cumulative > top_p
# Shift right so we always keep at least one token.
cutoff_mask[..., 1:] = cutoff_mask[..., :-1].clone()
cutoff_mask[..., 0] = False
sorted_probs[cutoff_mask] = 0.0
sorted_probs = sorted_probs / sorted_probs.sum()
sample_in_sorted = torch.multinomial(sorted_probs, 1)
return sorted_idx[sample_in_sorted].item()
Pit-of-failure. Set de candidatos vacío cuando la probabilidad top-1 ya excede top_p. Señales: nunca codeaste el caso borde.
→ Fase 21.
Drill 08 — Implementa LayerNorm forward y backward¶
Brief. Dado input x de shape (B, N, D), gamma, beta aprendidos, computa y = gamma * (x - mu) / sqrt(var + eps) + beta, y los gradientes backward.
Permitido. NumPy. Sin autograd.
Restricciones. 30 minutos. El backward es el filtro real.
Stub de solución.
import numpy as np
def layernorm_forward(x, gamma, beta, eps=1e-5):
# x: (B, N, D)
mu = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
inv = 1.0 / np.sqrt(var + eps)
x_hat = (x - mu) * inv
y = gamma * x_hat + beta
cache = (x_hat, inv, gamma)
return y, cache
def layernorm_backward(dy, cache):
x_hat, inv, gamma = cache
D = x_hat.shape[-1]
# Param grads (sum over all but feature axis).
dgamma = (dy * x_hat).sum(axis=(0, 1))
dbeta = dy.sum(axis=(0, 1))
# Activation grad.
dx_hat = dy * gamma # (B, N, D)
# Standard LN backward derivation:
dx = (1.0 / D) * inv * (
D * dx_hat
- dx_hat.sum(axis=-1, keepdims=True)
- x_hat * (dx_hat * x_hat).sum(axis=-1, keepdims=True)
)
return dx, dgamma, dbeta
Pit-of-failure. Fórmula backward errada por un factor de 1/D. Señales: nunca la derivaste. La mayoría de candidatos falla.
→ Fase 10, Q2.
Drill 09 — Implementa RMSNorm¶
Brief. RMSNorm = LayerNorm sin sustracción de media. y = gamma * x / sqrt(mean(x^2) + eps).
Permitido. PyTorch.
Restricciones. 10 minutos (es corto por diseño).
Stub de solución.
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
# x: (..., dim)
rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
return x * rms * self.weight
Follow-up. "¿Por qué RMSNorm funciona tan bien como LayerNorm?" → hallazgo empírico (Zhang & Sennrich 2019); la sustracción de media no carga peso; ahorra una reducción.
→ Fase 10.
Drill 10 — Implementa RoPE (Rotary Positional Embeddings)¶
Brief. Dados Q, K de shape (B, H, N, D), aplica rotación RoPE por posición.
Permitido. PyTorch.
Restricciones. 25 minutos. Debe usar el pareo estándar interleaved o even-odd.
Stub de solución.
import torch
def build_rope_cache(seq_len, dim, base=10000.0, device="cpu"):
# dim must be even.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
t = torch.arange(seq_len, device=device).float()
freqs = torch.outer(t, inv_freq) # (N, D/2)
cos = freqs.cos() # (N, D/2)
sin = freqs.sin()
return cos, sin
def apply_rope(x, cos, sin):
# x: (B, H, N, D). Split into pairs along D.
B, H, N, D = x.shape
x1 = x[..., 0::2] # (B, H, N, D/2)
x2 = x[..., 1::2]
# cos, sin: (N, D/2) -> broadcast.
cos = cos[None, None, :, :] # (1, 1, N, D/2)
sin = sin[None, None, :, :]
rotated_1 = x1 * cos - x2 * sin
rotated_2 = x1 * sin + x2 * cos
out = torch.stack([rotated_1, rotated_2], dim=-1).reshape(B, H, N, D)
return out
Pit-of-failure. Confundir qué convención de pareo (even-odd vs first-half / second-half — Llama usa convenciones diferentes entre versiones). Documenta la convención que usaste.
→ Fase 16, Q11.
Drill 11 — Implementa beam search decoding¶
Brief. Dada una función step(prefix) -> log_probs sobre vocab, devuelve las top-B secuencias completas de longitud T (o paradas en EOS).
Permitido. PyTorch.
Restricciones. 30 minutos.
Stub de solución.
import torch
import heapq
def beam_search(step_fn, start_token, eos_token, beam_size=4, max_len=64):
# Each beam: (score, tokens, finished).
beams = [(0.0, [start_token], False)]
finished = []
for _ in range(max_len):
candidates = []
for score, tokens, done in beams:
if done:
candidates.append((score, tokens, True))
continue
log_probs = step_fn(tokens) # (vocab,)
topk = torch.topk(log_probs, beam_size)
for lp, idx in zip(topk.values.tolist(), topk.indices.tolist()):
new_tokens = tokens + [idx]
new_done = (idx == eos_token)
candidates.append((score + lp, new_tokens, new_done))
# Keep the top beam_size by score.
candidates.sort(key=lambda x: x[0], reverse=True)
beams = candidates[:beam_size]
if all(b[2] for b in beams):
break
# Optional: length-normalize for fairness.
beams.sort(key=lambda b: b[0] / max(len(b[1]), 1), reverse=True)
return beams
Pit-of-failure. No manejar secuencias EOS / done (siguen "expandiéndose" y diluyen el beam). No normalizar por longitud (las secuencias cortas ganan injustamente).
→ Fase 21.
Drill 12 — Implementa un continuous batcher (esqueleto)¶
Brief. Construye la lógica del scheduler que admite requests nuevos mientras los requests existentes están a mitad de decode, rellenando el batch con requests nuevos de prefill y removiendo los terminados.
Permitido. Python. Sin GPU real; nivel pseudo-código es aceptable.
Restricciones. 30 minutos. Enfócate en la lógica de scheduling, no en el kernel.
Stub de solución.
from dataclasses import dataclass, field
from collections import deque
from typing import Callable, Any
@dataclass
class Request:
req_id: int
prompt_tokens: list
generated: list = field(default_factory=list)
max_new_tokens: int = 256
done: bool = False
class ContinuousBatcher:
def __init__(self, max_batch_size: int, prefill_fn: Callable, decode_fn: Callable):
self.max_batch = max_batch_size
self.prefill_fn = prefill_fn # (list[Request]) -> per-req state (KV cache)
self.decode_fn = decode_fn # (list[Request], states) -> next tokens
self.waiting = deque() # pending requests
self.active = [] # currently decoding
self.states = {} # req_id -> opaque state
def submit(self, req: Request):
self.waiting.append(req)
def step(self):
# 1. Admit new requests up to capacity.
while self.waiting and len(self.active) < self.max_batch:
new_req = self.waiting.popleft()
state = self.prefill_fn([new_req])[0]
self.states[new_req.req_id] = state
self.active.append(new_req)
# 2. Decode one token for everyone active.
if not self.active:
return
next_tokens = self.decode_fn(self.active, [self.states[r.req_id] for r in self.active])
# 3. Update each request; mark done if EOS or max length.
still_active = []
for r, tok in zip(self.active, next_tokens):
r.generated.append(tok)
if tok == EOS or len(r.generated) >= r.max_new_tokens:
r.done = True
self.states.pop(r.req_id)
else:
still_active.append(r)
self.active = still_active
EOS = -1 # convention for this stub
Pit-of-failure. Estancar el batch hasta que todos terminen (el anti-patrón "static batching"). El propósito entero del continuous batching es no esperar.
→ Fase 33, prompt 4 de 02-systems-design-for-llms.md.
Calendario de progresión de drills¶
Un calendario sugerido de 14 días:
| Día | Drills | Notas |
|---|---|---|
| 1 | 01, 07 | Attention + sampling, los dos drills más comunes de phone-screen |
| 2 | 02 | BPE — lento pero mecánico |
| 3 | 04 | KV cache |
| 4 | 08 | LayerNorm backward — programa tiempo extra |
| 5 | 05, 09 | LoRA + RMSNorm |
| 6 | 10 | RoPE |
| 7 | 06 | DPO loss |
| 8 | 11 | Beam search |
| 9 | 03 | Gradient checkpointing |
| 10 | 12 | Continuous batcher |
| 11-14 | Re-roll bajo temporizador | Elige 3 drills/día al azar; 30 min cada uno |
→ Siguiente: 05-behavioral-and-storytelling.md