Mechanics of Next Token Prediction with Self-Attention

12 Mar 2024 | Yingcong Li*, Yixiao Huang*, M. Emrullah Ildiz, Ankit Singh Rawat, Samet Oymak
This paper investigates the mechanics of next-token prediction using self-attention in Transformer-based models. The authors show how a single self-attention layer learns to predict the next token in two distinct steps: hard retrieval and soft composition. Hard retrieval selects high-priority input tokens associated with the last input token, while soft composition forms a convex combination of these high-priority tokens from which the next token is sampled. The paper rigorously characterizes these mechanisms through directed graphs over tokens extracted from the training data, proving that gradient descent implicitly discovers the strongly-connected components (SCCs) of the graph and learns to retrieve tokens belonging to the highest-priority SCC. The theory decomposes the model weights into a directional component (hard retrieval) and a finite component (soft composition), formalizing an implicit bias formula conjectured in previous work. The findings provide insights into how self-attention processes sequential data and contribute to demystifying more complex architectures. The paper also discusses the optimization landscape and implicit biases of self-attention, establishing global convergence results for gradient descent and regularized path algorithms.
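The retrieval/composition split can be made concrete with a small numerical sketch. The snippet below is an illustrative toy, not the paper's implementation: the vocabulary size, the random embeddings, and the combined attention matrix W (standing in for the product of key and query weights in a single self-attention layer) are all invented here. Scaling W up mimics the directional component growing during training, so the softmax over the last token's scores saturates from a soft convex combination onto the highest-scoring (retrieved) tokens.

```python
# Minimal sketch of one self-attention step for next-token prediction.
# All names and values below are assumptions for illustration only.
import numpy as np

rng = np.random.default_rng(0)

V, d = 5, 8                      # toy vocabulary size, embedding dimension
E = rng.normal(size=(V, d))      # random token embeddings

seq = [0, 2, 1, 3, 2]            # a toy input sequence of token ids
X = E[seq]                       # (T, d) sequence of embeddings

# W plays the role of the combined key/query weights of one attention layer.
W = rng.normal(size=(d, d))

def attention_output(X, W, scale=1.0):
    """The last token's query attends over all input tokens; the output is a
    convex combination (soft composition) of their embeddings."""
    q = X[-1] @ (scale * W)             # query of the last token
    scores = X @ q                      # one score per input token
    probs = np.exp(scores - scores.max())
    probs /= probs.sum()                # softmax -> convex combination weights
    return probs, probs @ X             # attention weights, composed output

for scale in (1.0, 10.0, 100.0):
    probs, _ = attention_output(X, W, scale)
    print(f"scale={scale:>6}: attention = {np.round(probs, 3)}")
# As `scale` grows, the softmax concentrates on the highest-scoring tokens:
# soft composition degenerates into hard retrieval of the priority tokens.
```

The graph-theoretic side of the analysis can be sketched in the same hedged spirit: the edge list below is an arbitrary toy priority graph over token ids (edges pointing from lower- to higher-priority tokens), and networkx is used only for convenience; the paper's construction of this graph from training sequences is not reproduced here.

```python
# Toy illustration: SCCs group tokens of equal priority in a directed graph.
import networkx as nx

edges = [(0, 1), (1, 2), (2, 1), (2, 3), (3, 4), (4, 3)]  # made-up edges
G = nx.DiGraph(edges)
sccs = list(nx.strongly_connected_components(G))
print(sccs)  # groups such as {1, 2} and {3, 4}; {3, 4} is the top-priority SCC here
```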