March 13, 2024 | Sainbayar Sukhbaatar, Olga Golovneva, Vasu Sharma, Hu Xu, Xi Victoria Lin, Baptiste Rozière, Jacob Kahn, Daniel Li, Wen-tau Yih, Jason Weston, Xian Li
The paper introduces Branch-Train-Mix (BTX), a method for training Large Language Models (LLMs) to possess specialized capabilities in multiple domains such as coding, math reasoning, and world knowledge. BTX starts with a seed model and branches it into multiple expert models, each trained on a specific domain dataset. These experts are then combined into a single Mixture-of-Experts (MoE) model through a finetuning stage, where the router network learns to select the appropriate expert for each token. This approach combines the benefits of asynchronous parallel training and MoE, achieving high throughput and reduced communication costs. Compared to alternative methods, BTX achieves better accuracy and efficiency, outperforming both Branch-Train-Merge (BTM) and sparse upcycling. The experiments demonstrate that BTX improves performance on various tasks, especially in domains where specialized models excel, while retaining general knowledge from the seed model. The method is efficient and robust, making it a promising approach for training LLMs with specialized capabilities.
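To make the "mix" step concrete, here is a minimal sketch (not the authors' code) of how the feed-forward sublayers from the branched domain experts could be combined into a single MoE layer with a newly initialized, token-level router that is learned during the finetuning stage. Class and parameter names (`MoEFeedForward`, `top_k`, etc.) are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoEFeedForward(nn.Module):
    """Illustrative MoE layer built from feed-forward sublayers of branched experts."""

    def __init__(self, expert_ffns, d_model, top_k=2):
        super().__init__()
        # expert_ffns: FFN sublayers copied from each domain-trained branch
        self.experts = nn.ModuleList(expert_ffns)
        # Router is newly initialized and trained during the MoE finetuning stage
        self.router = nn.Linear(d_model, len(expert_ffns), bias=False)
        self.top_k = top_k

    def forward(self, x):                      # x: (batch, seq, d_model)
        logits = self.router(x)                # (batch, seq, n_experts)
        weights, idx = logits.topk(self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)   # normalize over the selected experts
        out = torch.zeros_like(x)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = (idx[..., k] == e)      # tokens routed to expert e in slot k
                if mask.any():
                    out[mask] += weights[..., k][mask].unsqueeze(-1) * expert(x[mask])
        return out

# Usage sketch: three hypothetical domain-expert FFNs merged into one MoE layer.
d_model = 16
experts = [nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                         nn.Linear(4 * d_model, d_model)) for _ in range(3)]
moe = MoEFeedForward(experts, d_model)
y = moe(torch.randn(2, 5, d_model))            # each token routed to its top-2 experts
```

In this sketch only the router weights start from scratch; the expert FFNs retain what they learned during the parallel branch-training phase, which is the property that lets BTX keep domain-specialized capabilities while unifying them in one model.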