10 Jun 2024 | Songlin Yang, Bailin Wang, Yu Zhang, Yikang Shen, Yoon Kim
This paper introduces a hardware-efficient algorithm for training DeltaNet, a variant of the linear transformer that uses the delta rule for associative recall. DeltaNet replaces the additive outer-product update of standard linear attention with a delta-rule update that retrieves the value currently associated with the incoming key and writes back a correction toward the new value, which enables more effective associative recall. However, existing training algorithms for DeltaNet do not parallelize over sequence length and are therefore inefficient on modern hardware. The proposed algorithm exploits a memory-efficient representation for computing products of Householder matrices, enabling parallelization across the sequence. This allows DeltaNet to be scaled to standard language-modeling settings, where it outperforms recent linear-time baselines such as Mamba and GLA in perplexity and in zero-shot performance on downstream tasks. The authors also experiment with hybrid models that combine DeltaNet layers with sliding-window or global attention layers, finding that these hybrids outperform strong transformer baselines. The algorithm is implemented in Triton and shows significant speed-ups across a range of sequence lengths and head dimensions. The paper also discusses DeltaNet's limitations, including limited length generalization and the finite capacity of its matrix-valued memory, compares DeltaNet with other linear recurrent models, and discusses the potential for a unifying framework for efficient autoregressive sequence transformations.
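To make the contrast between the two updates concrete, here is a minimal PyTorch sketch of both state recurrences in their naive sequential form. The function names (`linear_attention_step`, `delta_rule_step`, `deltanet_recurrence`) and tensor shapes are illustrative assumptions; this is only the O(T)-step reference recurrence, not the paper's chunkwise-parallel algorithm or its Triton kernels.

```python
import torch

def linear_attention_step(S, k_t, v_t):
    # Vanilla linear attention: purely additive outer-product write.
    return S + torch.outer(v_t, k_t)

def delta_rule_step(S, k_t, v_t, beta_t):
    # Delta rule: read the value currently stored under key k_t,
    # then write back a beta-weighted correction toward the new value v_t.
    v_old = S @ k_t
    return S + beta_t * torch.outer(v_t - v_old, k_t)

def deltanet_recurrence(q, k, v, beta):
    """Naive sequential reference for one DeltaNet head (illustrative only).

    q, k: (T, d_k); v: (T, d_v); beta: (T,) with entries in (0, 1).
    """
    T, d_k = k.shape
    d_v = v.shape[-1]
    S = k.new_zeros(d_v, d_k)          # matrix-valued associative memory
    outs = []
    for t in range(T):
        S = delta_rule_step(S, k[t], v[t], beta[t])
        outs.append(S @ q[t])          # o_t = S_t q_t
    return torch.stack(outs)

# Example usage with random inputs and normalized keys/queries.
T, d_k, d_v = 16, 8, 8
q = torch.nn.functional.normalize(torch.randn(T, d_k), dim=-1)
k = torch.nn.functional.normalize(torch.randn(T, d_k), dim=-1)
v = torch.randn(T, d_v)
beta = torch.sigmoid(torch.randn(T))   # writing strengths in (0, 1)
o = deltanet_recurrence(q, k, v, beta) # (T, d_v)
```

Rearranged, each delta-rule step multiplies the state by (I - beta_t k_t k_t^T), a generalized Householder matrix; the paper's memory-efficient representation of products of such matrices is what makes training parallelizable over sequence length.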