Training Large Language Models for Reasoning through Reverse Curriculum Reinforcement Learning


17 Mar 2024 | Zhiheng Xi, Wenxiang Chen, Boyang Hong, Senjie Jin, Rui Zheng, Wei He, Yiwen Ding, Shichun Liu, Xin Guo, Junzhe Wang, Honglin Guo, Wei Shen, Xiaoran Fan, Yuhao Zhou, Shihan Dou, Xiao Wang, Xinbo Zhang, Peng Sun, Tao Gui, Qi Zhang, Xuanjing Huang
This paper proposes R³, a novel method for training large language models (LLMs) for reasoning through reverse curriculum reinforcement learning (RL), which uses only outcome supervision to achieve the benefits of process supervision. The core challenge in applying RL to complex reasoning is identifying a sequence of actions that yields positive rewards and providing appropriate supervision for optimization. Outcome supervision gives sparse rewards for final results without identifying where errors occur, whereas process supervision offers step-wise rewards but requires extensive manual annotation. R³ overcomes these limitations by learning from correct demonstrations: it progressively slides the start state of reasoning from the demonstration's end to its beginning, making exploration easier at every stage (sketched in code below). This establishes a step-wise curriculum in which outcome supervision delivers step-level signals and pinpoints errors precisely. Using Llama2-7B, the method surpasses the RL baseline on eight reasoning tasks by 4.1 points on average. Notably, in program-based reasoning on GSM8K, it exceeds the baseline by 4.2 points across three backbone models, and without any extra data, Codellama-7B + R³ performs comparably to larger or closed-source models.

The paper introduces R³ as a way to use outcome supervision to achieve the effect of process supervision, conducts extensive experiments across eight reasoning tasks to demonstrate the method's effectiveness, and performs in-depth ablations and analysis to explain the training dynamics of R³ and how it works. The method facilitates exploration by shortening the reasoning chain and narrowing the sampling space, helping the model obtain positive rewards more efficiently; the authors interpret R³ as a form of dynamic programming. It is effective across varied reasoning tasks, including logical reasoning, mathematical reasoning, reading comprehension, and natural language inference (NLI).

The experiments demonstrate that R³ outperforms both the SFT and RL baselines across eight reasoning tasks, with average improvements of 5.4 and 4.1 points, respectively. Notably, in program-based reasoning on GSM8K, it surpasses SFT and RL by an average of 11.4 and 4.2 points, respectively. Moreover, Codellama-7B + R³ outperforms models that rely on extra annotated data, such as MAmmoTH-Coder and ToRA, and is comparable to larger or closed-source models such as GPT-3.5-Turbo. The paper also examines the impact of different reward functions and the importance of challenging data in training, and the results show that R³ provides stable reinforcement learning across diverse reasoning tasks. The method is versatile and adaptable, extending to different reasoning styles such as programs.

The paper concludes that R³ is a promising approach for training LLMs for reasoning; future work will focus on scaling up model size and exploring the impact of larger and more diverse training data on R³.
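To make the reverse-curriculum idea concrete, below is a minimal sketch (not the authors' implementation) of how start states could be constructed and how a sparse outcome reward could be computed. It assumes a demonstration is a list of reasoning steps ending in the final answer; the function names `build_curriculum_stages` and `outcome_reward` are illustrative, and the RL optimization itself (e.g., PPO) and stage scheduling are omitted.

```python
# Minimal sketch of the reverse-curriculum idea behind R^3.
# Assumptions: a "demonstration" is a list of reasoning steps ending in the
# final answer, and training uses a sparse outcome reward (1 if the generated
# completion reaches the correct answer, else 0). Names are illustrative only.

from typing import List


def build_curriculum_stages(question: str, demo_steps: List[str]) -> List[str]:
    """Create start states by sliding the exploration start point from the
    demonstration's end back toward its beginning.

    The first stage gives the model almost the whole correct demonstration and
    asks it to complete only the last step; later stages reveal fewer steps,
    so the model must explore longer suffixes on its own.
    """
    stages = []
    n = len(demo_steps)
    # k = number of demonstration steps provided as part of the prompt.
    for k in range(n - 1, -1, -1):  # n-1 steps given first, then n-2, ..., 0
        prefix = "\n".join(demo_steps[:k])
        start_state = question + ("\n" + prefix if prefix else "")
        stages.append(start_state)
    return stages


def outcome_reward(generated_answer: str, gold_answer: str) -> float:
    """Sparse outcome supervision: reward only the final result."""
    return 1.0 if generated_answer.strip() == gold_answer.strip() else 0.0


if __name__ == "__main__":
    question = "Q: Tom has 3 apples and buys 2 more. How many apples does he have?"
    demo = [
        "Tom starts with 3 apples.",
        "He buys 2 more, so 3 + 2 = 5.",
        "The answer is 5.",
    ]
    for i, state in enumerate(build_curriculum_stages(question, demo)):
        print(f"--- curriculum stage {i} (start state) ---\n{state}\n")
```

In this sketch, early stages hand the model most of the correct chain so the outcome reward is easy to reach; as the start point slides back, the same outcome reward effectively localizes errors to the newly exposed steps, which is how outcome supervision comes to act like step-level supervision.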