February 15, 2024 | Clayton Sanford, Daniel Hsu, and Matus Telgarsky
This paper shows that a constant number of transformer layers can efficiently simulate, and be simulated by, a constant number of communication rounds of massively parallel computation (MPC). As a consequence, logarithmic depth suffices for transformers to solve basic computational tasks that cannot be efficiently solved by several other neural sequence models or by sub-quadratic transformer approximations. The key distinguishing property of transformers that these results identify is their ability to leverage parallelism.
The paper establishes a formal connection between transformers and MPC by designing transformers that simulate MPC protocols and vice versa. This allows a wide range of computational tasks to be solved by logarithmic-depth transformers, including tasks that cannot be efficiently solved by other architectures such as graph neural networks and recurrent models.
The paper presents two main results. First, it shows that any R-round MPC protocol can be implemented by a transformer of depth O(R), and that any depth-L transformer can be simulated by an O(L)-round MPC protocol. This implies that several graph problems can be solved by logarithmic-depth transformers, and that these transformers are near-optimal under certain conjectures about MPC algorithms.
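To make the round-to-depth correspondence concrete, here is a minimal, illustrative MPC-style simulator; the machine/inbox structure and names below are my own simplification, not the paper's formalism. Each of the R rounds consists of bounded local computation followed by message routing, and in the paper's simulation each round is handled by a constant number of transformer layers, with attention (roughly) playing the role of message delivery.

```python
def run_mpc(states, local_step, num_rounds):
    """Toy R-round MPC simulator (illustrative sketch only).

    states: dict machine_id -> local state (assumed small, per MPC's memory bound)
    local_step(machine_id, state, inbox) -> (new_state, outgoing), where
    outgoing is a list of (dest_id, message) pairs.
    """
    inboxes = {m: [] for m in states}
    for _ in range(num_rounds):              # one round ~ O(1) transformer layers in the simulation
        next_inboxes = {m: [] for m in states}
        for m in states:
            # local computation on this machine's state and received messages
            states[m], outgoing = local_step(m, states[m], inboxes[m])
            # message routing between machines for the next round
            for dest, msg in outgoing:
                next_inboxes[dest].append(msg)
        inboxes = next_inboxes
    return states
```

The point of the sketch is only the shape of the computation: the total depth of the simulating transformer scales with the number of rounds, not with the amount of data each machine touches.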
Second, the paper introduces a synthetic sequential task called the k-hop induction heads task, which can be solved by logarithmic-depth transformers but not efficiently by other standard architectures. Theoretical results show that depth L = Θ(log k) is necessary and sufficient for transformers to represent the task efficiently. Empirical results show that transformers trained on this task exhibit the same depth threshold and recover a mechanism similar to the theoretical construction. In contrast, non-parallelizable recurrent architectures, including state-space models like Mamba, are unable to solve the task in a size-efficient manner.
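For intuition, here is a small sketch of the task as I read it (exact conventions in the paper, such as where the query sits and how undefined hops are handled, may differ): a 1-hop step jumps from the current position to the token immediately after the most recent earlier occurrence of the current token, and the k-hop label iterates this step k times starting from the final position.

```python
import random

def sample_khop_example(k, length=30, alphabet="abcde", seed=0):
    """Sample one instance of the k-hop induction heads task (my reading of it)."""
    rng = random.Random(seed)
    seq = [rng.choice(alphabet) for _ in range(length)]

    def one_hop(i):
        # index just after the most recent earlier occurrence of seq[i], or None
        for j in range(i - 1, -1, -1):
            if seq[j] == seq[i]:
                return j + 1
        return None

    pos = len(seq) - 1                    # query at the final position
    for _ in range(k):
        pos = one_hop(pos)
        if pos is None:
            return "".join(seq), None     # hop undefined for this sample
    return "".join(seq), seq[pos]
```

With k = 1 this reduces to the familiar induction heads behavior of copying the token that followed the previous occurrence of the current token; larger k forces the model to chain several such lookups.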
The paper also shows that transformers can solve graph connectivity more efficiently than GNNs. Additionally, it demonstrates that sub-quadratic attention transformers and shallow transformers with chain-of-thought prompting are unable to solve the k-hop task efficiently. These results highlight the importance of parallelism in transformers and their ability to solve tasks that are not efficiently solvable by other architectures.
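One informal way to see why depth Θ(log k) is enough for the k-hop task (and why strictly sequential, non-parallel models need much more computation) is repeated squaring: composing a one-hop pointer map with itself doubles the hop distance, so roughly log2(k) compositions cover k hops. The sketch below illustrates only that counting argument; the paper's actual transformer construction is different but exploits the same doubling.

```python
def compose(f, g):
    """h[i] = g[f[i]]: apply pointer map f, then g (None marks an undefined hop)."""
    return [None if f[i] is None else g[f[i]] for i in range(len(f))]

def k_hop_map(one_hop, k):
    """Build the k-hop map from the 1-hop map with O(log k) compositions
    (repeated squaring), mirroring how a few extra layers can double the hop count."""
    result = list(range(len(one_hop)))    # identity map: 0 hops
    power = list(one_hop)                 # the map for 2^j hops at iteration j
    while k > 0:
        if k & 1:
            result = compose(result, power)
        power = compose(power, power)
        k >>= 1
    return result
```

A purely sequential model that resolves one hop per step would instead need on the order of k steps, which matches the intuition that parallelism, not raw size, is what the task rewards.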