Transformers Provably Learn Sparse Token Selection While Fully-Connected Nets Cannot

11 Jun 2024 | Zixuan Wang, Stanley Wei, Daniel Hsu, and Jason D. Lee
The paper "Transformers Provably Learn Sparse Token Selection While Fully-Connected Nets Cannot" by Zixuan Wang, Stanley Wei, Daniel Hsu, and Jason D. Lee explores the capabilities of transformers in learning sparse token selection tasks compared to fully-connected networks (FCNs). The authors build upon previous work by Sanford et al., who introduced the *q-sparse averaging* task, where a transformer can efficiently approximate the task with a logarithmic number of dimensions, while FCNs require a linear number of neurons in the first layer. The main contributions of this paper are: 1. **Gradient Descent Convergence**: The authors prove that a one-layer transformer trained with gradient descent can globally converge to the optimal solution for the sparse token selection task (STSq) with a width of \(O(d + q \log T)\), where \(d\) is the token dimension, \(q\) is the subset size, and \(T\) is the sequence length. 2. **Expressive Power Separation**: They show that FCNs cannot approximate STSq with a width of less than \(\Omega(Td)\), demonstrating an exponential separation in expressive power between transformers and FCNs. 3. **Length Generalization**: The paper also investigates the length generalization capability of the trained transformer, proving that it converges to zero loss on out-of-distribution data, while fixed positional encoding architectures fail to do so. The authors provide empirical simulations to support their theoretical findings, showing that the stochastic positional encoding used in the transformer significantly improves length generalization compared to fixed positional encoding. The paper concludes with a discussion of open questions, including the extension of these results to other tasks and practical settings.The paper "Transformers Provably Learn Sparse Token Selection While Fully-Connected Nets Cannot" by Zixuan Wang, Stanley Wei, Daniel Hsu, and Jason D. Lee explores the capabilities of transformers in learning sparse token selection tasks compared to fully-connected networks (FCNs). The authors build upon previous work by Sanford et al., who introduced the *q-sparse averaging* task, where a transformer can efficiently approximate the task with a logarithmic number of dimensions, while FCNs require a linear number of neurons in the first layer. The main contributions of this paper are: 1. **Gradient Descent Convergence**: The authors prove that a one-layer transformer trained with gradient descent can globally converge to the optimal solution for the sparse token selection task (STSq) with a width of \(O(d + q \log T)\), where \(d\) is the token dimension, \(q\) is the subset size, and \(T\) is the sequence length. 2. **Expressive Power Separation**: They show that FCNs cannot approximate STSq with a width of less than \(\Omega(Td)\), demonstrating an exponential separation in expressive power between transformers and FCNs. 3. **Length Generalization**: The paper also investigates the length generalization capability of the trained transformer, proving that it converges to zero loss on out-of-distribution data, while fixed positional encoding architectures fail to do so. The authors provide empirical simulations to support their theoretical findings, showing that the stochastic positional encoding used in the transformer significantly improves length generalization compared to fixed positional encoding. The paper concludes with a discussion of open questions, including the extension of these results to other tasks and practical settings.
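As a loose illustration of that selection mechanism (not the paper's construction: the inverse temperature `beta`, the hard-coded logits, and all sizes below are hypothetical choices), a single softmax-attention step recovers the average of the selected tokens once its logits separate the positions in \(S\) from the rest:

```python
import numpy as np

# Toy illustration of sparse token selection via one softmax-attention step.
rng = np.random.default_rng(0)
T, d, q = 16, 8, 3                         # sequence length, token dim, subset size (hypothetical)
X = rng.normal(size=(T, d))                # tokens x_1, ..., x_T
S = rng.choice(T, size=q, replace=False)   # the sparse index set S, |S| = q

target = X[S].mean(axis=0)                 # STS_q ground truth: average of selected tokens

# Attention logits: large on positions in S, small elsewhere. Here this
# separation is hard-coded to expose the mechanism; in the paper it has to be
# produced by the learned (stochastic) positional encodings.
beta = 20.0                                # larger beta -> closer to exact selection
logits = np.full(T, -beta)
logits[S] = beta

attn = np.exp(logits - logits.max())       # softmax over positions
attn /= attn.sum()

output = attn @ X                          # attention-weighted average of the tokens
print(np.max(np.abs(output - target)))     # ~0 for large beta
```

For large `beta` the softmax concentrates roughly \(1/q\) of its mass on each selected position, so the output matches the target up to floating-point error; in the trained transformer that separation is learned rather than hard-coded.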