The paper "FLASHATTENTION-2: Faster Attention with Better Parallelism and Work Partitioning" addresses the challenge of scaling Transformers to longer sequence lengths. Longer sequences are key to better language modeling and high-resolution image understanding, and to applications in code, audio, and video generation. The attention layer is the main bottleneck, because its runtime and memory grow quadratically with sequence length. FLASHATTENTION [5] exploits the asymmetric GPU memory hierarchy to cut memory use from quadratic to linear in sequence length and to run 2-4× faster than optimized baselines, with no approximation. However, it is still much slower than optimized matrix-multiply (GEMM) operations, reaching only 25-40% of the theoretical maximum FLOPs/s.
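The quadratic scaling can be made concrete with a rough cost model for standard attention on one sequence. This is a sketch under stated assumptions. It counts only the two matmuls (S = QKᵀ and O = PV, about 2·N²·d FLOPs each per head) and assumes 2-byte FP16/BF16 elements. For memory, it counts only the N×N score matrix that a non-fused implementation writes to GPU memory. The function name `attention_cost` is hypothetical.

```python
def attention_cost(seq_len: int, head_dim: int, num_heads: int = 1,
                   bytes_per_elem: int = 2) -> tuple[int, int]:
    """Rough forward-pass cost of standard (non-fused) attention for one sequence.

    Returns (matmul_flops, score_bytes):
      matmul_flops - FLOPs of S = Q @ K^T plus O = P @ V, 2*N*N*d each per head
      score_bytes  - size of the N x N score matrix for all heads, in bytes
    """
    matmul_flops = 4 * seq_len * seq_len * head_dim * num_heads
    score_bytes = seq_len * seq_len * num_heads * bytes_per_elem
    return matmul_flops, score_bytes


# Doubling the sequence length quadruples both the FLOPs and the score-matrix size.
for n in (1024, 2048, 4096):
    flops, nbytes = attention_cost(n, head_dim=64, num_heads=16)
    print(f"N={n:5d}: {flops / 1e9:7.1f} GFLOPs, scores {nbytes / 2**20:6.0f} MiB")
```

A fused kernel such as FLASHATTENTION performs the same matmul FLOPs but never stores the N×N matrix in GPU memory, which is where the memory saving comes from.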
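To further improve efficiency, the authors propose FLASHATTENTION-2, which makes three changes:
1. **Algorithm Tweak**: Reduces the number of non-matmul FLOPs. GPUs run matmuls on specialized Tensor Cores far faster than other operations. On an A100, FP16/BF16 matmul peaks at 312 TFLOPs/s, versus 19.5 TFLOPs/s for non-matmul FP32. Each non-matmul FLOP removed therefore keeps more of the runtime in high-throughput matmuls.
2. **Parallelism**: Parallelizes the attention computation, even for a single head, across different thread blocks along the sequence-length dimension. This increases GPU occupancy when the batch size or number of heads is small, as is typical for long sequences.
3. **Work Partitioning**: Within each thread block, divides work between warps (splitting Q rather than K and V across warps) to reduce communication through shared memory.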
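The algorithm tweak can be illustrated with a minimal pure-Python sketch of a tiled forward pass for a single head. The sketch uses the online-softmax idea behind FLASHATTENTION. K and V are processed block by block while a running row maximum m and row sum l are updated, so the full score matrix is never stored. Following the FLASHATTENTION-2 tweak, the output accumulator is kept unnormalized and divided by l once at the end, instead of being rescaled after every block. The function names, the `block_k` parameter, and the list-of-lists layout are illustrative assumptions. The sketch shows only the arithmetic. It does not model the GPU kernel, its thread-block parallelism, or its warp partitioning.

```python
import math


def naive_attention(Q, K, V):
    """Reference: materialize every score for a query row, softmax it, then weight V."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K]
        m = max(scores)
        p = [math.exp(s - m) for s in scores]
        total = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(V))) / total
                    for c in range(len(V[0]))])
    return out


def tiled_attention_deferred(Q, K, V, block_k=2):
    """Tiled forward pass with online softmax and deferred normalization.

    For each query row we keep:
      m   - running maximum of the scores seen so far
      l   - running sum of exp(score - m)
      acc - running *unnormalized* output, sum of exp(score - m) * v
    When m grows, l and acc are rescaled by exp(m_old - m_new). The division
    by l happens only once, after the last K/V block.
    """
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        m = -math.inf
        l = 0.0
        acc = [0.0] * len(V[0])
        for start in range(0, len(K), block_k):
            k_blk = K[start:start + block_k]
            v_blk = V[start:start + block_k]
            s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
            m_new = max(m, max(s))
            correction = math.exp(m - m_new)  # exp(-inf) == 0.0 on the first block
            p = [math.exp(x - m_new) for x in s]
            l = l * correction + sum(p)
            acc = [a * correction + sum(pj * v[c] for pj, v in zip(p, v_blk))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])  # single normalization at the end
    return out
```

Changing `block_k` only changes the order of accumulation, so the output should match `naive_attention` up to floating-point rounding.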
These improvements yield around a 2× speedup over FLASHATTENTION, reaching 50-73% of the theoretical maximum FLOPs/s on A100 GPUs. When used end-to-end to train GPT-style models, FLASHATTENTION-2 reaches up to 225 TFLOPs/s per A100 GPU, which corresponds to 72% model FLOPs utilization. The paper also discusses future directions, including support for more devices and optimization for H100 GPUs.
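The utilization figures are ratios of achieved throughput to hardware peak. The sketch below reproduces them, assuming the A100's dense FP16/BF16 Tensor Core peak of 312 TFLOPs/s, which the summary does not state. `achieved_tflops` is a hypothetical helper. It shows how a per-GPU figure could be derived from a training step's model FLOPs and wall-clock time. It is not the paper's measurement code.

```python
A100_PEAK_TFLOPS = 312.0  # assumed: dense FP16/BF16 Tensor Core peak of one A100


def utilization(achieved: float, peak: float = A100_PEAK_TFLOPS) -> float:
    """Fraction of the hardware peak that a measured throughput represents."""
    return achieved / peak


def achieved_tflops(model_flops: float, seconds: float, num_gpus: int) -> float:
    """Hypothetical helper: per-GPU TFLOPs/s from a step's total model FLOPs."""
    return model_flops / seconds / num_gpus / 1e12


print(f"{utilization(225.0):.1%}")  # prints 72.1%
print(f"{A100_PEAK_TFLOPS * 0.73:.0f} TFLOPs/s")  # prints 228 TFLOPs/s
```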