6 Jun 2024 | Dan Zhang, Sining Zhoubian, Yisong Yue, Yuxiao Dong, Jie Tang
ReST-MCTS* is a self-training method for large language models (LLMs) that combines process reward guidance with a tree-search algorithm (MCTS*) to collect high-quality reasoning traces and per-step values for training both the policy model and the process reward model. Unlike methods that rely on manual annotation of process rewards, ReST-MCTS* infers them automatically: given the oracle final answer, it estimates the probability that each intermediate step leads to that correct answer. This yields per-step labels for training per-step reward models without additional human intervention.

Within the same search budget, the tree-search policy achieves higher accuracy than prior LLM reasoning baselines such as Best-of-N and Tree-of-Thought, and the overall method outperforms other self-training algorithms such as ReST$^{EM}$ and Self-Rewarding LM. The approach is validated on multiple benchmarks, including SciBench and MATH, where it improves performance on reasoning tasks.

The key contributions are: (i) ReST-MCTS*, a framework for training LLMs with model-based reinforcement learning that uses a modified MCTS algorithm guided by a trained per-step process reward model; (ii) automatic generation of per-step labels through sufficient rollouts, leading to higher-quality process reward models and improved self-training; and (iii) a mutual self-training loop in which the policy and reward models continuously improve each other, enhancing LLM performance on complex reasoning tasks.
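To make the rollout-based labelling idea concrete, here is a minimal Python sketch of the estimation step described above. The helpers `sample_continuation` (a policy-LLM rollout) and `extract_answer` (an answer parser) are hypothetical placeholders, not the paper's actual API; this is an illustration under those assumptions, not the authors' implementation.

```python
# Hypothetical sketch: estimate per-step values as the empirical probability
# that continuing the reasoning from each step prefix reaches the oracle answer.
# `sample_continuation` and `extract_answer` are assumed helpers, not the paper's API.
from typing import Callable, List


def label_step_values(
    question: str,
    steps: List[str],                 # one reasoning trace, split into steps
    oracle_answer: str,               # known-correct final answer for `question`
    sample_continuation: Callable[[str, List[str]], List[str]],  # policy rollout from a prefix
    extract_answer: Callable[[List[str]], str],                  # parse the final answer of a trace
    num_rollouts: int = 8,
) -> List[float]:
    """For each step prefix, roll out the remainder of the solution several times
    and record how often the completed trace ends at the oracle answer. The
    resulting success rates serve as per-step training targets for the process
    reward model, with no human step-level annotation."""
    values = []
    for k in range(1, len(steps) + 1):
        prefix = steps[:k]
        successes = 0
        for _ in range(num_rollouts):
            # Complete the solution from this prefix with the policy LLM.
            completed = prefix + sample_continuation(question, prefix)
            if extract_answer(completed) == oracle_answer:
                successes += 1
        values.append(successes / num_rollouts)
    return values
```

In the framing of the paper, labels produced this way supervise the process reward model, which in turn scores candidate steps during the MCTS*-style search that collects traces for the next round of self-training.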