How Transformers Learn Causal Structure with Gradient Descent

August 14, 2024 | Eshaan Nichani, Alex Damian, and Jason D. Lee
The paper explores how transformers learn causal structure through gradient descent, focusing on the self-attention mechanism that enables information transfer within sequences. The authors introduce an in-context learning task to study this process, proving that a simplified two-layer transformer learns to solve it by encoding the latent causal graph in the first attention layer. They show that the gradient of the attention matrix encodes the mutual information between tokens, and that the largest entries in this gradient correspond to edges of the latent causal graph. When sequences are generated from Markov chains, the learned circuit is an induction head. The paper also discusses how this learning process generalizes to graphs with multiple parents and provides empirical evidence supporting the theoretical findings.
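The link between pairwise mutual information and the causal parents can be illustrated numerically. The sketch below is an illustration of that claim, not the paper's construction: the state size, sequence length, sample count, and random transition matrix are arbitrary choices. It samples sequences from a random Markov chain, estimates the mutual information I(x_i; x_j) between positions (the quantity the paper shows is encoded in the first attention layer's gradient), and checks that for each position the largest entry among earlier positions lands on the parent j = i - 1, the edge an induction head attends to.

```python
import numpy as np

rng = np.random.default_rng(0)

k = 5          # vocabulary / state size (illustrative choice)
T = 12         # sequence length (illustrative choice)
n_seq = 20000  # number of sampled sequences

# Random row-stochastic transition matrix; P[s] is the next-state distribution.
P = rng.dirichlet(np.ones(k), size=k)
pi0 = np.ones(k) / k  # uniform initial distribution

# Sample n_seq Markov-chain sequences of length T.
X = np.empty((n_seq, T), dtype=int)
X[:, 0] = rng.choice(k, size=n_seq, p=pi0)
for t in range(1, T):
    # Vectorized sampling: invert the per-sequence CDF of P[current state].
    u = rng.random(n_seq)
    cdf = np.cumsum(P[X[:, t - 1]], axis=1)
    X[:, t] = np.minimum((u[:, None] > cdf).sum(axis=1), k - 1)

def mutual_information(a, b, k):
    """Empirical mutual information (in nats) between two discrete columns."""
    joint = np.zeros((k, k))
    np.add.at(joint, (a, b), 1.0)
    joint /= joint.sum()
    pa, pb = joint.sum(1, keepdims=True), joint.sum(0, keepdims=True)
    mask = joint > 0
    return float((joint[mask] * np.log(joint[mask] / (pa @ pb)[mask])).sum())

# Pairwise mutual information between positions.
MI = np.zeros((T, T))
for i in range(T):
    for j in range(T):
        if i != j:
            MI[i, j] = mutual_information(X[:, i], X[:, j], k)

# For a Markov chain, the data-processing inequality gives
# I(x_i; x_j) <= I(x_i; x_{i-1}) for all j < i - 1, so the largest entry
# over earlier positions should sit at the parent j = i - 1.
for i in range(1, T):
    print(f"position {i}: strongest earlier position = {np.argmax(MI[i, :i])}")
    # expected output: i - 1 for every i
```

This mirrors only the Markov-chain special case discussed above; for graphs with multiple parents the same argmax-over-mutual-information picture applies edge by edge, but the code here does not reproduce the paper's transformer training dynamics.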