Multi-LoRA serving

One base model, N customer adapters. Serving them on one process without merging — and without one kernel launch per row.

The business problem

You have one base model — say LLaMA-3 70B at ~140 GB in fp16 — and N customers, each with their own LoRA fine-tune. The naive approach is N model loads: N × 140 GB. With even ten customers you've exceeded any single node. Untenable.

The fix is structural, not a serving trick: LoRA's mathematical form makes shared serving possible. This lesson is the path from that form to a kernel that batches a heterogeneous-adapter request mix at near-base throughput.

LoRA, the smallest possible recap

You freeze a pretrained weight W ∈ ℝ^{d_out × d_in} and learn a low-rank update:

W'  =  W  +  B · A,     B ∈ ℝ^{d_out × r},    A ∈ ℝ^{r × d_in},    r ∈ {4, 8, 16, 32}

Storage per adapter is 2 · r · d parameters per adapted projection, per layer. Concretely, that is roughly 30–160 MB per adapter (r=16, fp16, attention-only) depending on model size, and roughly double with MLP LoRA. For r=16 on LLaMA-3 70B (attention Q/K/V/O only, 80 layers, d=8192, treating each projection as d × d): ~160 MB per adapter. One base model + 1000 adapters = 140 GB + 160 GB = 300 GB, which fits in the 320 GB of a 4×H100 (80 GB) node, though it leaves only about 20 GB for KV cache and activations.

Why merging fails for multi-tenant

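The textbook formula y = x · W' with W' = W + BA suggests "for each request, materialize W' and run the matmul." This is fine for single-tenant. For multi-tenant it's a disaster: every adapter switch rewrites the full weight set in HBM (~140 GB for the 70B example), and rows that use different adapters cannot share one GEMM.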

The right move is to never merge. Keep W as the shared frozen tensor and run two matmuls:

y  =  x · W  +  x · B · A

First term: a normal batched GEMM, one big shared matmul across every row in the batch regardless of adapter. Second term: per-request — the per-row adapter delta. The question is how to compute that second term cheaply when the batch is heterogeneous.
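The equivalence of the merged and unmerged forms can be checked numerically. The following is a minimal NumPy sketch under assumed toy sizes (d = 64, r = 4) with a square weight and the row-vector convention used above; the names `merge_macs` and `delta_macs` are illustrative, not from any library.

```python
import numpy as np

rng = np.random.default_rng(0)
d, r = 64, 4                       # toy sizes, chosen only for illustration

x = rng.standard_normal((1, d))    # one token, row-vector convention
W = rng.standard_normal((d, d))    # shared frozen base weight
B = rng.standard_normal((d, r))    # adapter factor applied first (d -> r)
A = rng.standard_normal((r, d))    # adapter factor applied second (r -> d)

y_merged = x @ (W + B @ A)         # materializes a full d x d update
y_split = x @ W + (x @ B) @ A      # never forms B @ A

# Multiply-add counts for the adapter part only.
merge_macs = d * r * d             # forming B @ A as a (d, d) matrix
delta_macs = d * r + r * d         # (1, d) @ (d, r), then (1, r) @ (r, d)

print(np.allclose(y_merged, y_split), merge_macs // delta_macs)
```

Evaluating `(x @ B) @ A` left to right is what keeps the delta cheap: the intermediate is a length-r vector rather than a d × d matrix.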

The shape math

Let batch size be B tokens (a row count, not to be confused with the adapter matrix B), hidden size d, rank r. Per row i with adapter index a_i:

y_i  =  x_i · W  +  x_i · B_{a_i} · A_{a_i}

Shapes:

X     : (B, d)
W     : (d, d)                        — shared
B_i   : (d, r), A_i : (r, d)          — per-row, varies by adapter_id[i]

Y_base   = X @ W                      → (B, d)   one shared GEMM
Y_delta  = stack_i( x_i @ B_i @ A_i ) → (B, d)   grouped per-row
Y        = Y_base + Y_delta

The delta term is cheap relative to the base GEMM (on the order of B · d · r multiply-adds versus B · d²) only if we don't pay one CUDA launch per row. A naive Python loop with B = 64 rows means 64 launches × ~5 µs each = 320 µs of pure overhead for what should be a few microseconds of math. That's the throughput killer.
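As a reference for the semantics, here is a sketch of the naive loop in NumPy. It runs on CPU, so it only models the structure: on a GPU each loop iteration would issue its own small kernel launches. The function name `lora_forward_naive`, the toy sizes, and the 5 µs launch cost (the ballpark quoted above, not a measurement) are assumptions for illustration.

```python
import numpy as np

def lora_forward_naive(X, W, B_list, A_list, adapter_ids):
    """One shared GEMM for the base, then a Python loop over rows for the delta."""
    Y = X @ W                                   # (n_rows, d), shared by all rows
    for i, a in enumerate(adapter_ids):         # one iteration per row
        Y[i] += (X[i] @ B_list[a]) @ A_list[a]  # (d,) @ (d, r) @ (r, d)
    return Y

rng = np.random.default_rng(0)
n_rows, d, r, n_adapters = 64, 32, 4, 3         # toy sizes
X = rng.standard_normal((n_rows, d))
W = rng.standard_normal((d, d))
B_list = [rng.standard_normal((d, r)) for _ in range(n_adapters)]
A_list = [rng.standard_normal((r, d)) for _ in range(n_adapters)]
adapter_ids = rng.integers(0, n_adapters, size=n_rows)

Y = lora_forward_naive(X, W, B_list, A_list, adapter_ids)

launch_overhead_us = 5.0                        # assumed per-launch cost
overhead_us = n_rows * launch_overhead_us       # pure launch overhead in the loop
```

The loop defines the result any fused kernel has to reproduce; the overhead estimate grows linearly with the number of rows, independent of how little math each row needs.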

The two S-LoRA / Punica contributions

1 · Unified adapter memory

Put every adapter's B and A matrices in one contiguous HBM pool, indexed by adapter ID:

B_pool: (N_adapters, d_out, r)
A_pool: (N_adapters, r, d_in)

No per-adapter tensor allocation. Adapters live or die based on usage: warm in HBM, cold in CPU DRAM, LRU eviction on miss. Same exact playbook as paged KV from lesson 02 — the unit of allocation is fixed, the index is the resource.
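The slot-and-index idea can be sketched in a few lines of Python. This is a toy model, not the S-LoRA or vLLM implementation: NumPy arrays stand in for the HBM pool, a plain dict stands in for CPU DRAM, and the class name `AdapterPool`, its `acquire` method, and the square (d, r) / (r, d) layout are all hypothetical.

```python
from collections import OrderedDict
import numpy as np

class AdapterPool:
    """Fixed number of slots; adapter_id -> slot index, with LRU eviction."""

    def __init__(self, n_slots, d, r, dtype=np.float16):
        self.B_pool = np.zeros((n_slots, d, r), dtype=dtype)
        self.A_pool = np.zeros((n_slots, r, d), dtype=dtype)
        self.slot_of = OrderedDict()             # least recently used first
        self.free_slots = list(range(n_slots))

    def acquire(self, adapter_id, cold_store):
        """Return the slot holding adapter_id, loading it on a miss."""
        if adapter_id in self.slot_of:           # hit: mark as most recent
            self.slot_of.move_to_end(adapter_id)
            return self.slot_of[adapter_id]
        if self.free_slots:                      # miss with spare capacity
            slot = self.free_slots.pop()
        else:                                    # miss when full: evict LRU
            _, slot = self.slot_of.popitem(last=False)
        B, A = cold_store[adapter_id]            # "CPU DRAM" -> pool copy
        self.B_pool[slot] = B
        self.A_pool[slot] = A
        self.slot_of[adapter_id] = slot
        return slot

rng = np.random.default_rng(0)
d, r = 16, 4
cold_store = {
    name: (rng.standard_normal((d, r)).astype(np.float16),
           rng.standard_normal((r, d)).astype(np.float16))
    for name in ("a", "b", "c")
}
pool = AdapterPool(n_slots=2, d=d, r=r)
for name in ("a", "b", "a", "c"):                # "c" evicts "b", the LRU entry
    pool.acquire(name, cold_store)
```

A production pool would also need to pin adapters that are referenced by an in-flight batch so they cannot be evicted mid-step; the sketch leaves that out.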

2 · BGMV / SGMV grouped-matmul kernel

One CUDA kernel that, given X (B, d), the two pools, and adapter_ids (B,), computes y_i = x_i · B_{a_i} · A_{a_i} for every row in a single fused launch.

BGMV = "batched gather matrix-vector multiplication". One CUDA block per row. Inside the block:

  1. Read this row's adapter_id.
  2. Load slices B_pool[adapter_id] and A_pool[adapter_id] through normal global-memory loads into SRAM.
  3. Compute (x @ B) @ A for this row's x: a (1, d) @ (d, r) product, then (1, r) @ (r, d).
  4. Write the d-dim result to Y[i].

Parallelism is over rows, not adapters. One launch, regardless of how many distinct adapters appear in the batch.
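The four steps above can be written as a NumPy reference that states what the kernel computes, not how it runs. This sketch assumes square pools of shape (N, d, r) and (N, r, d); `bgmv_reference` is an illustrative name. Unlike a real kernel, the fancy-indexing gather materializes a per-row copy of each adapter, which is fine for checking results but not for performance.

```python
import numpy as np

def bgmv_reference(X, B_pool, A_pool, adapter_ids):
    """Per-row gather from the pools, then two small batched products."""
    B_rows = B_pool[adapter_ids]                  # (n_rows, d, r), one slice per row
    A_rows = A_pool[adapter_ids]                  # (n_rows, r, d)
    H = np.einsum("nd,ndr->nr", X, B_rows)        # project each row down to rank r
    return np.einsum("nr,nrd->nd", H, A_rows)     # project back up to d

rng = np.random.default_rng(0)
n_rows, d, r, n_adapters = 8, 16, 4, 3            # toy sizes
X = rng.standard_normal((n_rows, d))
B_pool = rng.standard_normal((n_adapters, d, r))
A_pool = rng.standard_normal((n_adapters, r, d))
adapter_ids = rng.integers(0, n_adapters, size=n_rows)

Y_delta = bgmv_reference(X, B_pool, A_pool, adapter_ids)

# Row-by-row loop with the same semantics, used only as a cross-check.
Y_loop = np.stack([(X[i] @ B_pool[a]) @ A_pool[a]
                   for i, a in enumerate(adapter_ids)])
```

The two einsum calls are issued once for the whole batch regardless of how many distinct adapter IDs appear, which mirrors the single-launch property.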

SGMV — segmented variant for when adapters repeat

If multiple rows in the batch share an adapter, we can do better than gathering (B, A) once per row. Sort the batch by adapter_id so all rows sharing an adapter sit contiguously, then assign one "segment" to each adapter group. Inside a segment, load (B, A) into SRAM once and reuse across every row in the segment — the same tile-reuse pattern FlashAttention (lesson 03) uses for K, V across query positions.

The math is exactly the same; the SRAM reuse pattern is the difference. When the batch has k distinct adapters across B rows, the (B, A) loads from HBM drop from B to k. For a typical request mix (B=64, k=4–8 distinct adapters), that's an 8–16× reduction in adapter-pool HBM traffic.
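The sort-then-segment bookkeeping can be sketched in NumPy as follows. This is a CPU reference under assumed toy sizes, with the hypothetical name `sgmv_reference`; it assumes a non-empty batch. Each segment reads its adapter's matrices from the pools once, which models the single SRAM load per segment.

```python
import numpy as np

def sgmv_reference(X, B_pool, A_pool, adapter_ids):
    """Sort rows by adapter, process one contiguous segment per adapter, unsort."""
    order = np.argsort(adapter_ids, kind="stable")   # same-adapter rows become adjacent
    sorted_ids = adapter_ids[order]
    # A segment starts wherever the adapter id changes (assumes a non-empty batch).
    starts = np.flatnonzero(np.r_[True, sorted_ids[1:] != sorted_ids[:-1]])
    ends = np.r_[starts[1:], len(sorted_ids)]
    Y_sorted = np.empty((X.shape[0], A_pool.shape[2]), dtype=X.dtype)
    for s, e in zip(starts, ends):
        a = sorted_ids[s]
        B, A = B_pool[a], A_pool[a]                  # one (B, A) read per segment
        Y_sorted[s:e] = (X[order[s:e]] @ B) @ A      # reused across the segment's rows
    Y = np.empty_like(Y_sorted)
    Y[order] = Y_sorted                              # restore the original row order
    return Y, len(starts)

rng = np.random.default_rng(0)
n_rows, d, r, n_adapters = 64, 16, 4, 6              # toy sizes
X = rng.standard_normal((n_rows, d))
B_pool = rng.standard_normal((n_adapters, d, r))
A_pool = rng.standard_normal((n_adapters, r, d))
adapter_ids = rng.integers(0, n_adapters, size=n_rows)

Y_delta, n_segments = sgmv_reference(X, B_pool, A_pool, adapter_ids)

# Row-by-row loop with the same semantics, used only as a cross-check.
Y_loop = np.stack([(X[i] @ B_pool[a]) @ A_pool[a]
                   for i, a in enumerate(adapter_ids)])
```

The returned segment count equals the number of distinct adapters in the batch, so it is the k in the "B loads drop to k" argument.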

Memory budget

Per-adapter size, attention only (Q/K/V/O projections):

size_per_adapter  =  4 · 2 · d · r · dtype_bytes · n_layers
model          n_layers   d      r    MB / adapter (fp16, attn only)   warm in 30 GB
LLaMA-1 7B     32         4096   16   33 MB                            ~930 adapters
Mistral 7B     32         4096   8    17 MB                            ~1830 adapters
LLaMA-2 70B    80         8192   16   160 MB                           ~190 adapters

Numbers in the table are attention-only (Q/K/V/O). Add MLP adapters (up + down projections) and they roughly double. The pattern that matters: one base + hundreds of warm adapters in a fraction of the HBM the base itself occupies.
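The size formula is easy to evaluate directly. The sketch below uses a hypothetical helper `adapter_bytes` and decimal units (1 MB = 10^6 bytes, 30 GB = 30 · 10^9 bytes). The table's figures are rounded and appear to mix decimal and binary units, so expect the printed values to differ from it by several percent.

```python
def adapter_bytes(d, r, n_layers, n_proj=4, dtype_bytes=2):
    """Bytes for one adapter: n_proj projections, two factors of d * r each."""
    return n_proj * 2 * d * r * dtype_bytes * n_layers

# (n_layers, d, r) per model, taken from the table above.
configs = {
    "LLaMA-1 7B": (32, 4096, 16),
    "Mistral 7B": (32, 4096, 8),
    "LLaMA-2 70B": (80, 8192, 16),
}

budget_bytes = 30 * 10**9          # the 30 GB warm-pool budget, decimal units

for name, (n_layers, d, r) in configs.items():
    size = adapter_bytes(d, r, n_layers)
    print(f"{name}: {size / 1e6:.1f} MB per adapter, "
          f"{budget_bytes // size} adapters in 30 GB")
```

Passing a larger `n_proj` is one way to approximate adding MLP adapters, though MLP projections are not d × d in these models, so that would only be a rough estimate.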

Throughput caveats

Interactive · heterogeneous-batch visualizer

The batch is a row of tokens colored by adapter ID. Below it is a per-token timeline showing the cost of gathering the adapter's (B, A) into SRAM plus the per-row matmul. In per-row mode each row pays the full gather cost. In SGMV grouped mode rows are sorted by adapter, and adjacent same-color rows share the gather cost (load once, reuse the tile). The KPI reports the difference.

KPIs shown: kernel launches, (B, A) HBM loads, gather bandwidth used, throughput (rows/µs).

Cost model:
// Per row:
//   gather cost (HBM load of (d, r) + (r, d) for B_i, A_i): ~G µs
//   matmul cost (one row × adapter matrices):              ~M µs
//
// per-row mode    : B rows × (G + M)
// grouped mode    : k segments × G  +  B rows × M
//                    where k = number of distinct adapters
//                    in the batch (when sorted, each segment
//                    loads (B, A) once into SRAM).
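The cost model above translates directly into two small functions. The gather and matmul costs used here (G = 2.0 µs, M = 0.25 µs) are made-up placeholders chosen to make the arithmetic exact, not measurements; the function names are likewise illustrative.

```python
def per_row_cost_us(n_rows, gather_us, matmul_us):
    """Per-row mode: every row pays its own gather plus its matmul."""
    return n_rows * (gather_us + matmul_us)

def grouped_cost_us(n_rows, n_distinct, gather_us, matmul_us):
    """Grouped mode: one gather per distinct adapter, one matmul per row."""
    return n_distinct * gather_us + n_rows * matmul_us

G, M = 2.0, 0.25                   # assumed costs in microseconds
n_rows, k = 64, 8                  # batch size and distinct adapters

per_row = per_row_cost_us(n_rows, G, M)
grouped = grouped_cost_us(n_rows, k, G, M)
speedup = per_row / grouped

print(per_row, grouped, speedup)
```

When k equals the number of rows the two modes cost the same, so grouping only pays off when adapters repeat within a batch.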

Putting it together

approach                 HBM rewrites/req   kernel launches   (B,A) loads             scales to N adapters
merge per request        140 GB             1                 -                       no
naive loop (no fusion)   0                  B                 B                       throughput-limited
BGMV (per-row fused)     0                  1                 B                       yes
SGMV (segmented)         0                  1                 k (distinct adapters)   yes, optimally

Takeaway

Don't merge. Compute y = x · W + x · B · A with one shared GEMM for the base and one BGMV/SGMV kernel for the delta. The shared adapter pool turns "N adapters cost N model loads" into "N adapters cost N · ~30–160 MB of HBM" (attention-only, fp16, r=16, from 7B to 70B; roughly double with MLP LoRA). The fused kernel turns "B rows cost B launches" into "1 launch with row-level parallelism". Throughput within a few percent of base, hundreds of customers per process.