FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

Tri Dao

FlashAttention-2 optimizes GPU work partitioning to reach 73% of theoretical peak throughput.

How can we optimize the attention mechanism's GPU execution by reducing non-matrix-multiplication overhead and improving work partitioning across thread blocks?

Standard attention implementations are bottlenecked by memory access, and even the original FlashAttention leaves significant GPU compute capacity unused due to inefficient work distribution. FlashAttention-2 re-partitions the attention computation across GPU thread blocks and warps to maximize occupancy and minimize shared-memory communication, while reducing non-matrix-multiply operations. These changes yield a 2× speedup over the original FlashAttention, reaching up to 225 TFLOPs/s per A100 GPU.

Paper Primer

The core innovation is a shift in how work is distributed across the GPU's hierarchy. By parallelizing the attention computation along the sequence-length dimension and re-partitioning the workload within each thread block to avoid redundant shared-memory synchronization, the algorithm keeps the GPU's specialized matrix-multiply units saturated.

FlashAttention-2 achieves significantly higher hardware utilization than its predecessor.

Benchmarking on A100 GPUs shows the forward pass reaching 73% of theoretical maximum FLOPs/s. The kernel is approximately 2× faster than the original FlashAttention and up to 10× faster than standard PyTorch attention.

The method scales effectively to end-to-end training of large models.

Training GPT-style models (1.3B–2.7B parameters) reached up to 225 TFLOPs/s per A100 GPU, a 1.3× speedup over the original FlashAttention in end-to-end training throughput.

Why does this paper focus on "work partitioning" rather than just using more compute?

Modern GPUs have specialized units for matrix multiplication that are significantly faster than general-purpose arithmetic. The paper observes that previous implementations were bottlenecked by suboptimal scheduling that forced these fast units to wait for data or perform unnecessary memory operations.

Is this an approximation method that trades accuracy for speed?

No. Like the original FlashAttention, this algorithm produces the exact same output as standard attention, maintaining mathematical equivalence while optimizing the execution path.

Introduction and Motivation

Attention’s quadratic cost blocks longer contexts, and FlashAttention‑2 reshapes the bottleneck by cutting non‑matmul work and improving GPU parallelism.

Transformers process sequences with an attention layer whose runtime and memory scale as $N^2$, where $N$ is the sequence length. This quadratic growth makes it impractical to train models on long texts, high‑resolution images, or extended video clips.

Recent large‑scale models such as GPT‑4 (32 k tokens), MPT (65 k), and Claude (100 k) illustrate the demand for longer contexts, yet most training pipelines still use the standard attention formulation. FlashAttention mitigates the memory burden by reordering the computation, achieving 2–4× wall‑clock speedups and linear memory usage, but it still attains only 30–50 % of the device’s peak FLOPs/s because of extra non‑matmul work and inefficient GPU work partitioning.

Even after making attention memory‑linear, the kernel still spends a sizable fraction of time on operations that GPUs cannot execute as fast as dense matrix‑multiply, and the work is not evenly spread across thread blocks, leading to low occupancy.

Standard attention: compute 16 dot‑products → 16 FLOPs, store 16 scores.

FlashAttention: compute 16 dot‑products (matmul) + 4 softmax normalizations + 4 index shuffles → 24 FLOPs total, but the 8 non‑matmul FLOPs run on slower units.

FlashAttention‑2 cuts down the 8 non‑matmul FLOPs (it cannot remove the softmax itself), so a larger share of the work is the 16 matmul FLOPs, which the GPU can execute at peak throughput.

Even on a tiny $N=4$ example, the non‑matmul work can take a disproportionate share of the runtime because it cannot be accelerated by the GPU’s matrix‑multiply engines; reducing it moves the kernel closer to compute‑bound.

The key shift is moving attention from a memory‑bound bottleneck to a compute‑bound regime by cutting non‑matmul overhead and improving work partitioning.

GPU Execution and Attention Basics

Understanding GPU execution and the classic attention pipeline is essential before introducing FlashAttention.

The performance of attention kernels is dictated by GPU hardware and its execution model, and by how the attention algorithm maps onto that model.

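A GPU launches a massive number of lightweight threads; threads are grouped into thread blocks, which are scheduled onto streaming multiprocessors (SMs). Within a thread block, threads are grouped into warps of 32 threads, and all threads in the block can communicate through fast on‑chip shared memory (SRAM).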

Attention first computes a similarity matrix between queries and keys, normalizes each row with softmax, and then uses the resulting weights to blend the values.
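The three steps above can be written down directly. The following is a minimal NumPy sketch of standard attention, not the paper's implementation: the function name `standard_attention` and the toy shapes are assumptions made for illustration, and the usual $1/\sqrt{d}$ score scaling is omitted for simplicity.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Reference attention that materializes the full N x N score matrix."""
    S = Q @ K.T                              # similarity scores, shape (N, N)
    S = S - S.max(axis=1, keepdims=True)     # subtract the row max for stability
    P = np.exp(S)
    P = P / P.sum(axis=1, keepdims=True)     # row-wise softmax
    return P @ V, P                          # blend the values with the weights

# Toy inputs (illustrative sizes: N = 4 tokens, head dimension d = 8).
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((4, 8)) for _ in range(3))
O, P = standard_attention(Q, K, V)
```

Both `S` and `P` are $N\times N$ here, which is exactly the intermediate storage that a tiled kernel tries to avoid writing to global memory.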

FlashAttention tiles the attention computation, keeping each tile’s intermediate scores in on‑chip SRAM and applying an online softmax that rescales partial results, thereby eliminating any $O(N^{2})$ writes to global memory.

**Figure 1.** Diagram of how the FlashAttention forward pass is performed, when the key K is partitioned into two blocks and the value V is also partitioned into two blocks. By computing attention with respect to each block and rescaling the output, we get the right answer at the end, while avoiding expensive memory reads/writes of the intermediate matrices S and P. We simplify the diagram, omitting the step in softmax that subtracts each element by the row-wise max.

Forward Pass Optimization

Describes forward‑pass tricks that cut non‑matmul work and speed up long‑context attention.

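On an A100, the matrix‑multiply units peak at 312 TFLOPs/s in FP16/BF16 while non‑matmul FP32 arithmetic peaks at 19.5 TFLOPs/s, so a non‑matmul floating‑point operation costs roughly sixteen times more than a matmul FLOP. Even though the softmax update is a small fraction of the total FLOPs, its extra arithmetic can take a disproportionate share of the runtime.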
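To make the sixteen‑fold cost gap concrete, here is a back‑of‑the‑envelope sketch. It assumes the A100 peak figures of 312 TFLOPs/s for FP16/BF16 matmul and 19.5 TFLOPs/s for non‑matmul FP32, and it assumes the two kinds of work do not overlap in time, which real kernels only approximate. The helper name `time_share_non_matmul` is hypothetical.

```python
MATMUL_TFLOPS = 312.0      # assumed A100 FP16/BF16 Tensor Core peak
NON_MATMUL_TFLOPS = 19.5   # assumed A100 FP32 peak on general-purpose units

def time_share_non_matmul(flop_fraction_non_matmul):
    """Share of ideal runtime spent on non-matmul work, given its share of FLOPs."""
    t_non = flop_fraction_non_matmul / NON_MATMUL_TFLOPS
    t_mat = (1.0 - flop_fraction_non_matmul) / MATMUL_TFLOPS
    return t_non / (t_non + t_mat)

cost_ratio = MATMUL_TFLOPS / NON_MATMUL_TFLOPS   # relative cost of a non-matmul FLOP
share = time_share_non_matmul(0.05)              # 5% of FLOPs are non-matmul
```

Under these assumptions, a kernel in which only 5% of the FLOPs are non‑matmul would still spend a little under half of its ideal runtime on them, which is why trimming that work pays off.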

Instead of scaling the partial output after every block, we keep an unscaled accumulator and apply the final diagonal scaling only once, and we store just the log‑sum‑exp needed for the backward pass.

Initialize $\tilde O^{(0)}=0$, $m^{(0)}=-\infty$, $\ell^{(0)}=0$.

After block 1: $\tilde O^{(1)} = e^{S^{(1)}-m^{(1)}} V^{(1)} = [[2,0],[0,1]]$, $L^{(1)} = m^{(1)} + \log(\ell^{(1)}) = [1+\log2, 0+\log1]$.

Compute scaling factor for block 2: $e^{m^{(1)}-m^{(2)}} = [e^{-1}, e^{-1}]$.

Update unscaled output: $\tilde O^{(2)} = \operatorname{diag}(e^{m^{(1)}-m^{(2)}})\tilde O^{(1)} + e^{S^{(2)}-m^{(2)}} V^{(2)} = [[2e^{-1}+1,0],[0, e^{-1}+1]]$.

Accumulate the running sum of exponentials: $\ell^{(2)} = e^{m^{(1)}-m^{(2)}}\ell^{(1)} + \operatorname{rowsum}(e^{S^{(2)}-m^{(2)}}) = [2e^{-1}+3, e^{-1}+2]$, then form the log‑sum‑exp $L^{(2)} = m^{(2)} + \log(\ell^{(2)})$.

Final scaling: $O = \operatorname{diag}(\ell^{(2)})^{-1}\tilde O^{(2)}$, yielding the exact softmax‑weighted output over all keys in both blocks.

The example shows that the division by $\ell$ is performed only once, at the end. Intermediate updates still multiply the accumulator by $e^{m^{(1)}-m^{(2)}}$ to track the running max, but they skip the per‑block normalization, cutting the number of expensive non‑matmul operations.
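The two‑block update can be checked numerically. The sketch below uses random scores rather than the specific numbers of the worked example, and the function names `reference` and `two_block_deferred` are made up for illustration. It processes two key/value blocks with an unnormalized accumulator and divides by the running sum once at the end.

```python
import numpy as np

def reference(S, V):
    """Plain softmax(S) @ V over all key columns at once."""
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def two_block_deferred(S1, S2, V1, V2):
    """Process two key/value blocks, normalizing by the running sum only once."""
    m1 = S1.max(axis=1)
    P1 = np.exp(S1 - m1[:, None])
    l1 = P1.sum(axis=1)
    O_acc = P1 @ V1                            # unnormalized accumulator
    m2 = np.maximum(m1, S2.max(axis=1))        # new running max
    scale = np.exp(m1 - m2)                    # correction for the changed max
    P2 = np.exp(S2 - m2[:, None])
    l2 = scale * l1 + P2.sum(axis=1)           # running sum of exponentials
    O_acc = scale[:, None] * O_acc + P2 @ V2
    O = O_acc / l2[:, None]                    # the single final normalization
    L = m2 + np.log(l2)                        # log-sum-exp kept for the backward pass
    return O, L

rng = np.random.default_rng(0)
S1, S2 = rng.standard_normal((2, 2)), rng.standard_normal((2, 2))
V1, V2 = rng.standard_normal((2, 3)), rng.standard_normal((2, 3))
O, L = two_block_deferred(S1, S2, V1, V2)
O_ref = reference(np.hstack([S1, S2]), np.vstack([V1, V2]))
```

The result `O` agrees with `O_ref` up to floating‑point rounding, which is the sense in which the blockwise computation is exact rather than approximate.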

How does this differ from the original FlashAttention online softmax?

The original version rescales the partial output by $\operatorname{diag}(\ell^{(j)})^{-1}$ after each block and keeps both the max $m^{(j)}$ and the sum $\ell^{(j)}$ for the backward pass. FlashAttention‑2 keeps the output unnormalized until the last block, applying only the running‑max correction in between, and stores only the combined log‑sum‑exp $L^{(j)}$, eliminating the extra diagonal multiplications.

FlashAttention‑2 forward pass – high‑level pseudocode.
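As a stand‑in for the pseudocode, here is a minimal NumPy sketch of a tiled forward pass with an outer loop over query tiles and an inner loop over key/value tiles. It emulates only the loop structure on the CPU and is not the CUDA kernel; the function name `flash2_forward`, the tile sizes `Br` and `Bc`, and the omission of the $1/\sqrt{d}$ scaling are assumptions for illustration.

```python
import numpy as np

def flash2_forward(Q, K, V, Br=2, Bc=2):
    """Tiled attention forward pass returning the output O and log-sum-exp L."""
    N, d = Q.shape
    O = np.zeros((N, d))
    L = np.zeros(N)
    for i in range(0, N, Br):                     # outer loop: query tiles
        Qi = Q[i:i + Br]
        m = np.full(Qi.shape[0], -np.inf)         # running row max
        l = np.zeros(Qi.shape[0])                 # running sum of exponentials
        acc = np.zeros((Qi.shape[0], d))          # unnormalized output
        for j in range(0, K.shape[0], Bc):        # inner loop: key/value tiles
            S = Qi @ K[j:j + Bc].T
            m_new = np.maximum(m, S.max(axis=1))
            scale = np.exp(m - m_new)             # equals 0 on the first tile
            P = np.exp(S - m_new[:, None])
            l = scale * l + P.sum(axis=1)
            acc = scale[:, None] * acc + P @ V[j:j + Bc]
            m = m_new
        O[i:i + Br] = acc / l[:, None]            # normalize once per query tile
        L[i:i + Br] = m + np.log(l)               # only L is kept for backward
    return O, L

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((5, 4)) for _ in range(3))
O, L = flash2_forward(Q, K, V)
```

With $N=5$ and tiles of size 2, the last tile is ragged, which shows that the loop does not rely on the sequence length being a multiple of the tile size.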

Causal masking lets us skip any block whose column indices lie entirely after its row indices, cutting roughly half the block computations for long sequences and giving a 1.7–1.8× speedup over attention without the causal mask. Only blocks that straddle the diagonal need an element‑wise mask.
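The "roughly half" figure can be verified by counting tiles. The following sketch assumes square 128 × 128 tiles and a sequence length of 4096, both chosen only for illustration, and the helper name `causal_block_counts` is hypothetical.

```python
import math

def causal_block_counts(N, Br, Bc):
    """Return (total tiles, tiles computed, tiles needing an element-wise mask)."""
    Tr, Tc = math.ceil(N / Br), math.ceil(N / Bc)
    computed = masked = 0
    for i in range(Tr):
        first_row = i * Br
        last_row = min((i + 1) * Br, N) - 1
        for j in range(Tc):
            first_col = j * Bc
            last_col = min((j + 1) * Bc, N) - 1
            if first_col > last_row:
                continue              # tile lies entirely in the future: skip it
            computed += 1
            if last_col > first_row:
                masked += 1           # tile straddles the diagonal
    return Tr * Tc, computed, masked

total, computed, masked = causal_block_counts(N=4096, Br=128, Bc=128)
```

With these sizes, 528 of the 1024 tiles are computed and only the 32 diagonal tiles need masking; the rest are either skipped or computed without any mask.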

Backward Pass Optimization

Shows how the backward pass drops extra softmax work by using a single log‑sum‑exp.

The original FlashAttention backward pass recomputes the softmax probabilities from two saved statistics per row, the row‑wise max and the row‑wise sum of exponentials. FlashAttention‑2 uses a single row‑wise log‑sum‑exp instead, which trims non‑matmul work and simplifies the kernel. This section details that optimization and its blockwise implementation.

Replace the two saved softmax statistics (row‑wise max $m$ and sum of exponentials $\ell$) with a single row‑wise log‑sum‑exp $L = m + \log(\ell)$, so that probabilities are recomputed with fewer non‑matmul FLOPs.

How does using only the row‑wise log‑sum‑exp differ from using both the max and the sum?

The original backward pass loads both the per‑row max $m$ and the per‑row sum $\ell$ saved by the forward pass and recomputes $P = \operatorname{diag}(\ell)^{-1}\exp(S - m)$. Because $L = m + \log(\ell)$, the same probabilities are $P = \exp(S - L)$: one subtraction and one exponentiation per entry, with a single statistic to store and load instead of two.

Compute raw scores for the first query tile: $S_{11}=Q_1 K_1^{\top}$ (a $2\times2$ matrix).

Take the row‑wise log‑sum‑exp $L_1$ saved by the forward pass. With a single key tile, its entry for row $r$ is $(L_1)_r = \log\!\big(\exp(S_{11,r1})+\exp(S_{11,r2})\big)$; with more key tiles the sum runs over every key column.

Form probabilities $P_{11} = \exp(S_{11} - L_1)$ (subtracting $L_1$ from each row), yielding a $2\times2$ softmax matrix.

Update gradients $dV_1 \mathrel{+=} P_{11}^{\top} dO_1$, then form $dP_{11} = dO_1 V_1^{\top}$ and $dS_{11} = P_{11}\circ(dP_{11} - D_1)$ with $D_1 = \operatorname{rowsum}(dO_1\circ O_1)$, and accumulate $dQ_1 \mathrel{+=} dS_{11} K_1$ and $dK_1 \mathrel{+=} dS_{11}^{\top} Q_1$.

Repeat the same steps for the second query tile ($i=2$), using its own $L_2$.

The stored log‑sum‑exp replaces two separate statistics (max and sum) with a single one, cutting the non‑matmul FLOPs while keeping the softmax numerically accurate.
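The equivalence between the two‑statistic and one‑statistic forms is easy to confirm numerically. This is a small sketch on random scores; the variable names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
S = rng.standard_normal((2, 4))              # scores for 2 query rows, 4 keys

m = S.max(axis=1)                            # row-wise max
l = np.exp(S - m[:, None]).sum(axis=1)       # row-wise sum of exponentials
L = m + np.log(l)                            # single combined statistic

P_two_stats = np.exp(S - m[:, None]) / l[:, None]   # needs both m and l
P_one_stat = np.exp(S - L[:, None])                 # needs only L
```

Both expressions give the same softmax matrix, but the second needs one saved vector instead of two and no division.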

Divide $Q$ into $T_r=\lceil N/B_r\rceil$ row tiles; divide $K$ and $V$ into $T_c=\lceil N/B_c\rceil$ column tiles.

Divide the output $O$, its gradient $dO$, and the auxiliary vector $L$ into the same $T_r$ row tiles.

Allocate zero‑initialized gradient buffers $dQ$, $dK$, $dV$ in HBM and split them into matching tiles.

Compute the auxiliary sum $D = \operatorname{rowsum}(dO\circ O)$ and split it into $T_r$ row tiles.

For each key/value tile $j$ (outer loop): load $K_j$, $V_j$ into SRAM; zero $dK_j$, $dV_j$.

For each query tile $i$ (inner loop): load $Q_i$, $O_i$, $dO_i$, $dQ_i$, $L_i$, $D_i$ into SRAM; compute scores $S_{ij}=Q_i K_j^{\top}$.

Form probabilities $P_{ij}=\exp(S_{ij}-L_i)$ using the pre‑computed $L_i$.

Accumulate gradient contributions: $dV_j\mathrel{+=}P_{ij}^{\top}dO_i$, $dP_{ij}=dO_i V_j^{\top}$, $dS_{ij}=P_{ij}\circ(dP_{ij}-D_i)$, $dQ_i\mathrel{+=}dS_{ij}K_j$, $dK_j\mathrel{+=}dS_{ij}^{\top}Q_i$.

After finishing all $i$, write $dK_j$ and $dV_j$ back to HBM.

Finally, output the assembled gradients $dQ$, $dK$, $dV$.
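The loop structure above can be emulated in NumPy. The sketch below is an illustration under stated assumptions, not the actual kernel: the names `reference_forward` and `flash2_backward` and the tile sizes are made up, the $1/\sqrt{d}$ scaling is omitted, and the softmax gradient is taken in its standard form $dS = P\circ(dP - D)$ with $dP = dO\,V^{\top}$ and $D = \operatorname{rowsum}(dO\circ O)$.

```python
import numpy as np

def reference_forward(Q, K, V):
    """Un-tiled forward pass returning O and the row-wise log-sum-exp L."""
    S = Q @ K.T
    m = S.max(axis=1, keepdims=True)
    E = np.exp(S - m)
    l = E.sum(axis=1, keepdims=True)
    return (E / l) @ V, (m + np.log(l))[:, 0]

def flash2_backward(Q, K, V, O, dO, L, Br=2, Bc=2):
    """Tiled backward pass: outer loop over K/V tiles, inner loop over Q tiles."""
    N = Q.shape[0]
    dQ, dK, dV = np.zeros_like(Q), np.zeros_like(K), np.zeros_like(V)
    D = (dO * O).sum(axis=1)                      # D = rowsum(dO * O)
    for j in range(0, K.shape[0], Bc):            # key/value tiles
        Kj, Vj = K[j:j + Bc], V[j:j + Bc]
        dKj, dVj = np.zeros_like(Kj), np.zeros_like(Vj)
        for i in range(0, N, Br):                 # query tiles
            Qi, dOi = Q[i:i + Br], dO[i:i + Br]
            S = Qi @ Kj.T
            P = np.exp(S - L[i:i + Br, None])     # recomputed from the stored L
            dVj += P.T @ dOi
            dP = dOi @ Vj.T
            dS = P * (dP - D[i:i + Br, None])
            dQ[i:i + Br] += dS @ Kj               # accumulated across K/V tiles
            dKj += dS.T @ Qi
        dK[j:j + Bc], dV[j:j + Bc] = dKj, dVj     # write back once per K/V tile
    return dQ, dK, dV

rng = np.random.default_rng(1)
Q, K, V, dO = (rng.standard_normal((5, 3)) for _ in range(4))
O, L = reference_forward(Q, K, V)
dQ, dK, dV = flash2_backward(Q, K, V, O, dO, L)
```

Note that `dQ` is updated inside the inner loop for every key/value tile, whereas `dK` and `dV` are written once per outer iteration; on a GPU, where different thread blocks own different key/value tiles, that is why `dQ` needs coordinated updates.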

By recomputing the softmax from a single stored log‑sum‑exp, FlashAttention‑2’s backward pass trims non‑matmul FLOPs relative to the original implementation.

Parallelism and Work Partitioning

Parallelism is reshaped to keep GPUs busy on long sequences by splitting work across thread blocks and warps.

When batch size or head count is small, the original per‑head thread‑block schedule leaves most of the 108 SMs on an A100 idle, throttling throughput on long sequences.

Instead of assigning a whole head to one block, we cut the long‑sequence dimension into chunks and give each chunk its own thread block, so every SM stays busy even when the batch is tiny.
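A quick count shows why the extra dimension of parallelism matters. The sketch below assumes one thread block per (batch, head) pair in the original schedule and one per chunk of 128 query rows in the new one; the chunk size, the example shapes, and the helper name `thread_blocks` are assumptions for illustration.

```python
import math

NUM_SMS_A100 = 108   # streaming multiprocessors on an A100

def thread_blocks(batch, heads, seqlen, block_rows=None):
    """Thread blocks launched; block_rows=None means one block per (batch, head)."""
    per_head = 1 if block_rows is None else math.ceil(seqlen / block_rows)
    return batch * heads * per_head

# A long-sequence, small-batch setting: 1 sequence, 16 heads, 16k tokens.
per_head_only = thread_blocks(batch=1, heads=16, seqlen=16384)
with_seq_split = thread_blocks(batch=1, heads=16, seqlen=16384, block_rows=128)
```

In this setting the per‑head schedule launches 16 thread blocks for 108 SMs, while splitting along the sequence launches 2048, enough to keep every SM occupied.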

We give each warp a distinct slice of the query matrix $Q$ while letting all warps share the key $K$ and value $V$, eliminating the need for inter‑warp reductions.

Does the backward pass also avoid the “split‑K” scheme?

Yes. With split‑K, every warp must write its partial results to shared memory, synchronize, and then sum them, and those shared‑memory reads and writes slow the kernel down. The backward pass also partitions warps to avoid split‑K, but it still requires some synchronization because of the more complicated dependencies among $Q$, $K$, $V$, $O$, $dO$, $dQ$, $dK$, and $dV$. Across thread blocks, each of which handles a block of columns, updates to $dQ$ are combined with atomic adds.

Warp 0 holds $Q_{0:2}$ (rows 0‑1); Warp 1 holds $Q_{2:4}$ (rows 2‑3).

Both warps read the full $K$ (4 × 8) and $V$ (4 × 8), which stay in shared memory accessible to every warp.

Warp 0 computes its $QK^\top$ slice (2 × 4), applies the row‑wise softmax, and multiplies by the shared $V$, producing output rows 0‑1.

Warp 1 does the same for rows 2‑3, producing output rows 2‑3.

No partial results are exchanged through shared memory; each warp writes its two output rows directly to the output matrix.

By giving each warp its own $Q$ slice we eliminate the inter‑warp reduction step entirely, which is the dominant overhead in the original split‑K design.
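The difference between the two partitionings can be mimicked in NumPy, with list entries standing in for warps. This is only an emulation of the data flow, not GPU code, and for brevity the split‑K variant assumes the row‑wise max is already known to every "warp", which a real kernel would also have to communicate.

```python
import numpy as np

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 8))
K = rng.standard_normal((4, 8))
V = rng.standard_normal((4, 8))

def softmax_rows(S):
    E = np.exp(S - S.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

full = softmax_rows(Q @ K.T) @ V

# Split-Q: each "warp" owns two query rows and sees all of K and V,
# so its output rows are final and are simply stacked.
split_q = np.vstack([softmax_rows(Q[r:r + 2] @ K.T) @ V for r in (0, 2)])

# Split-K: each "warp" owns two keys/values for all query rows and can only
# produce partial numerators and denominators.
m = (Q @ K.T).max(axis=1, keepdims=True)     # shared row max, assumed known
parts = []
for c in (0, 2):
    E = np.exp(Q @ K[c:c + 2].T - m)
    parts.append((E @ V[c:c + 2], E.sum(axis=1, keepdims=True)))
num = sum(p[0] for p in parts)               # the inter-warp reduction
den = sum(p[1] for p in parts)
split_k = num / den
```

Both variants reproduce `full`, but only the split‑K variant needs the final summation across "warps", which is the step that costs shared‑memory traffic and synchronization on the GPU.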

**Figure 2.** In the forward pass (left), we parallelize the workers (thread blocks) where each worker takes care of a block of rows of the attention matrix. In the backward pass (right), each worker takes care of a block of columns of the attention matrix.

**Figure 3.** Work partitioning between different warps in the forward pass

Performance Benchmarks

FlashAttention‑2 delivers dramatic speed gains and high TFLOPs on long‑context training.

FlashAttention‑2 builds on the original FlashAttention by trimming non‑matmul work and reshaping thread‑block partitioning. The result is a markedly faster attention kernel on modern GPUs.

FlashAttention‑2 is up to $10\times$ faster than the standard PyTorch attention implementation.

Measured on an A100 GPU across sequence lengths from 512 to 16 k, the optimized kernel outpaces the standard attention implementation by 3–10×.

**Figure 4.** Attention forward + backward speed on A100 GPU

**Figure 5.** Attention forward speed on A100 GPU

**Figure 6.** Attention backward speed on A100 GPU

**Figure 7.** Attention forward + backward speed on H100 GPU

**Table 1.** Performance comparison of GPT3 models using different attention mechanisms.

Discussion and Impact

We discuss the impact, future plans, and acknowledge contributors to FlashAttention‑2.

FlashAttention‑2 runs about twice as fast as FlashAttention, letting us train models with a 16 k context length for the same cost as an 8 k model.

This speedup opens the door to processing long books, reports, high‑resolution images, audio, and video with a single model.

Beyond training, the same gains accelerate fine‑tuning and inference of existing models.

We plan to work with researchers and engineers to broaden FlashAttention‑2 to diverse hardware—including H100 and AMD GPUs—and to support new data types such as FP8.

Our immediate target is to exploit H100 features like TMA, fourth‑generation Tensor Cores, and FP8 precision.

By pairing low‑level optimizations with algorithmic variants such as local, dilated, or block‑sparse attention, we anticipate training with substantially longer contexts, and we aim to collaborate with compiler researchers to make these techniques easy to program.

We thank Phil Tillet and Daniel Haziza for their FlashAttention implementations in Triton and xformers.

We thank the NVIDIA CUTLASS team—Vijay Thakkar, Cris Cecka, Haicheng Wu, and Andrew Kerr—for clean abstractions that underpin FlashAttention‑2.

We appreciate Driss Guessous for integrating FlashAttention into PyTorch.

We thank many colleagues (Phil Wang, Markus Rabe, James Bradbury, Young‑Jun Ko, Julien Launay, Daniel Hesslow, Michaël Benesty, Horace He, Ashish Vaswani, Erich Elsen) for insightful discussions.

We thank Stanford CRFM and Stanford NLP for compute support.

We thank Dan Fu and Christopher Ré for collaboration, feedback, and encouragement on designing hardware‑efficient algorithms.

Finally, we thank Albert Gu and Beidi Chen for their helpful suggestions on early drafts of this technical report.
