
Lab 03 — Continuous batching: iteration-level scheduling

The difference from static batching: instead of "run the whole decode until every request finishes", run one decode step per iteration. Fast requests leave early; new ones join between steps. This is the technique used by vLLM, TGI, and Triton.
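To make the difference concrete, here is a toy step-counting simulation. It is not part of the lab code, and the names `simulate_static` and `simulate_continuous` are hypothetical. It assumes every request arrives at time zero, every decode step costs one time unit regardless of batch size, and prefill is free. Each entry in `lengths` is the number of decode steps a request needs.

```python
from collections import deque

def simulate_static(lengths, batch_size):
    """Static batching: a batch only returns when its longest member finishes."""
    finish, clock = [], 0
    for i in range(0, len(lengths), batch_size):
        batch = lengths[i:i + batch_size]
        clock += max(batch)                 # everyone waits for the slowest request
        finish.extend([clock] * len(batch))
    return finish

def simulate_continuous(lengths, max_inflight):
    """Iteration-level scheduling: admit, run one step, reap, repeat."""
    queue = deque(enumerate(lengths))
    inflight = {}                           # request index -> decode steps remaining
    finish = [0] * len(lengths)
    clock = 0
    while queue or inflight:
        # Admit: fill free slots from the queue between steps.
        while queue and len(inflight) < max_inflight:
            idx, steps = queue.popleft()
            inflight[idx] = steps
        clock += 1                          # one decode step for the whole batch
        # Reap: finished requests leave immediately and free their slot.
        for idx in list(inflight):
            inflight[idx] -= 1
            if inflight[idx] == 0:
                finish[idx] = clock
                del inflight[idx]
    return finish

lengths = [1, 1, 10, 1, 2, 1, 1, 8]
print(simulate_static(lengths, 4))      # [10, 10, 10, 10, 18, 18, 18, 18]
print(simulate_continuous(lengths, 4))  # [1, 1, 10, 1, 3, 2, 2, 10]
```

In this toy run the short requests finish in 1-3 steps instead of 10 or 18, while the longest requests are no slower. That is the tail-latency effect the lab measures on real hardware.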

Objective

Replace the static batcher of lab 02 with a continuous (iteration-level) batcher. Show that p95 latency drops by ≥ 30% on a mixed-length workload, with equal or better throughput.

Setup

  • Lab 02's scheduler (src/miniserve/scheduler.py).
  • Phase 22's KV cache (src/minimodel/kv_cache.py) — strictly required: continuous batching without a KV cache makes every step quadratic. If Phase 22 is not yet done, defer this lab.
  • A mixed-length workload: prompts where some answers are 1-2 tokens, others are 8-12 tokens.

Tasks

  1. Extend the agent with a stepwise API:
class GrammarTutorAgent:
    def begin_correction(self, sentence: str, learner_id: str | None) -> "InflightCorrection":
        """Run the prefill, return a handle with the initial KV cache and first token."""
        ...

    def step_corrections(self, inflight: list["InflightCorrection"]) -> list[tuple[int, bool]]:
        """One decode step on a batch of in-flight corrections. Returns [(new_token_id, is_done), ...]"""
        ...

Implementation sketch for step_corrections:

  • Gather each in-flight's "last token" → batch of size \(B\).
  • Run model.forward_one_step(tokens, kv_caches): one decode step.
  • For each: sample the next token, check whether it is EOS or hits max_tokens.
  • Append the new token to each in-flight's state.
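The sketch above could look roughly like the following. This is a minimal illustration, not the lab's reference solution: the `InflightCorrection` fields, `EOS_ID`, and the `forward_one_step` callable (taking the last tokens and per-request caches, returning one already-chosen token id per request) are all assumptions made for the example, and it is written as a free function rather than a method so it can run on its own.

```python
from dataclasses import dataclass

EOS_ID = 0  # hypothetical end-of-sequence token id

@dataclass
class InflightCorrection:
    tokens: list[int]          # prompt plus tokens generated so far
    kv_cache: object = None    # opaque per-request cache handle (assumed)
    generated: int = 1         # begin_correction is assumed to have produced the first token
    max_tokens: int = 12

def step_corrections(inflight, forward_one_step):
    """One decode step over the batch; returns [(new_token_id, is_done), ...]."""
    # Gather each request's last token: a batch of size B.
    last_tokens = [c.tokens[-1] for c in inflight]
    caches = [c.kv_cache for c in inflight]
    # One decode step for the whole batch (greedy choice assumed inside the callable).
    next_tokens = forward_one_step(last_tokens, caches)
    results = []
    for corr, tok in zip(inflight, next_tokens):
        corr.tokens.append(tok)            # feeds back as input to the next step
        corr.generated += 1
        done = tok == EOS_ID or corr.generated >= corr.max_tokens
        results.append((tok, done))
    return results

# Toy stand-in for the model: each request counts down towards EOS_ID.
def countdown_model(tokens, caches):
    return [max(t - 1, 0) for t in tokens]

batch = [InflightCorrection(tokens=[2]), InflightCorrection(tokens=[5], max_tokens=3)]
print(step_corrections(batch, countdown_model))  # [(1, False), (4, False)]
print(step_corrections(batch, countdown_model))  # [(0, True), (3, True)]
```

The first request stops because it emitted `EOS_ID`; the second stops because it reached its `max_tokens` budget. Both paths need to set `is_done`, otherwise a request that never emits EOS stays in flight indefinitely.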

  2. Write the continuous batcher in src/miniserve/scheduler.py:
import asyncio

from fastapi import HTTPException

# PendingRequest (fields: payload, future) and InflightCorrection are assumed
# to be defined or imported elsewhere in this module.

class ContinuousBatchScheduler:
    def __init__(self, agent, max_inflight: int, max_queue: int):
        self.agent = agent
        self.max_inflight = max_inflight
        self.max_queue = max_queue
        # Bounded queue: overload turns into a 503 instead of unbounded memory growth.
        self.ready: asyncio.Queue[PendingRequest] = asyncio.Queue(maxsize=max_queue)
        self.inflight: list[InflightCorrection] = []

    async def submit(self, payload):
        # get_running_loop() is the supported call inside a coroutine.
        fut = asyncio.get_running_loop().create_future()
        try:
            self.ready.put_nowait(PendingRequest(payload, fut))
        except asyncio.QueueFull:
            raise HTTPException(503, "Server queue full") from None
        return await fut

    async def _loop(self):
        while True:
            # Admit: fill free slots from the queue, running prefill off the event loop.
            while len(self.inflight) < self.max_inflight and not self.ready.empty():
                req = self.ready.get_nowait()
                inflight_corr = await asyncio.to_thread(
                    self.agent.begin_correction, req.payload["sentence"], req.payload.get("learner_id")
                )
                inflight_corr.future = req.future
                self.inflight.append(inflight_corr)

            if not self.inflight:
                await asyncio.sleep(0.001)
                continue

            # Step: exactly one decode step for the whole in-flight batch.
            step_results = await asyncio.to_thread(
                self.agent.step_corrections, self.inflight
            )

            # Reap finished
            still_inflight = []
            for corr, (tok, done) in zip(self.inflight, step_results):
                if done:
                    # A cancelled future is already done; set_result on it
                    # would raise InvalidStateError and kill the loop.
                    if not corr.future.done():
                        corr.future.set_result(corr.to_response_dict())
                else:
                    still_inflight.append(corr)
            self.inflight = still_inflight
  3. Wire it in. Replace the static scheduler in app.py:
scheduler = ContinuousBatchScheduler(
    agent=agent,
    max_inflight=8,
    max_queue=200,
)
  4. Design the mixed-length workload. From the verb corpus, construct 200 prompts:
       • 60% are short: "He" → expected 1-2 tokens (just the verb form).
       • 30% are medium: "Yesterday I" → 2-3 tokens.
       • 10% are long: prompts that elicit "going to" futures or full explanations, 6-10 tokens.

Shuffle. Save as data/mixed_workload.json.
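One possible way to build and save the workload is sketched below. The function names, the `bucket` field, and the example long prompt are hypothetical; in the lab the three prompt pools would come from the verb corpus. The `sentence` key matches what the scheduler reads from the payload. A fixed seed keeps the shuffle reproducible so it can be recorded in the manifest.

```python
import json
import random

def build_mixed_workload(short_prompts, medium_prompts, long_prompts, total=200, seed=0):
    """Return `total` shuffled prompt records with a 60/30/10 short/medium/long split."""
    rng = random.Random(seed)                     # local RNG: reproducible, no global state
    n_short = total * 60 // 100                   # integer arithmetic avoids float rounding
    n_medium = total * 30 // 100
    n_long = total - n_short - n_medium           # remainder goes to the long bucket
    plan = [("short", short_prompts, n_short),
            ("medium", medium_prompts, n_medium),
            ("long", long_prompts, n_long)]
    items = [
        {"sentence": rng.choice(pool), "bucket": bucket}
        for bucket, pool, count in plan
        for _ in range(count)
    ]
    rng.shuffle(items)
    return items

def save_workload(items, path="data/mixed_workload.json"):
    """Write the workload as JSON; the target directory must already exist."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(items, f, ensure_ascii=False, indent=2)

# Placeholder pools; replace with prompts drawn from the verb corpus.
workload = build_mixed_workload(["He"], ["Yesterday I"], ["Tomorrow we are"], seed=33)
print(len(workload))  # 200
```

Calling `save_workload(workload)` afterwards writes the file the load tester reads.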

  5. Load-test with three schedulers on this workload, all with concurrency=50, total=500:
       • A: No scheduler (lab 01, variant C).
       • B: Static batching (max_batch=8, max_wait_ms=20).
       • C: Continuous batching (max_inflight=8).

  6. Plot the comparison:
       • Latency CDF: three curves on the same axes.
       • p50, p95, p99 bar chart: three groups of three bars.
       • Throughput (req/s) bar chart.

  7. Verify the property:
       • p95 of C is at least 30% lower than p95 of B.
       • p50 of C is similar to or better than p50 of B.
       • Throughput of C is similar to or better than B. Continuous batching is roughly throughput-neutral in this easy case; the win is in tail latency.

  8. Sweep max_inflight ∈ {1, 2, 4, 8, 16} for the continuous scheduler. Plot throughput vs max_inflight and p95 vs max_inflight. The expected shape: throughput rises and saturates; p95 stays roughly flat and then rises at the high end (queue wait kicks in).
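The "verify the property" check can be automated with a few lines. The sketch below is one way to do it, assuming the latency files have already been loaded into plain lists of numbers; `percentile` and `check_p95_gain` are hypothetical helper names, and the percentile uses the nearest-rank definition (other definitions interpolate and give slightly different values).

```python
import math

def percentile(values, q):
    """Nearest-rank percentile: the smallest sample with at least q% of samples at or below it."""
    ordered = sorted(values)
    rank = max(1, math.ceil(q * len(ordered) / 100))
    return ordered[rank - 1]

def check_p95_gain(static_ms, continuous_ms, min_gain=0.30):
    """Compare p95 latencies; `passes` is True when continuous is at least min_gain lower."""
    p95_static = percentile(static_ms, 95)
    p95_continuous = percentile(continuous_ms, 95)
    gain = (p95_static - p95_continuous) / p95_static
    return {
        "p95_static": p95_static,
        "p95_continuous": p95_continuous,
        "p95_gain": gain,
        "passes": gain >= min_gain,
    }

# Synthetic numbers for illustration only, not measurements.
static_ms = list(range(1, 101))
continuous_ms = [v * 0.5 for v in static_ms]
print(check_p95_gain(static_ms, continuous_ms))
```

Running the same check from the test suite turns the acceptance criterion into a failing test rather than something read off a chart.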

Measurements

Save to experiments/<date>-phase-33-lab-03/:

  • latencies_nobatch.json, latencies_static.json, latencies_continuous.json.
  • latency_cdf_3way.png.
  • pNN_bars.png.
  • inflight_sweep.csv and inflight_sweep.png.
  • manifest.json — workload composition (the 60/30/10 split), seeds, model checkpoint.

Acceptance (DoD relevant)

  • p95 latency of continuous batching is at least 30% lower than static batching on the mixed workload. (Phase 33 DoD.)
  • Throughput of continuous batching matches or exceeds static batching, with a 10% tolerance.
  • All requests return HTTP 200 or HTTP 503 (no 500 errors, no timeouts).
  • Under sustained over-capacity load (concurrency=200), the scheduler rejects with 503 once max_queue is exceeded, rather than running out of memory or stalling.

Pitfalls

  • Per-request KV cache bookkeeping. Each in-flight has its own cache. Don't share buffers across requests — they have different lengths. (Phase 27's PagedAttention solves the memory-waste version of this; for our lab, just allocate per-request.)
  • Batching the step, not the full forward. The whole point. If you accidentally batch a multi-token generation, you've reinvented static batching.
  • Decoding determinism. The token sampled at step t feeds back as the input to step t+1, so any divergence compounds. If your batched forward pass has subtle padding-dependent outputs, you'll get different tokens than the un-batched version. Add a property test: with greedy decoding (or a fixed sampling seed), continuous-batch output for a single request must match un-batched output (modulo floating-point reordering).
  • Cancellation. If the client disconnects, the request should leave the in-flight set. Starlette's Request, which FastAPI re-exports, provides the coroutine request.is_disconnected() (it must be awaited); poll it and drop the request. Otherwise you waste compute on a dead client.
  • Fairness. FIFO admission is the simplest. With more sophisticated policies (priorities, fairness), you can starve some requests. Leave this for Phase 34.
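For the cancellation pitfall, one option is to prune the in-flight set before each step. The sketch below assumes the request handler cancels the future it is awaiting when it notices a disconnect (for example after polling `request.is_disconnected()`), so the scheduler only needs to look at `future.cancelled()`. `drop_cancelled` is a hypothetical helper, and `SimpleNamespace` stands in for the real in-flight objects.

```python
import asyncio
from types import SimpleNamespace

def drop_cancelled(inflight):
    """Split in-flight entries into (alive, dropped) by whether their future was cancelled."""
    alive = [c for c in inflight if not c.future.cancelled()]
    dropped = [c for c in inflight if c.future.cancelled()]
    return alive, dropped

async def demo():
    loop = asyncio.get_running_loop()
    a = SimpleNamespace(name="a", future=loop.create_future())
    b = SimpleNamespace(name="b", future=loop.create_future())
    b.future.cancel()                      # simulate a client that went away
    alive, dropped = drop_cancelled([a, b])
    return [c.name for c in alive], [c.name for c in dropped]

print(asyncio.run(demo()))  # (['a'], ['b'])
```

In the scheduler loop this would run just before the step, with `self.inflight` replaced by the `alive` list; the dropped entries are where per-request KV caches would be released.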

Stretch (out of scope for DoD)

  • Continuous batching + prefill batching. Real systems batch prefill separately because prefill is expensive. Survey this and gesture at it in lab 04.
  • Variable-length attention in a single forward. Implement masking so requests of different cached-lengths can share the step kernel. Or: pad to the longest length (simpler, what lab 03 does).

Next: 04-vllm-and-tgi-survey.md