Transformers Provably Learn Sparse Token Selection While Fully-Connected Nets Cannot

11 Jun 2024 | Zixuan Wang, Stanley Wei, Daniel Hsu, and Jason D. Lee
Transformers have shown superior performance in learning sparse token selection tasks compared to fully-connected networks (FCNs). This paper demonstrates that a one-layer transformer trained with gradient descent (GD) provably learns the sparse token selection task (STS_q) and achieves strong out-of-distribution (OOD) length generalization. In contrast, FCNs require exponentially larger widths to approximate the same task: the study establishes an algorithmic separation between transformers and FCNs, showing that transformers can learn STS_q with a width of O(d + q log T), while FCNs need Ω(Td) neurons in the first layer.

The paper also proves that the trained transformer with stochastic positional encoding generalizes to longer sequences, with the OOD loss converging to zero whenever the in-distribution loss does. Empirical simulations validate these theoretical findings, showing that stochastic positional encoding outperforms fixed positional encoding in length generalization.

The work extends previous results by considering average-case settings and provides a deeper understanding of the training dynamics of transformers. It also highlights the importance of positional encoding in enabling length generalization, showing that randomized positional encoding enhances this capability. The paper contributes to the theoretical understanding of transformers by proving their learnability and generalization properties, and by showing that their expressive power translates to actual learning performance. Overall, the results demonstrate that transformers achieve an exponential separation over FCNs in both expressiveness and learnability, making them more effective at approximating certain arithmetic tasks.
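To make the setup concrete, the sketch below is our own illustration, not code from the paper: it generates STS_q instances (a sequence of T tokens in R^d plus a set S of q indices, with the sum of the selected tokens as the target) and shows how a single softmax-attention head with randomized positional encodings can pick out and average the q indexed tokens. The function names, the one-hot encodings, and the hand-chosen temperature beta are simplifying assumptions standing in for the construction analyzed in the paper.

```python
# Minimal NumPy sketch of the STS_q task and a hand-built attention "selector".
# Assumptions: one-hot positional encodings over a longer range [0, T_max)
# approximate the (near-)orthogonal encodings used in the analysis, and the
# query is formed directly from the target index set S.
import numpy as np

rng = np.random.default_rng(0)

def sts_q_instance(T, d, q, rng):
    """One STS_q example: T tokens in R^d, a subset S of q indices,
    and the target sum of the selected tokens."""
    X = rng.standard_normal((T, d))
    S = rng.choice(T, size=q, replace=False)
    y = X[S].sum(axis=0)
    return X, S, y

def stochastic_positions(T, T_max, rng):
    """Sample T positions uniformly without replacement from the longer
    range [0, T_max) and one-hot encode them (sorted). Randomizing the
    positions seen during training is the ingredient the paper ties to
    OOD length generalization."""
    pos = np.sort(rng.choice(T_max, size=T, replace=False))
    P = np.zeros((T, T_max))
    P[np.arange(T), pos] = 1.0
    return P

def attention_select(X, S, P, beta=10.0):
    """Single-head softmax attention: the query is the sum of the positional
    encodings of the indices in S, so for large beta the attention weights
    concentrate uniformly on those q tokens and the output approximates
    (1/q) * sum_{i in S} x_i."""
    query = P[S].sum(axis=0)                  # (T_max,)
    scores = beta * (P @ query)               # (T,), = beta for i in S, 0 otherwise
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return (weights[:, None] * X).sum(axis=0)

# Quick check: rescaling the softmax average by q recovers the STS_q target.
T, d, q, T_max = 16, 8, 3, 64
X, S, y = sts_q_instance(T, d, q, rng)
P = stochastic_positions(T, T_max, rng)
approx = q * attention_select(X, S, P)
print("relative error:", np.linalg.norm(approx - y) / np.linalg.norm(y))
```

Because training positions are drawn from the longer range [0, T_max), a test sequence of any length up to T_max reuses the same encoding vectors the model has already seen, which is the intuition behind why stochastic positional encoding helps length generalization in this sketch.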