The Benefits of Reusing Batches for Gradient Descent in Two-Layer Networks: Breaking the Curse of Information and Leap Exponents


30 Jun 2024 | Yatin Dandi¹,², Emanuele Troiani², Luca Arnaboldi¹, Luca Pesce¹, Lenka Zdeborová², and Florent Krzakala¹
The paper investigates the training dynamics of two-layer neural networks learning multi-index target functions, focusing on multi-pass gradient descent (GD) that reuses the same batch of data multiple times. It shows that batch reuse substantially changes which functions are learnable compared to single-pass (online) SGD: multi-pass GD with finite stepsize overcomes the limitations of gradient flow and single-pass SGD that are dictated by the information exponent and leap exponent of the target. In particular, the network develops overlap with the target subspace in as few as two time steps, even for targets that do not satisfy the staircase property. The paper characterizes the class of functions learned efficiently in finite time, showing that it extends well beyond staircase functions even with minimal batch repetition, while symmetric functions remain hard because of their long symmetry-breaking times. It also provides a closed-form description of the dynamics of the low-dimensional projections of the weights, supported by numerical experiments.

The analysis rests on Dynamical Mean-Field Theory (DMFT), which gives a rigorous framework for the learning dynamics. It shows that reusing batches produces non-Gaussian pre-activations correlated with the target, and this correlation is what allows directions to be learned that are inaccessible in the one-pass setting. As a consequence, gradient descent on a single repeated batch can outperform one-pass SGD on fresh batches, even when one-pass SGD uses more data points in total. These findings challenge the common wisdom that more data is always better and shed light on incremental learning in the presence of correlations between data points across batches.
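For reference, the information exponent mentioned above is a standard quantity in this literature (the definition below follows the usual convention for Gaussian inputs and is not quoted from this paper): for a single-index target y = g(⟨w*, x⟩), it is the degree of the first non-zero Hermite coefficient of the link function g,

k^\star(g) \;=\; \min\big\{\, k \ge 1 \;:\; \mathbb{E}_{z \sim \mathcal{N}(0,1)}\big[\, g(z)\, \mathrm{He}_k(z) \,\big] \neq 0 \,\big\}.

For example, g(z) = He_3(z) = z^3 - 3z has k* = 3, and online one-pass SGD is known to need on the order of d^{k*-1} samples (up to log factors) to recover such a direction. The leap exponent plays the analogous role for multi-index targets, loosely measuring the largest jump in new directions or Hermite degrees that must be acquired in a single step of the learning staircase.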
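To make the comparison concrete, below is a minimal, self-contained sketch of the two training protocols being contrasted: multi-pass GD on one reused batch versus one-pass SGD on fresh batches, tracking the overlap of the first-layer weights with the hidden direction. The dimensions, stepsize, and He_3 single-index target here are hypothetical choices for illustration, not the paper's exact setup, and such a small run need not reproduce the quantitative separation established in the high-dimensional analysis.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes and stepsize, chosen for illustration only.
d, p, n = 256, 32, 2048      # input dimension, hidden width, batch size
eta, steps = 2.0, 3          # large finite stepsize, a few passes

# Single-index target with information exponent 3: g(z) = He_3(z) = z^3 - 3z.
w_star = rng.normal(size=d) / np.sqrt(d)
def target(X):
    z = X @ w_star
    return z**3 - 3.0 * z

def forward(W, a, X):
    # Two-layer network f(x) = a . tanh(W x) with fixed second layer a.
    return np.tanh(X @ W.T) @ a

def grad_W(W, a, X, y):
    # Gradient of the mean squared error with respect to the first-layer weights.
    pre = X @ W.T                       # (n, p) pre-activations
    err = forward(W, a, X) - y          # (n,) residuals
    dact = 1.0 - np.tanh(pre) ** 2      # tanh'
    return ((err[:, None] * dact) * a).T @ X / len(y)

def overlap(W):
    # Largest cosine similarity between a hidden unit and the target direction.
    Wn = W / np.linalg.norm(W, axis=1, keepdims=True)
    return np.abs(Wn @ (w_star / np.linalg.norm(w_star))).max()

a = rng.choice([-1.0, 1.0], size=p) / p
W_multi = rng.normal(size=(p, d)) / np.sqrt(d)   # shared initialization
W_one = W_multi.copy()

X0 = rng.normal(size=(n, d)); y0 = target(X0)    # the single reused batch

for t in range(steps):
    # Multi-pass GD: every step reuses the SAME batch (X0, y0).
    W_multi -= eta * grad_W(W_multi, a, X0, y0)
    # One-pass SGD: every step draws a FRESH batch, so it sees more data overall.
    Xt = rng.normal(size=(n, d)); yt = target(Xt)
    W_one -= eta * grad_W(W_one, a, Xt, yt)
    print(f"step {t + 1}: overlap multi-pass = {overlap(W_multi):.3f}, "
          f"one-pass = {overlap(W_one):.3f}")
```

The paper's statements concern the proportional high-dimensional limit (batch size of order d as d grows); the script above only illustrates the two protocols and the overlap observable being compared.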