04 — Coding Drills: 12 Implement-It-From-Scratch Exercises
12 interview-style coding exercises (attention, BPE, gradient checkpointing, KV cache, LoRA, DPO loss, top-p, LayerNorm forward + backward, RMSNorm, RoPE, beam search, continuous batcher). 30 minutes each, plus a reference solution stub.
How to drill
- 30-minute budget. Whiteboard or blank .py file with no internet.
- NumPy preferred (more portable to whiteboard); PyTorch acceptable for the harder drills.
- Solution stub at the end of each section is ~40 lines; write more if needed but the reference is short.
- Cross-link: each drill names the phase that prepared you for it.
Why these 12 specifically
These drills recur in 2024-2026 phone screens and on-sites at Anthropic, OpenAI, Google DeepMind, xAI, Mistral, Cohere, and the major scaleups. The "implement attention" drill in particular is a common filter: most candidates know attention conceptually but cannot write it in a blank file in 20 minutes.
Drill 01 — Implement scaled dot-product attention (forward only)
Brief. Given Q, K, V of shape (batch, n_heads, seq, d_head) and an optional causal mask, return the attention output of the same shape.
Allowed. NumPy or PyTorch. No torch.nn.functional.scaled_dot_product_attention.
Constraints. 20 minutes. Must include sqrt(d_k) scaling. Must handle causal masking.
Solution stub.
```python
import numpy as np

def scaled_dot_product_attention(Q, K, V, causal=False):
    # Q, K, V: (B, H, N, D)
    B, H, N, D = Q.shape
    scores = np.matmul(Q, np.swapaxes(K, -1, -2)) / np.sqrt(D)  # (B, H, N, N)
    if causal:
        # True above the diagonal: query i must not see key j > i.
        mask = np.triu(np.ones((N, N), dtype=bool), k=1)
        scores = np.where(mask, -1e9, scores)
    # Numerically stable softmax
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    out = np.matmul(weights, V)  # (B, H, N, D)
    return out
```
Common bugs to flag.
- Forgetting sqrt(d_k) scaling.
- Causal mask applied on wrong axis (transpose error).
- Softmax overflow because the - max(scores) trick is omitted.
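The third bug is easy to reproduce in isolation. The sketch below is illustrative only: the helper names `naive_softmax` and `stable_softmax` and the example scores are assumptions, not part of the drill. It compares a softmax with and without the max-subtraction step on scores large enough to overflow `exp` in float64.

```python
import numpy as np

def naive_softmax(scores):
    # No max subtraction: exp() overflows to inf for large scores.
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)

def stable_softmax(scores):
    # Subtracting the row max does not change the result mathematically,
    # but it keeps every exponent <= 0.
    shifted = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([1000.0, 1001.0, 1002.0])
with np.errstate(over="ignore", invalid="ignore"):
    naive = naive_softmax(scores)   # inf / inf -> nan everywhere
stable = stable_softmax(scores)     # same as softmax([-2, -1, 0])

print(naive)
print(stable)
```

Unscaled scores rarely reach 1000 in float64, but in float16 or bfloat16 the overflow threshold is far lower, which is why the shift is expected even in a whiteboard answer.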
→ Phase 15.
Drill 02 — Implement BPE tokenizer (training + encode)
Brief. Train a byte-level BPE tokenizer on a small corpus to vocab_size = 1000, then encode a new string with the learned merges.
Allowed. Python standard library.
Constraints. 30 minutes. Must handle byte-level input (str.encode("utf-8")).
Solution stub.
```python
from collections import Counter

def train_bpe(corpus_bytes, vocab_size=1000):
    # corpus_bytes: iterable of bytes objects (one per document or line).
    # Initial tokenization: each byte is a token.
    tokens = [list(b) for b in corpus_bytes]  # list of list of int
    merges = []
    next_id = 256
    while next_id < vocab_size:
        pair_counts = Counter()
        for seq in tokens:
            for a, b in zip(seq, seq[1:]):
                pair_counts[(a, b)] += 1
        if not pair_counts:
            break
        best_pair, _ = pair_counts.most_common(1)[0]
        merges.append((best_pair, next_id))
        # Apply merge.
        new_tokens = []
        for seq in tokens:
            merged = []
            i = 0
            while i < len(seq):
                if i + 1 < len(seq) and (seq[i], seq[i+1]) == best_pair:
                    merged.append(next_id)
                    i += 2
                else:
                    merged.append(seq[i])
                    i += 1
            new_tokens.append(merged)
        tokens = new_tokens
        next_id += 1
    return merges

def encode(text, merges):
    seq = list(text.encode("utf-8"))
    # Replay merges in the order they were learned.
    for pair, new_id in merges:
        out = []
        i = 0
        while i < len(seq):
            if i + 1 < len(seq) and (seq[i], seq[i+1]) == pair:
                out.append(new_id)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        seq = out
    return seq
```
Pit-of-failure. Applying merges in arbitrary order instead of the order in which they were learned. Signals: never debugged a tokenizer.
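To see why merge order matters, consider a two-entry merge table in which the second merge depends on the token produced by the first. The table below is hypothetical, and `apply_merges` is an illustrative helper that mirrors the encode loop; neither comes from a real tokenizer.

```python
def apply_merges(seq, merges):
    """Apply (pair, new_id) merges to a list of token ids, in the order given."""
    for pair, new_id in merges:
        out, i = [], 0
        while i < len(seq):
            if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
                out.append(new_id)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        seq = out
    return seq

# Hypothetical merge table: "ab" -> 256 was learned first, then (256, "c") -> 257.
merges = [((97, 98), 256), ((256, 99), 257)]
seq = list("abc".encode("utf-8"))  # [97, 98, 99]

in_train_order = apply_merges(seq, merges)                  # [257]
in_wrong_order = apply_merges(seq, list(reversed(merges)))  # [256, 99]
print(in_train_order, in_wrong_order)
```

In the reversed order, the pair (256, 99) is tried before token 256 exists in the sequence, so the second merge never fires and the encoding is longer than the one seen during training.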
→ Phase 11.
Drill 03 — Implement gradient checkpointing for a 4-layer MLP
Brief. Build a 4-layer MLP. Implement gradient checkpointing manually: do not store activations between layers 2 and 3; recompute on backward.
Allowed. PyTorch. torch.utils.checkpoint.checkpoint is not allowed — write it yourself with torch.autograd.Function.
Constraints. 25 minutes.
Solution stub.
```python
import torch
from torch.autograd import Function

class Checkpoint(Function):
    @staticmethod
    def forward(ctx, run_function, x, *params):
        ctx.run_function = run_function
        # Save only the segment inputs, not its intermediate activations.
        ctx.save_for_backward(x, *params)
        with torch.no_grad():
            y = run_function(x, *params)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        x, *params = ctx.saved_tensors
        # Recompute with grad enabled
        x = x.detach().requires_grad_(True)
        params = [p.detach().requires_grad_(True) for p in params]
        with torch.enable_grad():
            y = ctx.run_function(x, *params)
        grads = torch.autograd.grad(y, [x, *params], grad_y)
        # One gradient per forward argument; run_function gets None.
        return (None, *grads)

def checkpoint(fn, x, *params):
    return Checkpoint.apply(fn, x, *params)
```
Pit-of-failure. Forgetting to detach inputs before requires_grad-ing. Signals: never wrote an autograd Function.
→ Phase 07, Phase 08.
Drill 04 — Implement a KV cache for autoregressive decoding
Brief. Given a transformer block, implement step(x_new, cache) that uses the cache for the prefix and only computes attention for the new token.
Allowed. PyTorch.
Constraints. 25 minutes. Must handle the first call (empty cache) correctly.
Solution stub.
```python
import torch

class KVCache:
    def __init__(self):
        self.k = None  # (B, H, N, D)
        self.v = None

    def update(self, k_new, v_new):
        # k_new, v_new: (B, H, 1, D) typically
        if self.k is None:
            self.k = k_new
            self.v = v_new
        else:
            # Grow along the sequence axis.
            self.k = torch.cat([self.k, k_new], dim=2)
            self.v = torch.cat([self.v, v_new], dim=2)
        return self.k, self.v

def attention_step(x_new, W_q, W_k, W_v, W_o, cache, n_heads):
    # x_new: (B, 1, d_model)
    B, _, d = x_new.shape
    d_head = d // n_heads
    q = (x_new @ W_q).view(B, 1, n_heads, d_head).transpose(1, 2)
    k = (x_new @ W_k).view(B, 1, n_heads, d_head).transpose(1, 2)
    v = (x_new @ W_v).view(B, 1, n_heads, d_head).transpose(1, 2)
    K, V = cache.update(k, v)
    # Attention over the full prefix (no need for causal mask: only one new query)
    scores = (q @ K.transpose(-1, -2)) / (d_head ** 0.5)  # (B, H, 1, N)
    weights = torch.softmax(scores, dim=-1)
    attn = weights @ V  # (B, H, 1, D)
    out = attn.transpose(1, 2).reshape(B, 1, d) @ W_o
    return out
```
Pit-of-failure. Concatenating on wrong axis; forgetting that the query is only 1 token but key/value span the whole prefix.
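A useful self-check for this drill is that token-by-token decoding with a cache reproduces full causal attention. The sketch below is a simplified single-head NumPy version with random weights and no output projection (both are assumptions made for brevity), not the multi-head PyTorch stub itself.

```python
import numpy as np

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
N, d = 6, 8
x = rng.standard_normal((N, d))
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))

# Reference: full causal attention over all N tokens at once.
Q, K, V = x @ W_q, x @ W_k, x @ W_v
scores = Q @ K.T / np.sqrt(d)
future = np.triu(np.ones((N, N), dtype=bool), k=1)
full = softmax(np.where(future, -np.inf, scores)) @ V   # (N, d)

# Incremental: one token per step, appending to the cached K and V.
k_cache = np.zeros((0, d))
v_cache = np.zeros((0, d))
steps = []
for t in range(N):
    x_t = x[t:t + 1]                                    # (1, d)
    k_cache = np.concatenate([k_cache, x_t @ W_k])      # (t + 1, d)
    v_cache = np.concatenate([v_cache, x_t @ W_v])
    q_t = x_t @ W_q                                     # single query row
    w = softmax(q_t @ k_cache.T / np.sqrt(d))           # (1, t + 1)
    steps.append(w @ v_cache)
incremental = np.concatenate(steps)                     # (N, d)

max_abs_diff = np.abs(full - incremental).max()
print(max_abs_diff)
```

The two paths agree up to floating-point rounding because row t of the causal attention matrix only ever looks at keys 0..t, which is exactly what the cache holds at step t.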
→ Phase 22.
Drill 05 — Implement a LoRA Linear layer
Brief. Build a LoRALinear(in_features, out_features, rank, alpha) module. Forward must equal x @ W^T + (alpha/rank) * x @ A @ B. The base W is frozen.
Allowed. PyTorch.
Constraints. 20 minutes. Initialization: A Kaiming, B zeros, so LoRA delta starts at zero.
Solution stub.
```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, in_features, out_features, rank=8, alpha=16):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        for p in self.base.parameters():
            p.requires_grad = False
        # Low-rank factors
        self.A = nn.Parameter(torch.empty(in_features, rank))
        self.B = nn.Parameter(torch.zeros(rank, out_features))
        # Kaiming init on A; B stays zero so the delta is zero at step 0.
        nn.init.kaiming_uniform_(self.A, a=5 ** 0.5)
        self.scale = alpha / rank

    def forward(self, x):
        # x: (..., in_features)
        return self.base(x) + self.scale * (x @ self.A @ self.B)
```
Pit-of-failure. Forgetting the (alpha/rank) scaling; initializing both A and B randomly (model behavior corrupted at step 0).
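Two properties of the forward formula can be checked numerically: with B at zero the layer matches the frozen base, and the low-rank path is equivalent to adding a rank-limited update to W. The NumPy sketch below uses made-up shapes and random values; folding the delta into `W_merged` is an illustration of the algebra, not something the drill asks for.

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features, rank, alpha = 16, 12, 4, 8
scale = alpha / rank

W = rng.standard_normal((out_features, in_features))   # frozen base weight
A = rng.standard_normal((in_features, rank))
x = rng.standard_normal((5, in_features))

# At initialization B is zero, so the output equals the base output.
B_init = np.zeros((rank, out_features))
y_base = x @ W.T
y_init = x @ W.T + scale * (x @ A @ B_init)

# After training B is nonzero (random here, standing in for learned values).
B = rng.standard_normal((rank, out_features))
y_lora = x @ W.T + scale * (x @ A @ B)

# Same result with the delta folded into the weight: W + scale * (A @ B)^T.
W_merged = W + scale * (A @ B).T                        # (out_features, in_features)
y_merged = x @ W_merged.T

print(np.allclose(y_init, y_base), np.allclose(y_lora, y_merged))
```

If the (alpha/rank) factor is dropped from either path, the second comparison fails, which makes this a quick test for the scaling bug.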
→ Phase 28.
Drill 06 — Implement the DPO loss
Brief. Given log-probs of chosen and rejected responses under the policy and under a frozen reference, and a beta hyperparameter, compute the DPO loss.
Allowed. PyTorch.
Constraints. 15 minutes. Must use logsigmoid for numerical stability.
Solution stub.
```python
import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps: torch.Tensor,  # (B,)
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
):
    """Direct Preference Optimization loss (Rafailov et al., 2023)."""
    pi_logratio = policy_chosen_logps - policy_rejected_logps
    ref_logratio = ref_chosen_logps - ref_rejected_logps
    logits = beta * (pi_logratio - ref_logratio)
    loss = -F.logsigmoid(logits).mean()
    # Implicit rewards for logging (margin = chosen_reward - rejected_reward)
    chosen_reward = beta * (policy_chosen_logps - ref_chosen_logps).detach()
    rejected_reward = beta * (policy_rejected_logps - ref_rejected_logps).detach()
    return loss, chosen_reward, rejected_reward
```
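Pit-of-failure. Using log(sigmoid(...)) instead of logsigmoid: when the logit is very negative (the policy strongly prefers the rejected response), sigmoid underflows to 0 and the log returns -inf. Signals: never trained a DPO model.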
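The failure is easiest to see with extreme logits in float32. The sketch below uses NumPy rather than PyTorch, and `stable_log_sigmoid` is one standard stable formulation written out for illustration; it is not claimed to be how `F.logsigmoid` is implemented internally.

```python
import numpy as np

def naive_log_sigmoid(x):
    # sigmoid saturates to exactly 0.0 for very negative x, so log() returns -inf.
    return np.log(1.0 / (1.0 + np.exp(-x)))

def stable_log_sigmoid(x):
    # log(sigmoid(x)) = min(x, 0) - log1p(exp(-|x|)); the exponent is never positive.
    return np.minimum(x, 0.0) - np.log1p(np.exp(-np.abs(x)))

logits = np.array([-200.0, 0.0, 200.0], dtype=np.float32)
with np.errstate(over="ignore", divide="ignore"):
    naive = naive_log_sigmoid(logits)
stable = stable_log_sigmoid(logits)

print(naive)   # first entry is -inf
print(stable)  # first entry is -200
```

A single -inf in the batch turns the mean loss into inf and the gradients into nan, so one badly ordered pair is enough to break a training step.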
→ X3 module.
Drill 07 — Implement top-p (nucleus) sampling
Brief. Given logits and top_p, sample a token.
Allowed. NumPy or PyTorch.
Constraints. 15 minutes. Must handle the edge case where the top token already exceeds top_p.
Solution stub.
```python
import torch

def top_p_sample(logits: torch.Tensor, top_p: float = 0.9, temperature: float = 1.0):
    # logits: (vocab,)
    logits = logits / max(temperature, 1e-6)
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Keep the smallest set whose cumsum >= top_p.
    # We include the first index that crosses top_p (otherwise empty when top_1 > top_p).
    cutoff_mask = cumulative > top_p
    # Shift right so we always keep at least one token.
    cutoff_mask[..., 1:] = cutoff_mask[..., :-1].clone()
    cutoff_mask[..., 0] = False
    sorted_probs[cutoff_mask] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum()
    sample_in_sorted = torch.multinomial(sorted_probs, 1)
    return sorted_idx[sample_in_sorted].item()
```
Pit-of-failure. Empty candidate set when top-1 probability already exceeds top_p. Signals: never coded the edge case.
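The edge case can be isolated from sampling by counting how many tokens survive the cutoff. The two helpers below are illustrative, pure-Python stand-ins with made-up names; the first uses a `>=` threshold, which differs from the stub's shifted `>` mask only when the cumulative sum lands exactly on top_p.

```python
def nucleus_size(sorted_probs, top_p):
    """Number of tokens kept, given probabilities sorted in descending order."""
    cumulative = 0.0
    for i, p in enumerate(sorted_probs):
        cumulative += p
        if cumulative >= top_p:
            return i + 1  # include the token that crosses the threshold
    return len(sorted_probs)

def buggy_nucleus_size(sorted_probs, top_p):
    """Common bug: drop every token whose cumulative probability exceeds top_p."""
    cumulative, kept = 0.0, 0
    for p in sorted_probs:
        cumulative += p
        if cumulative > top_p:
            break
        kept += 1
    return kept

peaked = [0.95, 0.03, 0.02]   # top-1 alone exceeds top_p = 0.9
flat = [0.5, 0.3, 0.2]

print(nucleus_size(peaked, 0.9), buggy_nucleus_size(peaked, 0.9))  # 1 0
print(nucleus_size(flat, 0.7), buggy_nucleus_size(flat, 0.7))      # 2 1
```

With the buggy version the peaked distribution leaves zero candidates, so renormalization divides by zero; the flat case shows the same bug also silently undershoots the requested probability mass.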
→ Phase 21.
Drill 08 — Implement LayerNorm forward and backward
Brief. Given input x of shape (B, N, D), learned gamma, beta, compute y = gamma * (x - mu) / sqrt(var + eps) + beta, and the backward gradients.
Allowed. NumPy. No autograd.
Constraints. 30 minutes. The backward is the actual filter.
Solution stub.
```python
import numpy as np

def layernorm_forward(x, gamma, beta, eps=1e-5):
    # x: (B, N, D)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    inv = 1.0 / np.sqrt(var + eps)
    x_hat = (x - mu) * inv
    y = gamma * x_hat + beta
    cache = (x_hat, inv, gamma)
    return y, cache

def layernorm_backward(dy, cache):
    x_hat, inv, gamma = cache
    D = x_hat.shape[-1]
    # Param grads (sum over all but feature axis).
    dgamma = (dy * x_hat).sum(axis=(0, 1))
    dbeta = dy.sum(axis=(0, 1))
    # Activation grad.
    dx_hat = dy * gamma  # (B, N, D)
    # Standard LN backward derivation:
    dx = (1.0 / D) * inv * (
        D * dx_hat
        - dx_hat.sum(axis=-1, keepdims=True)
        - x_hat * (dx_hat * x_hat).sum(axis=-1, keepdims=True)
    )
    return dx, dgamma, dbeta
```
Pit-of-failure. Backward formula off by a factor of 1/D. Signals: never derived it. Most candidates fail.
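A finite-difference check catches a missing 1/D or a sign error in seconds. The sketch below restates compact versions of the forward and the dx formula so that it runs on its own; the helper names, shapes, and step size `h` are arbitrary choices for illustration.

```python
import numpy as np

def ln_forward(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    inv = 1.0 / np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    x_hat = (x - mu) * inv
    return gamma * x_hat + beta, (x_hat, inv, gamma)

def ln_backward_dx(dy, cache):
    x_hat, inv, gamma = cache
    D = x_hat.shape[-1]
    dx_hat = dy * gamma
    return inv / D * (
        D * dx_hat
        - dx_hat.sum(axis=-1, keepdims=True)
        - x_hat * (dx_hat * x_hat).sum(axis=-1, keepdims=True)
    )

def numerical_dx(x, gamma, beta, dy, h=1e-5):
    # Central differences on the scalar loss L = sum(y * dy).
    grad = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        orig = x[idx]
        x[idx] = orig + h
        plus = (ln_forward(x, gamma, beta)[0] * dy).sum()
        x[idx] = orig - h
        minus = (ln_forward(x, gamma, beta)[0] * dy).sum()
        x[idx] = orig  # restore the entry before moving on
        grad[idx] = (plus - minus) / (2 * h)
    return grad

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8))
gamma = rng.standard_normal(8)
beta = rng.standard_normal(8)
dy = rng.standard_normal((2, 3, 8))

_, cache = ln_forward(x, gamma, beta)
analytic = ln_backward_dx(dy, cache)
numeric = numerical_dx(x, gamma, beta, dy)
max_err = np.abs(analytic - numeric).max()
print(max_err)
```

Using the loss sum(y * dy) makes dy the upstream gradient by construction, so the analytic dx can be compared directly against the numerical one. The same pattern extends to dgamma and dbeta by perturbing those arrays instead of x.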
→ Phase 10, Q2.
Drill 09 — Implement RMSNorm
Brief. RMSNorm = LayerNorm without mean subtraction. y = gamma * x / sqrt(mean(x^2) + eps).
Allowed. PyTorch.
Constraints. 10 minutes (it is short by design).
Solution stub.
```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # x: (..., dim)
        # Reciprocal of the root mean square: 1 / sqrt(mean(x^2) + eps).
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.weight
```
Follow-up. "Why does RMSNorm work as well as LayerNorm?" → empirical finding (Zhang & Sennrich 2019); mean-subtraction is not load-bearing; saves one reduction.
→ Phase 10.
Drill 10 — Implement RoPE (Rotary Positional Embeddings)
Brief. Given Q, K of shape (B, H, N, D), apply RoPE rotation by position.
Allowed. PyTorch.
Constraints. 25 minutes. Must use one of the two standard pairings: interleaved (even-odd) or half-split (first half / second half).
Solution stub.
```python
import torch

def build_rope_cache(seq_len, dim, base=10000.0, device="cpu"):
    # dim must be even.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(t, inv_freq)  # (N, D/2)
    cos = freqs.cos()  # (N, D/2)
    sin = freqs.sin()
    return cos, sin

def apply_rope(x, cos, sin):
    # x: (B, H, N, D). Split into pairs along D (interleaved: even-odd).
    B, H, N, D = x.shape
    x1 = x[..., 0::2]  # (B, H, N, D/2)
    x2 = x[..., 1::2]
    # cos, sin: (N, D/2) -> broadcast.
    cos = cos[None, None, :, :]  # (1, 1, N, D/2)
    sin = sin[None, None, :, :]
    rotated_1 = x1 * cos - x2 * sin
    rotated_2 = x1 * sin + x2 * cos
    # Re-interleave the rotated pairs back into the last axis.
    out = torch.stack([rotated_1, rotated_2], dim=-1).reshape(B, H, N, D)
    return out
```
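Pit-of-failure. Mixing up the pair convention (even-odd vs first-half / second-half). Implementations differ: Meta's reference Llama code rotates adjacent (even-odd) pairs, while the Hugging Face port uses the half-split form with permuted weights. Document the convention you used.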
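Whichever convention is chosen, a correct implementation should make the query-key dot product depend only on the position offset and should preserve vector norms. The NumPy sketch below checks both for the interleaved convention; `rope_interleaved` is an illustrative single-vector helper, and the positions and dimension are arbitrary.

```python
import numpy as np

def rope_interleaved(x, pos, base=10000.0):
    """Rotate one vector x (even length D) as if it sat at position pos.
    Pairs are (x[0], x[1]), (x[2], x[3]), ... (the interleaved convention)."""
    D = x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, D, 2) / D))  # (D/2,)
    angle = pos * inv_freq
    cos, sin = np.cos(angle), np.sin(angle)
    x1, x2 = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x1 * cos - x2 * sin
    out[1::2] = x1 * sin + x2 * cos
    return out

rng = np.random.default_rng(0)
q = rng.standard_normal(16)
k = rng.standard_normal(16)

# Same offset (3) at two different absolute positions.
score_near = rope_interleaved(q, 5) @ rope_interleaved(k, 2)
score_far = rope_interleaved(q, 105) @ rope_interleaved(k, 102)

# Rotation should not change the length of the vector.
norm_before = np.linalg.norm(q)
norm_after = np.linalg.norm(rope_interleaved(q, 7))

print(abs(score_near - score_far), abs(norm_before - norm_after))
```

A pairing mix-up between Q and K, or a sign error in the rotation, typically breaks the first check, so it works as a quick regression test before wiring RoPE into attention.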
→ Phase 16, Q11.
Drill 11 — Implement beam search decoding
Brief. Given a function step(prefix) -> log_probs over vocab, return the top-B complete sequences of length T (or stopped at EOS).
Allowed. PyTorch.
Constraints. 30 minutes.
Solution stub.
```python
import torch

def beam_search(step_fn, start_token, eos_token, beam_size=4, max_len=64):
    # Each beam: (score, tokens, done).
    beams = [(0.0, [start_token], False)]
    for _ in range(max_len):
        candidates = []
        for score, tokens, done in beams:
            if done:
                # Finished beams are carried forward unchanged, not expanded.
                candidates.append((score, tokens, True))
                continue
            log_probs = step_fn(tokens)  # (vocab,)
            topk = torch.topk(log_probs, beam_size)
            for lp, idx in zip(topk.values.tolist(), topk.indices.tolist()):
                new_tokens = tokens + [idx]
                new_done = (idx == eos_token)
                candidates.append((score + lp, new_tokens, new_done))
        # Keep the top beam_size by score.
        candidates.sort(key=lambda x: x[0], reverse=True)
        beams = candidates[:beam_size]
        if all(b[2] for b in beams):
            break
    # Optional: length-normalize for fairness.
    beams.sort(key=lambda b: b[0] / max(len(b[1]), 1), reverse=True)
    return beams
```
Pit-of-failure. Not handling EOS / done sequences (they continue to "expand" and dilute the beam). Not length-normalizing (short sequences win unfairly).
→ Phase 21.
Drill 12 — Implement a continuous batcher (skeleton)
Brief. Build the scheduler logic that admits new requests while existing requests are mid-decode, filling freed batch slots with new prefill requests and removing finished ones.
Allowed. Python. No actual GPU; pseudo-code level is acceptable.
Constraints. 30 minutes. Focus on the scheduling logic, not the kernel.
Solution stub.
```python
from dataclasses import dataclass, field
from collections import deque
from typing import Callable

EOS = -1  # convention for this stub

@dataclass
class Request:
    req_id: int
    prompt_tokens: list
    generated: list = field(default_factory=list)
    max_new_tokens: int = 256
    done: bool = False

class ContinuousBatcher:
    def __init__(self, max_batch_size: int, prefill_fn: Callable, decode_fn: Callable):
        self.max_batch = max_batch_size
        self.prefill_fn = prefill_fn  # (list[Request]) -> per-req state (KV cache)
        self.decode_fn = decode_fn    # (list[Request], states) -> next tokens
        self.waiting = deque()        # pending requests
        self.active = []              # currently decoding
        self.states = {}              # req_id -> opaque state

    def submit(self, req: Request):
        self.waiting.append(req)

    def step(self):
        # 1. Admit new requests up to capacity.
        while self.waiting and len(self.active) < self.max_batch:
            new_req = self.waiting.popleft()
            state = self.prefill_fn([new_req])[0]
            self.states[new_req.req_id] = state
            self.active.append(new_req)
        # 2. Decode one token for everyone active.
        if not self.active:
            return
        next_tokens = self.decode_fn(self.active, [self.states[r.req_id] for r in self.active])
        # 3. Update each request; mark done if EOS or max length.
        still_active = []
        for r, tok in zip(self.active, next_tokens):
            r.generated.append(tok)
            if tok == EOS or len(r.generated) >= r.max_new_tokens:
                r.done = True
                self.states.pop(r.req_id)  # free the slot and its state
            else:
                still_active.append(r)
        self.active = still_active
```
Pit-of-failure. Stalling the batch until everyone finishes (the "static batching" anti-pattern). Continuous batching's whole point is to not wait.
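The cost of the anti-pattern can be estimated with a toy step counter. The sketch below is a deliberately simplified model with made-up output lengths: it assumes one decode step per generated token, ignores prefill cost entirely, and both function names are illustrative.

```python
def static_batching_steps(lengths, batch_size):
    """Decode steps when each batch runs until its longest request finishes."""
    total = 0
    for i in range(0, len(lengths), batch_size):
        total += max(lengths[i:i + batch_size])
    return total

def continuous_batching_steps(lengths, batch_size):
    """Decode steps when a freed slot is refilled at the next step."""
    waiting = list(lengths)
    active = []  # remaining tokens for each in-flight request
    steps = 0
    while waiting or active:
        # Admit waiting requests into any free slots.
        while waiting and len(active) < batch_size:
            active.append(waiting.pop(0))
        steps += 1
        # Every active request emits one token; drop those that just finished.
        active = [r - 1 for r in active if r > 1]
    return steps

# Hypothetical output lengths (in tokens) for eight requests, batch size 2.
lengths = [2, 8, 2, 2, 8, 2, 2, 2]
static_steps = static_batching_steps(lengths, 2)          # 8 + 2 + 8 + 2 = 20
continuous_steps = continuous_batching_steps(lengths, 2)  # 14
print(static_steps, continuous_steps)
```

In this toy workload the short requests stop waiting on the long ones, and the continuous schedule reaches the lower bound of 28 total tokens over 2 slots, i.e. 14 steps.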
→ Phase 33, prompt 4 of 02-systems-design-for-llms.md.
Drill-progression schedule
A suggested 14-day schedule:
| Day | Drills | Notes |
|---|---|---|
| 1 | 01, 07 | Attention + sampling, the two most common phone-screen drills |
| 2 | 02 | BPE — slow but mechanical |
| 3 | 04 | KV cache |
| 4 | 08 | LayerNorm backward — schedule extra time |
| 5 | 05, 09 | LoRA + RMSNorm |
| 6 | 10 | RoPE |
| 7 | 06 | DPO loss |
| 8 | 11 | Beam search |
| 9 | 03 | Gradient checkpointing |
| 10 | 12 | Continuous batcher |
| 11-14 | Re-roll under timer | Pick 3 drills/day at random; 30 min each |