Lab 01 — Implement the KV Cache

Goal: write src/minicache/cache.py from BLUEPRINT.md. Make tests/test_minicache.py pass.

Estimated time: 4–8 hours over 2–3 sessions.

Prereq: lab/00-derive-cache-size.md committed. src/minicache/BLUEPRINT.md read; any open questions on it resolved with /phase-checkpoint before starting.


What you produce

  • src/minicache/cache.py — implementation per BLUEPRINT.md API.
  • tests/test_minicache.py — Claude scaffolds the failing tests; you make them pass.
  • Updated src/minimodel/attention.py: attention(...) accepts an optional cache: KVCache | None.
  • All tests green; mypy --strict src/minicache clean; ruff check src/minicache clean.

TODOs

Block A — read the blueprint

  • Open src/minicache/BLUEPRINT.md. Re-read every § (Purpose, API, Alternatives, Complexity, Test plan, Anti-goals, Open questions).
  • If any open question is unresolved, stop here and run /phase-checkpoint. Don't code against open questions.

Block B — write the failing tests (TDD)

The test scaffold is at tests/test_minicache.py (Claude commits it with empty bodies; you fill in the bodies first, before writing any cache code). Each test starts as a docstring describing the property it checks; you turn the docstring into asserts.

Test list (from BLUEPRINT.md §Test plan):

  1. test_allocate_shapes: KVCache.allocate(layers=4, heads=4, head_dim=32, max_seq=128, batch=2, dtype=np.float32) returns an object whose per-layer K and V buffers have shape (2, 4, 128, 32).
  2. test_initial_cursor_zero: cache.current_length() == 0 immediately after allocate.
  3. test_append_advances_cursor: appending 5 tokens leaves the cursor at 5.
  4. test_append_one_token_per_layer: appending writes the right slice; reads return the same bytes.
  5. test_read_returns_only_filled_rows: cache.read(layer=0) returns shape (B, H, cursor, d_h), not the whole pre-allocated tensor.
  6. test_capacity_exceeded_raises: appending past max_seq raises a custom CacheFullError.
  7. test_dtype_preserved: allocate fp16, append fp16 tokens, read returns fp16.
  8. test_reset_empties_cursor: cache.reset() zeroes the cursor; the underlying buffer is not required to be zeroed.
  9. test_independent_layers: writing to layer 1 doesn't disturb layer 0.
  10. test_independent_batch_entries: writing for batch index 0 doesn't disturb batch index 1.
  11. test_memory_footprint_matches_formula: cache.bytes_allocated() equals 2 · L · H · d_h · S_max · B · s exactly, where s is the size of one element in bytes.

Each test should be 5–15 lines. If yours is longer, the implementation is fighting the test.
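
For test 11, it can help to work the formula once by hand before writing the assert. The sketch below plugs in the shapes from test 1 with fp32 elements and cross-checks the result against real NumPy buffers. The variable names are illustrative only and are not part of the BLUEPRINT.md API.

```python
import numpy as np

# Shapes taken from test 1; fp32 means s = 4 bytes per element.
layers, heads, head_dim, max_seq, batch = 4, 4, 32, 128, 2
dtype = np.dtype(np.float32)

# Formula from test 11: 2 * L * H * d_h * S_max * B * s.
# The leading 2 accounts for storing both K and V.
expected_bytes = 2 * layers * heads * head_dim * max_seq * batch * dtype.itemsize

# Cross-check: one K buffer and one V buffer per layer, each (B, H, S_max, d_h).
shape = (batch, heads, max_seq, head_dim)
k_buffers = [np.empty(shape, dtype=dtype) for _ in range(layers)]
v_buffers = [np.empty(shape, dtype=dtype) for _ in range(layers)]
actual_bytes = sum(buf.nbytes for buf in k_buffers + v_buffers)

print(expected_bytes, actual_bytes)  # 1048576 1048576
```

That is exactly 1 MiB for this configuration. If bytes_allocated() reports anything else, check for a missing factor of 2 or a wrong element size first.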

Block C — implement KVCache

API per BLUEPRINT.md:

import numpy as np


class KVCache:
    @classmethod
    def allocate(cls, *, layers: int, heads: int, head_dim: int,
                 max_seq: int, batch: int, dtype: np.dtype) -> "KVCache": ...
    def append(self, layer: int, k_new: np.ndarray, v_new: np.ndarray) -> None: ...
    def read(self, layer: int) -> tuple[np.ndarray, np.ndarray]: ...
    def current_length(self) -> int: ...
    def reset(self) -> None: ...
    def bytes_allocated(self) -> int: ...

Constraints:

  • All public methods type-annotated. mypy --strict clean.
  • No external dependencies beyond numpy + stdlib.
  • append is O(1) per token (no concatenation). This is the cache's whole point; if you find yourself reaching for np.concatenate, re-read theory/02.
  • The cursor advances once per token, not once per tensor written. Appending K and V together advances the cursor by 1, not by 2. (Easy off-by-one: the K and V appends are conceptually the same "step".)
  • All layers share the same cursor (they all process the same token at the same step).
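
The O(1) append constraint above can be illustrated with a deliberately reduced toy: one layer, K only. This is a sketch of the write-into-a-pre-allocated-buffer technique, not the BLUEPRINT.md design. ToyLayerBuffer is a hypothetical name, and making CacheFullError a RuntimeError subclass is an assumption. Because it covers a single layer, it sidesteps how the shared cursor is managed across layers, which BLUEPRINT.md settles.

```python
import numpy as np


class CacheFullError(RuntimeError):
    """Raised when a write would run past the pre-allocated capacity."""


class ToyLayerBuffer:
    """Hypothetical single-layer K buffer; not the BLUEPRINT.md KVCache API."""

    def __init__(self, batch: int, heads: int, max_seq: int, head_dim: int) -> None:
        # Allocate once, up front. Contents are undefined until written.
        self._k = np.empty((batch, heads, max_seq, head_dim), dtype=np.float32)
        self._cursor = 0

    def append(self, k_new: np.ndarray) -> None:
        # k_new has shape (batch, heads, t, head_dim); t is 1 during decode.
        t = k_new.shape[2]
        if self._cursor + t > self._k.shape[2]:
            raise CacheFullError(
                f"cannot append {t} tokens at cursor {self._cursor}: "
                f"capacity is {self._k.shape[2]}"
            )
        # Slice assignment copies only the new rows; nothing is reallocated,
        # so the cost does not depend on how many tokens are already cached.
        self._k[:, :, self._cursor:self._cursor + t, :] = k_new
        self._cursor += t

    def read(self) -> np.ndarray:
        # A view over the filled prefix only, not the whole buffer.
        return self._k[:, :, :self._cursor, :]


buf = ToyLayerBuffer(batch=2, heads=4, max_seq=8, head_dim=32)
for _ in range(5):
    buf.append(np.ones((2, 4, 1, 32), dtype=np.float32))
print(buf.read().shape)  # (2, 4, 5, 32)
```

By contrast, np.concatenate allocates a new array and copies every cached row on each step, so the per-token cost grows with the sequence length.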

Block D — wire into attention

Update src/minimodel/attention.py:

def attention(q: np.ndarray, k: np.ndarray, v: np.ndarray, *,
              mask: np.ndarray | None = None,
              cache: KVCache | None = None,
              layer_idx: int | None = None) -> np.ndarray:
    """If cache is None: original Phase-15 path (training).
       If cache is not None: append (k, v) to cache, read full cached K, V, do attention.
       layer_idx must be supplied if cache is supplied."""

Constraints:

  • The training path (cache=None) is unchanged in numerics. Phase-15 tests must still pass byte-identically.
  • The decode path (cache is not None) must take q of shape (B, H, 1, d_h) (sequence length 1) and append (k, v) of the same shape.
  • No new mask needed in the decode path (see theory/01-prefill-vs-decode.md §Pseudo-pseudocode).
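
Why the decode path needs no mask can be checked numerically. The sketch below uses plain scaled dot-product attention with random inputs (not the lab's attention function) and compares the last row of full causal attention against a single-query step over all cached keys. The softmax helper and all variable names are local to this example.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
B, H, T, d_h = 2, 4, 6, 32
q = rng.standard_normal((B, H, T, d_h))
k = rng.standard_normal((B, H, T, d_h))
v = rng.standard_normal((B, H, T, d_h))
scale = 1.0 / np.sqrt(d_h)
k_t = k.transpose(0, 1, 3, 2)  # (B, H, d_h, T)

# Full causal attention: position i may attend only to positions <= i.
scores = (q @ k_t) * scale                      # (B, H, T, T)
causal = np.tril(np.ones((T, T), dtype=bool))
full_out = softmax(np.where(causal, scores, -np.inf)) @ v

# Decode step for the newest position: one query, all T cached keys, no mask.
q_last = q[:, :, -1:, :]                        # (B, H, 1, d_h)
step_out = softmax((q_last @ k_t) * scale) @ v  # (B, H, 1, d_h)

# The last causal row masks nothing, so the two results agree.
assert np.allclose(full_out[:, :, -1:, :], step_out)
```

The newest token is allowed to see every earlier position, and the cache contains only earlier positions plus the token itself, so there is nothing left to mask.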

Block E — manifest

Commit a manifest.json at experiments/22-cache-impl/:

{
  "experiment": "22-cache-impl",
  "date": "YYYY-MM-DD",
  "seed": 42,
  "versions": {"python": "3.11.x", "numpy": "X.Y.Z"},
  "tests": {"total": 11, "passed": null, "skipped": 0},
  "lines_added": null,
  "mypy_strict_clean": null,
  "ruff_clean": null
}

Fill the nulls after tests pass.
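
If you prefer not to edit the JSON by hand, a small helper along these lines can fill the placeholders. This is an optional sketch: fill_manifest is a hypothetical function, not part of the lab scaffold, and every result value is something you pass in after running the checks yourself.

```python
import datetime
import json
import platform
from pathlib import Path
from typing import Any

import numpy as np


def fill_manifest(path: Path, *, passed: int, lines_added: int,
                  mypy_strict_clean: bool, ruff_clean: bool) -> dict[str, Any]:
    """Hypothetical helper: replace the template placeholders in manifest.json."""
    manifest: dict[str, Any] = json.loads(path.read_text())
    # Date and versions are read from the running environment, not typed by hand.
    manifest["date"] = datetime.date.today().isoformat()
    manifest["versions"] = {
        "python": platform.python_version(),
        "numpy": np.__version__,
    }
    # The remaining fields are results you observed and pass in explicitly.
    manifest["tests"]["passed"] = passed
    manifest["lines_added"] = lines_added
    manifest["mypy_strict_clean"] = mypy_strict_clean
    manifest["ruff_clean"] = ruff_clean
    path.write_text(json.dumps(manifest, indent=2) + "\n")
    return manifest
```

Call it with the path to experiments/22-cache-impl/manifest.json once all 11 tests pass and both linters are clean.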

Constraints

  • NumPy only. No PyTorch yet (Phase 24 introduces it).
  • No np.concatenate in append. O(1) write into pre-allocated buffer; this is the entire correctness/perf contract.
  • No backward pass needed. The cache is inference-only. Training never uses it (Phase 17 trains without cache).
  • No threading. Cache is single-stream.

Stop conditions

Done when:

  1. All 11 tests in tests/test_minicache.py pass.
  2. mypy --strict src/minicache clean.
  3. ruff check src/minicache clean.
  4. Phase-15 attention tests still pass (i.e., you didn't break the training path).
  5. manifest.json committed.
  6. src/minicache/README.md reflects the final API (kept in sync with BLUEPRINT.md per A5).

Pitfalls (read before debugging)

  • Over-advancing the cursor. If append(layer, k, v) increments the cursor on every write, a single token can advance it more than once. The cursor advances once per token, not once per (layer, k_or_v) pair. Track this carefully; see BLUEPRINT.md §Pitfalls.
  • Forgetting layer_idx in the attention call. Without it, the cache can't know which layer's K, V to read.
  • Mutating returned slices. cache.read(layer) returns a view into the pre-allocated buffer. Mutating it corrupts the cache. Document the contract: read returns a view; do not write.
  • Test ordering dependence. If test_capacity_exceeded_raises shares a cache with other tests through a fixture, its starting cursor depends on what ran before it; ensure each test creates a fresh cache.
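
The view pitfall above is easy to reproduce with a bare NumPy array standing in for one layer's K buffer. The sketch shows that a basic slice shares memory with the buffer, and then one optional guard: handing out a read-only view. The guard is a suggestion for illustration, not something BLUEPRINT.md is known to require.

```python
import numpy as np

# Stand-in for one layer's pre-allocated K buffer, with 5 rows filled.
buffer = np.zeros((2, 4, 128, 32), dtype=np.float32)
cursor = 5

view = buffer[:, :, :cursor, :]
assert np.shares_memory(view, buffer)  # basic slicing returns a view, not a copy

# Writing through the view silently changes the underlying buffer.
view[0, 0, 0, 0] = 7.0
assert buffer[0, 0, 0, 0] == 7.0

# Optional guard: mark the returned view read-only so accidental writes
# raise instead of corrupting the cache. The buffer itself stays writable.
guarded = buffer[:, :, :cursor, :]
guarded.flags.writeable = False
try:
    guarded[0, 0, 0, 0] = 9.0
except ValueError:
    write_blocked = True
else:
    write_blocked = False

print(write_blocked, buffer[0, 0, 0, 0])  # True 7.0
```

Whichever way you go, state the contract in the read docstring so callers know the result is a view.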

When to consult solutions/

After all stop conditions are met. The reference at solutions/01-implement-cache-ref.md (written at phase open) walks through the layout decision and the cursor-management invariant.


Next lab: lab/02-correctness-test.md.