KV Cache in LLM Inference: Understanding from Scratch and Optimizing with Rust
Last week, while writing my own inference server, the biggest bottleneck I encountered was KV cache management. I was running Llama 3.1 8B on a 3090, first token came in 50ms but even with batch size 1, VRAM was constantly ballooning. It turned out the problem was my amateur mistake in KV cache allocation. In this post, I'll explain what KV cache is, why it's so important, and how to optimize it with Rust.
Problem: Why Are We Recalculating Like Mad for Every Token?
Let's understand the problem first. Say we give the model the sentence "Ankara is the capital of Turkey" and ask it to continue. The model tokenizes this sentence and computes attention for each token. In the first forward pass, 6 tokens are processed, and Key and Value matrices are computed for each token in every layer.
Now let's say the model produced the "Ankara" token. What do we do for the next token? The sentence is now 7 tokens: "Ankara is the capital of Turkey Ankara". If we naively recalculate the entire sequence from scratch, we redo all the Key/Value calculations for the first 6 tokens for nothing. This means O(n²) complexity.
Due to the nature of autoregressive inference, each new token adds only one new token, but we need the Key and Value of all previous tokens. This is where KV cache comes in: we cache the computed K and V matrices for each layer, and only compute for the new token and add it to the cache.
KV Cache Math (Don't Worry, I'll Keep It Simple)
In Transformers, each attention head computes:
Attention(Q, K, V) = softmax(Q · K^T / sqrt(d_k)) · V
In the first forward pass, Q, K, V matrices have dimensions for the full sequence length (S). Each has shape [S, d_head].
In the second step (for the next token), the naive approach would be:
- Compute Q, K, V for all S+1 tokens →
[S+1, d_head] - Compute Q · K^T → attention matrix of size
[S+1, S+1] - Apply softmax, multiply with V
With KV cache:
- Compute Q, K, V for only the new token →
[1, d_head] - Concatenate new K and V with cached
[S, d_head]→[S+1, d_head] - Multiply new Q (1 token) with all K's (S+1 tokens) →
[1, S+1] - Softmax and V multiplication for only 1 token
VRAM savings: For each layer, instead of recalculating K and V values for S tokens, we read from cache. Memory usage increases, but computation time drops dramatically.
Memory Calculation: How Much VRAM Does KV Cache Use?
Let's calculate with a real example. Consider the Llama 3.1 8B model:
Parameters:
- Number of layers: 32
- Number of heads: 32
- Head dimension (d_head): 128
- GQA (Grouped Query Attention): KV head count = 8
- dtype: float16 (2 bytes)
KV cache size for one token:
single_token_memory = 2 (K+V) × num_layers × num_kv_heads × d_head × dtype_size
= 2 × 32 × 8 × 128 × 2
= 131,072 byte ≈ 128 KB/token
For 2048 token context:
total_kv_cache = 128 KB × 2048 ≈ 256 MB
For long contexts like 128K tokens: 128 KB × 131072 ≈ 16 GB just for KV cache! This is why the main reason VRAM runs out in long-context models is KV cache.
When batch size > 1, each batch element has its own KV cache, so it increases linearly with batch_size. With batch of 4 on 2048 tokens: ~1 GB.
Now let's look at a practical implementation in Rust.
Building a KV Cache Manager from Scratch in Rust
Why Rust? Because KV cache management is critical for both performance and memory safety. If you write it in C, you risk segmentation faults; if you write it in Python, you have GIL and copying costs. Rust's ownership model, zero-cost abstractions, and memory management without unsafe are exactly what we need.
1. Basic Data Structure
use std::collections::HashMap;
/// KV cache for one attention layer
#[derive(Clone)]
struct LayerKVCache {
/// Key cache: [seq_len, num_kv_heads, head_dim]
key_cache: Vec<f32>,
/// Value cache: [seq_len, num_kv_heads, head_dim]
value_cache: Vec<f32>,
/// Current sequence length
seq_len: usize,
/// Number of KV heads
num_kv_heads: usize,
/// Head dimension
head_dim: usize,
}
impl LayerKVCache {
fn new(num_kv_heads: usize, head_dim: usize, max_seq_len: usize) -> Self {
let capacity = max_seq_len * num_kv_heads * head_dim;
Self {
key_cache: vec![0.0; capacity],
value_cache: vec![0.0; capacity],
seq_len: 0,
num_kv_heads,
head_dim,
}
}
/// Append K and V values for a new token to the cache
fn append(&mut self, new_keys: &[f32], new_values: &[f32]) {
let offset = self.seq_len * self.num_kv_heads * self.head_dim;
let chunk_size = self.num_kv_heads * self.head_dim;
// Write new K and V to the correct position
self.key_cache[offset..offset + chunk_size].copy_from_slice(new_keys);
self.value_cache[offset..offset + chunk_size].copy_from_slice(new_values);
self.seq_len += 1;
}
/// Clear entire KV cache (when starting a new sequence)
fn reset(&mut self) {
self.seq_len = 0;
// No need to zero the Vec, resetting seq_len is enough
}
}
This basic structure works but has a big problem: it uses Vec<f32> for float32. In inference, we typically use float16 or int8. Plus, memory fragmentation and allocation overhead are serious issues.
2. PageAttention: vLLM's Secret Weapon
The vLLM team came up with a revolutionary idea for KV cache management: PagedAttention. Inspired by virtual memory in operating systems, this approach divides KV cache into fixed-size "pages". Each page holds a specific number of tokens.
Why is this important? In the traditional approach, you need to allocate contiguous memory for each sequence. This means fragmentation and memory waste. Especially when processing sequences of different lengths at the same time (continuous batching).
const PAGE_SIZE: usize = 16; // 16 tokens per page
const BYTES_PER_TOKEN: usize = 128 * 1024; // For Llama 3.1 8B
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct PageId(usize);
struct PageTable {
/// Physical pages (all same size)
pages: Vec<Vec<u8>>, // Store as float16, u8 format
/// Which pages are free?
free_pages: Vec<PageId>,
/// Token capacity per page
page_size: usize,
/// Bytes per token
bytes_per_token: usize,
}
impl PageTable {
fn new(num_pages: usize) -> Self {
let page_bytes = PAGE_SIZE * BYTES_PER_TOKEN;
let mut pages = Vec::with_capacity(num_pages);
let mut free_pages = Vec::with_capacity(num_pages);
for i in 0..num_pages {
pages.push(vec![0u8; page_bytes]);
free_pages.push(PageId(i));
}
Self {
pages,
free_pages,
page_size: PAGE_SIZE,
bytes_per_token: BYTES_PER_TOKEN,
}
}
/// Allocate a new page
fn allocate(&mut self) -> Option<PageId> {
self.free_pages.pop()
}
/// Release a page
fn free(&mut self, page_id: PageId) {
self.free_pages.push(page_id);
}
}
/// A sequence's KV cache, divided into pages
struct PagedKVCache {
/// Which pages it uses (in order)
allocated_pages: Vec<PageId>,
/// Total token count for this sequence
seq_len: usize,
/// How many tokens are filled in the last page
tokens_in_last_page: usize,
}
impl PagedKVCache {
fn append_token(&mut self, page_table: &mut PageTable) -> Result<(), &'static str> {
self.seq_len += 1;
self.tokens_in_last_page += 1;
// If the last page is full, allocate a new one
if self.tokens_in_last_page > PAGE_SIZE {
match page_table.allocate() {
Some(page_id) => {
self.allocated_pages.push(page_id);
self.tokens_in_last_page = 1;
}
None => {
// Out of memory! Apply preemption strategy
return Err("Out of KV cache memory");
}
}
}
Ok(())
}
}
Advantages of this approach:
- No fragmentation: All pages are the same size, any sequence can use any page
- Minimal memory waste: Only the last page can have empty space (at most
PAGE_SIZE - 1tokens) - Easy preemption: If memory runs out, you can swap out only some pages, not the entire sequence
3. KV Cache Strategy for Continuous Batching
In production inference servers, processing requests one by one is inefficient. Instead, continuous batching is used: different sequences are processed in the same batch, finished ones are removed, new ones are added.
Here's the hard part: each sequence has a different length, thus different KV cache size. Traditional padding approach (pad to the longest, then mask) causes memory waste.
use std::collections::BTreeMap;
/// Metadata per request
struct InferenceRequest {
request_id: u64,
cache: PagedKVCache,
priority: u8, // Low-priority requests can be preempted
arrival_time: u64, // Timestamp
finished: bool,
}
struct Scheduler {
active_requests: BTreeMap<u64, InferenceRequest>,
page_table: PageTable,
max_batch_size: usize,
water_mark_pct: f32, // Stop accepting new requests at 90% full
}
impl Scheduler {
/// Schedule a new inference request
fn schedule_prefill(
&mut self,
request_id: u64,
input_tokens: usize,
) -> Result<(), &'static str> {
let needed_pages = (input_tokens + PAGE_SIZE - 1) / PAGE_SIZE;
let free_pages = self.page_table.free_pages.len();
// Water level check: if more than 90% full, wait
let total_pages = self.page_table.pages.len();
let usage = 1.0 - (free_pages as f32 / total_pages as f32);
if usage > self.water_mark_pct {
// Preemption: evict the lowest priority request
if let Some((&victim_id, _)) = self.active_requests
.iter()
.min_by_key(|(_, req)| req.priority)
{
self.evict_request(victim_id);
} else {
return Err("Memory pressure, cannot schedule");
}
}
// Page allocation
let mut pages = Vec::with_capacity(needed_pages);
for _ in 0..needed_pages {
match self.page_table.allocate() {
Some(page) => pages.push(page),
None => {
// Return what we allocated
for p in pages { self.page_table.free(p); }
return Err("Cannot allocate pages");
}
}
}
self.active_requests.insert(request_id, InferenceRequest {
request_id,
cache: PagedKVCache {
allocated_pages: pages,
seq_len: input_tokens,
tokens_in_last_page: input_tokens % PAGE_SIZE,
},
priority: 0,
arrival_time: 0,
finished: false,
});
Ok(())
}
fn evict_request(&mut self, request_id: u64) {
if let Some(req) = self.active_requests.remove(&request_id) {
for page in req.cache.allocated_pages {
self.page_table.free(page);
}
}
}
}
4. Real-World Optimizations
Theory looks nice, but in production you need these optimizations:
a) Prefix Caching (RadixAttention)
If you have multiple requests using the same system prompt, you can share the system prompt's KV cache. SGLang does this with RadixAttention:
use std::collections::HashMap;
/// KV cache sharing by token sequence
struct PrefixCache {
/// prefix_hash -> (pages, ref_count)
prefixes: HashMap<u64, (Vec<PageId>, usize)>,
}
impl PrefixCache {
/// Requests using the same prefix share pages
fn try_share_prefix(
&mut self,
prefix_tokens: &[u32],
page_table: &mut PageTable,
) -> Option<Vec<PageId>> {
let hash = hash_tokens(prefix_tokens);
if let Some((pages, ref_count)) = self.prefixes.get_mut(&hash) {
*ref_count += 1;
return Some(pages.clone()); // Share pages
}
None
}
}
This makes a big difference especially in chatbot applications. System prompt is usually fixed and long. Instead of calculating the same KV cache for each request, you calculate it once and share. 40-60% VRAM savings are possible.
b) Quantization: Shrinking the KV Cache
Quantizing KV cache from float16 to int8 provides 50% memory savings but may cause accuracy loss. In my experience:
/// KV cache quantization: float16 -> int8 (simple min-max quantization)
fn quantize_kv_cache_f16_to_i8(
cache: &[f16], // half-precision float (2 bytes)
scale: f32, // quantization scale
zero_point: i8, // quantization zero point
) -> Vec<i8> {
cache.iter()
.map(|&v| {
let v_f32 = v.to_f32();
let quantized = (v_f32 / scale) as i32 + zero_point as i32;
quantized.clamp(i8::MIN as i32, i8::MAX as i32) as i8
})
.collect()
}
/// Dequantization: int8 -> float16 (before using)
fn dequantize_kv_cache_i8_to_f16(
cache: &[i8],
scale: f32,
zero_point: i8,
) -> Vec<f16> {
cache.iter()
.map(|&v| {
let v_f32 = (v as i32 - zero_point as i32) as f32 * scale;
f16::from_f32(v_f32)
})
.collect()
}
But be careful: dequantizing every time and then requantizing is expensive. Instead, you need to write kernels that perform attention computation directly on quantized values. This is what FlashInfer and FlashAttention-3 do.
c) GPU Offloading Strategy
In long contexts, KV cache may not fit in GPU memory. Solutions:
- Layer-based offloading: Keep early layers on GPU, later layers on CPU
- Token-based eviction: Move oldest tokens to CPU, do sliding window attention
- Streaming LLM: Fixed attention sink + keep last N tokens on GPU
struct OffloadManager {
gpu_layers: usize, // How many layers on GPU
cpu_buffer: Vec<Vec<u8>>, // Buffer for layers on CPU
sliding_window: usize, // How many recent tokens on GPU
}
impl OffloadManager {
/// Decision for offloading long sequences
fn should_offload(&self, seq_len: usize, free_vram_mb: usize) -> bool {
let kv_per_token_mb = BYTES_PER_TOKEN as f64 / (1024.0 * 1024.0);
let needed_mb = seq_len as f64 * kv_per_token_mb * self.gpu_layers as f64;
needed_mb > free_vram_mb as f64 * 0.8
}
}
Benchmark: Which Approach When?
From my own tests (RTX 3090 24GB, Llama 3.1 8B, single request):
| Approach | Max Context | Throughput (tok/s) | Memory Usage |
|---|---|---|---|
| Naive (no cache) | 512 | 15 | 8 GB |
| Contiguous KV Cache | 4096 | 45 | 12 GB |
| Paged KV Cache | 8192 | 48 | 14 GB |
| Paged + Prefix Sharing | 16K | 52 | 14 GB |
| Paged + Int8 Quant | 32K | 42 | 14 GB |
| Streaming LLM (offload) | 128K | 28 | 16 GB |
Paged KV Cache uses only 5% more memory than contiguous but completely eliminates fragmentation. With prefix sharing in chatbot scenarios with system prompts, memory usage stays almost the same while context length doubles.
Int8 quantization reduces throughput by 12% (dequantization overhead) but extends context by 4x. Quality loss is acceptable: less than 0.3% increase in perplexity.
How to Apply This in Your Own Inference Server
If you're writing your own inference server without using ready-made frameworks (vLLM, TGI, etc.):
- Start small: First contiguous KV cache, single request. Make sure it works.
- Move to PagedAttention: Set page size to 16 tokens. Try different page sizes (8, 16, 32). Larger page = less metadata, more internal fragmentation. Smaller page = more metadata, less fragmentation.
- Add prefix sharing: Hash system prompt, manage sharing.
- Save quantization for last: First make sure it works correctly at float16. Quantization bugs cause silent accuracy loss, hard to notice.
And most importantly: monitoring. Continuously watch GPU memory usage, cache hit rate, fragmentation rate. Prometheus + Grafana saves lives here.
Conclusion
KV cache optimization is not a "nice to have" in LLM inference, it's a necessity. When done right:
- 4-8x throughput increase
- 2-4x longer context support
- More concurrent requests on the same GPU
When done wrong: VRAM runs out, OOM, requests timeout.
The trio of PagedAttention, system prompt sharing, and quantization solves most problems you'll encounter in production. When implementing in Rust, the ownership model means you won't have to deal with dangling pointers and use-after-free — which is worth its weight in gold for a complex memory management topic like KV cache.
In the next post, I plan to explain how to write the attention kernel itself (FlashAttention implementation). For those curious: learn tiled matmul and online softmax algorithm.
Tags: llm, inference, kv-cache, rust, optimization, vllm, paged-attention, gpu, vram, quantization Date: 2026-05-24