MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads

2024 | Tianle Cai*, Yuhong Li*, Zhengyang Geng, Hongwu Peng, Jason D. Lee, Deming Chen, Tri Dao
**Abstract:** Large Language Models (LLMs) rely on auto-regressive decoding, which is memory-bandwidth-bound: each sequential step requires moving the full model parameters from High-Bandwidth Memory (HBM) to the accelerator's cache. This paper introduces MEDUSA, an efficient method that accelerates LLM inference by adding multiple decoding heads that predict several subsequent tokens in parallel. Using a tree-based attention mechanism, MEDUSA constructs and verifies multiple candidate continuations simultaneously, reducing the number of decoding steps. Two fine-tuning procedures are proposed for different use cases: MEDUSA-1 fine-tunes the heads on top of a frozen backbone LLM, enabling lossless inference acceleration, while MEDUSA-2 fine-tunes the heads together with the backbone LLM, improving prediction accuracy and speedup. Extensions include self-distillation for data-limited scenarios and a typical acceptance scheme that boosts acceptance rates while maintaining generation quality. Experiments on various models show that MEDUSA-1 achieves over 2.2× speedup without compromising quality, and MEDUSA-2 further improves speedup to 2.3-2.8×.

**Introduction:** The growth of LLMs has led to increased inference latency, a significant challenge for practical applications. MEDUSA addresses this by integrating multiple decoding heads on top of the backbone model, allowing several tokens to be predicted concurrently. This avoids the main obstacle of speculative decoding, namely obtaining and serving a separate draft model, and allows seamless integration into existing LLM systems. MEDUSA-1 and MEDUSA-2 provide different fine-tuning procedures to meet different computational and performance requirements. The paper also introduces extensions, self-distillation and a typical acceptance scheme, to enhance efficiency and maintain generation quality.

**Methodology:** MEDUSA follows a framework similar to speculative decoding, with three substeps: generating candidates, processing candidates, and accepting candidates. MEDUSA heads are additional decoding heads appended to the last hidden states of the original model, predicting multiple tokens in parallel. Tree attention is used to process multiple candidates concurrently, and a typical acceptance scheme is proposed to select reasonable candidates. Two training strategies, MEDUSA-1 and MEDUSA-2, are detailed, along with extensions for self-distillation and optimized tree construction.

**Experiments:** Experiments on Vicuna-7B, Vicuna-13B, and Zephyr-7B demonstrate the effectiveness of MEDUSA. MEDUSA-1 achieves over 2× speedup, and MEDUSA-2 further improves this to 2.3-2.8×. Ablation studies on the tree attention configuration, typical acceptance thresholds, and two-stage fine-tuning strategies are conducted to validate the methods.

**Discussion:** MEDUSA speeds up LLM inference by 2.3-2.8× using multiple decoding heads, while remaining parameter-efficient and easy to integrate. The typical acceptance scheme simplifies the rejection sampling process while…
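The three substeps named in the Methodology (generating, processing, and accepting candidates) can be illustrated with a small, self-contained Python sketch. It is a toy under stated assumptions, not the authors' implementation: all function names are hypothetical, the candidate tree is the full Cartesian product of each head's top-k tokens rather than an optimized tree, the backbone is replaced by a lookup table instead of a single forward pass with tree attention, and the acceptance rule (accept a token when its backbone probability exceeds min(epsilon, delta · exp(−entropy))) is one reading of the typical acceptance idea, with illustrative threshold values.

```python
import math
from itertools import product


def build_candidate_tree(top_tokens_per_head):
    """Generate candidates: every path through the per-head top-k tokens.

    top_tokens_per_head[i] is the list of token ids proposed by head i.
    Nodes are tuples of token ids (the path from the root), ordered so
    that every parent appears before its children.
    """
    nodes = []
    for depth in range(1, len(top_tokens_per_head) + 1):
        nodes.extend(product(*top_tokens_per_head[:depth]))
    return nodes


def tree_attention_mask(nodes):
    """Process candidates: mask[i][j] is True when node j is node i
    itself or one of its ancestors, so paths cannot see each other."""
    index = {path: i for i, path in enumerate(nodes)}
    mask = [[False] * len(nodes) for _ in nodes]
    for path, i in index.items():
        for depth in range(1, len(path) + 1):
            mask[i][index[path[:depth]]] = True
    return mask


def entropy(dist):
    """Shannon entropy (in nats) of a {token: probability} mapping."""
    return -sum(p * math.log(p) for p in dist.values() if p > 0)


def longest_accepted(nodes, backbone_dist, epsilon, delta):
    """Accept candidates: return the longest path whose every token
    passes the (assumed) typical acceptance threshold."""
    best = ()
    for path in nodes:
        accepted = True
        for d, token in enumerate(path):
            dist = backbone_dist(path[:d])
            threshold = min(epsilon, delta * math.exp(-entropy(dist)))
            if dist.get(token, 0.0) <= threshold:
                accepted = False
                break
        if accepted and len(path) > len(best):
            best = path
    return best


def toy_backbone(prefix):
    """Hypothetical stand-in for the backbone's next-token distribution."""
    table = {
        (): {5: 0.9, 7: 0.1},
        (5,): {2: 0.02, 9: 0.98},
        (7,): {2: 0.5, 9: 0.5},
    }
    return table[prefix]


if __name__ == "__main__":
    tree = build_candidate_tree([[5, 7], [2, 9]])  # two heads, top-2 each
    print(tree)
    print(tree_attention_mask(tree))
    print(longest_accepted(tree, toy_backbone, epsilon=0.09, delta=0.3))
```

In this toy run, two heads with two proposals each give six tree nodes, and each node's mask row marks only its own path. In a real system the mask would be passed to the attention layers so that all candidates are scored in one backbone forward pass; the lookup table here only stands in for those scores.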