2024 | Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D. Lee, Deming Chen, Tri Dao
MEDUSA is a framework that accelerates large language model (LLM) inference by adding multiple decoding heads to predict multiple tokens in parallel. It uses a tree-based attention mechanism to generate candidate continuations and verify them simultaneously, reducing the number of decoding steps. Two fine-tuning approaches, MEDUSA-1 and MEDUSA-2, are introduced. MEDUSA-1 fine-tunes the heads on a frozen backbone model, achieving over 2.2× speedup without compromising generation quality. MEDUSA-2 fine-tunes the heads and backbone together, achieving 2.3–2.8× speedup but requiring a specialized training recipe. Extensions include self-distillation for training without data and a typical acceptance scheme to improve acceptance rates while maintaining quality. Experiments on various models show significant speedups, with MEDUSA-2 achieving up to 2.8× speedup on Vicuna-7B and Vicuna-33B.
The framework is efficient, easy to integrate, and suitable for distributed systems. It enhances inference speed while preserving model quality, offering a promising direction for optimizing LLMs.
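The core draft-and-verify loop described above can be sketched in a few lines. This is a toy illustration, not the paper's implementation: the real MEDUSA heads are residual blocks attached to the backbone's final hidden state, and verification walks a tree of candidates with a shared attention mask; here each head is a plain linear projection, the sizes (`HIDDEN`, `VOCAB`, `NUM_HEADS`) are invented, and verification checks a single greedy candidate sequence.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy dimensions (assumptions, not the paper's configuration).
HIDDEN, VOCAB, NUM_HEADS = 16, 100, 3

# Each head maps the backbone's last hidden state to logits for the token
# k steps ahead; all heads fire in parallel from the same hidden state.
head_weights = [rng.standard_normal((HIDDEN, VOCAB)) for _ in range(NUM_HEADS)]

def propose(hidden_state):
    """Greedy draft: head k predicts the token k+1 positions ahead."""
    return [int(np.argmax(hidden_state @ W)) for W in head_weights]

def verify(candidates, backbone_greedy):
    """Accept the longest prefix of the draft that the backbone agrees with.

    In MEDUSA the backbone scores all candidates in one forward pass via
    tree attention; with greedy decoding, acceptance reduces to this
    longest-matching-prefix check.
    """
    accepted = []
    for draft_tok, backbone_tok in zip(candidates, backbone_greedy):
        if draft_tok != backbone_tok:
            break
        accepted.append(draft_tok)
    return accepted

h = rng.standard_normal(HIDDEN)
draft = propose(h)
# If the backbone agrees with the first two drafted tokens, one decoding
# step yields two extra accepted tokens instead of one.
assert verify(draft, draft[:2] + [-1]) == draft[:2]
```

The speedup comes from the fact that every accepted draft token saves one full sequential forward pass of the backbone, while rejected tokens cost only the single verification pass that would have run anyway.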