Tensor Attention Training: Provably Efficient Learning of Higher-order Transformers

26 May 2024 | Jiuxiang Gu, Yingyu Liang, Zhenmei Shi, Zhao Song, Yufa Zhou
This paper presents a method for efficiently training higher-order transformers with Tensor Attention, a mechanism that captures high-order correlations among multiple modalities. The key challenge is the computational cost of Tensor Attention, which is $\Omega(n^3)$ in the sequence length $n$, making naive training impractical for long input sequences. The authors prove that the backward gradient of Tensor Attention training can be computed in almost linear time $n^{1+o(1)}$, matching the complexity of its forward computation, under a bounded-entries assumption. They derive a closed-form expression for the gradient and propose a fast computation method based on polynomial approximation and tensor algebraic tricks. They further prove the necessity and tightness of the assumption through a hardness analysis, showing that slightly weakening it makes the gradient problem unsolvable in truly subcubic time. These theoretical results establish the feasibility of efficient higher-order transformer training and may enable practical use of tensor attention architectures. The paper also covers related work, preliminary definitions, and the complexity analysis of tensor attention gradient computation. The main contributions are the closed-form gradient, the fast gradient computation method, and the proof that the bounded-entries assumption is necessary and tight. Together, these results show that Tensor Attention can be trained efficiently, overcoming the cubic complexity barrier in both forward and backward computation.
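To make the $\Omega(n^3)$ bottleneck concrete, here is a minimal NumPy sketch of a *naive* tensor-attention forward pass. It assumes the formulation commonly used in the tensor-attention literature, where the score matrix is $\exp(Q(K_1 \oslash K_2)^\top) \in \mathbb{R}^{n \times n^2}$ ($\oslash$ denoting the column-wise Kronecker product) and is applied to $V_1 \oslash V_2$; the function names, the $1/d$ scaling, and the variable names are illustrative assumptions, not the paper's code.

```python
# Minimal sketch of a naive tensor-attention forward pass (assumed formulation,
# not the paper's implementation). The n x n^2 score matrix A is what makes the
# naive approach Omega(n^3) in time and memory.
import numpy as np

def naive_tensor_attention(Q, K1, K2, V1, V2):
    n, d = Q.shape
    # Column-wise Kronecker products: row (i*n + j) is the Hadamard product of
    # row i of the first factor and row j of the second, giving (n^2, d) matrices.
    K = np.einsum('id,jd->ijd', K1, K2).reshape(n * n, d)
    V = np.einsum('id,jd->ijd', V1, V2).reshape(n * n, d)
    A = np.exp(Q @ K.T / d)                      # n x n^2 scores: the cubic bottleneck
    D_inv = 1.0 / A.sum(axis=1, keepdims=True)   # row-wise softmax normalization
    return D_inv * (A @ V)                       # n x d output

# Example: even n = 64 tokens with d = 16 materializes a 64 x 4096 score matrix.
rng = np.random.default_rng(0)
n, d = 64, 16
Q, K1, K2, V1, V2 = (rng.standard_normal((n, d)) for _ in range(5))
out = naive_tensor_attention(Q, K1, K2, V1, V2)
print(out.shape)  # (64, 16)
```

Materializing `A` already costs $\Theta(n^3)$ before any backward pass. The paper's almost-linear-time result instead approximates the exponential with a low-rank polynomial factorization and exploits the Kronecker structure of the keys and values, so that neither `A` nor its gradient is ever formed explicitly; the bounded-entries assumption is what makes this approximation accurate, and the hardness analysis shows it cannot be relaxed without losing truly subcubic time.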
[slides and audio] Tensor Attention Training: Provably Efficient Learning of Higher-order Transformers