10 Feb 2025 | Luca Arnaboldi, Yatin Dandi, Florent Krzakala, Luca Pesce, and Ludovic Stephan
The paper investigates the training dynamics of shallow two-layer neural networks trained with gradient-based algorithms, focusing on how they learn the relevant features of multi-index models, i.e., target functions that depend only on a low-dimensional set of relevant directions. In the high-dimensional regime where the input dimension \( d \) diverges, the authors show that a simple modification of the idealized single-pass gradient-descent training scenario, in which data can be repeated or iterated upon twice, significantly improves computational efficiency. This modification surpasses the limitations previously believed to be dictated by the information and leap exponents associated with the target function to be learned. The results highlight the ability of networks to learn relevant structure from the data alone, without any preprocessing. Specifically, the authors demonstrate that (almost) all directions are learned with a complexity of \( O(d \log d) \) steps, except for a set of hard functions that includes sparse parities. For these hard functions, learning can still be achieved through a hierarchical mechanism that generalizes the notion of staircase functions. The findings are supported by a rigorous study of the evolution of the relevant statistics of the high-dimensional dynamics. The paper also discusses how this work differs from other recent studies, emphasizing the importance of realistic training scenarios that account for correlations in the data.
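For context, a multi-index target in this setting takes the standard form (a textbook definition, not quoted from the paper):

\[
f^*(x) \;=\; g\big(\langle w_1^*, x\rangle, \ldots, \langle w_k^*, x\rangle\big),
\qquad x \in \mathbb{R}^d, \quad k = O(1), \quad d \to \infty,
\]

where the span of \( w_1^*, \ldots, w_k^* \) is the low-dimensional subspace of relevant directions the network must recover.

The training modification can be pictured with a minimal sketch: each fresh batch is used for two gradient steps instead of one. The network size, learning rate, loss, and the illustrative two-index target \( f^*(x) = x_1 x_2 \) below are assumptions made for illustration, not the authors' exact experimental protocol.

```python
# Minimal sketch (assumed setup, not the authors' exact protocol): online SGD on the
# first layer of a shallow two-layer network, with the number of gradient steps taken
# on each fresh batch as a parameter -- repeats=1 is idealized single-pass training,
# repeats=2 is the "reuse the data twice" modification discussed in the paper.
import numpy as np

rng = np.random.default_rng(0)
d, p = 256, 64                               # input dimension, hidden width

def target(X):
    # Illustrative two-index target f*(x) = x_1 * x_2 (an assumption for this sketch).
    return X[:, 0] * X[:, 1]

def forward(W, a, X):
    return np.tanh(X @ W.T) @ a              # network output, shape (n,)

def grad_W(W, a, X, y):
    # Batch-averaged gradient of the squared loss 0.5 * (f(x) - y)^2 w.r.t. W.
    pre = X @ W.T                            # (n, p) pre-activations
    err = forward(W, a, X) - y               # (n,) residuals
    g = err[:, None] * a[None, :] * (1.0 - np.tanh(pre) ** 2)   # (n, p)
    return g.T @ X / X.shape[0]              # (p, d)

def train(repeats, n_batches=2000, batch=128, lr=0.5):
    W = rng.normal(size=(p, d)) / np.sqrt(d)         # random first-layer init
    a = rng.choice([-1.0, 1.0], size=p) / p          # fixed second layer
    for _ in range(n_batches):
        X = rng.normal(size=(batch, d))              # fresh Gaussian batch
        y = target(X)
        for _ in range(repeats):                     # key difference: reuse the same batch
            W -= lr * grad_W(W, a, X, y)
    return W

for repeats in (1, 2):
    W = train(repeats)
    # Fraction of the learned weights' mass lying in the relevant subspace span(e_1, e_2).
    alignment = np.linalg.norm(W[:, :2]) / np.linalg.norm(W)
    print(f"repeats={repeats}: alignment with relevant subspace = {alignment:.3f}")
```

The only change between the two runs is the inner loop: taking a second gradient step on the same batch introduces the data correlations that, according to the paper, allow the dynamics to escape the bottleneck set by the information and leap exponents.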