04 — Coding Drills: 12 Implement-It-From-Scratch Exercises

12 interview-style coding exercises (attention, BPE, gradient checkpointing, KV cache, LoRA, DPO loss, top-p, LayerNorm fwd+bwd, RMSNorm, RoPE, beam search, continuous batcher). 30 minutes each, plus a reference solution stub.

How to drill

  • 30-minute budget. Whiteboard or blank .py file with no internet.
  • NumPy preferred (more portable to whiteboard); PyTorch acceptable for the harder drills.
  • Each section ends with a reference solution stub of about 40 lines; yours can be longer, but treat the reference length as the target.
  • Cross-link: each drill names the phase that prepared you for it.

Why these 12 specifically

These drills come up repeatedly in 2024-2026 phone screens and on-sites at Anthropic, OpenAI, Google DeepMind, xAI, Mistral, Cohere, and the major scaleups. The "implement attention" drill in particular is a reliable filter: most candidates know attention conceptually but cannot write it in a blank file in 20 minutes.


Drill 01 — Implement scaled dot-product attention (forward only)

Brief. Given Q, K, V of shape (batch, n_heads, seq, d_head) and an optional causal mask, return the attention output of the same shape.

Allowed. NumPy or PyTorch. No torch.nn.functional.scaled_dot_product_attention.

Constraints. 20 minutes. Must include sqrt(d_k) scaling. Must handle causal masking.

Solution stub.

import numpy as np

def scaled_dot_product_attention(Q, K, V, causal=False):
    # Q, K, V: (B, H, N, D)
    B, H, N, D = Q.shape
    scores = np.matmul(Q, np.swapaxes(K, -1, -2)) / np.sqrt(D)   # (B, H, N, N)
    if causal:
        mask = np.triu(np.ones((N, N), dtype=bool), k=1)
        scores = np.where(mask, -1e9, scores)
    # Numerically stable softmax
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    out = np.matmul(weights, V)                                  # (B, H, N, D)
    return out

Common bugs to flag.

  • Forgetting the sqrt(d_k) scaling.
  • Applying the causal mask on the wrong axis (transpose error).
  • Softmax overflow because the scores - max(scores) shift is omitted.
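
The softmax overflow bug is easy to reproduce in isolation. The sketch below is illustrative only: `naive_softmax` and `stable_softmax` are hypothetical helper names, and the score values are chosen artificially large so that `np.exp` overflows.

```python
import numpy as np

def naive_softmax(scores):
    # No max-shift: exp() overflows to inf for large scores.
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)

def stable_softmax(scores):
    # Subtracting the row max leaves the result unchanged mathematically
    # but keeps every exponent <= 0.
    shifted = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([[1000.0, 999.0, 0.0]])  # artificially large logits

# Silence the expected overflow / invalid-value warnings for the demo.
with np.errstate(over="ignore", invalid="ignore"):
    naive = naive_softmax(scores)
stable = stable_softmax(scores)

print("naive :", naive)    # contains nan (inf / inf)
print("stable:", stable)   # finite, rows sum to 1
```

The naive version produces `inf / inf`, which is `nan`, while the shifted version stays finite.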

→ Phase 15.


Drill 02 — Implement BPE tokenizer (training + encode)

Brief. Train a byte-level BPE tokenizer on a small corpus to vocab_size = 1000, then encode a new string with the learned merges.

Allowed. Python standard library.

Constraints. 30 minutes. Must handle byte-level input (str.encode("utf-8")).

Solution stub.

from collections import Counter

def train_bpe(corpus_bytes, vocab_size=1000):
    # corpus_bytes: iterable of bytes objects (e.g. one per document or word).
    # Initial tokenization: each byte is a token (ids 0-255).
    tokens = [list(b) for b in corpus_bytes]                 # list of list of int
    merges = []
    next_id = 256
    while next_id < vocab_size:
        pair_counts = Counter()
        for seq in tokens:
            for a, b in zip(seq, seq[1:]):
                pair_counts[(a, b)] += 1
        if not pair_counts:
            break
        best_pair, _ = pair_counts.most_common(1)[0]
        merges.append((best_pair, next_id))
        # Apply merge.
        new_tokens = []
        for seq in tokens:
            merged = []
            i = 0
            while i < len(seq):
                if i + 1 < len(seq) and (seq[i], seq[i+1]) == best_pair:
                    merged.append(next_id)
                    i += 2
                else:
                    merged.append(seq[i])
                    i += 1
            new_tokens.append(merged)
        tokens = new_tokens
        next_id += 1
    return merges

def encode(text, merges):
    seq = list(text.encode("utf-8"))
    for pair, new_id in merges:
        out = []
        i = 0
        while i < len(seq):
            if i + 1 < len(seq) and (seq[i], seq[i+1]) == pair:
                out.append(new_id)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        seq = out
    return seq

Pit-of-failure. Applying merges in arbitrary order instead of the order in which they were learned during training. Signals: never debugged a tokenizer.
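
A small sketch of why merge order matters. The two-entry merge table here is made up for illustration (it is not the output of any real training run), and `apply_merges` is a hypothetical helper that mirrors the encode loop in the stub.

```python
def apply_merges(seq, merges):
    # Apply each (pair -> new_id) rule in the order given, left to right.
    for pair, new_id in merges:
        out, i = [], 0
        while i < len(seq):
            if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
                out.append(new_id)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        seq = out
    return seq

# Hypothetical merge table: "ab" -> 256 was learned first, then (256, "c") -> 257.
merges = [((97, 98), 256), ((256, 99), 257)]
seq = list("abc".encode("utf-8"))                   # [97, 98, 99]

in_train_order = apply_merges(seq, merges)          # second rule can fire
in_wrong_order = apply_merges(seq, merges[::-1])    # 256 does not exist yet

print(in_train_order)   # [257]
print(in_wrong_order)   # [256, 99]
```

The second rule depends on a token that only exists after the first rule has run, so reversing the order silently produces a different (longer) encoding.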

→ Phase 11.


Drill 03 — Implement gradient checkpointing for a 4-layer MLP

Brief. Build a 4-layer MLP. Implement gradient checkpointing manually: do not store activations between layers 2 and 3; recompute on backward.

Allowed. PyTorch. torch.utils.checkpoint.checkpoint is not allowed — write it yourself with torch.autograd.Function.

Constraints. 25 minutes.

Solution stub.

import torch
from torch.autograd import Function

class Checkpoint(Function):
    @staticmethod
    def forward(ctx, run_function, x, *params):
        ctx.run_function = run_function
        ctx.save_for_backward(x, *params)
        with torch.no_grad():
            y = run_function(x, *params)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        x, *params = ctx.saved_tensors
        # Recompute with grad enabled
        x = x.detach().requires_grad_(True)
        params = [p.detach().requires_grad_(True) for p in params]
        with torch.enable_grad():
            y = ctx.run_function(x, *params)
        grads = torch.autograd.grad(y, [x, *params], grad_y)
        return (None, *grads)

def checkpoint(fn, x, *params):
    return Checkpoint.apply(fn, x, *params)

Pit-of-failure. Forgetting to detach inputs before requires_grad-ing. Signals: never wrote an autograd Function.
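
One way to convince yourself the stub is right is to compare gradients with and without checkpointing on the 4-layer MLP from the brief. This is a sketch under assumptions: the `Checkpoint` class is repeated so the block runs on its own, and `middle`, `mlp_loss`, the bias-free layers, and the 8-wide shapes are illustrative choices, not part of the drill.

```python
import torch
from torch.autograd import Function

class Checkpoint(Function):
    @staticmethod
    def forward(ctx, run_function, x, *params):
        ctx.run_function = run_function
        ctx.save_for_backward(x, *params)
        with torch.no_grad():              # no graph is stored for this segment
            return run_function(x, *params)

    @staticmethod
    def backward(ctx, grad_y):
        x, *params = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        params = [p.detach().requires_grad_(True) for p in params]
        with torch.enable_grad():          # recompute the segment with a graph
            y = ctx.run_function(x, *params)
        grads = torch.autograd.grad(y, [x, *params], grad_y)
        return (None, *grads)              # None for the run_function argument

def middle(h, W2, W3):
    # Layers 2 and 3: the activation between them is what we avoid storing.
    return torch.relu(torch.relu(h @ W2) @ W3)

def mlp_loss(x, Ws, use_checkpoint):
    W1, W2, W3, W4 = Ws
    h = torch.relu(x @ W1)
    h = Checkpoint.apply(middle, h, W2, W3) if use_checkpoint else middle(h, W2, W3)
    return (h @ W4).sum()

torch.manual_seed(0)
x = torch.randn(5, 8)
Ws = [torch.randn(8, 8, requires_grad=True) for _ in range(4)]

ref_grads = torch.autograd.grad(mlp_loss(x, Ws, False), Ws)
ckpt_grads = torch.autograd.grad(mlp_loss(x, Ws, True), Ws)
grads_match = all(
    torch.allclose(a, b, rtol=1e-4, atol=1e-5) for a, b in zip(ref_grads, ckpt_grads)
)
print("gradients match:", grads_match)
```

If the gradients disagree, the usual culprits are a missing `detach()` or a wrong number of values returned from `backward`.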

→ Phase 07, Phase 08.


Drill 04 — Implement a KV cache for autoregressive decoding

Brief. Given a transformer block, implement step(x_new, cache) that uses the cache for the prefix and only computes attention for the new token.

Allowed. PyTorch.

Constraints. 25 minutes. Must handle the first call (empty cache) correctly.

Solution stub.

import torch

class KVCache:
    def __init__(self):
        self.k = None  # (B, H, N, D)
        self.v = None

    def update(self, k_new, v_new):
        # k_new, v_new: (B, H, 1, D) typically
        if self.k is None:
            self.k = k_new
            self.v = v_new
        else:
            self.k = torch.cat([self.k, k_new], dim=2)
            self.v = torch.cat([self.v, v_new], dim=2)
        return self.k, self.v

def attention_step(x_new, W_q, W_k, W_v, W_o, cache, n_heads):
    # x_new: (B, 1, d_model)
    B, _, d = x_new.shape
    d_head = d // n_heads
    q = (x_new @ W_q).view(B, 1, n_heads, d_head).transpose(1, 2)
    k = (x_new @ W_k).view(B, 1, n_heads, d_head).transpose(1, 2)
    v = (x_new @ W_v).view(B, 1, n_heads, d_head).transpose(1, 2)
    K, V = cache.update(k, v)
    # Attention over the full prefix (no need for causal mask: only one new query)
    scores = (q @ K.transpose(-1, -2)) / (d_head ** 0.5)         # (B, H, 1, N)
    weights = torch.softmax(scores, dim=-1)
    attn = weights @ V                                            # (B, H, 1, D)
    out = attn.transpose(1, 2).reshape(B, 1, d) @ W_o
    return out

Pit-of-failure. Concatenating on wrong axis; forgetting that the query is only 1 token but key/value span the whole prefix.
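
The point about one query against the whole prefix can be checked numerically: decoding token by token with a cache should reproduce each row of full causal attention. The sketch below is a simplified single-head NumPy version (no batch dimension, no output projection), with hypothetical variable names, rather than the PyTorch stub itself.

```python
import numpy as np

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
N, d = 6, 4                                   # sequence length, head dim
x = rng.normal(size=(N, d))
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))

# Reference: full causal attention over all N tokens at once.
Q, K, V = x @ W_q, x @ W_k, x @ W_v
scores = Q @ K.T / np.sqrt(d)
future = np.triu(np.ones((N, N), dtype=bool), k=1)
full = softmax(np.where(future, -np.inf, scores)) @ V          # (N, d)

# Incremental: one new token per step, keys/values appended along the
# sequence axis (axis 0 here; dim=2 in the (B, H, N, D) stub).
k_cache = np.empty((0, d))
v_cache = np.empty((0, d))
steps = []
for t in range(N):
    x_new = x[t:t + 1]                                          # (1, d)
    k_cache = np.concatenate([k_cache, x_new @ W_k], axis=0)    # (t+1, d)
    v_cache = np.concatenate([v_cache, x_new @ W_v], axis=0)
    q = x_new @ W_q                                             # (1, d)
    weights = softmax(q @ k_cache.T / np.sqrt(d))               # (1, t+1)
    steps.append(weights @ v_cache)
incremental = np.concatenate(steps, axis=0)                     # (N, d)

print("max abs diff:", np.abs(full - incremental).max())
```

No mask is needed in the incremental loop because the cache never contains future tokens.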

→ Phase 22.


Drill 05 — Implement a LoRA Linear layer

Brief. Build a LoRALinear(in_features, out_features, rank, alpha) module. Forward must equal x @ W^T + (alpha/rank) * x @ A @ B. The base W is frozen.

Allowed. PyTorch.

Constraints. 20 minutes. Initialization: A Kaiming, B zeros, so LoRA delta starts at zero.

Solution stub.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, in_features, out_features, rank=8, alpha=16):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        for p in self.base.parameters():
            p.requires_grad = False
        # Low-rank factors
        self.A = nn.Parameter(torch.empty(in_features, rank))
        self.B = nn.Parameter(torch.zeros(rank, out_features))
        nn.init.kaiming_uniform_(self.A, a=5 ** 0.5)
        self.scale = alpha / rank

    def forward(self, x):
        # x: (..., in_features)
        return self.base(x) + self.scale * (x @ self.A @ self.B)

Pit-of-failure. Forgetting the (alpha/rank) scaling; initializing both A and B randomly (model behavior corrupted at step 0).
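
The initialization pitfall can be seen without any training. This NumPy sketch uses made-up shapes and a plain normal draw for A as a stand-in for Kaiming init (an assumption for brevity); `lora_forward` is a hypothetical helper implementing the formula from the brief.

```python
import numpy as np

rng = np.random.default_rng(0)
in_f, out_f, rank, alpha = 16, 8, 4, 8        # illustrative sizes
scale = alpha / rank

W = rng.normal(size=(out_f, in_f))            # frozen base weight
x = rng.normal(size=(3, in_f))
A = rng.normal(size=(in_f, rank))             # stand-in for Kaiming init

def lora_forward(x, W, A, B, scale):
    # x @ W^T + (alpha / rank) * x @ A @ B
    return x @ W.T + scale * (x @ A @ B)

base_out = x @ W.T
B_zero = np.zeros((rank, out_f))              # recommended init
B_rand = rng.normal(size=(rank, out_f))       # the buggy init

drift_zero_init = np.abs(lora_forward(x, W, A, B_zero, scale) - base_out).max()
drift_rand_init = np.abs(lora_forward(x, W, A, B_rand, scale) - base_out).max()

print("drift with B = 0     :", drift_zero_init)   # 0.0
print("drift with random B  :", drift_rand_init)   # > 0
```

With B at zero the adapted layer reproduces the base layer exactly at step 0; with both factors random, the output is already perturbed before the first gradient step.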

→ Phase 28.


Drill 06 — Implement the DPO loss

Brief. Given log-probs of chosen and rejected responses under the policy and under a frozen reference, and a beta hyperparameter, compute the DPO loss.

Allowed. PyTorch.

Constraints. 15 minutes. Must use logsigmoid for numerical stability.

Solution stub.

import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps: torch.Tensor,  # (B,)
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
):
    """Direct Preference Optimization loss (Rafailov 2023)."""
    pi_logratio = policy_chosen_logps - policy_rejected_logps
    ref_logratio = ref_chosen_logps - ref_rejected_logps
    logits = beta * (pi_logratio - ref_logratio)
    loss = -F.logsigmoid(logits).mean()
    # Implicit reward margin for logging
    chosen_reward = beta * (policy_chosen_logps - ref_chosen_logps).detach()
    rejected_reward = beta * (policy_rejected_logps - ref_rejected_logps).detach()
    return loss, chosen_reward, rejected_reward

Pit-of-failure. Using log(sigmoid(...)) instead of logsigmoid: when the margin is strongly negative, sigmoid underflows to 0 and the log returns -inf, so the loss becomes inf. Signals: never trained a DPO model.
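
To see the numerical issue concretely, here is a NumPy sketch in float32. `naive_logsigmoid` and `stable_logsigmoid` are hypothetical helpers written for this illustration (the stable form is one common formulation, not necessarily what `F.logsigmoid` does internally), and the logit values are chosen to be extreme on purpose.

```python
import numpy as np

def naive_logsigmoid(z):
    # log(sigmoid(z)) computed literally.
    return np.log(1.0 / (1.0 + np.exp(-z)))

def stable_logsigmoid(z):
    # log(sigmoid(z)) = min(z, 0) - log1p(exp(-|z|)); the exponent is never positive.
    return np.minimum(z, 0.0) - np.log1p(np.exp(-np.abs(z)))

# Scaled margins beta * (pi_logratio - ref_logratio); -200 is an extreme case.
z = np.array([-200.0, 0.0, 200.0], dtype=np.float32)

# Silence the expected overflow / divide-by-zero warnings for the demo.
with np.errstate(over="ignore", divide="ignore"):
    naive = naive_logsigmoid(z)
stable = stable_logsigmoid(z)

print("naive :", naive)    # first entry is -inf, so the loss would be +inf
print("stable:", stable)   # first entry is -200.0
```

In the naive version `exp(200)` overflows float32, the sigmoid collapses to 0, and the log returns `-inf`; the stable form returns the correct value of about -200.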

→ X3 module.


Drill 07 — Implement top-p (nucleus) sampling

Brief. Given logits and top_p, sample a token.

Allowed. NumPy or PyTorch.

Constraints. 15 minutes. Must handle the edge case where the top token already exceeds top_p.

Solution stub.

import torch

def top_p_sample(logits: torch.Tensor, top_p: float = 0.9, temperature: float = 1.0):
    # logits: (vocab,)
    logits = logits / max(temperature, 1e-6)
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Keep the smallest set whose cumsum >= top_p.
    # We include the first index that crosses top_p (otherwise empty when top_1 > top_p).
    cutoff_mask = cumulative > top_p
    # Shift right so we always keep at least one token.
    cutoff_mask[..., 1:] = cutoff_mask[..., :-1].clone()
    cutoff_mask[..., 0] = False
    sorted_probs[cutoff_mask] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum()
    sample_in_sorted = torch.multinomial(sorted_probs, 1)
    return sorted_idx[sample_in_sorted].item()

Pit-of-failure. Empty candidate set when top-1 probability already exceeds top_p. Signals: never coded the edge case.
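
The edge case is easiest to see on a three-token distribution. This NumPy sketch uses made-up probabilities and hypothetical variable names; it only shows the masking step, not the sampling.

```python
import numpy as np

sorted_probs = np.array([0.95, 0.03, 0.02])   # already sorted, descending
top_p = 0.9

cumulative = np.cumsum(sorted_probs)
crossed = cumulative > top_p                  # True at every position here

# Buggy: drop every position whose cumulative mass exceeds top_p.
keep_unshifted = ~crossed

# Fixed: shift the mask right by one so the token that crosses top_p is kept.
shifted = np.concatenate([[False], crossed[:-1]])
keep_shifted = ~shifted

print("unshifted keeps:", keep_unshifted.tolist())   # [False, False, False]
print("shifted keeps  :", keep_shifted.tolist())     # [True, False, False]
```

Without the shift the candidate set is empty and renormalization divides by zero; with it, the top token always survives.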

→ Phase 21.


Drill 08 — Implement LayerNorm forward and backward

Brief. Given input x of shape (B, N, D), learned gamma, beta, compute y = gamma * (x - mu) / sqrt(var + eps) + beta, and the backward gradients.

Allowed. NumPy. No autograd.

Constraints. 30 minutes. The backward is the actual filter.

Solution stub.

import numpy as np

def layernorm_forward(x, gamma, beta, eps=1e-5):
    # x: (B, N, D)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    inv = 1.0 / np.sqrt(var + eps)
    x_hat = (x - mu) * inv
    y = gamma * x_hat + beta
    cache = (x_hat, inv, gamma)
    return y, cache

def layernorm_backward(dy, cache):
    x_hat, inv, gamma = cache
    D = x_hat.shape[-1]
    # Param grads (sum over all but feature axis).
    dgamma = (dy * x_hat).sum(axis=(0, 1))
    dbeta = dy.sum(axis=(0, 1))
    # Activation grad.
    dx_hat = dy * gamma                                            # (B, N, D)
    # Standard LN backward derivation:
    dx = (1.0 / D) * inv * (
        D * dx_hat
        - dx_hat.sum(axis=-1, keepdims=True)
        - x_hat * (dx_hat * x_hat).sum(axis=-1, keepdims=True)
    )
    return dx, dgamma, dbeta

Pit-of-failure. Backward formula off by a factor of 1/D. Signals: never derived it. Most candidates fail.
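
A finite-difference check is a quick way to catch a stray factor of 1/D. The sketch below repeats a compact version of the forward and the dx part of the backward so it runs on its own; the names `ln_forward` and `ln_backward`, the shapes, and the step size `h` are illustrative choices.

```python
import numpy as np

def ln_forward(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    inv = 1.0 / np.sqrt(var + eps)
    x_hat = (x - mu) * inv
    return gamma * x_hat + beta, (x_hat, inv, gamma)

def ln_backward(dy, cache):
    # Only dx is checked here; dgamma and dbeta are plain sums.
    x_hat, inv, gamma = cache
    D = x_hat.shape[-1]
    dx_hat = dy * gamma
    return (inv / D) * (
        D * dx_hat
        - dx_hat.sum(axis=-1, keepdims=True)
        - x_hat * (dx_hat * x_hat).sum(axis=-1, keepdims=True)
    )

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 5))
gamma, beta = rng.normal(size=5), rng.normal(size=5)
dy = rng.normal(size=x.shape)                 # upstream gradient

_, cache = ln_forward(x, gamma, beta)
dx_analytic = ln_backward(dy, cache)

# Central differences on the scalar L = sum(y * dy), whose gradient w.r.t. x is dx.
h = 1e-5
dx_numeric = np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    x_plus, x_minus = x.copy(), x.copy()
    x_plus[idx] += h
    x_minus[idx] -= h
    L_plus = (ln_forward(x_plus, gamma, beta)[0] * dy).sum()
    L_minus = (ln_forward(x_minus, gamma, beta)[0] * dy).sum()
    dx_numeric[idx] = (L_plus - L_minus) / (2 * h)

max_err = np.abs(dx_analytic - dx_numeric).max()
print("max abs error:", max_err)
```

An off-by-1/D bug shows up here as an error on the order of the gradient itself rather than near machine precision.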

→ Phase 10, Q2.


Drill 09 — Implement RMSNorm

Brief. RMSNorm = LayerNorm without mean subtraction. y = gamma * x / sqrt(mean(x^2) + eps).

Allowed. PyTorch.

Constraints. 10 minutes (it is short by design).

Solution stub.

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # x: (..., dim)
        # inv_rms = 1 / sqrt(mean(x^2) + eps), i.e. the reciprocal of the RMS.
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.weight

Follow-up. "Why does RMSNorm work as well as LayerNorm?" → empirical finding (Zhang & Sennrich 2019); mean-subtraction is not load-bearing; saves one reduction.
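
The "LayerNorm without mean subtraction" relationship can be checked directly: on inputs that are already zero-mean along the feature axis, the two normalizations coincide. This is a NumPy sketch with gamma fixed to 1 and beta to 0; the function names and shapes are illustrative assumptions.

```python
import numpy as np

def layernorm(x, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def rmsnorm(x, eps=1e-6):
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))

# Zero-mean rows: variance equals mean(x^2), so the two agree.
x_centered = x - x.mean(axis=-1, keepdims=True)
same_when_centered = np.allclose(layernorm(x_centered), rmsnorm(x_centered))

# Shifted rows: LayerNorm removes the offset, RMSNorm does not.
x_shifted = x + 3.0
differ_when_shifted = not np.allclose(layernorm(x_shifted), rmsnorm(x_shifted))

print(same_when_centered, differ_when_shifted)
```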

→ Phase 10.


Drill 10 — Implement RoPE (Rotary Positional Embeddings)

Brief. Given Q, K of shape (B, H, N, D), apply RoPE rotation by position.

Allowed. PyTorch.

Constraints. 25 minutes. Must use a standard pairing: interleaved (even-odd) or half-split (first-half / second-half).

Solution stub.

import torch

def build_rope_cache(seq_len, dim, base=10000.0, device="cpu"):
    # dim must be even.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(t, inv_freq)                              # (N, D/2)
    cos = freqs.cos()                                              # (N, D/2)
    sin = freqs.sin()
    return cos, sin

def apply_rope(x, cos, sin):
    # x: (B, H, N, D). Split into pairs along D.
    B, H, N, D = x.shape
    x1 = x[..., 0::2]                                              # (B, H, N, D/2)
    x2 = x[..., 1::2]
    # cos, sin: (N, D/2) -> broadcast.
    cos = cos[None, None, :, :]                                    # (1, 1, N, D/2)
    sin = sin[None, None, :, :]
    rotated_1 = x1 * cos - x2 * sin
    rotated_2 = x1 * sin + x2 * cos
    out = torch.stack([rotated_1, rotated_2], dim=-1).reshape(B, H, N, D)
    return out

Pit-of-failure. Mixing up the pair convention (even-odd vs first-half / second-half; Meta's reference Llama code and the Hugging Face Transformers port use different conventions). Document the convention you used.
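
Two properties are worth checking after writing RoPE: the attention score should depend only on the position offset, and the two pairing conventions should not be mixed. The NumPy sketch below works on single vectors; `rope_interleaved` and `rope_half_split` are hypothetical helper names and the positions are arbitrary.

```python
import numpy as np

def _angles(D, pos, base=10000.0):
    inv_freq = 1.0 / (base ** (np.arange(0, D, 2) / D))
    return pos * inv_freq                                   # (D/2,)

def rope_interleaved(x, pos):
    # Pairs are (x[0], x[1]), (x[2], x[3]), ...
    a = _angles(x.shape[-1], pos)
    x1, x2 = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x1 * np.cos(a) - x2 * np.sin(a)
    out[1::2] = x1 * np.sin(a) + x2 * np.cos(a)
    return out

def rope_half_split(x, pos):
    # Pairs are (x[i], x[i + D/2]).
    D = x.shape[-1]
    a = _angles(D, pos)
    x1, x2 = x[: D // 2], x[D // 2:]
    return np.concatenate([x1 * np.cos(a) - x2 * np.sin(a),
                           x1 * np.sin(a) + x2 * np.cos(a)])

rng = np.random.default_rng(0)
q, k = rng.normal(size=8), rng.normal(size=8)

# Same offset (3) at two different absolute positions gives the same score.
score_near = rope_interleaved(q, 5) @ rope_interleaved(k, 2)
score_far = rope_interleaved(q, 105) @ rope_interleaved(k, 102)

# Rotation preserves the vector norm.
norm_preserved = np.isclose(np.linalg.norm(rope_interleaved(q, 5)), np.linalg.norm(q))

# The two conventions produce different vectors for the same input.
conventions_differ = not np.allclose(rope_interleaved(q, 5), rope_half_split(q, 5))

print(score_near, score_far, norm_preserved, conventions_differ)
```

Each convention is internally consistent; the bug only appears when Q and K (or the weights and the code) disagree about which one is in use.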

→ Phase 16, Q11.


Drill 11 — Implement beam search decoding

Brief. Given a function step(prefix) -> log_probs over vocab, return the top-B complete sequences of length T (or stopped at EOS).

Allowed. PyTorch.

Constraints. 30 minutes.

Solution stub.

import torch

def beam_search(step_fn, start_token, eos_token, beam_size=4, max_len=64):
    # Each beam: (score, tokens, done). Finished beams stay in `beams`
    # and compete with live ones on cumulative log-prob.
    beams = [(0.0, [start_token], False)]
    for _ in range(max_len):
        candidates = []
        for score, tokens, done in beams:
            if done:
                candidates.append((score, tokens, True))
                continue
            log_probs = step_fn(tokens)                            # (vocab,)
            topk = torch.topk(log_probs, beam_size)
            for lp, idx in zip(topk.values.tolist(), topk.indices.tolist()):
                new_tokens = tokens + [idx]
                new_done = (idx == eos_token)
                candidates.append((score + lp, new_tokens, new_done))
        # Keep the top beam_size by score.
        candidates.sort(key=lambda x: x[0], reverse=True)
        beams = candidates[:beam_size]
        if all(b[2] for b in beams):
            break
    # Optional: length-normalize for fairness.
    beams.sort(key=lambda b: b[0] / max(len(b[1]), 1), reverse=True)
    return beams

Pit-of-failure. Not handling EOS / done sequences (they continue to "expand" and dilute the beam). Not length-normalizing (short sequences win unfairly).
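
A toy comparison shows how length normalization changes the winner. The two hypotheses and their scores below are invented for illustration; the normalization is the same divide-by-length rule used at the end of the stub.

```python
# Each hypothesis: (sum of token log-probs, tokens). Values are made up.
hypotheses = [
    (-2.0, [1, 7, 2]),                  # short: 3 tokens
    (-3.0, [1, 4, 9, 5, 3, 8, 2]),      # longer: 7 tokens, better per token
]

# Raw score: every extra token adds a negative log-prob, so short wins.
best_raw = max(hypotheses, key=lambda h: h[0])

# Length-normalized score: average log-prob per token.
best_norm = max(hypotheses, key=lambda h: h[0] / max(len(h[1]), 1))

print("raw winner       :", best_raw[1])    # the 3-token sequence
print("normalized winner:", best_norm[1])   # the 7-token sequence
```

Here the per-token averages are about -0.67 and -0.43, so the longer hypothesis wins once length is accounted for.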

→ Phase 21.


Drill 12 — Implement a continuous batcher (skeleton)

Brief. Build the scheduler logic that admits new requests while existing requests are mid-decode, filling free batch slots with newly prefilled requests and removing finished ones.

Allowed. Python. No actual GPU; pseudo-code level is acceptable.

Constraints. 30 minutes. Focus on the scheduling logic, not the kernel.

Solution stub.

from dataclasses import dataclass, field
from collections import deque
from typing import Callable

@dataclass
class Request:
    req_id: int
    prompt_tokens: list
    generated: list = field(default_factory=list)
    max_new_tokens: int = 256
    done: bool = False

class ContinuousBatcher:
    def __init__(self, max_batch_size: int, prefill_fn: Callable, decode_fn: Callable):
        self.max_batch = max_batch_size
        self.prefill_fn = prefill_fn      # (list[Request]) -> per-req state (KV cache)
        self.decode_fn = decode_fn        # (list[Request], states) -> next tokens
        self.waiting = deque()            # pending requests
        self.active = []                  # currently decoding
        self.states = {}                  # req_id -> opaque state

    def submit(self, req: Request):
        self.waiting.append(req)

    def step(self):
        # 1. Admit new requests up to capacity.
        while self.waiting and len(self.active) < self.max_batch:
            new_req = self.waiting.popleft()
            state = self.prefill_fn([new_req])[0]
            self.states[new_req.req_id] = state
            self.active.append(new_req)
        # 2. Decode one token for everyone active.
        if not self.active:
            return
        next_tokens = self.decode_fn(self.active, [self.states[r.req_id] for r in self.active])
        # 3. Update each request; mark done if EOS or max length.
        still_active = []
        for r, tok in zip(self.active, next_tokens):
            r.generated.append(tok)
            if tok == EOS or len(r.generated) >= r.max_new_tokens:
                r.done = True
                self.states.pop(r.req_id)
            else:
                still_active.append(r)
        self.active = still_active

EOS = -1   # convention for this stub

Pit-of-failure. Stalling the batch until everyone finishes (the "static batching" anti-pattern). Continuous batching's whole point is to not wait.
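
The cost of waiting can be estimated with a tiny simulation that counts decode steps only. This is a deliberately simplified sketch: `static_steps` and `continuous_steps` are hypothetical helpers, prefill cost is ignored, and the request lengths are made up.

```python
from collections import deque

def static_steps(lengths, batch_size):
    # Static batching: each batch runs until its longest request finishes.
    total = 0
    for i in range(0, len(lengths), batch_size):
        total += max(lengths[i:i + batch_size])
    return total

def continuous_steps(lengths, batch_size):
    # Continuous batching: a freed slot is refilled before the next step.
    waiting = deque(lengths)
    active = []                       # tokens still to generate, per request
    steps = 0
    while waiting or active:
        while waiting and len(active) < batch_size:
            active.append(waiting.popleft())
        steps += 1                    # one decode step for everyone active
        active = [r - 1 for r in active if r - 1 > 0]
    return steps

lengths = [2, 8, 2, 2]                # tokens to generate per request
n_static = static_steps(lengths, batch_size=2)
n_continuous = continuous_steps(lengths, batch_size=2)

print("static    :", n_static)       # 10
print("continuous:", n_continuous)   # 8
```

With static batching the short request in the first batch idles for six steps behind the long one; with continuous batching the two remaining short requests run in that slot instead.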

→ Phase 33, prompt 4 of 02-systems-design-for-llms.md.


Drill-progression schedule

A suggested 14-day schedule:

Day     Drills                Notes
1       01, 07                Attention + sampling, the two most common phone-screen drills
2       02                    BPE: slow but mechanical
3       04                    KV cache
4       08                    LayerNorm backward: schedule extra time
5       05, 09                LoRA + RMSNorm
6       10                    RoPE
7       06                    DPO loss
8       11                    Beam search
9       03                    Gradient checkpointing
10      12                    Continuous batcher
11-14   Re-roll under timer   Pick 3 drills/day at random; 30 min each

→ Next: 05-behavioral-and-storytelling.md