Transformers Learn Nonlinear Features In Context: Nonconvex Mean-field Dynamics on the Attention Landscape


2024 | Juno Kim, Taiji Suzuki
This paper investigates the optimization dynamics of a Transformer model consisting of a two-layer MLP followed by a linear attention layer, pretrained on tasks whose labels are linear transformations of nonlinear feature representations. The authors prove that the infinite-dimensional loss landscape over the distribution of MLP parameters, while highly nonconvex, becomes benign in the mean-field and two-timescale limit. They analyze the second-order stability of the mean-field dynamics and show that the Wasserstein gradient flow almost always avoids saddle points. They also derive concrete improvement rates in three regimes: away from saddle points, near global minima, and near saddle points. This is the first saddle-point analysis of mean-field dynamics in general, and the techniques are of independent interest. Numerical experiments complement the theoretical findings.
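To make the architecture concrete, below is a minimal sketch of the kind of model the summary describes: a two-layer MLP feature map followed by a single linear attention head that predicts a query label from in-context examples. This is an illustrative reconstruction, not the paper's exact parameterization: the dimensions, the ReLU activation, the uniform normalization of the attention scores, and the use of the model's own features to generate toy labels are all assumptions made here for readability.

```python
import numpy as np

# Illustrative dimensions (not taken from the paper's experiments):
# input dim d, feature dim k, MLP width N, context length n_ctx
d, k, N, n_ctx = 8, 16, 64, 32
rng = np.random.default_rng(0)

def mlp_features(X, W, a):
    """Two-layer MLP feature map: average of a_j * relu(w_j^T x) over N neurons.

    X: (n, d) inputs; W: (N, d) first-layer weights; a: (N, k) output weights.
    Returns (n, k) nonlinear features (a mean over the width, mirroring the
    mean-field view of the MLP as a distribution over neuron parameters).
    """
    return np.maximum(X @ W.T, 0.0) @ a / W.shape[0]

def linear_attention_predict(X_ctx, y_ctx, x_query, W, a, Gamma):
    """In-context prediction with a linear attention head on top of MLP features.

    The query feature h(x) attends linearly to each context feature h(x_i) with
    score h(x)^T Gamma h(x_i); the prediction is the score-weighted average of
    the context labels y_i (no softmax, hence "linear" attention).
    """
    H_ctx = mlp_features(X_ctx, W, a)           # (n_ctx, k)
    h_q = mlp_features(x_query[None, :], W, a)  # (1, k)
    scores = h_q @ Gamma @ H_ctx.T              # (1, n_ctx) unnormalized linear attention
    return (scores @ y_ctx).item() / len(y_ctx)

# Toy task: labels are a linear transformation (beta) of nonlinear features.
# Here the model's own feature map stands in for the true one, purely for illustration.
W = rng.normal(size=(N, d)); a = rng.normal(size=(N, k)); Gamma = np.eye(k)
X_ctx = rng.normal(size=(n_ctx, d)); beta = rng.normal(size=k)
y_ctx = mlp_features(X_ctx, W, a) @ beta
x_query = rng.normal(size=d)
print(linear_attention_predict(X_ctx, y_ctx, x_query, W, a, Gamma))
```

Roughly speaking, the paper's analysis takes the MLP width to infinity, so training the first layer becomes a Wasserstein gradient flow over the distribution of neuron parameters, while in the two-timescale limit the attention parameters adapt on a faster timescale than that distribution.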