BOND: Aligning LLMs with Best-of-N Distillation

19 Jul 2024 | Pier Giuseppe Sessa, Robert Dadashi, Léonard Hussenot, Johan Ferret, Nino Vieillard, Alexandre Ramé, Bobak Shahriari, Sarah Perrin, Abe Friesen, Geoffrey Cideron, Sertan Girgin, Piotr Stanczyk, Andrea Michi, Danila Sinopalnikov, Sabela Ramos, Amélie Héliou, Aliaksei Severyn, Matt Hoffman, Nikola Momchev, Olivier Bachem
BOND (Best-of-N Distillation) is a reinforcement learning from human feedback (RLHF) algorithm that aligns large language models (LLMs) with the Best-of-N sampling strategy without the high computational cost of Best-of-N at inference time. BOND distills the Best-of-N strategy into the policy itself, so that a single inference sample matches the quality of Best-of-N sampling. The problem is formulated as distribution matching: the policy is trained to align with the Best-of-N distribution. To balance mode-covering and mode-seeking behaviors, BOND minimizes the Jeffreys divergence, a weighted combination of forward and backward KL divergences, and it applies the distillation iteratively against a moving anchor policy, which improves performance and stability.

Experiments on abstractive summarization and on Gemma models show that BOND outperforms other RLHF methods, achieving better reward/KL trade-offs and improved benchmark performance. J-BOND, a practical instantiation of BOND, combines an exponential-moving-average anchor with additional KL regularization for further gains in performance and stability, and it surpasses standard RLHF baselines on reward/KL trade-offs and alignment with human preferences. Overall, the approach is an effective way to align LLMs with human preferences, making them safer and more reliable.
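To make the distribution-matching view concrete, the sketch below computes the exact Best-of-N distribution over a small discrete set of candidate responses (assuming distinct rewards) and a Jeffreys-style divergence between that target and a policy. This is a toy illustration, not the paper's implementation: the mixing weight `beta`, the direction convention of the two KL terms, and the example numbers are assumptions.

```python
# Toy sketch (not the paper's implementation): Best-of-N distribution and
# a Jeffreys-style divergence on a small discrete set of candidates.
import numpy as np

def best_of_n_distribution(pi_ref, rewards, n):
    """Exact Best-of-N distribution over a discrete support with distinct rewards.

    P_BoN(x) = F(r(x))^n - F(r(x)^-)^n, where F is the CDF of the reward
    of a single sample drawn from pi_ref.
    """
    order = np.argsort(rewards)                    # ascending reward order
    cdf = np.cumsum(pi_ref[order])                 # F(r(x)) along that order
    cdf_prev = np.concatenate([[0.0], cdf[:-1]])   # F(r(x)^-)
    p_bon_sorted = cdf ** n - cdf_prev ** n
    p_bon = np.empty_like(p_bon_sorted)
    p_bon[order] = p_bon_sorted                    # scatter back to original order
    return p_bon

def kl(p, q, eps=1e-12):
    """KL(p || q) for discrete distributions, with a small epsilon for stability."""
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

def jeffreys(policy, target, beta=0.5):
    # Weighted forward + backward KL; beta trades mode-covering vs. mode-seeking.
    # The weighting convention here is an assumption for illustration.
    return beta * kl(target, policy) + (1.0 - beta) * kl(policy, target)

# Example: 5 candidate responses with a reference policy and scalar rewards.
pi_ref = np.array([0.4, 0.3, 0.15, 0.1, 0.05])
rewards = np.array([0.1, 0.5, 0.2, 0.9, 0.3])
pi_bon = best_of_n_distribution(pi_ref, rewards, n=4)
print("Best-of-4 distribution:", np.round(pi_bon, 3))
print("Jeffreys divergence to reference:", round(jeffreys(pi_ref, pi_bon), 4))
```

Shrinking the Jeffreys divergence between the trained policy and `pi_bon` is what lets a single sample from the policy approximate the quality of drawing N samples and keeping the best.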
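The moving-anchor idea behind J-BOND can be pictured as a simple parameter-space update: the distillation target drifts slowly toward the current policy each iteration. A minimal sketch is below; the decay rate `gamma` and the flat parameter vectors are illustrative assumptions, not values from the paper.

```python
# Minimal sketch of an exponential-moving-average (EMA) anchor update,
# the mechanism J-BOND uses to keep a slowly moving distillation target.
import numpy as np

def ema_update(anchor_params, policy_params, gamma=0.99):
    """Move the anchor a small step toward the current policy parameters."""
    return gamma * anchor_params + (1.0 - gamma) * policy_params

anchor = np.zeros(4)   # stand-in for the anchor policy's parameters
policy = np.ones(4)    # stand-in for the current policy's parameters
for step in range(3):
    anchor = ema_update(anchor, policy)
print("anchor after 3 steps:", np.round(anchor, 4))
```

Because the anchor lags behind the policy, each round of Best-of-N distillation targets a slightly improved distribution, which is the source of the stability and steady reward gains described above.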