Speculative execution for LLMs is an excellent inference-time optimization.
It hinges on the following unintuitive observation: forwarding an LLM on a single input token takes about as much time as forwarding it on K input tokens in a batch (for larger K than you might think). This is because sampling is heavily memory bound: most of the "work" is not compute, it is reading the weights of the transformer from VRAM into on-chip cache for processing. So if you're going to do all that work of reading in all those weights, you might as well apply them to a whole batch of input vectors. I went into more detail in an earlier thread:
twitter.com/karpathy/status/…
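To make the memory-bound argument concrete, here is a back-of-envelope sketch. It assumes a step costs roughly max(time to read the weights, time to do the matmuls), with about 2 FLOPs per parameter per token. The bandwidth and FLOP numbers are made-up round figures, not measurements of any real device, and the model ignores the KV cache, attention cost and activations.

```python
def step_time_s(n_params, k_tokens, bytes_per_param=2,
                mem_bw=1e12, peak_flops=100e12):
    """Rough time for one forward pass over k_tokens positions.

    mem_bw (bytes/s) and peak_flops (FLOP/s) are illustrative assumptions.
    """
    t_mem = n_params * bytes_per_param / mem_bw          # read every weight once
    t_compute = 2 * n_params * k_tokens / peak_flops     # ~2 FLOPs per param per token
    return max(t_mem, t_compute)


def break_even_k(bytes_per_param=2, mem_bw=1e12, peak_flops=100e12):
    """K at which compute time catches up with weight-read time."""
    return peak_flops * bytes_per_param / (2 * mem_bw)


if __name__ == "__main__":
    for k in (1, 4, 16, 64, 256):
        print(k, step_time_s(34e9, k))
    print("break-even K:", break_even_k())
```

Under these assumed numbers the step time stays flat until K gets close to the break-even point, which is the "larger K than you might think" part.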
The reason we can't naively use this fact to sample in chunks of K tokens at a time is that the N-th token depends on which token we sampled at step N-1. There is a serial dependency, so the baseline implementation just goes one by one, left to right.
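The serial dependency is easy to see in a minimal sketch of the baseline loop. The "model" here is a hypothetical toy function (`toy_next_token`) standing in for a full forward pass plus greedy pick; it is not a real LLM.

```python
from typing import Callable, List


def greedy_decode(next_token: Callable[[List[int]], int],
                  prompt: List[int], n_new: int) -> List[int]:
    """Baseline decoding: one model call per generated token."""
    tokens = list(prompt)
    for _ in range(n_new):
        # Step N cannot start until the token from step N-1 is known.
        tokens.append(next_token(tokens))
    return tokens


def toy_next_token(tokens: List[int]) -> int:
    """Hypothetical stand-in for a model: deterministic toy rule."""
    return (tokens[-1] * 3 + len(tokens)) % 11


if __name__ == "__main__":
    print(greedy_decode(toy_next_token, [1], 3))
```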
Now the clever idea is to use a small and cheap draft model to first generate a candidate sequence of K tokens - a "draft". Then we feed all of these together through the big model in a batch. This is almost as fast as feeding in just one token, per the above. Then we go from left to right over the logits predicted by the big model and sample tokens. Any sample that agrees with the draft allows us to immediately skip forward to the next token. If there is a disagreement, we throw away the rest of the draft from that point and eat the cost of some throwaway work (sampling the draft and the forward pass for all the later tokens).
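Here is a minimal sketch of that loop for the greedy (argmax) case only. The target and draft "models" are hypothetical toy functions, and the batched big-model pass is simulated with a Python loop; in a real model it would be one forward pass over all K+1 positions. For actual sampling, the referenced papers use a rejection-sampling acceptance rule that this sketch does not implement.

```python
from typing import Callable, List, Tuple

NextToken = Callable[[List[int]], int]


def speculative_greedy(target: NextToken, draft: NextToken,
                       prompt: List[int], n_new: int,
                       k: int = 4) -> Tuple[List[int], int]:
    """Greedy speculative decoding; returns (tokens, number of target passes)."""
    tokens = list(prompt)
    goal = len(prompt) + n_new
    target_passes = 0
    while len(tokens) < goal:
        # 1) Draft K tokens serially with the cheap model.
        proposal: List[int] = []
        for _ in range(k):
            proposal.append(draft(tokens + proposal))
        # 2) Big model predicts the next token at all K+1 positions.
        #    Simulated with a loop here; counted as one batched pass.
        preds = [target(tokens + proposal[:i]) for i in range(k + 1)]
        target_passes += 1
        # 3) Walk left to right, accepting draft tokens while they agree.
        n_ok = 0
        while n_ok < k and proposal[n_ok] == preds[n_ok]:
            n_ok += 1
        tokens.extend(proposal[:n_ok])
        # The big model's own token at the first disagreement
        # (or one extra token if the whole draft was accepted).
        tokens.append(preds[n_ok])
    return tokens[:goal], target_passes


def toy_target(tokens: List[int]) -> int:
    """Hypothetical 'big model': deterministic toy rule."""
    return (tokens[-1] * 3 + len(tokens)) % 11


def toy_draft(tokens: List[int]) -> int:
    """Hypothetical 'draft model': agrees except at every 4th context length."""
    t = toy_target(tokens)
    return t if len(tokens) % 4 else (t + 1) % 11


if __name__ == "__main__":
    out, passes = speculative_greedy(toy_target, toy_draft, [1], 12, k=4)
    print(out, passes)
```

In the greedy case the output is the same as decoding with the big model alone; the only thing that changes is how many big-model passes it takes to get there.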
The reason this works in practice is that most of the time the draft tokens get accepted, because they are easy, so even a much smaller draft model gets them. As these easy tokens get accepted, we skip through those parts in leaps. The hard tokens where the big model disagrees "fall back" to original speed, but actually a bit slower because of all the extra work.
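A rough way to put a number on "skipping in leaps": assume, as a simplification, that each draft token is accepted independently with the same probability `accept_prob`. Each big-model pass then yields 1 + (number of accepted draft tokens) tokens, which in expectation is the geometric sum 1 + a + a^2 + ... + a^K. Real acceptance rates vary by position and by text, so treat this as a sketch.

```python
def expected_tokens_per_pass(accept_prob: float, k: int) -> float:
    """Expected tokens emitted per big-model pass with a K-token draft,
    assuming i.i.d. acceptance with probability accept_prob."""
    if accept_prob == 1.0:
        return float(k + 1)  # whole draft accepted plus one extra token
    return (1 - accept_prob ** (k + 1)) / (1 - accept_prob)


if __name__ == "__main__":
    for a in (0.0, 0.5, 0.8, 0.9):
        print(a, expected_tokens_per_pass(a, k=4))
```

With `accept_prob` at 0 this collapses to one token per pass, which is the baseline speed minus the wasted draft work; the closer it gets to 1, the closer each pass gets to K+1 tokens.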
So TLDR: this one weird trick works because LLMs are memory bound at inference time in the "batch size 1" setting of sampling a single sequence of interest, which a large fraction of "local LLM" use cases fall into. And because most tokens are "easy".
References
arxiv.org/abs/2302.01318
arxiv.org/abs/1811.03115
arxiv.org/abs/2211.17192
Full F16 precision 34B Code Llama at >20 t/s on M2 Ultra