21 Jun 2024 | Kaixuan Huang, Xudong Guo, Mengdi Wang
**SpecDec++: Boosting Speculative Decoding via Adaptive Candidate Lengths**
**Authors:** Kaixuan Huang, Xudong Guo, Mengdi Wang
**Abstract:**
Speculative decoding reduces inference latency by using a smaller, faster draft model to generate candidate tokens for the target large language model. Performance depends on the hyperparameter \( K \), the number of candidate tokens generated per round. Previous methods often choose \( K \) with heuristics, leading to sub-optimal performance. This paper formulates the choice of \( K \) as a Markov Decision Process (MDP) and shows theoretically that the optimal policy is a threshold policy: speculation should stop as soon as the probability of rejection exceeds a threshold. Guided by this theory, SpecDec++ adaptively determines the candidate length by training an acceptance prediction head on the draft model to predict the conditional acceptance probability of each candidate token, and it stops speculation when the predicted probability that at least one candidate token is rejected exceeds a threshold. Experiments on the Alpaca, HumanEval, and GSM8K datasets show that SpecDec++ achieves 2.04x, 2.26x, and 2.23x overall speedups, respectively, improving over the baseline speculative decoding method.
**Contributions:**
- Formulate the dynamic choice of candidate length in speculative decoding as an MDP.
- Propose SpecDec++, an enhanced version of speculative decoding that adaptively determines the candidate length.
- Use a weighted Binary Cross-Entropy loss and a token-mixing method to train the acceptance prediction head efficiently (see the sketch after this list).
- Achieve significant improvements over the baseline speculative decoding method on multiple datasets.
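
As a reference point for the training objective, here is a minimal PyTorch sketch of a weighted binary cross-entropy loss for the acceptance prediction head. The weight value, which class is up-weighted, and all tensor names are illustrative assumptions rather than the authors' exact implementation, and the token-mixing step is omitted.

```python
import torch
import torch.nn.functional as F

def weighted_bce_loss(pred_logits: torch.Tensor,
                      accept_labels: torch.Tensor,
                      pos_weight: float = 3.0) -> torch.Tensor:
    """Weighted BCE for predicting whether each candidate token is accepted.

    pred_logits:   (batch, seq_len) raw scores from the prediction head.
    accept_labels: (batch, seq_len) 1.0 if the target model would accept
                   the candidate token, 0.0 otherwise.
    pos_weight:    illustrative per-class weight to counter label imbalance;
                   which class needs up-weighting depends on the data.
    """
    weight = torch.where(accept_labels > 0.5,
                         torch.full_like(accept_labels, pos_weight),
                         torch.ones_like(accept_labels))
    return F.binary_cross_entropy_with_logits(pred_logits, accept_labels,
                                              weight=weight)
```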
**Background:**
- **Rejection Sampling:** A method to sample from a target distribution by first drawing from a draft distribution and then accepting or rejecting the draw with a probability determined by the ratio of the target and draft probabilities (see the sketch after this list).
- **Speculative Decoding:** A technique that chains multiple rejection sampling procedures to generate sequences from the target distribution.
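
As referenced above, here is a minimal sketch of the standard accept/reject step that speculative decoding chains together. The function name and inputs are illustrative, assuming `p_target` and `q_draft` are the target- and draft-model probabilities of the same candidate token.

```python
import random

def accept_candidate(p_target: float, q_draft: float) -> bool:
    """Accept a draft token x with probability min(1, p(x) / q(x)).

    p_target: target-model probability of the candidate token.
    q_draft:  draft-model probability of the candidate token (> 0).
    On rejection, the standard procedure resamples from the residual
    distribution proportional to max(0, p - q), so the final output
    still follows the target distribution exactly.
    """
    return random.random() < min(1.0, p_target / q_draft)
```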
**Inference Time Analysis:**
- The total inference time is driven by two quantities: the number of discarded draft tokens and the number of forward passes of the target model (see the decomposition below).
- The optimal policy in the MDP framework is to stop speculation when the probability of at least one token being rejected exceeds a threshold.
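
A rough decomposition under assumed notation (not the paper's exact formula): with \( N \) output tokens, \( D \) discarded candidate tokens, \( M \) target-model forward passes, and per-call costs \( t_{\text{draft}} \) and \( t_{\text{target}} \), the total wall-clock time is approximately

\[
T_{\text{total}} \;\approx\; (N + D)\, t_{\text{draft}} + M\, t_{\text{target}},
\]

so fewer discarded tokens and fewer verification passes both translate directly into lower latency, and the choice of candidate length trades one against the other.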
**SpecDec++: Theory and Algorithm:**
- Formulate speculative decoding as an MDP with states, actions, transitions, and immediate costs.
- Propose SpecDec++ with an additional prediction head to determine whether to stop speculation based on the predicted acceptance probability.
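
A minimal sketch of such a stopping rule, assuming the prediction head emits a per-token acceptance probability; the threshold value and names are illustrative, not the authors' exact implementation.

```python
def should_stop(accept_probs, threshold: float = 0.7) -> bool:
    """Stop drafting once P(at least one candidate is rejected) > threshold.

    accept_probs: predicted acceptance probabilities of the candidate tokens
                  drafted so far in the current round (illustrative input).
    threshold:    illustrative cutoff; the MDP analysis shows the optimal
                  policy is a threshold rule on this rejection probability.
    """
    p_all_accepted = 1.0
    for p in accept_probs:
        p_all_accepted *= p  # probability that every candidate so far is accepted
    return (1.0 - p_all_accepted) > threshold
```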
**Experiments:**
- Compare SpecDec++ with the baseline speculative decoding method on the Alpaca, HumanEval, and GSM8K datasets.
- Show that SpecDec++ achieves significant speedups along with improvements in both the discard rate (candidate tokens thrown away) and the verification rate (target-model forward passes per generated token).
**Related Work:**
- Discuss previous works on speculative decoding and candidate length selection, highlighting the complementary nature of this work.
**Conclusion:**
This paper addresses the problem of determining the candidate lengths for speculative decoding by formulating it as a Markov Decision Process, proving that the optimal policy is a threshold rule, and instantiating that rule as SpecDec++, which stops speculation when a learned acceptance prediction head signals that rejection has become likely. The result is a consistent speedup over baseline speculative decoding across multiple datasets.