Heavy-Tailed Class Imbalance and Why Adam Outperforms Gradient Descent on Language Models

12 Jul 2024 | Frederik Kunstner, Robin Yadav, Alan Milligan, Mark Schmidt, Alberto Bietti
The paper investigates why Adam outperforms stochastic gradient descent (SGD) when training large language models, and identifies heavy-tailed class imbalance as a key factor behind the performance gap. In language data, class (token) frequencies are heavy-tailed: a large number of low-frequency words collectively account for a substantial fraction of the data. When such data is trained with SGD, the loss on infrequent words decreases much more slowly than the loss on frequent words, which stalls overall progress; Adam and sign-based methods are far less sensitive to this imbalance. The authors show that this behavior is consistent across architectures and data types, including language transformers, vision CNNs, and linear models, and that class imbalance produces imbalanced, correlated gradients and Hessians of a form that benefits Adam. They further prove that, in continuous time, gradient descent converges slowly on low-frequency classes while sign descent does not. The findings suggest that heavy-tailed class imbalance is a major contributor to the Adam-SGD gap and should be taken into account when designing optimization algorithms for language models and other tasks with similar characteristics.
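The per-class effect described above can be reproduced in a toy setting. The sketch below is not the paper's experiment: it uses an illustrative linear softmax classifier on Gaussian features with a Zipf-like label distribution, and the sizes, learning rates, and the frequent/rare class split are assumptions chosen for demonstration. It trains the same model with SGD and with Adam and reports the loss separately on head (frequent) and tail (rare) classes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes (not from the paper): a linear softmax classifier
# on Gaussian features with a Zipf-like (heavy-tailed) label distribution.
num_classes, dim, num_samples = 100, 64, 20_000
class_probs = 1.0 / torch.arange(1, num_classes + 1, dtype=torch.float)
class_probs /= class_probs.sum()

y = torch.multinomial(class_probs, num_samples, replacement=True)
# Give each class its own mean so the problem is learnable.
class_means = torch.randn(num_classes, dim)
X = class_means[y] + 0.5 * torch.randn(num_samples, dim)

def train(optimizer_name, steps=2_000, lr=None):
    model = torch.nn.Linear(dim, num_classes)
    if optimizer_name == "sgd":
        opt = torch.optim.SGD(model.parameters(), lr=lr or 0.5)
    else:
        opt = torch.optim.Adam(model.parameters(), lr=lr or 1e-2)
    for _ in range(steps):
        idx = torch.randint(0, num_samples, (256,))
        loss = F.cross_entropy(model(X[idx]), y[idx])
        opt.zero_grad()
        loss.backward()
        opt.step()
    # Report the loss restricted to frequent vs. rare classes
    # (the 10-class head and the bottom-half tail are arbitrary cutoffs).
    with torch.no_grad():
        per_sample = F.cross_entropy(model(X), y, reduction="none")
        frequent = per_sample[y < 10].mean().item()
        rare = per_sample[y >= num_classes // 2].mean().item()
    print(f"{optimizer_name:>4}: loss on frequent classes {frequent:.3f}, "
          f"on rare classes {rare:.3f}")

train("sgd")
train("adam")
```

With settings like these, one would expect SGD's loss on the rare classes to lag well behind its loss on the frequent classes, while Adam reduces both more evenly, mirroring the qualitative behavior the paper reports; the exact numbers depend on the seed and hyperparameters.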