KV Prediction for Improved Time to First Token
LLM inference can be split into two phases: prefilling, which processes the whole prompt at once and fills the KV-cache, and decoding.
The decoding phase is autoregressive: tokens are generated one by one, reusing the Key/Value tensors of previous tokens stored in the KV-cache. To speed this up, we can use speculative decoding (arxiv.org/abs/2302.01318), where a small draft model quickly samples several tokens and a bigger scorer model periodically checks, in a single forward pass, whether the draft looks OK.
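The draft-then-verify loop can be sketched as below. This is a simplified greedy variant, not the rejection-sampling scheme of the linked paper: a drafted token is kept only if the scorer would have picked the same token, and the first mismatch is replaced by the scorer's choice. The two toy "models" (`draft_next`, `target_next`) are hypothetical stand-ins, and the scorer is called once per drafted position here, whereas a real implementation would score all drafted positions in one batched forward pass.

```python
def speculative_decode(draft_next, target_next, prompt, k, max_new_tokens):
    """Greedy speculative decoding sketch.

    draft_next / target_next map a list of tokens to the next token.
    """
    tokens = list(prompt)
    target_len = len(prompt) + max_new_tokens
    while len(tokens) < target_len:
        # 1. The cheap draft model proposes up to k tokens autoregressively.
        proposal = []
        for _ in range(min(k, target_len - len(tokens))):
            proposal.append(draft_next(tokens + proposal))

        # 2. The scorer checks each drafted position (one batched forward
        #    pass in practice; simulated here with one call per position).
        accepted = []
        for i, tok in enumerate(proposal):
            expected = target_next(tokens + proposal[:i])
            if tok == expected:
                accepted.append(tok)
            else:
                # First mismatch: keep the scorer's token, drop the rest.
                accepted.append(expected)
                break
        tokens.extend(accepted)
    return tokens


# Toy models (assumptions for illustration): the scorer counts modulo 10,
# the draft agrees except that it wrongly restarts at 0 after a 4.
def target_next(tokens):
    return (tokens[-1] + 1) % 10


def draft_next(tokens):
    return 0 if tokens[-1] == 4 else (tokens[-1] + 1) % 10


result = speculative_decode(draft_next, target_next, [0], k=3, max_new_tokens=8)
```

In this greedy setting the output matches what the scorer alone would have generated; the draft model only changes how many scorer passes are needed.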
KV Prediction for Improved Time to First Token (arxiv.org/abs/2410.08391) proposes something similar, but for prefilling: a smaller model predicts the KV-cache of the big model, layer by layer. This is especially useful with extremely long contexts, where the initial prefill contains tons of PDFs or videos.
The problem is that, while in speculative decoding the small and big models share the same token logit space, when prefilling the KV-cache the tensor dimensions differ between the two models. So it wouldn't work out of the box.
The paper proposes to train a linear projection per layer, to go from the small model space to the larger model space.
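A minimal sketch of that idea, in plain Python, assuming hypothetical sizes and random matrices standing in for the trained projection weights. The rule that decides which small-model layer feeds which big-model layer (`source_layer`) is an assumption for illustration, not necessarily the mapping used in the paper.

```python
import random

# Hypothetical sizes, chosen only to keep the example tiny.
D_SMALL, D_BIG = 4, 8                 # KV width of the small / big model
N_SMALL_LAYERS, N_BIG_LAYERS = 2, 4
N_TOKENS = 5                          # prompt length


def matmul(x, w):
    """(tokens x d_in) times (d_in x d_out) -> (tokens x d_out)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)]
            for row in x]


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)

# KV-cache produced by the small model: one (K, V) pair per layer.
small_cache = [
    (random_matrix(N_TOKENS, D_SMALL, rng), random_matrix(N_TOKENS, D_SMALL, rng))
    for _ in range(N_SMALL_LAYERS)
]

# One projection pair (W_k, W_v) per big-model layer.
# Random here; in practice these would be trained.
projections = [
    (random_matrix(D_SMALL, D_BIG, rng), random_matrix(D_SMALL, D_BIG, rng))
    for _ in range(N_BIG_LAYERS)
]


def source_layer(big_layer):
    # Assumed mapping: spread the big layers evenly over the small layers.
    return big_layer * N_SMALL_LAYERS // N_BIG_LAYERS


def predict_big_cache(small_cache, projections):
    """Project the small model's K/V tensors into the big model's space."""
    big_cache = []
    for j, (w_k, w_v) in enumerate(projections):
        k_small, v_small = small_cache[source_layer(j)]
        big_cache.append((matmul(k_small, w_k), matmul(v_small, w_v)))
    return big_cache


big_cache = predict_big_cache(small_cache, projections)
```

The big model can then start decoding from `big_cache` without ever running its own forward pass over the prompt.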
FLOPs are reduced, with less performance degradation than the baselines (see KVP-C and KVP-LP).
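To get a feel for where the savings come from, here is a back-of-the-envelope estimate using the common rule of thumb of about 2 FLOPs per parameter per token for a forward pass. The model sizes and prompt length are hypothetical, and the estimate ignores the quadratic attention term, the cost of the projections, and the big model's pass for the first generated token.

```python
def prefill_flops(n_params, n_tokens):
    # Rule of thumb: ~2 FLOPs per parameter per token for one forward pass.
    return 2 * n_params * n_tokens


# Hypothetical numbers, not taken from the paper.
BIG_PARAMS = 3_000_000_000
SMALL_PARAMS = 300_000_000
PROMPT_TOKENS = 4096

baseline = prefill_flops(BIG_PARAMS, PROMPT_TOKENS)        # big model prefills
kv_predicted = prefill_flops(SMALL_PARAMS, PROMPT_TOKENS)  # small model prefills

ratio = kv_predicted / baseline
print(f"prefill FLOPs relative to baseline: {ratio:.2f}")
```

Under these assumptions, prefill compute scales roughly with the parameter ratio between the two models.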
The prompt lengths seem quite small, however, so it would be worth investigating how well this scales, for instance on the needle-in-a-haystack benchmark from Anthropic (anthropic.com/news/claude-3-…, ctrl f "recall").