2024 | Frederik Kunstner, Robin Yadav, Alan Milligan, Mark Schmidt, Alberto Bietti
Heavy-tailed class imbalance is a key factor explaining why Adam outperforms gradient descent (GD) on language models. Language tasks exhibit heavy-tailed class imbalance: rare classes collectively account for a large fraction of the data, which produces imbalanced and correlated gradients and Hessians. GD makes slow progress on low-frequency classes, while Adam and sign-based methods are far less sensitive to class frequency. The behavior appears across architectures and data types, including language transformers, vision CNNs, and linear models. On a linear model with cross-entropy loss, heavy-tailed class imbalance yields imbalanced gradients and Hessians, which are hypothesized to be the property Adam exploits. A continuous-time analysis shows that GD converges slowly on low-frequency classes, whereas sign descent is insensitive to class frequencies. Experiments on vision and linear models confirm that heavy-tailed class imbalance degrades GD's performance while Adam remains largely unaffected. The gap is attributed to the correlation between gradient and Hessian magnitudes across parameters, which is more pronounced in language tasks. Class imbalance is not the only reason Adam outperforms GD, but it is a significant one. The findings suggest that heavy-tailed class imbalance should be taken into account when designing optimizers for language and other tasks that exhibit it.
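The effect can be reproduced in miniature. Below is a minimal sketch (not the paper's code) in which a linear softmax model is trained on synthetic data whose class frequencies follow a Zipf-like, heavy-tailed distribution; full-batch GD is compared against plain sign descent, used here as a simplified stand-in for Adam. The data-generating process, class counts, step sizes, and iteration budget are all illustrative assumptions, chosen only to make the contrast on low-frequency classes visible.

```python
# Minimal sketch: heavy-tailed class imbalance, GD vs. sign descent
# on a linear softmax model. All constants below are illustrative choices,
# not values from the paper.
import numpy as np

rng = np.random.default_rng(0)

num_classes, dim, n = 64, 32, 4096
# Heavy-tailed (Zipf-like) class frequencies: p(k) proportional to 1/k
freqs = 1.0 / np.arange(1, num_classes + 1)
freqs /= freqs.sum()
y = rng.choice(num_classes, size=n, p=freqs)
# Each class gets its own mean direction so the problem is learnable
class_means = rng.normal(size=(num_classes, dim))
X = class_means[y] + 0.5 * rng.normal(size=(n, dim))

def softmax_loss_grad(W, X, y):
    """Mean cross-entropy loss and gradient for a linear softmax model."""
    logits = X @ W.T                                 # (n, num_classes)
    logits -= logits.max(axis=1, keepdims=True)
    p = np.exp(logits)
    p /= p.sum(axis=1, keepdims=True)
    m = X.shape[0]
    loss = -np.log(p[np.arange(m), y] + 1e-12).mean()
    p[np.arange(m), y] -= 1.0                        # softmax minus one-hot
    grad = (p.T @ X) / m                             # (num_classes, dim)
    return loss, grad

def loss_on_classes(W, X, y, group):
    """Mean cross-entropy loss restricted to samples from a class subset."""
    mask = np.isin(y, group)
    logits = X[mask] @ W.T
    logits -= logits.max(axis=1, keepdims=True)
    p = np.exp(logits)
    p /= p.sum(axis=1, keepdims=True)
    return -np.log(p[np.arange(mask.sum()), y[mask]] + 1e-12).mean()

rare = np.arange(num_classes // 2, num_classes)      # low-frequency classes

# Compare GD with sign descent (a simplified proxy for Adam's behavior).
for name, step in [("GD", lambda g: 0.1 * g),
                   ("sign descent", lambda g: 0.01 * np.sign(g))]:
    W = np.zeros((num_classes, dim))
    for _ in range(500):
        _, g = softmax_loss_grad(W, X, y)
        W -= step(g)
    print(f"{name:12s} | loss on rare classes: {loss_on_classes(W, X, y, rare):.3f}")
```

Under these assumptions, the per-class breakdown is the quantity to watch: GD's loss on the rare half of the classes tends to lag behind sign descent's, mirroring the paper's observation that GD is slow on low-frequency classes while sign-based updates are largely insensitive to class frequency.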