19 Mar 2024 | Tao Li, Pan Zhou, Zhengbao He, Xinwen Cheng, Xiaolin Huang
The paper "Friendly Sharpness-Aware Minimization" by Tao Li, Pan Zhou, Zhengbao He, Xinwen Cheng, and Xiaolin Huang investigates the mechanisms behind the generalization improvements of Sharpness-Aware Minimization (SAM) and introduces a new variant called "Friendly-SAM" (F-SAM). SAM is a deep neural network optimization algorithm that aims to minimize both training loss and loss sharpness to improve generalization. The authors decompose the minibatch gradient into two components: the full gradient component and the stochastic gradient noise component. They find that using only the full gradient component degrades generalization, while excluding it improves performance. The key insight is that the full gradient component increases the sharpness loss for the entire dataset, creating inconsistencies with the subsequent sharpness minimization step on the current minibatch data.
To address this issue, F-SAM removes the full gradient component, which it estimates with an exponentially moving average (EMA) of historical stochastic gradients, and perturbs along the remaining stochastic gradient noise for improved generalization. The paper provides theoretical justification for the EMA approximation and proves the convergence of F-SAM on non-convex problems. Extensive experiments show that F-SAM outperforms vanilla SAM in generalization and robustness across various datasets and tasks, including training from scratch and transfer learning. The code for F-SAM is available at <https://github.com/nbit/F-SAM>.
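The following is a hedged PyTorch-style sketch of one F-SAM-like step as described above, not the authors' released implementation (see the linked repository for that); the function name `fsam_step`, the placeholders `model`, `loss_fn`, `base_opt`, `ema_grads`, and the hyperparameter values `rho`, `sigma`, `beta` are all illustrative assumptions.

```python
# Minimal sketch: maintain an EMA of minibatch gradients as an estimate of the
# full gradient, remove (a scaled version of) it from the current gradient to
# isolate the noise direction, perturb along that direction, then take the
# base-optimizer step using the gradient at the perturbed weights.
import torch

def fsam_step(model, loss_fn, inputs, targets, base_opt, ema_grads,
              rho=0.05, sigma=1.0, beta=0.9):
    # 1) Minibatch gradient at the current weights.
    base_opt.zero_grad()
    loss_fn(model(inputs), targets).backward()

    eps = []
    with torch.no_grad():
        # 2) Update the EMA estimate of the full gradient.
        for p, m in zip(model.parameters(), ema_grads):
            m.mul_(beta).add_(p.grad, alpha=1 - beta)
        # Noise direction d = g - sigma * m, normalized to radius rho.
        d = [p.grad - sigma * m for p, m in zip(model.parameters(), ema_grads)]
        scale = rho / (torch.norm(torch.stack([x.norm() for x in d])) + 1e-12)
        for p, x in zip(model.parameters(), d):
            e = x * scale
            p.add_(e)          # ascend along the noise direction
            eps.append(e)

    # 3) Gradient at the perturbed weights, then undo the perturbation and step.
    base_opt.zero_grad()
    loss_fn(model(inputs), targets).backward()
    with torch.no_grad():
        for p, e in zip(model.parameters(), eps):
            p.sub_(e)
    base_opt.step()
```

In a training loop this would be called once per minibatch, with `ema_grads = [torch.zeros_like(p) for p in model.parameters()]` initialized beforehand; the repository linked above remains the authoritative reference for the exact algorithm and hyperparameters.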