FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
Abstract
The paper introduces FlashAttention-3, a new algorithm for speeding up attention on Hopper GPUs. The key contributions are:
- Exploiting asynchrony of the Tensor Cores and the Tensor Memory Accelerator (TMA) to overlap computation and data movement via warp-specialization.
- Interleaving block-wise matrix multiplication (GEMM) and softmax operations to hide the lower-throughput softmax computations.
- Using block quantization and incoherent processing to leverage hardware support for FP8 low-precision, achieving close to 1.2 PFLOPs/s on H100 GPUs.
The authors demonstrate that FlashAttention-3 achieves 1.5-2.0x speedup over the previous FlashAttention-2 algorithm, reaching up to 740 TFLOPs/s (75% utilization) in FP16 and close to 1.2 PFLOPs/s in FP8. They also show that FP8 FlashAttention-3 has 2.6x lower numerical error compared to a baseline FP8 attention implementation.
Q&A
[01] Attention and GPU Characteristics
1. What is the main computational bottleneck in the Transformer architecture? The attention mechanism: computing the self-attention scores between queries and keys scales quadratically with the sequence length (see the sketch below).
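A minimal sketch (standard scaled-dot-product attention, not the paper's kernel) makes the quadratic term concrete: the intermediate score matrix has shape (N, N), so both its FLOPs and its memory grow quadratically with the sequence length N.

```python
import torch

def naive_attention(Q, K, V):
    # The (N, N) score matrix S is what makes FLOPs and memory quadratic in N.
    d = Q.shape[-1]
    S = Q @ K.transpose(-1, -2) / d ** 0.5   # (N, N) attention scores
    P = torch.softmax(S, dim=-1)             # row-wise softmax
    return P @ V                              # (N, d) output

N, d = 1024, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
print(naive_attention(Q, K, V).shape)  # torch.Size([1024, 64])
```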
2. What are the key hardware characteristics of GPUs that the paper aims to leverage? The paper focuses on leveraging the following GPU hardware characteristics:
- Asynchrony of the Tensor Cores and Tensor Memory Accelerator (TMA) to overlap computation and data movement
- Hardware support for low-precision computation, specifically FP8 Tensor Cores on the Hopper GPU
3. How does the paper's approach differ from the previous FlashAttention-2 algorithm? Compared to FlashAttention-2, the key differences in FlashAttention-3 are:
- Exploiting asynchrony through warp-specialized software pipelining
- Overlapping the softmax computation with the asynchronous block-wise GEMMs
- Adapting the algorithm to target the FP8 Tensor Cores, including handling layout constraints
[02] FlashAttention-3 Algorithm
1. What are the three main techniques introduced in FlashAttention-3? The three main techniques in FlashAttention-3 are:
- Exploiting asynchrony of the Tensor Cores and TMA to overlap computation and data movement via warp-specialization.
- Interleaving block-wise matrix multiplication (GEMM) and softmax operations to hide the lower-throughput softmax computations.
- Using block quantization and incoherent processing to leverage hardware support for FP8 low-precision.
2. How does the 2-stage pipelining algorithm work? The 2-stage GEMM-softmax pipelining restructures each warpgroup's main loop so that the softmax of the current block overlaps with the asynchronous WGMMA computing the scores of the next block: the next block's GEMM is issued first, and the softmax runs while it is in flight. This complements pingpong scheduling, which overlaps the softmax of one warpgroup with the GEMMs of another; both exploit the asynchronous nature of the Hopper Tensor Cores and TMA. A sequential sketch of the restructured loop follows.
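The sketch below is my own Python restaging of the idea, not the paper's CUDA kernel: a blockwise online-softmax loop reorganized so the score GEMM for the next key block is issued before the softmax of the current block. On Hopper that GEMM is an asynchronous WGMMA and the two genuinely overlap; Python executes everything in order, so this only illustrates the dependency structure. The function name `blockwise_attention` and the block size of 128 are illustrative choices.

```python
import torch

def blockwise_attention(Q, K, V, block_n=128):
    N, d = Q.shape
    scale = d ** -0.5
    O = torch.zeros(N, d)
    m = torch.full((N,), float("-inf"))  # running row-wise max of the scores
    l = torch.zeros(N)                   # running row-wise sum of exponentials

    # On hardware this first QK^T block is already an asynchronous WGMMA.
    S = (Q @ K[:block_n].T) * scale
    for j in range(0, N, block_n):
        j_next = j + block_n
        # Issue the GEMM for the *next* key block; in the real kernel it runs
        # asynchronously and overlaps with the softmax work below.
        S_next = (Q @ K[j_next:j_next + block_n].T) * scale if j_next < N else None

        # Softmax / rescaling of the *current* block -- the lower-throughput
        # work that FlashAttention-3 hides behind the in-flight GEMMs.
        m_new = torch.maximum(m, S.max(dim=-1).values)
        P = torch.exp(S - m_new[:, None])
        alpha = torch.exp(m - m_new)
        l = alpha * l + P.sum(dim=-1)
        O = alpha[:, None] * O + P @ V[j:j + block_n]
        m, S = m_new, S_next

    return O / l[:, None]

N, d = 512, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
ref = torch.softmax((Q @ K.T) * d ** -0.5, dim=-1) @ V
print(torch.allclose(blockwise_attention(Q, K, V), ref, atol=1e-4))  # expect: True
```

The final check confirms that the blockwise rescaling reproduces standard attention; the speedup only materializes on hardware where the issued GEMMs actually run asynchronously.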
3. What modifications are required to support FP8 precision in FlashAttention-3? To support FP8 precision, FlashAttention-3 needs to address two main challenges:
- Layout conformance: Performing in-kernel transposition of the input tensors to satisfy the k-major layout requirement of FP8 WGMMA instructions.
- Numerical accuracy: Employing block quantization and incoherent processing to mitigate the loss of accuracy when moving to FP8 precision (see the sketch after this list).
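The following is a standalone sketch of the two accuracy techniques as plain tensor operations, not the paper's kernel code: block quantization keeps one FP8 scale per block (per 64 rows here, an arbitrary choice) instead of one per tensor, and incoherent processing multiplies Q and K by a shared random orthogonal matrix, which leaves Q K^T unchanged while spreading outlier entries across many dimensions before quantization. It assumes a PyTorch build that exposes torch.float8_e4m3fn (2.1+).

```python
import torch

FP8_MAX = 448.0  # largest finite value of float8_e4m3fn

def fp8_roundtrip(x, scale):
    # Scale into FP8 range, cast to FP8 e4m3, cast back, undo the scale.
    q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.to(torch.float32) * scale

def block_quantize(x, block=64):
    # One FP8 scale per `block` rows instead of one scale for the whole
    # tensor, so a single outlier only inflates the scale of its own block.
    out = torch.empty_like(x)
    for i in range(0, x.shape[0], block):
        blk = x[i:i + block]
        out[i:i + block] = fp8_roundtrip(blk, blk.abs().max() / FP8_MAX)
    return out

torch.manual_seed(0)
N, d = 512, 64
Q, K = torch.randn(N, d), torch.randn(N, d)
Q[10, 3] = 100.0  # inject an outlier feature

# Incoherent processing: a random orthogonal matrix M shared by Q and K,
# so (Q M)(K M)^T = Q K^T while the outlier gets spread across dimensions.
M = torch.linalg.qr(torch.randn(d, d)).Q
Qr, Kr = Q @ M, K @ M
print("Q K^T preserved:", torch.allclose(Qr @ Kr.T, Q @ K.T, atol=1e-2))
print("max |Q| before / after rotation:",
      Q.abs().max().item(), Qr.abs().max().item())
print("mean |error| of FP8 block-quantized scores:",
      (block_quantize(Qr) @ block_quantize(Kr).T - Q @ K.T).abs().mean().item())
```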
[03] Empirical Validation
1. What are the key performance results of FlashAttention-3 compared to prior work? The key performance results of FlashAttention-3 are:
- FP16 FlashAttention-3 achieves 1.5-2.0x speedup over FlashAttention-2 in the forward pass, reaching up to 740 TFLOPs/s (75% utilization) on the H100 GPU.
- FP8 FlashAttention-3 achieves close to 1.2 PFLOPs/s on the H100 GPU.
- For large sequence lengths, FP16 FlashAttention-3 outperforms a state-of-the-art cuDNN implementation, and FP8 FlashAttention-3 is competitive with it.
2. How does the numerical accuracy of FP8 FlashAttention-3 compare to a baseline FP8 attention implementation? With block quantization and incoherent processing, FP8 FlashAttention-3 achieves 2.6x lower numerical error (RMSE) than a baseline FP8 attention implementation that uses per-tensor quantization, with the gap most pronounced on inputs containing outlier features. A sketch of the kind of measurement involved follows.
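The sketch below shows the general shape of such a measurement with my own toy setup, not the paper's benchmark harness: RMSE of an attention output against a higher-precision reference, on inputs with injected outlier features (here FP32 merely stands in for the low-precision candidate; the paper evaluates FP8 kernels).

```python
import torch

def attention(Q, K, V):
    return torch.softmax(Q @ K.T / Q.shape[-1] ** 0.5, dim=-1) @ V

torch.manual_seed(0)
N, d = 1024, 64
Q, K, V = (torch.randn(N, d, dtype=torch.float64) for _ in range(3))
Q[::128, 7] += 50.0  # inject outlier features into a few rows

ref = attention(Q, K, V)                            # float64 reference
test = attention(*(t.float() for t in (Q, K, V)))   # lower-precision candidate
rmse = torch.sqrt(((test.double() - ref) ** 2).mean())
print(f"RMSE vs FP64 reference: {rmse.item():.3e}")
```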
3. What are the key takeaways from the ablation study on the 2-stage pipelining algorithm? The ablation study confirms that the algorithmic improvements in FlashAttention-3, namely asynchrony with warp-specialization and the overlapping of GEMM and softmax computations, account for a significant share of the speedup: adding these optimizations improves forward-pass throughput from 570 to 661 TFLOPs/s on the H100 GPU.