Lab 03 — Continuous batching: iteration-level scheduling
The difference from static batching: instead of "run the whole decode until everyone finishes", run one decode step per iteration. Fast requests leave early; new ones join between steps. This is the technique used by vLLM, TGI, and Triton.
Objective
Replace the static batcher of lab 02 with a continuous (iteration-level) batcher. Show that p95 latency drops by ≥ 30% on a mixed-length workload, with equal or better throughput.
Setup
- Lab 02's scheduler (src/miniserve/scheduler.py).
- Phase 22's KV cache (src/minimodel/kv_cache.py). Strictly required: without a KV cache, every decode step recomputes attention over the whole prefix, so each step costs quadratic time in the sequence length. If Phase 22 is not yet done, defer this lab.
- A mixed-length workload: prompts where some answers are 1-2 tokens, others are 8-12 tokens.
Tasks
- Extend the agent with a stepwise API:

    class GrammarTutorAgent:
        def begin_correction(self, sentence: str, learner_id: str | None) -> "InflightCorrection":
            """Run the prefill, return a handle with the initial KV cache and first token."""
            ...

        def step_corrections(self, inflight: list["InflightCorrection"]) -> list[tuple[int, bool]]:
            """One decode step on a batch of in-flight corrections. Returns [(new_token_id, is_done), ...]"""
            ...

Implementation sketch for step_corrections:
- Gather each in-flight's "last token" → batch of size \(B\).
- Run model.forward_one_step(tokens, kv_caches) — one decode step.
- For each: sample next token, check if it's EOS or hits max_tokens.
- Append the new token to each in-flight's state.
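
The four steps above can be sketched as follows. This is a minimal illustration, not the lab's reference solution: ToyModel, the shape of its forward_one_step return value (one logits row per request), greedy argmax sampling, EOS_ID, and the InflightCorrection fields are all assumptions made so the example runs standalone.

```python
from dataclasses import dataclass, field

EOS_ID = 0  # assumed end-of-sequence token id


@dataclass
class InflightCorrection:
    # Hypothetical per-request state; the real handle also carries a future.
    tokens: list[int]  # generated so far; the first token comes from prefill
    kv_cache: list[int] = field(default_factory=list)  # stand-in for a real KV cache
    max_tokens: int = 8


class ToyModel:
    """Stand-in model: the next token is always (last token - 1), floored at 0."""

    vocab_size = 16

    def forward_one_step(self, tokens, kv_caches):
        logits = []
        for tok, cache in zip(tokens, kv_caches):
            cache.append(tok)  # each request extends its own cache
            row = [0.0] * self.vocab_size
            row[max(tok - 1, 0)] = 1.0
            logits.append(row)
        return logits


def step_corrections(model, inflight):
    # 1. Gather each request's last token into a batch of size B.
    last_tokens = [c.tokens[-1] for c in inflight]
    # 2. One decode step for the whole batch.
    logits = model.forward_one_step(last_tokens, [c.kv_cache for c in inflight])
    results = []
    for corr, row in zip(inflight, logits):
        # 3. Greedy sampling (argmax over the logits row).
        new_tok = max(range(len(row)), key=row.__getitem__)
        # 4. Append the new token, then evaluate the two stop conditions.
        corr.tokens.append(new_tok)
        done = new_tok == EOS_ID or len(corr.tokens) >= corr.max_tokens
        results.append((new_tok, done))
    return results
```

Calling step_corrections once advances every request by exactly one token, which is what lets the scheduler reap and admit requests between calls.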
- Write the continuous batcher in src/miniserve/scheduler.py:

    import asyncio

    from fastapi import HTTPException


    class ContinuousBatchScheduler:
        def __init__(self, agent, max_inflight: int, max_queue: int):
            self.agent = agent
            self.max_inflight = max_inflight
            self.max_queue = max_queue
            self.ready: asyncio.Queue[PendingRequest] = asyncio.Queue(maxsize=max_queue)
            self.inflight: list[InflightCorrection] = []

        async def submit(self, payload):
            # get_running_loop() is the supported call inside a coroutine.
            fut = asyncio.get_running_loop().create_future()
            try:
                self.ready.put_nowait(PendingRequest(payload, fut))
            except asyncio.QueueFull:
                raise HTTPException(503, "Server queue full")
            return await fut

        async def _loop(self):
            while True:
                # Admit: fill free in-flight slots from the ready queue.
                while len(self.inflight) < self.max_inflight and not self.ready.empty():
                    req = self.ready.get_nowait()
                    inflight_corr = await asyncio.to_thread(
                        self.agent.begin_correction, req.payload["sentence"], req.payload.get("learner_id")
                    )
                    inflight_corr.future = req.future
                    self.inflight.append(inflight_corr)
                if not self.inflight:
                    await asyncio.sleep(0.001)
                    continue
                # Step: one decode step for every in-flight request.
                step_results = await asyncio.to_thread(
                    self.agent.step_corrections, self.inflight
                )
                # Reap finished requests; the rest stay for the next iteration.
                still_inflight = []
                for corr, (tok, done) in zip(self.inflight, step_results):
                    if not done:
                        still_inflight.append(corr)
                    elif not corr.future.done():  # a cancelled future cannot take a result
                        corr.future.set_result(corr.to_response_dict())
                self.inflight = still_inflight

- Wire it in. Replace the static scheduler in app.py.
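
One way to wire the scheduler in is to start its _loop as a background task when the app starts and to cancel it on shutdown. The sketch below is an assumption about how app.py could do this, not the lab's actual wiring: it uses an async context manager in the shape FastAPI accepts for its lifespan parameter, and StubScheduler is a hypothetical stand-in for ContinuousBatchScheduler so the example runs without a model or a web server.

```python
import asyncio
from contextlib import asynccontextmanager


class StubScheduler:
    """Hypothetical stand-in for ContinuousBatchScheduler: only counts loop iterations."""

    def __init__(self):
        self.iterations = 0

    async def _loop(self):
        while True:
            self.iterations += 1
            await asyncio.sleep(0.001)


scheduler = StubScheduler()


@asynccontextmanager
async def lifespan(app):
    # Startup: run the scheduler loop for the lifetime of the app.
    task = asyncio.create_task(scheduler._loop())
    try:
        yield
    finally:
        # Shutdown: stop the loop and wait for the cancellation to land.
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass


async def demo():
    async with lifespan(app=None):
        await asyncio.sleep(0.02)  # the app would serve requests here
    return scheduler.iterations
```

With FastAPI this would be passed as FastAPI(lifespan=lifespan), and the route handler would return await scheduler.submit(payload).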
- Design the mixed-length workload. From the verb corpus, construct 200 prompts:
  - 60% are short: "He" → expected 1-2 tokens (just the verb form).
  - 30% are medium: "Yesterday I" → 2-3 tokens.
  - 10% are long: prompts that elicit "going to" futures or full explanations, 6-10 tokens.

  Shuffle. Save as data/mixed_workload.json.
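
A minimal sketch of building that workload with a fixed seed, so the 60/30/10 split is reproducible. The prompt pools and the "bucket"/"prompt" record fields are hypothetical placeholders; in the lab the prompts come from the verb corpus.

```python
import json
import random
from pathlib import Path

# Hypothetical prompt pools; replace with prompts drawn from the verb corpus.
SHORT = ["He", "She", "They"]
MEDIUM = ["Yesterday I", "Last week we"]
LONG = ["Tomorrow she is", "Explain why 'he go' is wrong:"]


def build_workload(n=200, seed=0):
    rng = random.Random(seed)  # fixed seed so the workload is reproducible
    spec = [("short", SHORT, 0.6), ("medium", MEDIUM, 0.3), ("long", LONG, 0.1)]
    items = []
    for bucket, pool, frac in spec:
        for _ in range(round(n * frac)):
            items.append({"bucket": bucket, "prompt": rng.choice(pool)})
    rng.shuffle(items)  # interleave short and long requests
    return items


def save_workload(items, path):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(items, indent=2), encoding="utf-8")
```

For example, save_workload(build_workload(), "data/mixed_workload.json") writes the file the load test reads. Recording the bucket with each prompt makes it easy to break latency down by request length later.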
- Load-test with three schedulers on this workload, all with concurrency=50, total=500:
  - A: No scheduler (lab 01, variant C).
  - B: Static batching (max_batch=8, max_wait_ms=20).
  - C: Continuous batching (max_inflight=8).
- Plot the comparison:
  - Latency CDF: three curves on the same axes.
  - p50, p95, p99 bar chart: three groups of three bars.
  - Throughput (req/s) bar chart.
- Verify the property:
  - p95 of C is ≥ 30% better than p95 of B.
  - p50 of C is similar to or better than p50 of B.
  - Throughput of C is similar to or better than B (continuous batching is throughput-neutral in the easy case; this is the latency win).
- Sweep max_inflight ∈ {1, 2, 4, 8, 16} for the continuous scheduler. Plot throughput vs max_inflight and p95 vs max_inflight. The expected shape: throughput rises and saturates; p95 stays roughly flat and then rises at the high end (queue wait kicks in).
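
The "Verify the property" checks can be computed directly from the recorded latencies. The sketch below is one way to do it, under stated assumptions: percentiles use the nearest-rank definition, "similar to" is read as "within 10%" for p50, and the throughput tolerance is the 10% from the acceptance criteria. The function names are illustrative.

```python
import math


def percentile(values, q):
    """Nearest-rank percentile for q in (0, 100]."""
    ordered = sorted(values)
    rank = max(1, math.ceil(q * len(ordered) / 100))
    return ordered[rank - 1]


def summarize(latencies_s, wall_time_s):
    return {
        "p50": percentile(latencies_s, 50),
        "p95": percentile(latencies_s, 95),
        "p99": percentile(latencies_s, 99),
        "throughput": len(latencies_s) / wall_time_s,  # completed requests per second
    }


def check_property(static, continuous):
    # Relative p95 reduction of C (continuous) with respect to B (static).
    p95_gain = (static["p95"] - continuous["p95"]) / static["p95"]
    return {
        "p95_gain": p95_gain,
        "p95_ok": p95_gain >= 0.30,
        "p50_ok": continuous["p50"] <= static["p50"] * 1.10,  # assumed tolerance
        "throughput_ok": continuous["throughput"] >= static["throughput"] * 0.90,
    }
```

The same summarize call can be run once per max_inflight value to produce the rows of the sweep table.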
Measurements
Save to experiments/<date>-phase-33-lab-03/:
- latencies_nobatch.json, latencies_static.json, latencies_continuous.json.
- latency_cdf_3way.png.
- pNN_bars.png.
- inflight_sweep.csv and inflight_sweep.png.
- manifest.json: workload composition (the 60/30/10 split), seeds, model checkpoint.
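
A small sketch of writing manifest.json. The key names and the write_manifest helper are hypothetical; only the three kinds of information (workload composition, seeds, model checkpoint) come from the list above, and the checkpoint value is whatever identifier your run actually used.

```python
import json
from pathlib import Path


def write_manifest(out_dir, seeds, checkpoint):
    manifest = {
        # Workload composition as designed in the tasks: 200 prompts, 60/30/10.
        "workload": {
            "total_prompts": 200,
            "split": {"short": 0.6, "medium": 0.3, "long": 0.1},
        },
        "seeds": seeds,
        "model_checkpoint": checkpoint,
    }
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / "manifest.json"
    path.write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    return path
```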
Acceptance (DoD relevant)
- p95 of continuous batching ≥ 30% better than static batching on the mixed workload. (Phase 33 DoD.)
- Throughput of continuous batching ≥ static batching (within 10%).
- All requests return HTTP 200 OR HTTP 503 (no 500 errors, no timeouts).
- Under sustained over-capacity load (concurrency=200), the scheduler rejects with 503 when max_queue is exceeded; it must not run out of memory or stall.
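
The over-capacity behaviour comes from the bounded queue: once max_queue requests are waiting, put_nowait raises instead of buffering more. The sketch below isolates that mechanism with the standard library only. QueueFullError and BoundedAdmission are hypothetical stand-ins for the HTTP 503 and the scheduler's submit path, and nothing drains the queue here, so every request beyond max_queue is rejected.

```python
import asyncio


class QueueFullError(Exception):
    """Stand-in for the HTTP 503 the real scheduler raises."""


class BoundedAdmission:
    def __init__(self, max_queue):
        self.ready = asyncio.Queue(maxsize=max_queue)

    def try_submit(self, payload):
        try:
            self.ready.put_nowait(payload)
        except asyncio.QueueFull:
            # Reject immediately instead of buffering without bound.
            raise QueueFullError("Server queue full") from None


async def overload(max_queue, attempts):
    admission = BoundedAdmission(max_queue)
    accepted = rejected = 0
    for i in range(attempts):
        try:
            admission.try_submit({"id": i})
            accepted += 1
        except QueueFullError:
            rejected += 1
    # Queue size never exceeds max_queue, so memory use stays bounded.
    return accepted, rejected, admission.ready.qsize()
```

Running asyncio.run(overload(max_queue=4, attempts=10)) accepts four requests and rejects the other six.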
Pitfalls
- Per-request KV cache bookkeeping. Each in-flight has its own cache. Don't share buffers across requests — they have different lengths. (Phase 27's PagedAttention solves the memory-waste version of this; for our lab, just allocate per-request.)
- Batching the step, not the full forward. The whole point. If you accidentally batch a multi-token generation, you've reinvented static batching.
- Tokenizer determinism. When you sample a token in step t, it has to feed back as the input to step t+1. If your batched forward pass has subtle padding-dependent outputs, you'll get different tokens than the un-batched version. Add a property test: continuous-batch output for a single request must match un-batched output (modulo floating-point reordering).
- Cancellation. If the client disconnects, the request should leave the in-flight set. FastAPI exposes request.is_disconnected(); handle it. Otherwise you waste compute on a dead client.
- Fairness. FIFO admission is the simplest. With more sophisticated policies (priorities, fairness), you can starve some requests. Leave this for Phase 34.
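
The property test from the determinism pitfall can be sketched as follows. Everything here is a toy under stated assumptions: forward_one_step is a made-up integer "model" that pads to the longest history and masks the padding, and the test is a hand-rolled randomized check rather than one built on a property-testing library. The point is the shape of the assertion: whatever a request's batch mates are, its tokens must equal the un-batched run.

```python
import random

PAD = -1  # assumed padding id, never produced as a real token


def forward_one_step(histories):
    """Toy batched step: pad to the longest history, mask the padding, pick the next token."""
    width = max(len(h) for h in histories)
    out = []
    for h in histories:
        padded = h + [PAD] * (width - len(h))
        real = [t for t in padded if t != PAD]  # the mask: padding must not contribute
        out.append((sum(real) * 31 + len(real)) % 50)
    return out


def generate_alone(prompt, steps):
    history = list(prompt)
    for _ in range(steps):
        history.append(forward_one_step([history])[0])
    return history[len(prompt):]


def generate_batched(prompts, steps):
    histories = [list(p) for p in prompts]
    for _ in range(steps):
        for h, tok in zip(histories, forward_one_step(histories)):
            h.append(tok)
    return [h[len(p):] for h, p in zip(histories, prompts)]


def check_batch_invariance(trials=100, seed=0):
    rng = random.Random(seed)
    for _ in range(trials):
        prompts = [[rng.randrange(50) for _ in range(rng.randint(1, 6))]
                   for _ in range(rng.randint(1, 8))]
        batched = generate_batched(prompts, steps=5)
        for prompt, got in zip(prompts, batched):
            assert got == generate_alone(prompt, steps=5), (prompt, got)
    return True
```

With a real floating-point model the equality would need a tolerance on logits, or a comparison of sampled tokens under greedy decoding.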
Stretch (out of scope for DoD)
- Continuous batching + prefill batching. Real systems batch prefill separately because prefill is expensive. Survey this and gesture at it in lab 04.
- Variable-length attention in a single forward. Implement masking so requests of different cached-lengths can share the step kernel. Or: pad to the longest length (simpler, what lab 03 does).
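
To make the padding option concrete, here is a pure-Python sketch of one attention step over a cache padded to a common width, with positions beyond the request's true length masked out. The function names and the list-of-lists cache layout are assumptions for illustration; a real implementation would do this on tensors for the whole batch at once.

```python
import math


def pad_cache(rows, width, fill):
    """Pad a per-request cache (list of vectors) to `width` rows with copies of `fill`."""
    return rows + [list(fill) for _ in range(width - len(rows))]


def masked_attention_step(query, keys, values, length):
    """Single-query attention; cache positions >= length are masked out (length >= 1)."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    # Masked positions get -inf, so they receive zero weight after the softmax.
    scores = [s if i < length else -math.inf for i, s in enumerate(scores)]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]
```

Because masked positions get zero weight, the padded rows can hold arbitrary values without changing the result, which is the property the batch-invariance test in the pitfalls relies on.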