Lab 02 — Static batching

Collect requests in a queue. When you reach N of them (or a timeout expires), run the model over the whole batch in a single pass. Throughput goes up; so does the latency of the last request in the batch. Measure it.

Objective

Implement a static batching scheduler: collect requests, run the model once per batch. Measure throughput and tail latency vs the un-batched baseline from lab 01.

Setup

  • Lab 01's working FastAPI service (variant C: async + to_thread).
  • src/miniserve/scheduler.py — new module.
  • The loadtest script from lab 01.

Tasks

  1. Modify the model API to accept batched input. The Mini-GPT's forward() already supports a batch dimension (Phase 17). Add agent.correct_batch(sentences: list[str], learner_ids: list[str | None]) -> list[Correction]:
class GrammarTutorAgent:
    def correct_batch(self, sentences: list[str], learner_ids: list[str | None]) -> list[Correction]:
        # Run the agent loop for each sentence with a single batched model.forward.
        # For now: simple — generate all responses to max_tokens with the same generation length.
        ...

For Phase 33 simplicity, the agent will generate to a fixed max_tokens for every batch member. This is intentional — it makes the static-batching tail-latency problem visible.
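
To see why fixed-length generation makes the tail-latency cost visible, here is a minimal sketch with a stand-in model. Everything in it is hypothetical: `fake_forward`, the `EOS` token id, and the 20% stop probability are illustration only and are not the Mini-GPT API. It counts how many decode steps each batch member actually needed versus how many it was forced to run.

```python
import random

EOS = 0  # hypothetical end-of-sequence token id, chosen for this sketch


def fake_forward(batch: list[list[int]], rng: random.Random) -> list[int]:
    """Stand-in for one batched forward pass: returns one next token per sequence."""
    return [EOS if rng.random() < 0.2 else rng.randint(1, 99) for _ in batch]


def generate_fixed_length(prompts: list[list[int]], max_tokens: int, seed: int = 0):
    """Decode every batch member for exactly max_tokens steps, EOS or not."""
    rng = random.Random(seed)
    seqs = [list(p) for p in prompts]
    first_eos = [None] * len(seqs)  # step at which each member first emitted EOS
    for step in range(1, max_tokens + 1):
        for i, tok in enumerate(fake_forward(seqs, rng)):
            seqs[i].append(tok)
            if tok == EOS and first_eos[i] is None:
                first_eos[i] = step  # this member was done here, but keeps decoding
    useful = [e if e is not None else max_tokens for e in first_eos]
    return seqs, useful


prompts = [[5, 6], [7], [8, 9, 10]]
seqs, useful = generate_fixed_length(prompts, max_tokens=16)
wasted = 1 - sum(useful) / (len(prompts) * 16)
print(f"useful steps per member: {useful}, wasted fraction: {wasted:.2f}")
```

Every member pays for all 16 steps, so a request that would have finished early still waits for the full batch.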

  2. Write the scheduler in src/miniserve/scheduler.py:
import asyncio
from dataclasses import dataclass
from typing import Callable

@dataclass
class PendingRequest:
    payload: dict
    future: asyncio.Future

class StaticBatchScheduler:
    def __init__(self, batch_fn: Callable, max_batch: int, max_wait_ms: int):
        self.batch_fn = batch_fn
        self.max_batch = max_batch
        self.max_wait_ms = max_wait_ms
        self.queue: asyncio.Queue[PendingRequest] = asyncio.Queue()
        self._loop_task: asyncio.Task | None = None

    async def submit(self, payload: dict) -> dict:
        # get_running_loop() is the recommended call inside a coroutine.
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put(PendingRequest(payload, fut))
        return await fut

    async def start(self):
        self._loop_task = asyncio.create_task(self._loop())

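    async def _loop(self):
        while True:
            batch = await self._collect_batch()
            if not batch:
                # _collect_batch already waited up to 1 s for a first request.
                continue
            try:
                # Run the model in a worker thread so the event loop stays free
                # to accept and queue requests during the forward pass.
                results = await asyncio.to_thread(
                    self.batch_fn, [r.payload for r in batch]
                )
            except Exception as exc:
                # Fail this batch's callers instead of killing the scheduler task,
                # which would leave every pending future waiting forever.
                for r in batch:
                    if not r.future.done():
                        r.future.set_exception(exc)
                continue
            for r, res in zip(batch, results):
                # A caller that disconnected has a cancelled future; skip it.
                if not r.future.done():
                    r.future.set_result(res)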

    async def _collect_batch(self) -> list[PendingRequest]:
        batch = []
        try:
            first = await asyncio.wait_for(self.queue.get(), timeout=1.0)
            batch.append(first)
        except asyncio.TimeoutError:
            return batch
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self.max_wait_ms / 1000
        while len(batch) < self.max_batch:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                batch.append(await asyncio.wait_for(self.queue.get(), timeout=remaining))
            except asyncio.TimeoutError:
                break
        return batch
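
A quick way to sanity-check the collection rule (block for the first request, then fill until the batch is full or the deadline passes) is to run it in isolation. The sketch below re-implements that rule as a free function, `collect_batch`, purely so it runs without the rest of the service; the function name and the demo values are assumptions for illustration.

```python
import asyncio


async def collect_batch(queue: asyncio.Queue, max_batch: int, max_wait_s: float) -> list:
    """Block for a first item, then fill until max_batch items or the deadline."""
    loop = asyncio.get_running_loop()
    batch = [await queue.get()]
    deadline = loop.time() + max_wait_s
    while len(batch) < max_batch:
        remaining = deadline - loop.time()
        if remaining <= 0:
            break
        try:
            batch.append(await asyncio.wait_for(queue.get(), timeout=remaining))
        except asyncio.TimeoutError:
            break
    return batch


async def demo() -> list[int]:
    queue: asyncio.Queue = asyncio.Queue()
    for i in range(5):  # a burst of five requests, already waiting
        queue.put_nowait(i)
    sizes = []
    while not queue.empty():
        sizes.append(len(await collect_batch(queue, max_batch=4, max_wait_s=0.02)))
    return sizes


sizes = asyncio.run(demo())
print(sizes)
```

With five queued items and max_batch=4, the first batch fills immediately and the second one holds the leftover item after waiting out the 20 ms deadline.
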
  3. Wire it into the FastAPI app:
scheduler = StaticBatchScheduler(
    batch_fn=lambda payloads: agent.correct_batch(
        [p["sentence"] for p in payloads],
        [p.get("learner_id") for p in payloads],
    ),
    max_batch=8,
    max_wait_ms=20,
)

@app.on_event("startup")  # deprecated in recent FastAPI releases; a lifespan handler is the replacement
async def _start():
    await scheduler.start()

@app.post("/correct")
async def correct(req: CorrectRequest) -> CorrectResponse:
    result = await scheduler.submit(req.model_dump())
    return CorrectResponse(**result)
  4. Sweep batch parameters. For each (max_batch, max_wait_ms) ∈ {(1,0), (2,10), (4,10), (8,20), (16,20), (32,50)}:
       • Run loadtest with concurrency=50, total=500.
       • Record p50, p95, p99, throughput.

  5. Plot results.
       • X-axis: max_batch. Y-axis: throughput (req/s). One line per max_wait_ms (with this sweep, each line has only one or two points).
       • X-axis: max_batch. Y-axis: p95 latency. Same line setup.
       • You should see throughput increase and p95 latency also increase: that is the trade-off.

  6. Compare to lab 01's baseline. On the same axes, plot the un-batched (variant C) result as a single point.
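
If the lab 01 loadtest script does not already report percentiles, they can be derived from the raw per-request latencies. This is a minimal sketch using the nearest-rank definition; `percentile` and `summarize` are hypothetical helper names, and the latencies in the demo are synthetic values, not measurements.

```python
import math


def percentile(sorted_vals: list[float], q: int) -> float:
    """Nearest-rank percentile of an ascending list; q is in 1..100."""
    rank = max(1, math.ceil(q * len(sorted_vals) / 100))
    return sorted_vals[rank - 1]


def summarize(latencies_s: list[float], wall_time_s: float) -> dict[str, float]:
    """Reduce one loadtest run to the four numbers the sweep records."""
    vals = sorted(latencies_s)
    return {
        "p50": percentile(vals, 50),
        "p95": percentile(vals, 95),
        "p99": percentile(vals, 99),
        "throughput": len(vals) / wall_time_s,  # completed requests per second
    }


# Synthetic latencies (0.01 s .. 1.00 s), only to exercise the helpers.
latencies = [0.01 * i for i in range(1, 101)]
summary = summarize(latencies, wall_time_s=2.0)
print(summary)
```

Throughput here is total completed requests divided by wall-clock time for the whole run, not the inverse of mean latency; with 50 concurrent clients the two differ substantially.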

Measurements

Save to experiments/<date>-phase-33-lab-02/:

  • batch_sweep.csv — one row per (max_batch, max_wait_ms): p50, p95, p99, throughput.
  • throughput_vs_batch.png
  • p95_vs_batch.png
  • latency_cdf_baseline_vs_batch8.png — un-batched vs max_batch=8.
  • manifest.json.
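
One possible layout for batch_sweep.csv, taking the median over repeated runs of each config (three runs per config, as recommended under Pitfalls). The column names and the choice of seconds as the latency unit are assumptions; the numbers in the demo are placeholders to show the shape of the file, not results.

```python
import csv
import io
import statistics

FIELDS = ["max_batch", "max_wait_ms", "p50", "p95", "p99", "throughput"]


def median_row(max_batch: int, max_wait_ms: int, runs: list[dict]) -> dict:
    """Collapse repeated runs of one config into a single row of medians."""
    row = {"max_batch": max_batch, "max_wait_ms": max_wait_ms}
    for key in ("p50", "p95", "p99", "throughput"):
        row[key] = statistics.median(run[key] for run in runs)
    return row


def write_sweep(rows: list[dict], fh) -> None:
    """Write one CSV row per (max_batch, max_wait_ms) config."""
    writer = csv.DictWriter(fh, fieldnames=FIELDS)
    writer.writeheader()
    writer.writerows(rows)


# Placeholder numbers only; replace with your own measurements.
runs = [
    {"p50": 0.10, "p95": 0.30, "p99": 0.40, "throughput": 20.0},
    {"p50": 0.12, "p95": 0.25, "p99": 0.45, "throughput": 22.0},
    {"p50": 0.11, "p95": 0.35, "p99": 0.42, "throughput": 21.0},
]
row = median_row(8, 20, runs)
buf = io.StringIO()  # swap for open(path, "w", newline="") to write a real file
write_sweep([row], buf)
print(buf.getvalue())
```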

Acceptance

  • For max_batch ≥ 4, throughput exceeds the un-batched baseline by ≥ 50%.
  • For max_batch ≥ 4, p95 latency degrades vs the un-batched baseline (this is the expected trade-off, not a regression).
  • The throughput curve saturates somewhere in the sweep — increasing max_batch beyond a point doesn't help.
  • All requests return HTTP 200 (no timeouts) under the load.
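
The criteria above can be checked mechanically once the sweep rows exist. The sketch below is one interpretation, not the grading script: the `errors` field and the saturation rule (some step up in max_batch gains less than 10% throughput) are assumptions, and the demo rows are placeholders.

```python
def check_acceptance(baseline: dict, rows: list[dict],
                     min_gain: float = 0.5, saturation_eps: float = 0.10) -> dict:
    """Evaluate the four acceptance criteria against sweep rows."""
    big = [r for r in rows if r["max_batch"] >= 4]
    ordered = sorted(rows, key=lambda r: r["max_batch"])
    # Relative throughput gain between consecutive batch sizes.
    gains = [b["throughput"] / a["throughput"] - 1 for a, b in zip(ordered, ordered[1:])]
    return {
        "throughput_gain": bool(big) and all(
            r["throughput"] >= (1 + min_gain) * baseline["throughput"] for r in big),
        "p95_degrades": bool(big) and all(r["p95"] > baseline["p95"] for r in big),
        # Assumed definition of "saturates": at least one step gains < saturation_eps.
        "saturates": any(g < saturation_eps for g in gains),
        # "errors" is a hypothetical per-row count of non-200 responses.
        "all_http_200": all(r.get("errors", 0) == 0 for r in rows),
    }


# Placeholder rows to show the expected shape; not measurements.
baseline = {"throughput": 10.0, "p95": 0.20}
rows = [
    {"max_batch": 2, "throughput": 14.0, "p95": 0.22, "errors": 0},
    {"max_batch": 4, "throughput": 18.0, "p95": 0.30, "errors": 0},
    {"max_batch": 8, "throughput": 19.0, "p95": 0.45, "errors": 0},
]
report = check_acceptance(baseline, rows)
print(report)
```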

Pitfalls

  • Calling self.batch_fn(...) directly inside the scheduler's _loop. Running the model via asyncio.to_thread inside _loop is correct: the loop must yield while the model runs. A direct call blocks the event loop, so no new requests can be accepted or queued during the forward pass.
  • Setting max_wait_ms too high. With max_wait_ms=100, the first request of any batch that does not fill eats up to 100 ms of pure queue wait before the model even starts. Find the right value experimentally.
  • Forgetting batch padding. All sequences in a batch must have the same length (or you must mask). For lab 02 we sidestep this by generating to fixed max_tokens. The padding waste is the cost.
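  • Memory growth under load. The queue is unbounded, so if requests arrive faster than batches drain, it keeps growing until the process runs out of memory. For lab 02, bound it: in submit(), reject the request when queue.qsize() exceeds MAX_QUEUE, and have the endpoint return HTTP 503.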
  • Measurement noise. Run each config 3 times, report median p95.
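
One way to bound the backlog is sketched below. It swaps the qsize() check for an `asyncio.Queue` with `maxsize` plus `put_nowait`, which enforces the same limit in a single step. `BoundedSubmitter`, `QueueFullError`, and the `MAX_QUEUE` value are hypothetical names and numbers for this sketch, not part of the lab code.

```python
import asyncio

MAX_QUEUE = 64  # hypothetical limit; tune it against your memory budget


class QueueFullError(Exception):
    """Raised when the scheduler backlog is at capacity."""


class BoundedSubmitter:
    def __init__(self, max_queue: int = MAX_QUEUE):
        self.queue: asyncio.Queue = asyncio.Queue(maxsize=max_queue)

    def enqueue(self, payload: dict) -> asyncio.Future:
        """Queue a request without blocking, or fail fast if the backlog is full."""
        fut = asyncio.get_running_loop().create_future()
        try:
            self.queue.put_nowait((payload, fut))
        except asyncio.QueueFull:
            raise QueueFullError("scheduler backlog is full") from None
        return fut

    async def submit(self, payload: dict) -> dict:
        return await self.enqueue(payload)


async def demo() -> tuple[bool, int]:
    sub = BoundedSubmitter(max_queue=2)
    sub.enqueue({"sentence": "a"})
    sub.enqueue({"sentence": "b"})
    try:
        sub.enqueue({"sentence": "c"})  # third request exceeds the bound
        rejected = False
    except QueueFullError:
        rejected = True
    return rejected, sub.queue.qsize()


rejected, depth = asyncio.run(demo())
print(rejected, depth)
```

In the FastAPI handler, catching `QueueFullError` and raising `HTTPException(status_code=503)` turns the rejection into the 503 response; failing fast here keeps queue wait from growing without bound.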

Stretch

  • Add a histogram of "how full was each batch" — many batches will be size 1 at low load, size N at high load. This explains the saturation.
  • What happens if you set max_batch=1? It should behave identically to the un-batched baseline (modulo the scheduler overhead). Verify.
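
For the batch-fill histogram, a small wrapper around batch_fn is enough. This is a sketch under the assumption that the scheduler runs one batch at a time (it awaits each batch before collecting the next), so a plain `Counter` needs no locking; `record_batch_sizes` is a hypothetical helper name and the echo function stands in for agent.correct_batch.

```python
from collections import Counter
from typing import Callable


def record_batch_sizes(batch_fn: Callable[[list], list], sizes: Counter) -> Callable[[list], list]:
    """Wrap batch_fn so every call also tallies how many payloads it received."""
    def wrapped(payloads: list) -> list:
        sizes[len(payloads)] += 1
        return batch_fn(payloads)
    return wrapped


sizes: Counter = Counter()
# Echo function used as a stand-in for the real batched model call.
echo = record_batch_sizes(lambda payloads: list(payloads), sizes)
echo([{"sentence": "a"}])
echo([{"sentence": "a"}, {"sentence": "b"}])
echo([{"sentence": "c"}])

# Text histogram: one row per observed batch size.
for size in sorted(sizes):
    print(f"{size:>3} | {'#' * sizes[size]}")
```

Passing the wrapped function as batch_fn to the scheduler leaves the request path unchanged while collecting the counts.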

Next: 03-continuous-batching.md