I added KV caching and INT8 KV quantization to our transformer inference, improving generation throughput by about 35x.
All of this was done from scratch in Rust CUDA, on top of a homemade ML framework.
On a 4-token prompt with 252 generated tokens:
- Original (no KV cache): 0.76 tok/s
- KV cache, FP32: 27.21 tok/s
- KV cache, INT8 (quantized): 27.29 tok/s
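Here is a minimal CPU sketch in Rust of what a KV cache does during decoding, for a single attention head. The names (`KvCache`, `append`, `attend`) are hypothetical and are not this framework's API. The real implementation runs on the GPU. Without a cache, each decode step reruns the model over the whole prefix and recomputes keys and values that never change. With a cache, each step computes K and V only for the new token, appends them, and attends over the stored rows.

```rust
/// Single-head KV cache (hypothetical CPU sketch, not the framework's real API).
struct KvCache {
    d: usize,         // head dimension
    keys: Vec<f32>,   // seq_len * d, row-major
    values: Vec<f32>, // seq_len * d, row-major
}

impl KvCache {
    fn new(d: usize) -> Self {
        Self { d, keys: Vec::new(), values: Vec::new() }
    }

    /// Number of cached positions.
    fn len(&self) -> usize {
        self.keys.len() / self.d
    }

    /// Append the new token's key/value rows; earlier rows are never recomputed.
    fn append(&mut self, k: &[f32], v: &[f32]) {
        assert_eq!(k.len(), self.d);
        assert_eq!(v.len(), self.d);
        self.keys.extend_from_slice(k);
        self.values.extend_from_slice(v);
    }

    /// Scaled dot-product attention of one query over every cached position.
    fn attend(&self, q: &[f32]) -> Vec<f32> {
        let d = self.d;
        let scale = 1.0 / (d as f32).sqrt();
        let mut scores: Vec<f32> = self
            .keys
            .chunks(d)
            .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale)
            .collect();
        // Numerically stable softmax over the cached positions.
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0;
        for s in scores.iter_mut() {
            *s = (*s - max).exp();
            sum += *s;
        }
        // Weighted sum of cached value rows.
        let mut out = vec![0.0f32; d];
        for (w, v) in scores.iter().zip(self.values.chunks(d)) {
            for (o, x) in out.iter_mut().zip(v) {
                *o += (w / sum) * x;
            }
        }
        out
    }
}

fn main() {
    let mut cache = KvCache::new(2);
    // Each decode step appends one K/V row, then attends once.
    cache.append(&[1.0, 0.0], &[1.0, 2.0]);
    cache.append(&[0.0, 1.0], &[3.0, 4.0]);
    // A zero query gives uniform weights, so this prints the mean value row: [2.0, 3.0]
    println!("{:?}", cache.attend(&[0.0, 0.0]));
}
```

Per step, the cached version only does work proportional to the current sequence length for attention. The uncached version also reruns every layer's projections for all earlier tokens. That redundant work grows with the prompt plus generated length.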
Try it out yourself here:
mni-ml.github.io/demos/kv-ca…
In practice:
- KV caching gave us about a 35x end-to-end speedup
- INT8 KV cache ran at roughly the same speed as FP32 but cut KV cache memory by 3.78x
In this run, the FP32 cache used 4.5 MB while the INT8 cache used only 1.19 MB.
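The post does not say which INT8 scheme the cache uses. The sketch below assumes one common choice: symmetric per-row quantization, where each cached K or V row is stored as `i8` codes plus one `f32` scale. All names here (`QuantRow`, `quantize_row`, `dequantize_row`, `fp32_bytes`, `int8_bytes`) are hypothetical. Extra metadata like per-row scales is one plausible reason a reduction lands slightly under the ideal 4x. The row width of 64 in the example is an illustrative assumption, not the model's real shape.

```rust
/// One INT8-quantized row: i8 codes plus a single f32 scale (assumed per-row symmetric scheme).
struct QuantRow {
    codes: Vec<i8>,
    scale: f32,
}

/// Map [-max_abs, max_abs] onto [-127, 127].
fn quantize_row(x: &[f32]) -> QuantRow {
    let max_abs = x.iter().fold(0.0f32, |m, v| m.max(v.abs()));
    // Guard all-zero rows so we never divide by zero.
    let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
    let codes = x
        .iter()
        .map(|v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    QuantRow { codes, scale }
}

/// Recover approximate f32 values (error per element is at most scale / 2).
fn dequantize_row(q: &QuantRow) -> Vec<f32> {
    q.codes.iter().map(|&c| c as f32 * q.scale).collect()
}

/// Bytes for `rows` cached vectors of width `cols` stored as f32.
fn fp32_bytes(rows: usize, cols: usize) -> usize {
    rows * cols * 4
}

/// Bytes for the same rows stored as i8 codes plus one f32 scale per row.
fn int8_bytes(rows: usize, cols: usize) -> usize {
    rows * cols + rows * 4
}

fn main() {
    let row = [0.5f32, -1.27, 0.0, 1.27];
    let q = quantize_row(&row);
    println!("codes = {:?}, scale = {}", q.codes, q.scale);
    println!("round trip = {:?}", dequantize_row(&q));
    // Illustrative shape only: 1024 rows of width 64.
    let ratio = fp32_bytes(1024, 64) as f64 / int8_bytes(1024, 64) as f64;
    println!("fp32/int8 memory ratio = {:.2}", ratio); // 256/68, prints 3.76
}
```

With a per-row scale, the ratio is 4·cols / (cols + 4), which approaches 4x as rows get wider. Finer-grained scales (for example per group within a row) improve accuracy but push the ratio further below 4x.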
This relatively simple change to inference had a huge impact on performance. To learn more about the KV cache and other optimizations like it, check out the blog at
mni.ml!