Learning to Decode Collaboratively with Multiple Language Models


6 Mar 2024 | Shannon Zejiang Shen, Hunter Lang, Bailin Wang, Yoon Kim, David Sontag
We propose a method to teach multiple large language models (LLMs) to collaborate by interleaving their generations at the token level. We model the decision of which LLM generates the next token as a latent variable. By optimizing the marginal likelihood of a training set under our latent-variable model, the base LLM automatically learns when to generate itself and when to call on one of the "assistant" language models to generate, all without direct supervision. Token-level collaboration during decoding allows a fusion of each model's expertise, tailored to the specific task at hand. Our collaborative decoding is especially useful in cross-domain settings where a generalist base LLM learns to invoke domain-expert models. On instruction-following, domain-specific QA, and reasoning tasks, we show that the performance of the joint system exceeds that of the individual models. Through qualitative analysis of the learned latent decisions, we show that models trained with our method exhibit several interesting collaboration patterns, e.g., template-filling.

Concretely, we propose a latent-variable framework for collaborative generation in which the models learn to interleave their generations token by token. Each token is generated by exactly one model, so the models jointly produce the token sequence. We represent the decision of which LLM generates the next token as a latent variable, assuming no direct supervision on the per-step model choice. This lets the best collaboration pattern for a given task be learned organically from data.
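To make the objective concrete, here is a minimal sketch (not the authors' released code) of a marginal-likelihood loss with a single assistant model, assuming the latent decision is a binary "defer" gate predicted per token; the function name `marginal_nll` and the tensor layout are illustrative assumptions.

```python
# Minimal sketch of the latent-variable training objective:
# marginalize over which model generates each token.
import torch
import torch.nn.functional as F

def marginal_nll(base_logits, assistant_logits, gate_logits, targets):
    """Negative log marginal likelihood over the latent model-choice variable.

    base_logits:      (T, V) next-token logits from the trainable base LLM
    assistant_logits: (T, V) next-token logits from the frozen assistant LLM
    gate_logits:      (T,)   logits of the latent decision Z_t = "defer"
    targets:          (T,)   gold next tokens
    """
    # Per-token log-likelihood of the target under each model.
    base_ll = F.log_softmax(base_logits, dim=-1).gather(-1, targets[:, None]).squeeze(-1)
    asst_ll = F.log_softmax(assistant_logits, dim=-1).gather(-1, targets[:, None]).squeeze(-1)

    # log P(Z_t = defer) and log P(Z_t = keep) from the gate.
    log_defer = F.logsigmoid(gate_logits)
    log_keep = F.logsigmoid(-gate_logits)

    # Marginalize the latent choice:
    # P(x_t) = P(keep) * p_base(x_t) + P(defer) * p_assistant(x_t).
    marginal = torch.logsumexp(
        torch.stack([log_keep + base_ll, log_defer + asst_ll]), dim=0
    )
    return -marginal.mean()
```

Because the gate is trained only through the marginal likelihood, no per-token labels of "which model should generate here" are needed, matching the no-direct-supervision setup described above.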
In our experiments, we fine-tune models for specific tasks and evaluate them in-domain, comparing end-task performance between Co-LLM and several single- and multi-model baselines. We test on four datasets ranging from instruction following to expert problem solving, aiming to understand when and how model collaboration can be beneficial. We investigate collaboration between different models (e.g., between Llama models at multiple scales, and between models fine-tuned on different domains). Overall, we find that Co-LLM can learn successful collaborations between different base and reference models, leading to better results than tuning the base models alone.

Our results show that Co-LLM enables a modular approach to continued pretraining and task-specific fine-tuning: one can pretrain a large model on a domain-specific corpus, then fine-tune smaller models with Co-LLM to leverage the larger model's knowledge and attain improved performance on downstream tasks. Co-LLM also allows collaboration across model scales, as shown in our experiments. We compare Co-LLM with other collaborative methods, such as Proxy Tuning and Contrastive Decoding, and find that Co-LLM performs better in terms of both accuracy and efficiency. Our results also show that Co-LLM can be applied to classification tasks, where it boosts performance by enabling improved reasoning capability. We also evaluate the joint model at different deferral frequencies on small validation sets for GSM8k, MATH, and BioASQ.
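At inference time, the deferral frequency can be controlled by thresholding the learned gate. The following is a minimal decoding sketch under stated assumptions: `base` and `assistant` are Hugging Face-style causal LMs sharing a tokenizer, `gate_head` is a hypothetical trained linear head over the base model's last hidden state, batch size is 1, and `eta` is the deferral threshold swept to vary the deferral frequency.

```python
import torch

@torch.no_grad()
def collaborative_decode(base, assistant, gate_head, input_ids,
                         max_new_tokens=128, eta=0.5):
    """Interleave base and assistant generations token by token."""
    for _ in range(max_new_tokens):
        base_out = base(input_ids, output_hidden_states=True)
        # Probability of deferring the next token to the assistant,
        # predicted from the base model's final hidden state.
        p_defer = torch.sigmoid(
            gate_head(base_out.hidden_states[-1][:, -1])
        ).item()  # assumes batch size 1
        if p_defer > eta:
            # The assistant produces this token; the base model only
            # decides *when* to defer.
            logits = assistant(input_ids).logits[:, -1]
        else:
            logits = base_out.logits[:, -1]
        next_token = logits.argmax(dim=-1, keepdim=True)  # greedy for simplicity
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids
```

Raising `eta` makes the system rely more on the base model (fewer assistant calls); lowering it defers more often, which is the knob varied in the deferral-frequency evaluation mentioned above.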