Layer-Condensed KV Cache for Efficient Inference of Large Language Models


4 Jun 2024 | Haoyi Wu and Kewei Tu
This paper proposes a novel method to reduce the memory consumption and improve the inference throughput of large language models (LLMs) by reducing the number of layers whose key-value (KV) caches need to be computed and stored. The key idea is to compute and cache the KVs of only a small number of layers, which dramatically shrinks the KV cache while preserving inference quality. The method is orthogonal to existing memory-saving techniques and can be easily integrated with them for further gains.

Concretely, the method pairs the queries of all layers with the KVs of just the top layer, so the KVs of the other layers never need to be computed or cached. This saves both memory and computation and also removes the corresponding KV projection parameters. One complication is a token's attention to itself: a token's top-layer KVs depend on its own lower-layer computations and are therefore not yet available while those computations are running. The method resolves this by masking the diagonal of the attention matrix so that no token attends to itself, with the first token attending to zero vectors as dummy KVs. To avoid degrading performance, a small number of layers (called warmup layers) retain standard attention; placing the warmup layers at the top and bottom of the model, in a sandwich configuration, outperforms alternative placements.
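To make the decoding-time mechanism concrete, below is a minimal PyTorch sketch of one autoregressive decoding step under these ideas. The module names (`q_proj`, `o_proj`, `ffn`, `kv_proj`) are hypothetical stand-ins, layer normalization and the warmup layers are omitted, and this is an illustration of the idea rather than the authors' implementation.

```python
import torch
import torch.nn.functional as F

def decode_step(x, top_kv_cache, layers, kv_proj, n_heads, head_dim):
    """Decode one new token with a layer-condensed KV cache (sketch).

    x:            (batch, d_model) hidden state of the newest token
    top_kv_cache: (k, v), each (batch, n_heads, t, head_dim), computed from
                  the *top-layer* hidden states of the t previous tokens
    """
    k_cache, v_cache = top_kv_cache
    for layer in layers:
        q = layer.q_proj(x).view(x.size(0), n_heads, 1, head_dim)
        if k_cache.size(2) == 0:
            # First token: the diagonal is masked, so attending to a zero
            # dummy KV yields a zero attention output.
            attn = torch.zeros_like(q)
        else:
            # Every layer's queries attend to the same top-layer KV cache.
            # The newest token's own KVs are simply absent, i.e. the
            # diagonal of the attention matrix is masked out.
            attn = F.scaled_dot_product_attention(q, k_cache, v_cache)
        h = x + layer.o_proj(attn.reshape(x.shape))  # attention sub-block
        x = h + layer.ffn(h)                         # MLP sub-block (norms omitted)
    # Only after the top layer are this token's KVs computed and cached.
    k_new, v_new = kv_proj(x)  # each (batch, n_heads, 1, head_dim), hypothetical
    return x, (torch.cat([k_cache, k_new], dim=2),
               torch.cat([v_cache, v_new], dim=2))
```

Only a single layer's KVs are appended to the cache per token, which is where the memory saving comes from; any warmup layers would additionally keep their own standard caches.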
Training is more involved than for a standard transformer: the top-layer KVs of earlier tokens are needed when processing later tokens, which introduces sequential dependencies between tokens and prevents naive parallel training. The paper therefore introduces an approximate training procedure that restores parallelism: the whole sequence is passed through the model for a fixed number of iterations, with each iteration attending to the top-layer KVs produced by the previous one; the cross-entropy loss is computed only after the last iteration, and gradients are stopped through earlier iterations to bound memory consumption. Because the KVs converge quickly over iterations, a small number of iterations suffices to approximate the result of many (a hedged sketch of such a loop appears at the end of this summary).

Experiments on Llama-based models show that the method supports significantly larger batch sizes and achieves substantially higher throughput than standard transformers, while remaining competitive in language modeling and downstream tasks such as commonsense reasoning. The method can also be integrated with other memory-saving techniques such as StreamingLLM for further improvements in inference efficiency, and it handles long sequences well, making it suitable for tasks with large generation lengths.

Analyses of the design choices show that the number of warmup layers controls a trade-off between model performance and throughput, and that the cached KVs indeed converge within a few iterations. Overall, the method reduces memory consumption and improves inference throughput while maintaining competitive performance across the evaluated datasets.
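Returning to the training procedure, here is a hedged sketch of the approximate parallel training loop described above. It assumes a hypothetical `model(input_ids, top_kv)` that runs one full parallel forward pass in which every layer attends (with the diagonal masked) to the supplied top-layer KVs and returns logits together with fresh top-layer KVs; the gradient-stopping detail is simplified and the iteration count is illustrative.

```python
import torch.nn.functional as F

def train_step(model, optimizer, input_ids, labels, n_iters=4):
    """One approximate training step (sketch, not the authors' code)."""
    top_kv = None   # iteration 0: every token attends to zero dummy KVs
    logits = None
    for it in range(n_iters):
        logits, top_kv = model(input_ids, top_kv)  # full parallel forward pass
        if it < n_iters - 1:
            # Stop gradients on the KVs carried into the next iteration so that
            # activation memory does not grow with the number of iterations
            # (a simplification of the paper's gradient-stopping scheme).
            top_kv = tuple(t.detach() for t in top_kv)
    # The cross-entropy loss is computed only on the last iteration's outputs.
    loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Because the KVs converge quickly, a handful of iterations is reported to be sufficient; the exact iteration count and which iterations receive gradients are choices made in the paper and are only approximated here.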