Multi-LoRA serving
One base model, N customer adapters. Serving them on one process without merging — and without one kernel launch per row.
The business problem
You have one base model (say LLaMA-3 70B at ~140 GB in fp16) and N customers, each with their own LoRA fine-tune. The naive approach is N model loads: N × 140 GB. Ten customers already need 1.4 TB, more than twice the 640 GB of HBM in an 8×H100 node. Untenable.
The fix is structural, not a serving trick: LoRA's mathematical form makes shared serving possible. This lesson is the path from that form to a kernel that batches a heterogeneous-adapter request mix at near-base throughput.
LoRA, the smallest possible recap
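You freeze a pretrained weight $W \in \mathbb{R}^{d_{in} \times d_{out}}$ and learn a low-rank update:

$$W' = W + BA, \qquad B \in \mathbb{R}^{d_{in} \times r}, \quad A \in \mathbb{R}^{r \times d_{out}}, \quad r \ll \min(d_{in}, d_{out})$$

This lesson uses the row-vector convention y = x · W, so B is the down-projection (d_in → r) and A the up-projection (r → d_out); the original LoRA paper writes h = Wx and names the two factors the other way around.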
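Storage per adapter is 2 · r · d parameters per attention projection, per layer (d · r for B plus r · d for A, treating each projection as d × d). At r = 16 in fp16 with attention-only LoRA, that is roughly 33 MB per adapter for a 7B model up to ~168 MB for a 70B model; MLP LoRA roughly doubles those. For LLaMA-3 70B (Q/K/V/O only, 80 layers, d = 8192): 80 × 4 × 2 × 16 × 8192 × 2 bytes ≈ 168 MB (160 MiB). One base model plus 1000 adapters is 140 GB + ~168 GB ≈ 308 GB, which squeezes into a 4×H100 node's 320 GB of HBM but leaves almost nothing for the KV cache. That is why only the hot adapters stay in HBM and the rest wait in host memory.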
Why merging fails for multi-tenant
The textbook formula y = x · W' with W' = W + BA suggests "for each request, materialize W' and run the matmul." That is fine for a single tenant. For many tenants it's a disaster:
- Per-request weight rewrites in HBM. Every adapter switch rewrites every adapted weight matrix: tens of GB for attention-only LoRA on a 70B model, nearly the full 140 GB if LoRA also covers the MLP. Throughput collapses.
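- No prefix-caching win to offset it. Once any layer carries an adapter, two requests with the same system prompt but different adapters produce different KV entries, merged or not, so prefix-cache entries must be keyed by adapter ID either way. Merging adds the rewrite cost without making more of the KV cache sharable.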
- Kills batching. A batch can only run one adapter at a time; you can't mix customers.
The right move is to never merge. Keep W as the shared frozen tensor and run two matmuls:

$$y = xW + (xB)A$$
First term: a normal batched GEMM, one big shared matmul across every row in the batch regardless of adapter. Second term: per-request — the per-row adapter delta. The question is how to compute that second term cheaply when the batch is heterogeneous.
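A minimal pure-Python sketch of the identity behind "never merge", on small made-up integer matrices (the helpers `matmul` and `add` are hypothetical, written out so the example needs no libraries). Merging BA into W and doing one matmul gives exactly the same row as the shared base matmul plus the low-rank delta, but the unmerged path only ever builds a length-r intermediate instead of a d × d matrix.

```python
def matmul(x, y):
    """Multiply matrices stored as lists of rows: (m x n) @ (n x k) -> (m x k)."""
    n = len(y)
    assert all(len(row) == n for row in x), "inner dimensions must match"
    return [[sum(row[t] * y[t][j] for t in range(n)) for j in range(len(y[0]))] for row in x]

def add(x, y):
    """Elementwise sum of two equally shaped matrices."""
    return [[a + b for a, b in zip(rx, ry)] for rx, ry in zip(x, y)]

d, r = 4, 2
W = [[(i + j) % 3 for j in range(d)] for i in range(d)]  # frozen base weight, d x d
B = [[i - k for k in range(r)] for i in range(d)]        # down-projection, d x r
A = [[k + j for j in range(d)] for k in range(r)]        # up-projection, r x d
x = [[1, 2, 3, 4]]                                       # one input row, 1 x d

merged = matmul(x, add(W, matmul(B, A)))               # x (W + BA): materializes a d x d matrix
unmerged = add(matmul(x, W), matmul(matmul(x, B), A))  # xW + (xB)A: 1 x r intermediate only
assert merged == unmerged                              # exact, since everything is an integer
```

In the serving setting W is the same for every row, so only the (xB)A term has to know which adapter a row belongs to.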
The shape math
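Let the batch be B tokens (rows), hidden size d, rank r. Row i uses adapter index $a_i$:

$$y_i = x_i W + (x_i B_{a_i}) A_{a_i}$$

Here B is overloaded: as a count it is the batch size, as a matrix it is the LoRA down-projection.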
Shapes:
X : (B, d)
W : (d, d) — shared
B_i : (d, r), A_i : (r, d) — per-row, varies by adapter_id[i]
Y_base = X @ W → (B, d) one shared GEMM
Y_delta = stack_i( x_i @ B_i @ A_i ) → (B, d) grouped per-row
Y = Y_base + Y_delta
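The delta term is cheap next to the base GEMM: per row it costs about 4 · d · r FLOPs (d → r, then r → d) against 2 · d² for x · W, a ratio of 2r/d, or about 0.4% at r = 16, d = 8192. That advantage survives only if we don't pay one CUDA launch per row. A naive Python loop over B = 64 rows means 64 launches × ~5 µs each = 320 µs of pure overhead (twice that if x · B and · A run as separate kernels) for what should be a few microseconds of math. That's the throughput killer.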
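The same arithmetic as a tiny cost model. The ~5 µs per launch is the text's own rough figure, taken here as an assumption; FLOPs count one multiply-add as 2 and ignore memory traffic entirely. Both function names are hypothetical.

```python
def delta_to_base_flop_ratio(d: int, r: int) -> float:
    """FLOPs of (x @ B) @ A divided by FLOPs of x @ W, per row, for a d x d layer."""
    base = 2 * d * d               # one length-d row times a d x d matrix
    delta = 2 * d * r + 2 * r * d  # d -> r, then r -> d
    return delta / base            # simplifies to 2r / d

def launch_overhead_us(rows: int, launches_per_row: int = 1, us_per_launch: float = 5.0) -> float:
    """Pure launch overhead of a naive per-row loop (assumed ~5 us per launch)."""
    return rows * launches_per_row * us_per_launch

print(f"delta/base FLOPs at r=16, d=8192: {delta_to_base_flop_ratio(8192, 16):.4%}")
print("one launch per row:", launch_overhead_us(64), "us")
print("two launches per row:", launch_overhead_us(64, launches_per_row=2), "us")
```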
The two S-LoRA / Punica contributions
1 · Unified adapter memory
Put every adapter's B and A matrices in one contiguous HBM pool, indexed by adapter ID:
B_pool: (N_adapters, d_in, r)
A_pool: (N_adapters, r, d_out)
No per-adapter tensor allocation. Which adapters occupy the pool is driven by usage: warm adapters sit in HBM, cold ones wait in CPU DRAM, and a miss evicts the least recently used adapter. It is the same playbook as paged KV from lesson 02: fixed-size allocation units, plus an index that maps each adapter ID to its slot.
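A minimal sketch of the warm/cold bookkeeping, assuming a fixed number of HBM slots and a plain dict standing in for host DRAM. `AdapterPool`, `acquire`, and the slot layout are hypothetical names for illustration; a real pool would copy tensors into the slot and would also pin adapters used by the in-flight batch so they cannot be evicted mid-step.

```python
from collections import OrderedDict

class AdapterPool:
    """Fixed number of HBM slots; each slot holds one adapter's (B, A). LRU eviction on miss."""

    def __init__(self, num_slots: int, host_store: dict):
        self.host_store = host_store     # stands in for CPU DRAM: adapter_id -> weights
        self.slots = [None] * num_slots  # stands in for the B_pool / A_pool slots
        self.index = OrderedDict()       # adapter_id -> slot, oldest use first
        self.misses = 0

    def acquire(self, adapter_id: str) -> int:
        """Return the slot holding adapter_id, loading it from host memory on a miss."""
        if adapter_id in self.index:
            self.index.move_to_end(adapter_id)  # hit: mark as most recently used
            return self.index[adapter_id]
        self.misses += 1
        if len(self.index) < len(self.slots):
            slot = len(self.index)              # slots fill in order until the pool is full
        else:
            _, slot = self.index.popitem(last=False)  # evict the least recently used adapter
        self.slots[slot] = self.host_store[adapter_id]  # the DRAM -> HBM copy in a real system
        self.index[adapter_id] = slot
        return slot

host = {name: f"weights-{name}" for name in "abcd"}
pool = AdapterPool(num_slots=2, host_store=host)
for name in ["a", "b", "a", "c", "b"]:
    pool.acquire(name)
```

In that trace, the request for "c" evicts "b" (the older of the two warm adapters), and the next request for "b" evicts "a", so four of the five requests miss.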
2 · BGMV / SGMV grouped-matmul kernel
One CUDA kernel that, given X (B, d), the two pools, and adapter_ids (B,), computes the delta $(x_i B_{a_i}) A_{a_i}$ for every row i in a single fused launch.
BGMV = "batched gather matrix-vector" multiplication. One CUDA block per row. Inside the block:
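- Read this row's adapter_id.
- Load the slices B_pool[adapter_id] and A_pool[adapter_id] from global memory (HBM) into SRAM.
- Compute two matrix-vector products for this row: x (length d) times B (d × r) gives a length-r vector, which times A (r × d) gives the length-d delta. The d × d product BA is never formed.
- Write the d-dim result to Y[i].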
Parallelism is over rows, not adapters. One launch, regardless of how many distinct adapters appear in the batch.
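A pure-Python reference for what the BGMV launch computes, the kind of oracle you would test a real kernel against. It assumes the row-vector pool layout B_pool[a] : (d, r) and A_pool[a] : (r, d); `bgmv_reference` and `matvec` are hypothetical names. The loop over rows plays the role of one CUDA block per row, and each iteration gathers its own adapter's slices.

```python
def matvec(x, m):
    """Row vector x (length n) times matrix m (n x k, list of rows) -> length-k row."""
    return [sum(x[t] * m[t][j] for t in range(len(x))) for j in range(len(m[0]))]

def bgmv_reference(X, B_pool, A_pool, adapter_ids):
    """Y_delta[i] = (X[i] @ B_pool[a]) @ A_pool[a] with a = adapter_ids[i]."""
    Y = []
    for x, a in zip(X, adapter_ids):  # one "block" per row
        B, A = B_pool[a], A_pool[a]   # gather this row's adapter slices
        h = matvec(x, B)              # d -> r
        Y.append(matvec(h, A))        # r -> d
    return Y

B_pool = [
    [[1, 0], [0, 1], [1, 1]],  # adapter 0: d x r with d = 3, r = 2
    [[2, 0], [0, 0], [0, 2]],  # adapter 1
]
A_pool = [
    [[1, 0, 0], [0, 1, 0]],    # adapter 0: r x d
    [[0, 0, 1], [1, 0, 0]],    # adapter 1
]
X = [[1, 2, 3], [1, 1, 1], [0, 1, 0]]
adapter_ids = [1, 0, 1]        # heterogeneous batch: rows pick different adapters
Y_delta = bgmv_reference(X, B_pool, A_pool, adapter_ids)
```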
SGMV — segmented variant for when adapters repeat
If multiple rows in the batch share an adapter, we can do better than gathering (B, A) once per row. Sort the batch by adapter_id so all rows sharing an adapter sit contiguously, then assign one "segment" to each adapter group. Inside a segment, load (B, A) into SRAM once and reuse across every row in the segment — the same tile-reuse pattern FlashAttention (lesson 03) uses for K, V across query positions.
The math is exactly the same; the SRAM reuse pattern is the difference. When the batch has k distinct adapters across B rows, the (B, A) loads from HBM drop from B to k. For a typical request mix (B=64, k=4–8 distinct adapters), that's an 8–16× reduction in adapter-pool HBM traffic.
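A sketch of the bookkeeping SGMV relies on, assuming the kernel receives rows sorted by adapter plus per-segment boundaries. `build_segments` and `sgmv_reference` are hypothetical names. Each segment gathers its adapter once (counted in `loads`) and reuses it for every row in the segment; outputs go back to the original row order, and the result matches the per-row computation exactly.

```python
def build_segments(adapter_ids):
    """Return (order, segments): order sorts rows by adapter; segments = [(adapter, start, end)]."""
    order = sorted(range(len(adapter_ids)), key=lambda i: adapter_ids[i])
    segments = []
    for pos, row in enumerate(order):
        a = adapter_ids[row]
        if segments and segments[-1][0] == a:
            segments[-1] = (a, segments[-1][1], pos + 1)  # extend the current segment
        else:
            segments.append((a, pos, pos + 1))            # start a new segment
    return order, segments

def matvec(x, m):
    """Row vector x (length n) times matrix m (n x k) -> length-k row."""
    return [sum(x[t] * m[t][j] for t in range(len(x))) for j in range(len(m[0]))]

def sgmv_reference(X, B_pool, A_pool, adapter_ids):
    order, segments = build_segments(adapter_ids)
    Y = [None] * len(X)
    loads = 0
    for a, start, end in segments:
        B, A = B_pool[a], A_pool[a]   # one gather per segment (the SRAM tile in a real kernel)
        loads += 1
        for row in order[start:end]:  # every row in the segment reuses the same (B, A)
            Y[row] = matvec(matvec(X[row], B), A)
    return Y, loads

B_pool = [[[1], [0]], [[0], [1]], [[1], [1]]]  # 3 adapters, each d x r with d = 2, r = 1
A_pool = [[[1, 1]], [[2, 0]], [[0, 3]]]        # each r x d
X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]
adapter_ids = [2, 0, 2, 1, 0, 2]               # B = 6 rows, k = 3 distinct adapters
Y, loads = sgmv_reference(X, B_pool, A_pool, adapter_ids)
```

Here six rows cost three adapter gathers instead of six, the B → k drop described above.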
Memory budget
Per-adapter size, attention only (Q/K/V/O projections):
| model | n_layers | d | r | MB / adapter (fp16, attn only) | warm in 30 GB |
|---|---|---|---|---|---|
| LLaMA-1 7B | 32 | 4096 | 16 | 33 MB | ~930 adapters |
| Mistral 7B | 32 | 4096 | 8 | 17 MB | ~1830 adapters |
| LLaMA-2 70B | 80 | 8192 | 16 | 160 MB | ~190 adapters |
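Numbers in the table are attention-only (Q/K/V/O) and treat all four projections as d × d. Mistral 7B and LLaMA-2 70B use grouped-query attention with narrower K/V projections, so their real adapters are somewhat smaller. Add MLP adapters (up + down projections) and the sizes roughly double. The pattern that matters: hundreds of warm adapters fit in a 30 GB slice of HBM, about a fifth of what the 70B base itself occupies.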
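The table's arithmetic as a short script, under the table's own simplification (four d × d attention projections, 2 bytes per fp16 parameter). It reports binary units (MiB, GiB), so the counts differ slightly from the rounded figures in the table. `adapter_bytes` is a hypothetical helper.

```python
def adapter_bytes(n_layers: int, d: int, r: int, n_proj: int = 4, bytes_per_param: int = 2) -> int:
    """LoRA size: per projection B (d x r) + A (r x d) = 2*r*d params, times projections and layers."""
    return n_layers * n_proj * 2 * r * d * bytes_per_param

MiB, GiB = 2**20, 2**30
models = {
    "LLaMA-1 7B": (32, 4096, 16),
    "Mistral 7B": (32, 4096, 8),
    "LLaMA-2 70B": (80, 8192, 16),
}
for name, (n_layers, d, r) in models.items():
    size = adapter_bytes(n_layers, d, r)
    print(f"{name:12s} {size / MiB:6.0f} MiB  {30 * GiB // size:5d} adapters in 30 GiB")
```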
Throughput caveats
- Best case: mostly distinct adapters, all warm. BGMV scales linearly with rows; latency overhead vs base-only is a few percent.
- Best+ case: grouped batch (sorted by adapter). SGMV unlocks tile reuse: (B, A) loads drop from B to k, a B/k× reduction. Better than the best case whenever k < B.
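- Worst case: a batch heavy with cold misses. Each miss stalls on a host-to-HBM copy of the whole adapter, a few milliseconds for a 160 MB adapter over PCIe and far longer if it must first come off disk; misses queue behind one another, so if half the batch misses, latency tanks. Mitigation: predict adapter affinity at the load balancer and route each request to a replica where its adapter is already warm.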
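A sketch of the routing mitigation under one simple assumed policy: prefer replicas where the adapter is already warm, break ties by queue length, and fall back to the least-loaded replica when no replica has it. The replica dictionaries and `route` are hypothetical; a production balancer would also have to predict affinity ahead of time and keep its view of each replica's warm set current as replicas evict.

```python
def route(adapter_id, replicas):
    """Pick a replica index: least-loaded among those with adapter_id warm, else least-loaded overall."""
    warm = [i for i, rep in enumerate(replicas) if adapter_id in rep["warm"]]
    candidates = warm if warm else list(range(len(replicas)))
    return min(candidates, key=lambda i: replicas[i]["load"])

replicas = [
    {"warm": {"a"}, "load": 5},
    {"warm": {"b"}, "load": 1},
    {"warm": {"a", "b"}, "load": 3},
]
choice_a = route("a", replicas)  # warm on replicas 0 and 2; replica 2 has the shorter queue
choice_c = route("c", replicas)  # warm nowhere: falls back to the least-loaded replica
```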
Interactive · heterogeneous-batch visualizer
The batch is a row of tokens colored by adapter ID. Below it is a per-token timeline showing the cost of gathering the adapter's (B, A) into SRAM plus the per-row matmul. In per-row mode each row pays the full gather cost. In SGMV grouped mode rows are sorted by adapter, and adjacent same-color rows share the gather cost (load once, reuse the tile). The KPI reports the difference.
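A text stand-in for the visualizer's two modes. `GATHER` and `MATMUL` are made-up unit costs chosen only to show the shape of the difference, not measurements: per-row mode pays a gather for every row, while grouped mode sorts by adapter and pays one gather per run of identical IDs.

```python
GATHER = 4  # assumed cost of loading one adapter's (B, A) into SRAM, arbitrary units
MATMUL = 1  # assumed cost of one row's (x @ B) @ A, arbitrary units

def per_row_cost(adapter_ids):
    """Every row gathers its own adapter, then does its matmul."""
    return len(adapter_ids) * (GATHER + MATMUL)

def grouped_cost(adapter_ids):
    """Rows sorted by adapter; each run of equal IDs shares a single gather."""
    ordered = sorted(adapter_ids)
    runs = sum(1 for i, a in enumerate(ordered) if i == 0 or a != ordered[i - 1])
    return runs * GATHER + len(adapter_ids) * MATMUL

batch = [3, 1, 3, 0, 1, 3, 0, 3]
print(per_row_cost(batch), grouped_cost(batch))
```

The larger the gather cost is relative to the per-row matmul, and the fewer distinct adapters the batch holds, the wider the gap between the two modes.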
Putting it together
| approach | HBM weight rewrites / request | kernel launches | (B, A) loads from HBM | scales to N adapters |
|---|---|---|---|---|
| merge per request | up to 140 GB | 1 | n/a | no |
| naive loop (no fusion) | 0 | B | B | throughput-limited |
| BGMV (per-row fused) | 0 | 1 | B | yes |
| SGMV (segmented) | 0 | 1 | k (distinct adapters) | yes, optimally |