Skip to content

English · Español

04 — Coding Drills: 12 Ejercicios Implement-It-From-Scratch

12 ejercicios de coding tipo entrevista (attention, BPE, gradient checkpointing, KV cache, LoRA, DPO loss, top-p, LayerNorm fwd+bwd, RMSNorm, RoPE, beam search, continuous batcher). 30 minutos cada uno + stub de solución de referencia.

Cómo drillear

  • Presupuesto de 30 minutos. Pizarra o archivo .py en blanco sin internet.
  • NumPy preferido (más portable a pizarra); PyTorch aceptable para los drills más difíciles.
  • El stub de solución al final de cada sección es ~40 líneas; escribe más si hace falta pero la referencia es corta.
  • Enlace cruzado: cada drill nombra la fase que te preparó para él.

Por qué estos 12 específicamente

Estos son los drills que aparecen repetidamente en phone screens y on-sites 2024-2026 en Anthropic, OpenAI, Google Brain, DeepMind, xAI, Mistral, Cohere, y los principales scaleups. El drill "implementa attention" en particular es el filtro probado — la mayoría de candidatos conocen attention conceptualmente pero no pueden escribirla en un archivo en blanco en 20 minutos.


Drill 01 — Implementa scaled dot-product attention (solo forward)

Brief. Dados Q, K, V de shape (batch, n_heads, seq, d_head) y una máscara causal opcional, devuelve el output de attention del mismo shape.

Permitido. NumPy o PyTorch. Sin torch.nn.functional.scaled_dot_product_attention.

Restricciones. 20 minutos. Debe incluir escalado sqrt(d_k). Debe manejar masking causal.

Stub de solución.

import numpy as np

def scaled_dot_product_attention(Q, K, V, causal=False):
    # Q, K, V: (B, H, N, D)
    B, H, N, D = Q.shape
    scores = np.matmul(Q, np.swapaxes(K, -1, -2)) / np.sqrt(D)   # (B, H, N, N)
    if causal:
        mask = np.triu(np.ones((N, N), dtype=bool), k=1)
        scores = np.where(mask, -1e9, scores)
    # Numerically stable softmax
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    out = np.matmul(weights, V)                                  # (B, H, N, D)
    return out

Bugs comunes a señalar. - Olvidar el escalado sqrt(d_k). - Máscara causal aplicada en el eje equivocado (error de transpuesta). - Overflow del softmax porque se omite el truco - max(scores).

→ Fase 15.


Drill 02 — Implementa tokenizer BPE (entrenamiento + encode)

Brief. Entrena un tokenizer BPE byte-level en un corpus pequeño a vocab_size = 1000, luego codifica un string nuevo con los merges aprendidos.

Permitido. Biblioteca estándar de Python.

Restricciones. 30 minutos. Debe manejar input byte-level (bytes.encode("utf-8")).

Stub de solución.

from collections import Counter

def train_bpe(corpus_bytes, vocab_size=1000):
    # Initial tokenization: each byte is a token.
    tokens = [list(b) for b in corpus_bytes]                 # list of list of int
    merges = []
    next_id = 256
    while next_id < vocab_size:
        pair_counts = Counter()
        for seq in tokens:
            for a, b in zip(seq, seq[1:]):
                pair_counts[(a, b)] += 1
        if not pair_counts:
            break
        best_pair, _ = pair_counts.most_common(1)[0]
        merges.append((best_pair, next_id))
        # Apply merge.
        new_tokens = []
        for seq in tokens:
            merged = []
            i = 0
            while i < len(seq):
                if i + 1 < len(seq) and (seq[i], seq[i+1]) == best_pair:
                    merged.append(next_id)
                    i += 2
                else:
                    merged.append(seq[i])
                    i += 1
            new_tokens.append(merged)
        tokens = new_tokens
        next_id += 1
    return merges

def encode(text, merges):
    seq = list(text.encode("utf-8"))
    for pair, new_id in merges:
        out = []
        i = 0
        while i < len(seq):
            if i + 1 < len(seq) and (seq[i], seq[i+1]) == pair:
                out.append(new_id)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        seq = out
    return seq

Pit-of-failure. Aplicar merges en cualquier orden en vez de en train-order. Señales: nunca depuraste un tokenizer.

→ Fase 11.


Drill 03 — Implementa gradient checkpointing para un MLP de 4 capas

Brief. Construye un MLP de 4 capas. Implementa gradient checkpointing manualmente: no almacenes activaciones entre las capas 2 y 3; recomputa en backward.

Permitido. PyTorch. torch.utils.checkpoint.checkpoint no permitido — escríbelo tú con torch.autograd.Function.

Restricciones. 25 minutos.

Stub de solución.

import torch
from torch.autograd import Function

class Checkpoint(Function):
    @staticmethod
    def forward(ctx, run_function, x, *params):
        ctx.run_function = run_function
        ctx.save_for_backward(x, *params)
        with torch.no_grad():
            y = run_function(x, *params)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        x, *params = ctx.saved_tensors
        # Recompute with grad enabled
        x = x.detach().requires_grad_(True)
        params = [p.detach().requires_grad_(True) for p in params]
        with torch.enable_grad():
            y = ctx.run_function(x, *params)
        grads = torch.autograd.grad(y, [x, *params], grad_y)
        return (None, *grads)

def checkpoint(fn, x, *params):
    return Checkpoint.apply(fn, x, *params)

Pit-of-failure. Olvidar detachar inputs antes de requires_grad-earlos. Señales: nunca escribiste un autograd Function.

→ Fase 07, Fase 08.


Drill 04 — Implementa un KV cache para decoding autorregresivo

Brief. Dado un bloque transformer, implementa step(x_new, cache) que use el cache para el prefijo y solo compute attention para el token nuevo.

Permitido. PyTorch.

Restricciones. 25 minutos. Debe manejar la primera llamada (cache vacío) correctamente.

Stub de solución.

import torch

class KVCache:
    def __init__(self):
        self.k = None  # (B, H, N, D)
        self.v = None

    def update(self, k_new, v_new):
        # k_new, v_new: (B, H, 1, D) typically
        if self.k is None:
            self.k = k_new
            self.v = v_new
        else:
            self.k = torch.cat([self.k, k_new], dim=2)
            self.v = torch.cat([self.v, v_new], dim=2)
        return self.k, self.v

def attention_step(x_new, W_q, W_k, W_v, W_o, cache, n_heads):
    # x_new: (B, 1, d_model)
    B, _, d = x_new.shape
    d_head = d // n_heads
    q = (x_new @ W_q).view(B, 1, n_heads, d_head).transpose(1, 2)
    k = (x_new @ W_k).view(B, 1, n_heads, d_head).transpose(1, 2)
    v = (x_new @ W_v).view(B, 1, n_heads, d_head).transpose(1, 2)
    K, V = cache.update(k, v)
    # Attention over the full prefix (no need for causal mask: only one new query)
    scores = (q @ K.transpose(-1, -2)) / (d_head ** 0.5)         # (B, H, 1, N)
    weights = torch.softmax(scores, dim=-1)
    attn = weights @ V                                            # (B, H, 1, D)
    out = attn.transpose(1, 2).reshape(B, 1, d) @ W_o
    return out

Pit-of-failure. Concatenar en el eje equivocado; olvidar que la query es solo 1 token pero key/value abarcan el prefijo entero.

→ Fase 22.


Drill 05 — Implementa una capa LoRA Linear

Brief. Construye un módulo LoRALinear(in_features, out_features, rank, alpha). El forward debe equivaler a x @ W^T + (alpha/rank) * x @ A @ B. El W base está congelado.

Permitido. PyTorch.

Restricciones. 20 minutos. Inicialización: A Kaiming, B ceros, así el delta LoRA arranca en cero.

Stub de solución.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, in_features, out_features, rank=8, alpha=16):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        for p in self.base.parameters():
            p.requires_grad = False
        # Low-rank factors
        self.A = nn.Parameter(torch.empty(in_features, rank))
        self.B = nn.Parameter(torch.zeros(rank, out_features))
        nn.init.kaiming_uniform_(self.A, a=5 ** 0.5)
        self.scale = alpha / rank

    def forward(self, x):
        # x: (..., in_features)
        return self.base(x) + self.scale * (x @ self.A @ self.B)

Pit-of-failure. Olvidar el escalado (alpha/rank); inicializar tanto A como B aleatoriamente (comportamiento del modelo corrompido en el paso 0).

→ Fase 28.


Drill 06 — Implementa la pérdida DPO

Brief. Dados los log-probs de respuestas chosen y rejected bajo la política y bajo una referencia congelada, y un hiperparámetro beta, computa la pérdida DPO.

Permitido. PyTorch.

Restricciones. 15 minutos. Debe usar logsigmoid para estabilidad numérica.

Stub de solución.

import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps: torch.Tensor,  # (B,)
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
):
    """Direct Preference Optimization loss (Rafailov 2023)."""
    pi_logratio = policy_chosen_logps - policy_rejected_logps
    ref_logratio = ref_chosen_logps - ref_rejected_logps
    logits = beta * (pi_logratio - ref_logratio)
    loss = -F.logsigmoid(logits).mean()
    # Implicit reward margin for logging
    chosen_reward = beta * (policy_chosen_logps - ref_chosen_logps).detach()
    rejected_reward = beta * (policy_rejected_logps - ref_rejected_logps).detach()
    return loss, chosen_reward, rejected_reward

Pit-of-failure. Usar log(sigmoid(...)) en vez de logsigmoid — overflow en pares confidentes. Señales: nunca entrenaste un modelo DPO.

→ Módulo X3.


Drill 07 — Implementa top-p (nucleus) sampling

Brief. Dados logits y top_p, muestrea un token.

Permitido. NumPy o PyTorch.

Restricciones. 15 minutos. Debe manejar el caso borde donde el top token ya excede top_p.

Stub de solución.

import torch

def top_p_sample(logits: torch.Tensor, top_p: float = 0.9, temperature: float = 1.0):
    # logits: (vocab,)
    logits = logits / max(temperature, 1e-6)
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Keep the smallest set whose cumsum >= top_p.
    # We include the first index that crosses top_p (otherwise empty when top_1 > top_p).
    cutoff_mask = cumulative > top_p
    # Shift right so we always keep at least one token.
    cutoff_mask[..., 1:] = cutoff_mask[..., :-1].clone()
    cutoff_mask[..., 0] = False
    sorted_probs[cutoff_mask] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum()
    sample_in_sorted = torch.multinomial(sorted_probs, 1)
    return sorted_idx[sample_in_sorted].item()

Pit-of-failure. Set de candidatos vacío cuando la probabilidad top-1 ya excede top_p. Señales: nunca codeaste el caso borde.

→ Fase 21.


Drill 08 — Implementa LayerNorm forward y backward

Brief. Dado input x de shape (B, N, D), gamma, beta aprendidos, computa y = gamma * (x - mu) / sqrt(var + eps) + beta, y los gradientes backward.

Permitido. NumPy. Sin autograd.

Restricciones. 30 minutos. El backward es el filtro real.

Stub de solución.

import numpy as np

def layernorm_forward(x, gamma, beta, eps=1e-5):
    # x: (B, N, D)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    inv = 1.0 / np.sqrt(var + eps)
    x_hat = (x - mu) * inv
    y = gamma * x_hat + beta
    cache = (x_hat, inv, gamma)
    return y, cache

def layernorm_backward(dy, cache):
    x_hat, inv, gamma = cache
    D = x_hat.shape[-1]
    # Param grads (sum over all but feature axis).
    dgamma = (dy * x_hat).sum(axis=(0, 1))
    dbeta = dy.sum(axis=(0, 1))
    # Activation grad.
    dx_hat = dy * gamma                                            # (B, N, D)
    # Standard LN backward derivation:
    dx = (1.0 / D) * inv * (
        D * dx_hat
        - dx_hat.sum(axis=-1, keepdims=True)
        - x_hat * (dx_hat * x_hat).sum(axis=-1, keepdims=True)
    )
    return dx, dgamma, dbeta

Pit-of-failure. Fórmula backward errada por un factor de 1/D. Señales: nunca la derivaste. La mayoría de candidatos falla.

→ Fase 10, Q2.


Drill 09 — Implementa RMSNorm

Brief. RMSNorm = LayerNorm sin sustracción de media. y = gamma * x / sqrt(mean(x^2) + eps).

Permitido. PyTorch.

Restricciones. 10 minutos (es corto por diseño).

Stub de solución.

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # x: (..., dim)
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

Follow-up. "¿Por qué RMSNorm funciona tan bien como LayerNorm?" → hallazgo empírico (Zhang & Sennrich 2019); la sustracción de media no carga peso; ahorra una reducción.

→ Fase 10.


Drill 10 — Implementa RoPE (Rotary Positional Embeddings)

Brief. Dados Q, K de shape (B, H, N, D), aplica rotación RoPE por posición.

Permitido. PyTorch.

Restricciones. 25 minutos. Debe usar el pareo estándar interleaved o even-odd.

Stub de solución.

import torch

def build_rope_cache(seq_len, dim, base=10000.0, device="cpu"):
    # dim must be even.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(t, inv_freq)                              # (N, D/2)
    cos = freqs.cos()                                              # (N, D/2)
    sin = freqs.sin()
    return cos, sin

def apply_rope(x, cos, sin):
    # x: (B, H, N, D). Split into pairs along D.
    B, H, N, D = x.shape
    x1 = x[..., 0::2]                                              # (B, H, N, D/2)
    x2 = x[..., 1::2]
    # cos, sin: (N, D/2) -> broadcast.
    cos = cos[None, None, :, :]                                    # (1, 1, N, D/2)
    sin = sin[None, None, :, :]
    rotated_1 = x1 * cos - x2 * sin
    rotated_2 = x1 * sin + x2 * cos
    out = torch.stack([rotated_1, rotated_2], dim=-1).reshape(B, H, N, D)
    return out

Pit-of-failure. Confundir qué convención de pareo (even-odd vs first-half / second-half — Llama usa convenciones diferentes entre versiones). Documenta la convención que usaste.

→ Fase 16, Q11.


Drill 11 — Implementa beam search decoding

Brief. Dada una función step(prefix) -> log_probs sobre vocab, devuelve las top-B secuencias completas de longitud T (o paradas en EOS).

Permitido. PyTorch.

Restricciones. 30 minutos.

Stub de solución.

import torch
import heapq

def beam_search(step_fn, start_token, eos_token, beam_size=4, max_len=64):
    # Each beam: (score, tokens, finished).
    beams = [(0.0, [start_token], False)]
    finished = []
    for _ in range(max_len):
        candidates = []
        for score, tokens, done in beams:
            if done:
                candidates.append((score, tokens, True))
                continue
            log_probs = step_fn(tokens)                            # (vocab,)
            topk = torch.topk(log_probs, beam_size)
            for lp, idx in zip(topk.values.tolist(), topk.indices.tolist()):
                new_tokens = tokens + [idx]
                new_done = (idx == eos_token)
                candidates.append((score + lp, new_tokens, new_done))
        # Keep the top beam_size by score.
        candidates.sort(key=lambda x: x[0], reverse=True)
        beams = candidates[:beam_size]
        if all(b[2] for b in beams):
            break
    # Optional: length-normalize for fairness.
    beams.sort(key=lambda b: b[0] / max(len(b[1]), 1), reverse=True)
    return beams

Pit-of-failure. No manejar secuencias EOS / done (siguen "expandiéndose" y diluyen el beam). No normalizar por longitud (las secuencias cortas ganan injustamente).

→ Fase 21.


Drill 12 — Implementa un continuous batcher (esqueleto)

Brief. Construye la lógica del scheduler que admite requests nuevos mientras los requests existentes están a mitad de decode, rellenando el batch con requests nuevos de prefill y removiendo los terminados.

Permitido. Python. Sin GPU real; nivel pseudo-código es aceptable.

Restricciones. 30 minutos. Enfócate en la lógica de scheduling, no en el kernel.

Stub de solución.

from dataclasses import dataclass, field
from collections import deque
from typing import Callable, Any

@dataclass
class Request:
    req_id: int
    prompt_tokens: list
    generated: list = field(default_factory=list)
    max_new_tokens: int = 256
    done: bool = False

class ContinuousBatcher:
    def __init__(self, max_batch_size: int, prefill_fn: Callable, decode_fn: Callable):
        self.max_batch = max_batch_size
        self.prefill_fn = prefill_fn      # (list[Request]) -> per-req state (KV cache)
        self.decode_fn = decode_fn        # (list[Request], states) -> next tokens
        self.waiting = deque()            # pending requests
        self.active = []                  # currently decoding
        self.states = {}                  # req_id -> opaque state

    def submit(self, req: Request):
        self.waiting.append(req)

    def step(self):
        # 1. Admit new requests up to capacity.
        while self.waiting and len(self.active) < self.max_batch:
            new_req = self.waiting.popleft()
            state = self.prefill_fn([new_req])[0]
            self.states[new_req.req_id] = state
            self.active.append(new_req)
        # 2. Decode one token for everyone active.
        if not self.active:
            return
        next_tokens = self.decode_fn(self.active, [self.states[r.req_id] for r in self.active])
        # 3. Update each request; mark done if EOS or max length.
        still_active = []
        for r, tok in zip(self.active, next_tokens):
            r.generated.append(tok)
            if tok == EOS or len(r.generated) >= r.max_new_tokens:
                r.done = True
                self.states.pop(r.req_id)
            else:
                still_active.append(r)
        self.active = still_active

EOS = -1   # convention for this stub

Pit-of-failure. Estancar el batch hasta que todos terminen (el anti-patrón "static batching"). El propósito entero del continuous batching es no esperar.

→ Fase 33, prompt 4 de 02-systems-design-for-llms.md.


Calendario de progresión de drills

Un calendario sugerido de 14 días:

Día Drills Notas
1 01, 07 Attention + sampling, los dos drills más comunes de phone-screen
2 02 BPE — lento pero mecánico
3 04 KV cache
4 08 LayerNorm backward — programa tiempo extra
5 05, 09 LoRA + RMSNorm
6 10 RoPE
7 06 DPO loss
8 11 Beam search
9 03 Gradient checkpointing
10 12 Continuous batcher
11-14 Re-roll bajo temporizador Elige 3 drills/día al azar; 30 min cada uno

→ Siguiente: 05-behavioral-and-storytelling.md