FLASHATTENTION-2: Faster Attention with Better Parallelism and Work Partitioning

July 18, 2023 | Tri Dao
FLASHATTENTION-2 improves the efficiency of attention in Transformer models through better parallelism and work partitioning. The original FLASHATTENTION reduces memory usage and runtime by exploiting the GPU memory hierarchy, achieving a 2-4× speedup over optimized baselines without any approximation. However, it still falls short of optimized matrix-multiply (GEMM) kernels, reaching only 25-40% of the theoretical maximum FLOPs/s. FLASHATTENTION-2 closes much of this gap by reducing non-matmul FLOPs, parallelizing the attention computation across more thread blocks, and distributing work among warps within each thread block to minimize shared-memory traffic.

These changes yield roughly a 2× speedup over FLASHATTENTION, reaching 50-73% of the theoretical maximum FLOPs/s on A100 GPUs. Used end-to-end to train GPT-style models, FLASHATTENTION-2 reaches up to 225 TFLOPs/s per A100 GPU (72% model FLOPs utilization). The algorithm optimizes both the forward and backward passes, parallelizing across sequence length in addition to batch size and number of heads. It also reduces shared-memory accesses by splitting Q across warps instead of K and V.

Empirically, FLASHATTENTION-2 is 1.7-3.0× faster than FLASHATTENTION, 1.3-2.5× faster than the Triton implementation of FLASHATTENTION, and 3-10× faster than standard attention implementations. The method applies across devices and data types, with future work targeting H100 GPUs and new data types such as FP8.
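To make the tiling and parallelization concrete, here is a minimal PyTorch sketch of block-wise attention with an online softmax, the computation pattern that FLASHATTENTION-2 parallelizes. This is not the actual CUDA kernel: the single-head layout, the function name, and the block sizes (`block_q`, `block_k`) are illustrative assumptions, and the loop over Q blocks stands in for the per-thread-block work that runs in parallel over the sequence length on the GPU.

```python
import torch

def blockwise_attention(q, k, v, block_q=128, block_k=128):
    """q, k, v: (seq_len, head_dim) for a single head; returns (seq_len, head_dim)."""
    seq_len, head_dim = q.shape
    scale = head_dim ** -0.5
    out = torch.empty_like(q)
    # Outer loop over blocks of Q rows: in FLASHATTENTION-2 each such block maps
    # to its own thread block, giving parallelism over sequence length in
    # addition to batch size and number of heads.
    for i in range(0, seq_len, block_q):
        qi = q[i:i + block_q] * scale                   # (Bq, d)
        m = torch.full((qi.shape[0],), float('-inf'))   # running row max
        l = torch.zeros(qi.shape[0])                    # running softmax denominator
        acc = torch.zeros_like(qi)                      # unnormalized output accumulator
        # Inner loop over blocks of K/V columns, processed one tile at a time
        # (on the GPU these tiles stay in on-chip shared memory / registers).
        for j in range(0, seq_len, block_k):
            kj, vj = k[j:j + block_k], v[j:j + block_k]
            s = qi @ kj.T                               # (Bq, Bk) attention scores
            m_new = torch.maximum(m, s.max(dim=-1).values)
            p = torch.exp(s - m_new[:, None])
            correction = torch.exp(m - m_new)           # rescale previous partial results
            l = l * correction + p.sum(dim=-1)
            acc = acc * correction[:, None] + p @ vj
            m = m_new
        # Normalize once per row block at the end, which keeps the
        # non-matmul FLOPs low relative to the two matmuls above.
        out[i:i + block_q] = acc / l[:, None]
    return out

# Sanity check against standard attention (illustrative sizes).
q, k, v = (torch.randn(256, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(blockwise_attention(q, k, v), ref, atol=1e-4)
```

In this sketch the warp-level detail is not visible, but the same partitioning idea applies one level down: within a thread block, FLASHATTENTION-2 splits the Q tile across warps (rather than splitting K and V), so each warp can accumulate its slice of the output without synchronizing through shared memory on every inner step.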