Lab 02 — Static batching
Queue up requests. When the queue reaches N (or a timeout expires), run the model over the whole batch in a single pass. Throughput goes up; the latency of the last request in the batch goes up too. Measure it.
Objective
Implement a static batching scheduler: collect requests, run the model once per batch. Measure throughput and tail latency vs the un-batched baseline from lab 01.
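To see why batching trades latency for throughput, here is a toy cost model. It is a sketch under assumed numbers, not a measurement: it supposes one batched forward pass costs a fixed overhead plus a small per-sequence increment, and FIXED_MS and PER_ITEM_MS are made-up illustrative values.

```python
# Toy cost model for static batching. FIXED_MS and PER_ITEM_MS are assumed,
# illustrative numbers; replace them with timings from your own machine.
FIXED_MS = 40.0     # hypothetical cost paid once per forward pass
PER_ITEM_MS = 5.0   # hypothetical extra cost per sequence in the batch


def batch_time_ms(batch_size: int) -> float:
    """Wall time of one batched pass; every request in the batch waits at least this long."""
    return FIXED_MS + PER_ITEM_MS * batch_size


def throughput_rps(batch_size: int) -> float:
    """Requests per second if the model runs full batches back to back."""
    return batch_size / (batch_time_ms(batch_size) / 1000.0)


for b in (1, 4, 16, 64):
    print(f"batch={b:>3}  pass={batch_time_ms(b):6.1f} ms  throughput={throughput_rps(b):6.1f} req/s")
```

Under these assumptions a batch of 1 takes 45 ms (about 22 req/s) and a batch of 16 takes 120 ms (about 133 req/s): the fixed cost is amortized, but every request now waits for the bigger pass (queue wait from max_wait_ms comes on top). Throughput can never exceed 1000 / PER_ITEM_MS = 200 req/s, a toy version of the saturation the acceptance criteria ask you to find.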
Setup
- Lab 01's working FastAPI service (variant C: async + to_thread).
- src/miniserve/scheduler.py: a new module.
- The loadtest script from lab 01.
Tasks
- Modify the model API to accept batched input. The Mini-GPT's forward() already supports a batch dimension (Phase 17). Add agent.correct_batch(sentences, learner_ids) -> list[Correction]:
```python
class GrammarTutorAgent:
    def correct_batch(
        self, sentences: list[str], learner_ids: list[str | None]
    ) -> list[Correction]:
        # Run the agent loop for every sentence using a single batched model.forward.
        # For now, keep it simple: generate all responses to max_tokens, so every
        # batch member has the same generation length.
        ...
```
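To keep Phase 33 simple, the agent generates a fixed max_tokens for every batch member, even when a response could have finished earlier. This is intentional: it makes the static-batching tail-latency problem visible, since every request in a batch waits until the full generation length has been produced.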
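A minimal sketch of "one batched forward pass per decoding step, fixed max_tokens". It assumes a hypothetical forward callable that maps a batch of equal-length token-id rows to one greedy next-token id per row; the real Mini-GPT forward() very likely has a different signature, and PAD = 0 is also an assumption.

```python
from typing import Callable

# Hypothetical stand-in for a greedy decode step over Mini-GPT's forward():
# takes equal-length rows of token ids, returns one next-token id per row.
ForwardFn = Callable[[list[list[int]]], list[int]]

PAD = 0  # assumed padding token id


def generate_batch_fixed(
    forward: ForwardFn, prompts: list[list[int]], max_tokens: int
) -> list[list[int]]:
    # Left-pad so every row has the same length, as a static batch requires.
    # A real model would also need an attention mask so PAD tokens are ignored.
    width = max(len(p) for p in prompts)
    rows = [[PAD] * (width - len(p)) + list(p) for p in prompts]
    outputs: list[list[int]] = [[] for _ in prompts]
    for _ in range(max_tokens):
        next_tokens = forward(rows)  # ONE forward pass for the whole batch per step
        for row, out, tok in zip(rows, outputs, next_tokens):
            row.append(tok)
            out.append(tok)
    # No early exit: every row decodes exactly max_tokens tokens, even if its
    # answer was complete sooner. That wasted work is the price of this design.
    return outputs
```

Note that forward runs exactly max_tokens times regardless of batch size, and the slowest-to-finish row sets the pace for everyone.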
- Write the scheduler in src/miniserve/scheduler.py:
```python
import asyncio
from dataclasses import dataclass
from typing import Callable


@dataclass
class PendingRequest:
    payload: dict
    future: asyncio.Future


class StaticBatchScheduler:
    def __init__(self, batch_fn: Callable[[list[dict]], list], max_batch: int, max_wait_ms: int):
        self.batch_fn = batch_fn  # sync, CPU-bound: list of payloads -> one result per payload
        self.max_batch = max_batch
        self.max_wait_ms = max_wait_ms
        self.queue: asyncio.Queue[PendingRequest] = asyncio.Queue()
        self._loop_task: asyncio.Task | None = None

    async def submit(self, payload: dict) -> dict:
        # get_running_loop() is the non-deprecated way to reach the loop from a coroutine.
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put(PendingRequest(payload, fut))
        return await fut

    async def start(self):
        self._loop_task = asyncio.create_task(self._loop())

    async def _loop(self):
        while True:
            batch = await self._collect_batch()
            try:
                # Run the model in a worker thread (CPU-bound) so the event loop keeps
                # accepting and queueing requests during the forward pass.
                results = await asyncio.to_thread(
                    self.batch_fn, [r.payload for r in batch]
                )
            except Exception as exc:
                # Fail this batch's requests instead of letting the loop task die.
                for r in batch:
                    if not r.future.done():
                        r.future.set_exception(exc)
                continue
            for r, res in zip(batch, results):
                if not r.future.done():  # skip futures cancelled by disconnected clients
                    r.future.set_result(res)

    async def _collect_batch(self) -> list[PendingRequest]:
        # Block until the first request arrives (no polling needed), then keep
        # collecting until the batch is full or max_wait_ms has elapsed.
        batch = [await self.queue.get()]
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self.max_wait_ms / 1000
        while len(batch) < self.max_batch:
            remaining = deadline - loop.time()
            if remaining <= 0:  # with max_wait_ms=0 the batch is just the first request
                break
            try:
                batch.append(await asyncio.wait_for(self.queue.get(), timeout=remaining))
            except asyncio.TimeoutError:
                break
        return batch
```
- Wire it into the FastAPI app:
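```python
from miniserve.scheduler import StaticBatchScheduler

scheduler = StaticBatchScheduler(
    # batch_fn must return one result per payload, in order. Each result is
    # unpacked into CorrectResponse below, so it must be a dict of its fields.
    batch_fn=lambda payloads: agent.correct_batch(
        [p["sentence"] for p in payloads],
        [p.get("learner_id") for p in payloads],
    ),
    max_batch=8,
    max_wait_ms=20,
)


@app.on_event("startup")  # deprecated in recent FastAPI; a lifespan handler is the replacement
async def _start():
    await scheduler.start()


@app.post("/correct")
async def correct(req: CorrectRequest) -> CorrectResponse:
    # Enqueue, then wait until the batch containing this request has run.
    result = await scheduler.submit(req.model_dump())
    return CorrectResponse(**result)
```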
- Sweep the batch parameters. For each (max_batch, max_wait_ms) ∈ {(1,0), (2,10), (4,10), (8,20), (16,20), (32,50)}:
  - Run the loadtest with concurrency=50, total=500.
  - Record p50, p95, p99, and throughput.
- Plot the results:
  - X-axis: max_batch; Y-axis: throughput (req/s); one line per max_wait_ms.
  - X-axis: max_batch; Y-axis: p95 latency; same line setup.
  - You should see throughput rise and p95 latency rise with it: that is the trade-off.
- Compare to lab 01's baseline: on the same axes, plot the un-batched (variant C) result as a single point.
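A sketch of the sweep loop. It assumes a hypothetical run_loadtest(max_batch, max_wait_ms) hook that restarts the service with that config, runs lab 01's loadtest at concurrency=50, total=500, and returns the per-request latencies in ms plus the wall-clock seconds. It also applies the "run each config 3 times, report the median" advice from the Pitfalls.

```python
import csv
import statistics
from typing import Callable

# (max_batch, max_wait_ms) grid from the sweep task.
CONFIGS = [(1, 0), (2, 10), (4, 10), (8, 20), (16, 20), (32, 50)]

# Hypothetical hook: (max_batch, max_wait_ms) -> (latencies_ms, wall_seconds).
LoadtestFn = Callable[[int, int], tuple[list[float], float]]


def summarize(latencies_ms: list[float], wall_s: float) -> dict[str, float]:
    # 99 cut points with linear interpolation; cuts[k - 1] is the k-th percentile.
    cuts = statistics.quantiles(latencies_ms, n=100, method="inclusive")
    return {
        "p50": cuts[49],
        "p95": cuts[94],
        "p99": cuts[98],
        "throughput": len(latencies_ms) / wall_s,  # completed requests per second
    }


def sweep(run_loadtest: LoadtestFn, repeats: int = 3) -> list[dict]:
    rows = []
    for max_batch, max_wait_ms in CONFIGS:
        runs = [summarize(*run_loadtest(max_batch, max_wait_ms)) for _ in range(repeats)]
        # Median of each metric across repeats damps run-to-run noise.
        medians = {key: statistics.median(r[key] for r in runs) for key in runs[0]}
        rows.append({"max_batch": max_batch, "max_wait_ms": max_wait_ms, **medians})
    return rows


def write_csv(rows: list[dict], path: str) -> None:
    # One row per (max_batch, max_wait_ms), the layout batch_sweep.csv asks for.
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0]))
        writer.writeheader()
        writer.writerows(rows)
```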
Measurements
Save to experiments/<date>-phase-33-lab-02/:
- batch_sweep.csv: one row per (max_batch, max_wait_ms) with p50, p95, p99, throughput.
- throughput_vs_batch.png
- p95_vs_batch.png
- latency_cdf_baseline_vs_batch8.png: un-batched vs max_batch=8.
- manifest.json
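Each curve in latency_cdf_baseline_vs_batch8.png is an empirical CDF: sort the latencies and plot, for each value, the fraction of requests that finished at or below it. A minimal sketch that assumes only a list of per-request latencies in ms:

```python
def ecdf(latencies_ms: list[float]) -> tuple[list[float], list[float]]:
    """Empirical CDF: xs are sorted latencies, ys[i] is the fraction of requests <= xs[i]."""
    xs = sorted(latencies_ms)
    n = len(xs)
    ys = [(i + 1) / n for i in range(n)]
    return xs, ys


def fraction_within(latencies_ms: list[float], budget_ms: float) -> float:
    """Share of requests that finished within budget_ms, i.e. one point read off the CDF."""
    return sum(1 for x in latencies_ms if x <= budget_ms) / len(latencies_ms)
```

Draw the baseline and max_batch=8 curves on the same axes as step plots (for example with matplotlib's plt.step(xs, ys, where="post")); the horizontal gap between them around y = 0.95 roughly corresponds to the p95 difference.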
Acceptance
- For max_batch ≥ 4, throughput exceeds the un-batched baseline by ≥ 50%.
- For max_batch ≥ 4, p95 latency is worse than the un-batched baseline (this is the expected trade-off, not a regression).
- The throughput curve saturates somewhere in the sweep: beyond some point, increasing max_batch no longer helps.
- All requests return HTTP 200 (no timeouts) under load.
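The first three criteria can be checked mechanically from the sweep rows. A sketch, assuming rows are dicts with numeric max_batch, throughput, and p95 (for example, batch_sweep.csv parsed back into numbers) and baseline numbers from lab 01 variant C. The 10% threshold that defines "saturates" is an assumption, not part of the lab; the HTTP 200 criterion needs the loadtest's status codes and is not covered here.

```python
def check_acceptance(
    rows: list[dict],
    base_throughput: float,
    base_p95: float,
    saturation_gain: float = 0.10,  # assumed: a step gaining < 10% counts as saturated
) -> dict[str, bool]:
    batched = [r for r in rows if r["max_batch"] >= 4]
    ordered = sorted(rows, key=lambda r: r["max_batch"])
    return {
        # Throughput beats the un-batched baseline by at least 50%.
        "throughput_gain": all(r["throughput"] >= 1.5 * base_throughput for r in batched),
        # p95 is worse than baseline: the expected trade-off, not a regression.
        "p95_tradeoff": all(r["p95"] > base_p95 for r in batched),
        # Some increase in max_batch buys less than saturation_gain extra throughput.
        "saturates": any(
            nxt["throughput"] < (1 + saturation_gain) * cur["throughput"]
            for cur, nxt in zip(ordered, ordered[1:])
        ),
    }
```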
Pitfalls
- Calling the model directly inside the scheduler's _loop. Run it through asyncio.to_thread instead, because the loop must yield while the model runs. If you call self.batch_fn(...) directly (not via to_thread), the event loop blocks and no other requests can be queued during the forward pass.
- Setting max_wait_ms too high. If you wait up to 100 ms for a batch to fill, the first request in every batch can spend up to 100 ms in pure queue wait (the full 100 ms whenever load is too low to fill the batch). Find the right value experimentally.
- Forgetting batch padding. All sequences in a batch must have the same length (or you must mask). Lab 02 sidesteps this by generating to a fixed max_tokens; the padding waste is the cost.
- Memory growth under load. An unbounded queue keeps growing when requests arrive faster than batches complete, and eventually you run out of memory (OOM). For lab 02, bound the queue: in submit(), if queue.qsize() > MAX_QUEUE, reject the request with HTTP 503.
- Measurement noise. Run each config 3 times and report the median p95.
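One way to bound the queue with the standard library, sketched under assumptions: asyncio.Queue(maxsize=...) caps the backlog, and put_nowait() raises asyncio.QueueFull instead of waiting for space. MAX_QUEUE = 256 and the enqueue helper are hypothetical; in the scheduler, submit() would call the helper and then await the returned future, and the /correct endpoint would catch asyncio.QueueFull and raise FastAPI's HTTPException(status_code=503).

```python
import asyncio

MAX_QUEUE = 256  # hypothetical bound; choose it from your own load tests


def make_queue(max_queue: int = MAX_QUEUE) -> asyncio.Queue:
    # maxsize > 0 makes the queue bounded: it never holds more than max_queue items.
    return asyncio.Queue(maxsize=max_queue)


def enqueue(queue: asyncio.Queue, payload: dict) -> asyncio.Future:
    # Must be called from code running inside the event loop.
    fut = asyncio.get_running_loop().create_future()
    # put_nowait() rather than `await put()`: a full queue should fail fast
    # (-> HTTP 503) instead of parking the client until space frees up.
    queue.put_nowait((payload, fut))  # raises asyncio.QueueFull when full
    return fut
```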
Stretch
- Add a histogram of how full each batch was. At low load many batches will have size 1; at high load, size N (max_batch). This explains the saturation.
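A minimal sketch of the batch-fill histogram using a hypothetical BatchFillStats helper: call record(len(batch)) once per batch, for instance right after _collect_batch() returns in the scheduler loop, then print render() after a loadtest.

```python
from collections import Counter


class BatchFillStats:
    """Counts how many batches of each size ran (hypothetical helper, not part of the lab code)."""

    def __init__(self, max_batch: int):
        self.max_batch = max_batch
        self.sizes: Counter[int] = Counter()

    def record(self, batch_size: int) -> None:
        self.sizes[batch_size] += 1

    def mean_fill(self) -> float:
        # Average batch size as a fraction of max_batch (1.0 means every batch was full).
        total = sum(self.sizes.values())
        if total == 0:
            return 0.0
        return sum(s * n for s, n in self.sizes.items()) / (total * self.max_batch)

    def render(self, width: int = 40) -> str:
        # Text histogram: one row per batch size, bar length proportional to its count.
        if not self.sizes:
            return "(no batches recorded)"
        peak = max(self.sizes.values())
        return "\n".join(
            f"{size:>3} | {'#' * round(width * count / peak)} {count}"
            for size, count in sorted(self.sizes.items())
        )
```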
- What happens if you set max_batch=1? It should behave identically to the un-batched baseline (modulo the scheduler overhead). Verify.