Mechanics of Next Token Prediction with Self-Attention

12 Mar 2024 | Yingcong Li*, Yixiao Huang*, M. Emrullah Ildiz, Ankit Singh Rawat, Samet Oymak
This paper investigates the mechanics of next-token prediction with self-attention in Transformer-based language models. The authors show that training self-attention with gradient descent yields an automaton that generates the next token in two steps: (1) hard retrieval, where the model selects the high-priority input tokens associated with the last input token, and (2) soft composition, where it forms a convex combination of the retrieved tokens from which the next token is sampled. These mechanics are characterized rigorously through a directed graph over the tokens extracted from the training data: gradient descent implicitly discovers the strongly-connected components (SCCs) of this graph, which capture the priority order among tokens, and the model learns to retrieve tokens from the highest-priority SCC present in the context window. Accordingly, the learned weights decompose into a directional component, which implements hard retrieval, and a finite component, which implements soft composition, formalizing an implicit-bias conjecture.
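To make the two-step picture concrete, here is a minimal numerical sketch (my own illustration, not code from the paper; the one-hot embeddings, the toy direction `W_mm`, and the scale `R` are made up for the example). As the scale on the directional component grows, the softmax collapses onto the occurrences of the highest-priority token (hard retrieval), and the surviving attention weights form the convex combination used for sampling (soft composition).

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())
    return e / e.sum()

def two_step_attention(X, x_last, W_fin, W_mm, R):
    """Attention from the last token under W = W_fin + R * W_mm.

    X      : (T, d) embeddings of the context tokens
    x_last : (d,)   embedding of the last token (the query)
    W_fin  : (d, d) finite component   -> soft composition weights
    W_mm   : (d, d) directional component -> hard retrieval as R grows
    """
    attn = softmax(X @ (W_fin + R * W_mm) @ x_last)
    return attn, attn @ X   # convex-combination weights and composed output

# Toy vocabulary {a, b, z} with one-hot embeddings; context = [a, b, a, z],
# query = last token z. Suppose a outranks b when the last token is z.
E = np.eye(3)                       # rows: embeddings of a, b, z
X = E[[0, 1, 0, 2]]                 # the context window
W_mm = np.outer(E[0] - E[1], E[2])  # hypothetical max-margin direction
W_fin = np.zeros((3, 3))            # trivial composition for simplicity
for R in (0.0, 2.0, 20.0):
    attn, _ = two_step_attention(X, E[2], W_fin, W_mm, R)
    print(R, np.round(attn, 3))
# R = 20 gives ~[0.5, 0, 0.5, 0]: all mass on the occurrences of token a.
```

With a nonzero `W_fin`, the split between the surviving occurrences would no longer be uniform; that residual mixing is exactly what the finite component controls in the decomposition above.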
This directed graph is the token-priority graph (TPG), which encodes the priority relationships between tokens observed in the training data. The TPG defines an SVM problem, the Graph-SVM, whose solution describes the direction along which gradient descent converges; the same solution arises as the implicit bias of regularized training in the limit of vanishing regularization. The paper also studies gradient descent for next-token prediction under the log-loss, showing that the problem is convex under suitable assumptions and establishing a global convergence result that fully formalizes the theory in terms of the directional (Graph-SVM) component and the finite component. These results reveal a connection between continuous and discrete optimization: during training, self-attention implicitly performs the discrete computation of finding the SCCs of the TPGs.

For more general next-token prediction problems, the authors show that the regularization path still converges to the Graph-SVM solution, whereas gradient descent can exhibit local rather than global directional convergence; these local directions are characterized through the SVM solutions of pseudo-TPGs. Overall, the paper provides a comprehensive analysis of the mechanics of next-token prediction with self-attention, showing how the model learns to select and compose tokens according to their priority in the training data. The results highlight the importance of understanding the implicit biases of self-attention in Transformer-based language models and lay a foundation for extending the analysis to more complex architectures.
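The TPG is described only verbally above, so the following sketch fills in one plausible construction (the edge rule `next_token -> context_token`, the grouping by last token `z`, and the toy corpus are my assumptions for illustration; the SCC routine is standard Tarjan, not the paper's code). Tokens that end up in the same SCC are mutually reachable and therefore tie in priority.

```python
from collections import defaultdict

def build_tpgs(examples):
    """One token-priority graph per last token z.

    Each example is (context, next_token). Assumed edge rule: the observed
    next token outranks every token in the context, so we add the edge
    next_token -> k for each context token k, in the graph keyed by z.
    """
    tpgs = defaultdict(lambda: defaultdict(set))
    for context, nxt in examples:
        z = context[-1]
        for k in context:
            tpgs[z][nxt].add(k)
    return tpgs

def tarjan_sccs(graph):
    """Strongly-connected components via Tarjan's algorithm."""
    index, low, on_stack = {}, {}, set()
    stack, comps, clock = [], [], [0]

    def visit(v):
        index[v] = low[v] = clock[0]; clock[0] += 1
        stack.append(v); on_stack.add(v)
        for w in graph.get(v, ()):
            if w not in index:
                visit(w)
                low[v] = min(low[v], low[w])
            elif w in on_stack:
                low[v] = min(low[v], index[w])
        if low[v] == index[v]:          # v is the root of an SCC
            comp = set()
            while True:
                w = stack.pop(); on_stack.discard(w); comp.add(w)
                if w == v:
                    break
            comps.append(comp)

    nodes = set(graph) | {w for ws in graph.values() for w in ws}
    for v in nodes:
        if v not in index:
            visit(v)
    return comps

# Toy corpus: both 'a' and 'b' have been the label over the other when the
# context ends in 'z', so they form one SCC (a priority tie) above 'c', 'z'.
examples = [(["a", "b", "z"], "a"),
            (["b", "a", "z"], "b"),
            (["c", "a", "z"], "a")]
print(tarjan_sccs(build_tpgs(examples)["z"]))
```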
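Finally, the Graph-SVM itself is only named above, never displayed. A hedged reconstruction from the description (token embeddings $e_k$, one TPG $G_z$ per last token $z$; the exact constraint set in the paper may differ) is the max-margin program

```latex
W^{\mathrm{mm}} \in \arg\min_{W} \|W\|_F
\quad \text{subject to} \quad
\begin{cases}
(e_i - e_j)^\top W e_z \ \ge\ 1, & i \text{ strictly outranks } j \text{ in } G_z,\\
(e_i - e_j)^\top W e_z \ =\ 0,  & i, j \text{ lie in the same SCC of } G_z.
\end{cases}
```

Along this direction, attention scores separate higher-priority tokens from lower-priority ones by a margin while leaving ties within an SCC unbroken; the summary's claim is that gradient descent on the log-loss converges in direction to this solution, with the finite component supplying the remaining composition weights.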