5 Mar 2024 | Wonseok Jeon, Mukul Gagrani, Raghav Goel, Junyoung Park, Mingu Lee, Christopher Lott
Recursive Speculative Decoding (RSD) is a novel tree-based method for accelerating large language model (LLM) inference by sampling draft tokens without replacement, maximizing the diversity of the draft-token tree. The method exploits the parallelizability of the transformer and uses a small draft model to generate draft tokens, which the target LLM then verifies. RSD introduces recursive rejection sampling, which recovers the target distribution even when draft tokens are sampled without replacement. Two variants are proposed: RSD-C, which builds the draft-token tree with a constant branching factor, and RSD-S, which uses stochastic beam search to sample sequences without replacement while truncating unlikely ones. RSD outperforms existing methods under both a fixed draft sequence length and a fixed computational budget, achieving higher block efficiency, memory-bound speed-up, and token rate. Evaluations on Llama 2 and OPT models show consistent improvements over baseline methods. By sampling without replacement and maximizing diversity, RSD attains higher acceptance rates and performs better in resource-constrained environments. Theoretical analysis confirms that RSD recovers the target distribution, and empirical results validate its effectiveness in accelerating LLM inference.
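To make the verification step concrete, below is a minimal sketch of recursive rejection sampling with sampling without replacement for a single decoding position. It assumes the target distribution `p` and draft distribution `q` are already available as normalized arrays; the function name is illustrative, and the interleaving of drafting and verification here is a simplification of the paper's approach, which drafts an entire token tree first and verifies it against the target model in parallel.

```python
import numpy as np

def recursive_rejection_sampling_wor(p, q, num_drafts, rng=None):
    """Illustrative sketch: verify up to `num_drafts` draft tokens drawn
    without replacement, so the returned token follows the target p.

    p: target-model next-token distribution (1-D array, sums to 1).
    q: draft-model next-token distribution (1-D array, sums to 1).
    Assumes num_drafts is smaller than the support size of q.
    """
    rng = rng or np.random.default_rng()
    p = np.asarray(p, dtype=float).copy()
    q = np.asarray(q, dtype=float).copy()
    for _ in range(num_drafts):
        # Draw a draft token from the current (renormalized) draft distribution.
        y = rng.choice(len(q), p=q)
        # Standard speculative-decoding acceptance test.
        if rng.random() < min(1.0, p[y] / q[y]):
            return y
        # Rejected: recurse on the residual target distribution ...
        p = np.maximum(p - q, 0.0)
        p /= p.sum()
        # ... and zero out y in the draft distribution, so the next
        # candidate is effectively sampled without replacement.
        q[y] = 0.0
        q /= q.sum()
    # Every draft was rejected: fall back to the final residual target.
    return rng.choice(len(p), p=p)
```

Each rejection both reweights the target toward tokens the draft under-covers and removes the rejected token from the draft, so later candidates cannot repeat earlier failures. This is the diversity effect the summary credits for RSD's higher acceptance rates.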