Probabilistic Inference in Language Models via Twisted Sequential Monte Carlo


26 Apr 2024 | Stephen Zhao, Rob Brekelmans, Alireza Makhzani, Roger Grosse
This paper introduces a novel approach to probabilistic inference in language models using twisted Sequential Monte Carlo (SMC). The authors leverage the rich toolkit of SMC to address various tasks, including sampling from unnormalized target distributions defined by reward or potential functions. They propose learned *twist functions* to estimate the expected future value of the potential at each timestep, enabling efficient inference by focusing computation on promising partial sequences. A novel *contrastive twist learning* (CTL) method is developed to learn these twist functions, inspired by energy-based modeling and density ratio estimation. The paper also presents methods for evaluating the accuracy of language model inference techniques using *bidirectional* SMC bounds on the log partition function, which can estimate the KL divergence between the inference and target distributions in both directions. Experimental results demonstrate the effectiveness of twisted SMC in sampling undesirable outputs, generating reviews with varied sentiment, and performing infilling tasks. The contributions include a general framework for sampling and evaluation in language modeling, a novel twist learning method, and a set of tools for evaluating inference techniques.
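To make the sampling procedure concrete, the sketch below illustrates the basic twisted SMC loop described above: particles (partial sequences) are extended one token at a time from the base language model, weighted by the ratio of successive twist values, and resampled so that computation concentrates on promising prefixes, while the per-step weight averages accumulate into an estimate of the log partition function. This is a minimal illustration, not the authors' implementation; the helpers `propose_token` (base-model token sampler) and `log_psi` (learned twist evaluator) are hypothetical placeholders supplied by the caller.

```python
import math
import random


def twisted_smc(propose_token, log_psi, prompt, T, K):
    """Minimal twisted SMC sketch (illustrative only, not the paper's code).

    propose_token(seq) -> next token sampled from the base LM given prefix seq
    log_psi(seq, t)    -> learned twist log psi_t(seq), approximating the
                          expected future value of the potential for prefix seq
    Returns K sampled sequences and an estimate of log Z.
    """
    particles = [list(prompt) for _ in range(K)]
    log_weights = [0.0] * K
    log_Z_hat = 0.0

    for t in range(T):
        # Extend each particle with a token proposed by the base model.
        for k in range(K):
            particles[k].append(propose_token(particles[k]))
            # Incremental weight: ratio of successive twist values
            # (psi_{-1} is taken to be 1, i.e. log psi_{-1} = 0).
            prev = log_psi(particles[k][:-1], t - 1) if t > 0 else 0.0
            log_weights[k] += log_psi(particles[k], t) - prev

        # Accumulate the log partition function estimate from the mean weight.
        m = max(log_weights)
        w = [math.exp(lw - m) for lw in log_weights]
        log_Z_hat += m + math.log(sum(w) / K)

        # Resample particles in proportion to their weights, focusing
        # computation on promising partial sequences, then reset weights.
        particles = [list(random.choices(particles, weights=w, k=1)[0])
                     for _ in range(K)]
        log_weights = [0.0] * K

    return particles, log_Z_hat
```

Under these assumptions, the returned `log_Z_hat` corresponds to the single-direction SMC estimate of the log partition function; the paper's bidirectional bounds pair such a lower-bound estimate with an upper-bound counterpart computed from exact target samples.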