FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
FlashAttention reduces memory-bound attention overhead by tiling and recomputing to avoid HBM access.
How can we compute exact attention without the quadratic memory bottleneck caused by materializing the large attention matrix in GPU HBM?
Transformers struggle to scale to long sequences because the self-attention mechanism requires quadratic memory and time, bottlenecked by slow memory access. FlashAttention is an I/O-aware algorithm that tiles the attention computation to keep data in fast on-chip SRAM and recomputes intermediate values during the backward pass to avoid storing large matrices in slow HBM. This approach achieves up to 3× speedup on GPT-2 and 15% faster training on BERT-large compared to standard implementations, while enabling context lengths up to 64K.
Paper Primer
Standard attention implementations are memory-bound: they frequently read and write large intermediate matrices to high-bandwidth memory (HBM). FlashAttention treats this as an I/O problem, using tiling to compute softmax in blocks and recomputing attention during the backward pass to eliminate the need for storing the full $N \times N$ attention matrix.
FlashAttention achieves significant wall-clock speedups over standard attention implementations.
Training BERT-large is 15% faster than the MLPerf 1.1 record, and GPT-2 training is up to 3× faster than HuggingFace baselines.
FlashAttention enables scaling to significantly longer sequence lengths than standard Transformers.
The algorithm enables the first Transformers to achieve better-than-chance performance on Path-X (16K) and Path-256 (64K) challenges.
Why does this paper focus on I/O awareness rather than just reducing FLOPs?
Most attention operations are memory-bound, meaning their runtime is determined by the speed of memory access rather than the number of arithmetic operations. Reducing FLOPs often fails to translate to wall-clock speedup if the algorithm still requires frequent, slow reads and writes to HBM.
Does this method change the model output compared to standard attention?
No, FlashAttention computes exact attention. It produces the same output as standard attention implementations, ensuring numerical stability and identical training curves.
Researchers can now train Transformers with significantly longer context lengths without sacrificing model quality or incurring the memory overhead of standard attention, as long as they use the accompanying CUDA kernels.
The Attention Bottleneck
Self‑attention’s quadratic cost is a memory‑access bottleneck that FlashAttention eliminates.
Self‑attention scales as $O(N^{2})$ because it materialises an $N\times N$ attention matrix in GPU high‑bandwidth memory (HBM). This quadratic memory traffic dominates runtime on modern GPUs, where compute outpaces memory bandwidth.
Standard attention must materialise the full $N\!\times\!N$ similarity matrix, forcing every entry to be read from and written back to slow HBM before the softmax can be applied.
For example, at $N = 256$ the matrix has $256^{2} = 65{,}536$ entries; each float32 entry occupies 4 bytes, so the matrix needs $256\text{ KB}$ of HBM per head.
With $h=12$ heads, total HBM demand becomes $12\times256\text{ KB}=3072\text{ KB}=3\text{ MB}$.
During back-propagation the same amount must be read again, doubling the HBM traffic to 6 MB.
Because this traffic grows quadratically with $N$, longer sequences quickly become memory-bound.
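The arithmetic above can be reproduced with a few lines of Python. This is a minimal sketch: the function name `attention_matrix_bytes` is made up for illustration, and $N = 256$ is the sequence length implied by the 256 KB per-head figure.

```python
def attention_matrix_bytes(seq_len, num_heads=1, bytes_per_entry=4):
    """Bytes needed to hold one seq_len x seq_len score matrix per head."""
    return seq_len * seq_len * bytes_per_entry * num_heads

per_head = attention_matrix_bytes(256)        # 262,144 bytes = 256 KB
all_heads = attention_matrix_bytes(256, 12)   # 3,145,728 bytes = 3072 KB

# The footprint grows quadratically: doubling N quadruples the bytes.
for n in (256, 1024, 4096, 16384):
    mib = attention_matrix_bytes(n, 12) / 2**20
    print(f"N={n:>6}: {mib:,.0f} MiB across 12 heads")
```

The loop makes the quadratic growth concrete: each 4× increase in sequence length multiplies the footprint by 16.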
Quadratic complexity is a memory‑access problem, not just a FLOPs problem.
GPU Performance Characteristics
This section explains GPU memory tiers and how they affect attention’s memory traffic.
We first describe the GPU memory hierarchy and execution model that underlie modern deep‑learning kernels. Then we discuss how arithmetic intensity determines whether an operation is compute‑bound or memory‑bound, and why kernel fusion is a common optimization. Finally we review the standard attention implementation that our method will improve.
HBM is the large, relatively slow memory on a GPU, while SRAM is the tiny, extremely fast memory located directly on the compute units.
A GPU kernel launches a massive grid of threads; each thread loads inputs from HBM into registers and SRAM, performs its computation, then writes results back to HBM.
Arithmetic intensity—operations per byte of memory accessed—determines whether an operation’s runtime is limited by compute or by memory traffic.
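A rough roofline-style check illustrates the distinction. The sketch below is illustrative only: the peak throughput and bandwidth are round assumed numbers rather than the specification of any particular GPU, and the count of about 5 operations per softmax element is likewise an assumption.

```python
def arithmetic_intensity(flops, bytes_moved):
    """Operations performed per byte of memory traffic."""
    return flops / bytes_moved

def is_memory_bound(flops, bytes_moved, peak_flops, peak_bandwidth):
    """Memory-bound when intensity falls below the machine balance."""
    machine_balance = peak_flops / peak_bandwidth
    return arithmetic_intensity(flops, bytes_moved) < machine_balance

# Assumed, illustrative hardware numbers (not a real device spec).
PEAK_FLOPS = 312e12   # operations per second
PEAK_BW = 1.5e12      # bytes per second

N, BYTES = 1024, 2    # sequence length, bytes per fp16 entry

# Row-wise softmax over an N x N matrix: read and write every entry,
# with an assumed ~5 operations per entry.
softmax_flops = 5 * N * N
softmax_bytes = 2 * N * N * BYTES

# Large square matmul (n x n times n x n): 2 n^3 operations,
# three n x n matrices moved.
n = 4096
matmul_flops = 2 * n**3
matmul_bytes = 3 * n * n * BYTES

softmax_is_memory_bound = is_memory_bound(softmax_flops, softmax_bytes, PEAK_FLOPS, PEAK_BW)
matmul_is_memory_bound = is_memory_bound(matmul_flops, matmul_bytes, PEAK_FLOPS, PEAK_BW)
```

Under these assumptions the softmax performs about one operation per byte and is memory-bound, while the large matmul performs over a thousand operations per byte and is compute-bound.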
Kernel fusion is the standard technique for accelerating memory‑bound operations: if several elementwise kernels consume the same input, the input can be loaded once from HBM and reused across the fused kernel. However, during training the intermediate results must still be written back to HBM for the backward pass, limiting the benefit of naive fusion.
Standard attention computes the matrix $S = QK^{\top}$, applies row‑wise softmax to obtain $P$, and finally multiplies $P$ by $V$ to produce the output $O$. The implementation materializes $S$ and $P$ in HBM, incurring $O(N^{2})$ memory traffic, which makes the operation largely memory‑bound.
Load $Q$ and $K$ from HBM in blocks, compute $S = QK^{\top}$, and write $S$ back to HBM.
Read $S$ from HBM, compute $P = \text{softmax}(S)$, and write $P$ back to HBM.
Load $P$ and $V$ from HBM in blocks, compute $O = PV$, and write $O$ back to HBM.
Return $O$ as the attention output.
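The four steps above can be sketched in NumPy. This is a reference sketch, not the kernel the paper benchmarks: NumPy arrays stand in for HBM-resident tensors, the softmax is stabilized by subtracting the row maximum, and the function name is chosen here for illustration.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Reference attention that materializes the full N x N matrices."""
    S = Q @ K.T                                  # N x N scores
    E = np.exp(S - S.max(axis=1, keepdims=True)) # stabilized exponentials
    P = E / E.sum(axis=1, keepdims=True)         # N x N row-wise softmax
    O = P @ V                                    # N x d output
    return O, S, P

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 4)) for _ in range(3))
O, S, P = standard_attention(Q, K, V)
```

Both `S` and `P` have shape `(N, N)`; these are the intermediates whose reads and writes the rest of the document is concerned with.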
The FlashAttention Algorithm
FlashAttention reorganizes attention into SRAM‑resident tiles, cutting HBM traffic while preserving exact results.
Standard attention materializes an $N \times N$ matrix in slow HBM, causing huge memory traffic. FlashAttention reshapes the computation into SRAM‑resident tiles, cutting those accesses dramatically.
Break the $Q$, $K$, $V$ matrices into blocks that fit in on‑chip SRAM, compute each block’s contribution, and combine them with proper normalization to recover exact attention.
How is this tiling different from simply multiplying sub‑matrices of $Q$ and $K$?
Naïve block multiplication would compute each block’s raw scores but still need every score in a row before the softmax denominator could be formed. Tiling tracks the per-row maximum and sum across blocks, allowing the exact softmax to be assembled without ever materializing the full $N \times N$ score matrix.
Load $K_1$, $V_1$ (rows 1‑2) into SRAM; compute $S_{11}=Q_1 K_1^{\top}$.
Compute row‑wise max $m_{11}$ and sum $\ell_{11}$ for $S_{11}$; update $O_1$ using the formula in the algorithm.
Load $K_2$, $V_2$ (rows 3‑4); compute $S_{12}=Q_1 K_2^{\top}$, obtain $m_{12}$, $\ell_{12}$, and merge with previous aggregates.
Proceed to $Q_2$ and repeat the same two tile computations, finally producing $O$.
Even with tiny blocks, the incremental normalization reproduces the exact softmax that a full $4 \times 4$ computation would yield.
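The walkthrough above can be checked numerically. The following is a minimal NumPy sketch under simplifying assumptions: no scaling, masking, or dropout; random $4 \times 2$ inputs stand in for the example's matrices; and normalization is deferred to the end of each query block rather than applied after every tile. Function names are illustrative.

```python
import numpy as np

def reference_attention(Q, K, V):
    """Full softmax attention, materializing the score matrix."""
    S = Q @ K.T
    E = np.exp(S - S.max(axis=1, keepdims=True))
    return (E / E.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block_rows=2, block_cols=2):
    """Attention computed tile by tile with running max m and sum l."""
    N = Q.shape[0]
    O = np.zeros((N, V.shape[1]))
    for i in range(0, N, block_rows):
        Qi = Q[i:i + block_rows]
        m = np.full(Qi.shape[0], -np.inf)        # running row max
        l = np.zeros(Qi.shape[0])                # running row sum
        acc = np.zeros((Qi.shape[0], V.shape[1]))  # unnormalized output
        for j in range(0, K.shape[0], block_cols):
            Sij = Qi @ K[j:j + block_cols].T     # one small tile of scores
            m_new = np.maximum(m, Sij.max(axis=1))
            scale = np.exp(m - m_new)            # rescales earlier tiles
            Pij = np.exp(Sij - m_new[:, None])
            l = scale * l + Pij.sum(axis=1)
            acc = scale[:, None] * acc + Pij @ V[j:j + block_cols]
            m = m_new
        O[i:i + block_rows] = acc / l[:, None]   # normalize once per block
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((4, 2)) for _ in range(3))
O_tiled = tiled_attention(Q, K, V)
O_full = reference_attention(Q, K, V)
```

The two outputs agree to floating-point precision, even though `tiled_attention` only ever holds a $2 \times 2$ tile of scores at a time.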
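In the loop order of Figure 1, FlashAttention loads each $K$–$V$ tile from HBM once, while the $Q$ and $O$ blocks are revisited for every $K$–$V$ tile. Total HBM traffic is $\Theta(N^{2} d^{2} / M)$ for SRAM size $M$, and only the extra memory is linear in $N$.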
Why does keeping just $(m,\ell)$ per row give the exact softmax result?
Softmax of a concatenated vector can be expressed using the maximum of each sub‑vector and the sum of exponentials of each sub‑vector. By maintaining the global maximum $m$ and the global sum $\ell$ across blocks, the final normalized probabilities are identical to those obtained from the full vector.
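The decomposition can be written out explicitly. The notation here is assumed for illustration: for a vector split as $x = [x^{(1)}, x^{(2)}]$, let $m(\cdot)$ be the maximum, $f(x) = [e^{x_1 - m(x)}, \dots]$ the shifted exponentials, and $\ell(x) = \sum_i f(x)_i$ their sum. Then:

$$
\begin{aligned}
m(x) &= \max\bigl(m(x^{(1)}),\, m(x^{(2)})\bigr),\\
f(x) &= \bigl[\,e^{m(x^{(1)}) - m(x)} f(x^{(1)}),\;\; e^{m(x^{(2)}) - m(x)} f(x^{(2)})\,\bigr],\\
\ell(x) &= e^{m(x^{(1)}) - m(x)}\,\ell(x^{(1)}) + e^{m(x^{(2)}) - m(x)}\,\ell(x^{(2)}),\\
\operatorname{softmax}(x) &= f(x)/\ell(x).
\end{aligned}
$$

Each sub-vector contributes only its own maximum and sum, so the pair $(m, \ell)$ can be updated one block at a time.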
FlashAttention forward pass – fused CUDA kernel.
Algorithm 1 returns $O = \operatorname{softmax}(QK^{\top})V$ with $\mathcal{O}(N^{2}d)$ FLOPs and requires only $\mathcal{O}(N)$ additional memory beyond inputs and output.
For sequence length $N$, head dimension $d$, and SRAM size $M$ with $d \le M \le Nd$, FlashAttention needs $\Theta\!\bigl(N^{2} d^{2} / M\bigr)$ HBM accesses, versus $\Theta(Nd + N^{2})$ for standard attention.
No exact attention algorithm can achieve $o\!\bigl(N^{2} d^{2} / M\bigr)$ HBM accesses for all $M \in [d, Nd]$.
Block-sparse FlashAttention, where $s$ is the fraction of nonzero blocks, requires $\Theta\!\bigl(Nd + N^{2} d^{2} s / M\bigr)$ HBM accesses.
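To get a feel for the dense bounds, the asymptotic expressions can be evaluated with constants dropped. This is only a sketch: the function names are illustrative, $M$ is counted in matrix elements, and the values of $N$, $d$, and $M$ are assumed for the example rather than taken from a specific device.

```python
def standard_hbm_accesses(N, d):
    """Theta(N d + N^2) with constants dropped."""
    return N * d + N * N

def flash_hbm_accesses(N, d, M):
    """Theta(N^2 d^2 / M) with constants dropped, valid for d <= M <= N d."""
    if not (d <= M <= N * d):
        raise ValueError("bound is stated for d <= M <= N * d")
    return N * N * d * d / M

# Assumed example sizes: N = 4096 tokens, head dim 64, 100,000 elements of SRAM.
N, d, M = 4096, 64, 100_000
standard = standard_hbm_accesses(N, d)
flash = flash_hbm_accesses(N, d, M)
ratio = standard / flash   # how many times fewer accesses the tiled bound predicts
```

Because $d^{2}$ is typically much smaller than $M$, the tiled bound is well below the standard one for these sizes.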
**Figure 1.** **Left:** FlashAttention uses tiling to prevent materialization of the large $N \times N$ attention matrix (dotted box) on (relatively) slow GPU HBM. In the outer loop (red arrows), FlashAttention loops through blocks of the K and V matrices and loads them to fast on-chip SRAM. In each block, FlashAttention loops over blocks of Q matrix (blue arrows), loading them to SRAM, and writing the output of the attention computation back to HBM. **Right:** Speedup over the PyTorch implementation of attention on GPT-2. FlashAttention does not read and write the large $N \times N$ attention matrix to HBM, resulting in a 7.6x speedup on the attention computation.
Training Speed and Accuracy
FlashAttention delivers multi‑fold speedups and higher quality across a range of Transformer workloads.
FlashAttention speeds up GPT‑2 training up to 3.5× compared to the HuggingFace implementation while preserving perplexity.
Table 2 shows GPT‑2 small training time reduced from 9.5 days (1.0×) to 2.7 days (3.5×) with identical perplexity 18.2.
On the Long‑Range Arena benchmark FlashAttention achieves a 2.4× speedup over standard exact attention.
Table 3 reports a 2.4× runtime reduction for FlashAttention relative to the baseline Transformer.
**Figure 2.** **Left:** Forward + backward runtime of standard attention and FlashAttention for GPT-2 medium (seq. length 1024, head dim. 64, 16 heads, batch size 64) on A100 GPU. HBM access is the primary factor affecting runtime. **Middle:** Forward runtime of FlashAttention (seq. length 1024, head dim. 64, 16 heads, batch size 64) on A100 GPU. Fewer HBM accesses result in faster runtime, up to a point. **Right:** The runtime (for seq. length 4K) of block-sparse FlashAttention is faster than FlashAttention by a factor proportional to the sparsity.
**Figure 3.** **Left:** runtime of forward pass + backward pass. **Right:** attention memory usage.
**Figure 4.** Validation perplexity of GPT-2 small/medium using two implementations. We confirm that FlashAttention yields the same validation curves as the baseline implementation from HuggingFace.
Limitations and Future Work
We recap the attention bottleneck and outline how related ideas intersect with our IO‑aware approach.
Standard attention materializes a full $N \times N$ matrix in slow GPU HBM; FlashAttention tiles the computation, keeping intermediate tiles in fast SRAM and thus cutting memory traffic.
Our current pipeline requires hand‑writing a new CUDA kernel for every attention variant, which forces developers into low‑level code and makes the implementation fragile across GPU generations.
Because the bottleneck lives in memory movement, the same IO‑aware principle could benefit any layer that reads or writes large tensors, not just attention.
Our implementation is optimal for a single GPU up to constant factors; scaling the same IO analysis to multiple GPUs introduces an extra data‑transfer dimension that remains largely unexplored.
IO‑aware runtime optimization has a long pedigree in systems research: classic I/O‑complexity analysis [1], the working‑set model [21], data‑locality frameworks [86], the Roofline model of arithmetic intensity [85], and scalability studies of parallel algorithms [59] all formalize the trade‑off between fast on‑chip memory and slower off‑chip storage.
Structured matrices—sparse, low‑rank, Toeplitz‑like, low‑displacement‑rank, and quasi‑separable forms—offer sub‑quadratic parameter counts and runtimes. Butterfly matrices [15, 64] and their products can represent any structured matrix with near‑optimal cost [16, 20]; later work [17, 18] refines them for hardware friendliness, yet the “hardware lottery” [41] still hampers their adoption in practice.
Sparse‑training research shows that dense networks contain small subnetworks (lottery tickets) that match full‑model performance [28‑30]; our block‑sparse FlashAttention can be viewed as a fixed lottery ticket for the attention pattern, preserving accuracy while reducing memory.
Efficient‑transformer work tackles the quadratic $N^2$ scaling of attention through hashing (Reformer [51]), low‑rank approximations (Performer [12, 54]), hybrid sparse/low‑rank schemes (Longformer [3], BigBird [92], Scatterbrain [9]), and sequence‑compression tricks (Linformer [84], Token‑to‑Token ViT [91]). These methods trade exactness for speed, whereas our IO‑aware tiling retains exact attention while shaving memory traffic.
Beyond attention, several families propose alternative long‑context mechanisms: HiPPO and its state‑space extensions (S4) [35, 31‑37] model history with continuous dynamics; LambdaNetworks [2], AFT [93], and FLASH [42] replace attention with learned kernels or linear transforms. Our IO‑aware mindset could be applied to these modules to further reduce off‑chip traffic.
Backward Pass and Recomputation
FlashAttention computes exact attention with linear extra memory by blockwise tiling and recomputation.
Standard attention materializes an N × N matrix in slow GPU HBM, causing quadratic memory traffic that dominates runtime.
Compute each query’s softmax normalizer once, then stream the value vectors $v_j$ through the query loop, accumulating the output without ever storing the full attention matrix.
Compute L₁ = exp(1·1+0·1)+exp(1·0+0·1)+exp(1·1+0·0)+exp(1·0+0·0) = exp(1)+exp(0)+exp(1)+exp(0) ≈ 2.718+1+2.718+1 = 7.436.
Compute the weight for $v_1$: $e^{q_1 \cdot k_1}/L_1 = \exp(1)/7.436 \approx 0.366$.
Accumulate $o_1 \leftarrow 0 + 0.366\cdot[2,0] = [0.73, 0]$.
Repeat for j=2,3,4, adding the weighted $v_j$ to o₁. Final o₁ ≈ [1.23, 0.73].
Repeat the same procedure for queries i=2,3,4, each time only storing the scalar $L_i$ and the running sum $o_i$.
The algorithm never creates a 4 × 4 attention matrix; it only stores four scalars $L_i$ and four 2‑dimensional running sums, illustrating the linear‑memory claim.
Why does pre‑computing $L_i$ remove the need for the full attention matrix?
Because each output $o_i$ depends on the softmax weights only through the denominator $L_i$; once $L_i$ is known, the numerator $e^{q_i \cdot k_j}$ can be recomputed on the fly for each j, so we never have to keep the N × N weight matrix in memory.
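The worked example can be reproduced with a two-pass loop over the keys. In this sketch the query and keys are read off the $L_1$ computation above and $v_1 = [2, 0]$ is as stated; the remaining value vectors are assumptions chosen so that the result matches the quoted $o_1$. Like the example, the code omits the max-subtraction used for numerical stability.

```python
import numpy as np

def attend_one_query(q, K, V):
    """Output for one query using only a scalar normalizer and a running sum."""
    # Pass 1: accumulate the softmax normalizer L.
    L = 0.0
    for k in K:
        L += np.exp(q @ k)
    # Pass 2: recompute each weight on the fly and stream the values.
    o = np.zeros(V.shape[1])
    for k, v in zip(K, V):
        o += np.exp(q @ k) / L * v
    return L, o

q1 = np.array([1.0, 0.0])
K = np.array([[1.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]])
# v_1 is from the text; v_2..v_4 are assumed for illustration.
V = np.array([[2.0, 0.0], [1.0, 0.0], [1.0, 2.0], [0.0, 0.0]])

L1, o1 = attend_one_query(q1, K, V)   # L1 ≈ 7.436, o1 ≈ [1.23, 0.73]
```

At no point does the function hold more than one weight at a time, which is the linear-memory property the example illustrates.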
Re‑use the same per‑query normalizers $L_i$ computed in the forward pass, and express all gradients as sums over the same softmax weights, so the backward sweep also needs only linear extra storage.
For $j=1$, compute $dv_1 = \sum_i \bigl(\exp(q_i \cdot k_1)/L_i\bigr)\, do_i$. Using the previously computed $L_i$ values, the weighted sum yields $dv_1 \approx [0.03, -0.04]$.
Repeat for $j=2,3,4$ to obtain $dv_2$, $dv_3$, $dv_4$, all without storing $P$.
Compute $D_i = do_i^{\top} o_i$ for each $i$ (e.g., $D_1 \approx 0.1\cdot 1.23 + (-0.2)\cdot 0.73 \approx -0.02$).
Use $D_i$ to form $dS$ and then $dq_i = \sum_j \bigl(\exp(q_i \cdot k_j)/L_i\bigr)\,(do_i \cdot v_j - D_i)\, k_j$, again only needing the scalar $D_i$ and the per-query $L_i$.
The backward computation mirrors the forward one: every step re‑uses the same O(N) statistics, confirming the linear‑memory claim for gradients.
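A small NumPy sketch can confirm that these formulas match a backward pass that stores the full matrix. Assumptions: no scaling, masking, or dropout; random inputs; one query row processed at a time; and function names invented for illustration. The reference computes the row term as $\operatorname{rowsum}(P \odot dP)$, which equals $D_i$.

```python
import numpy as np

def forward_stats(Q, K, V):
    """Forward pass that keeps only the per-query normalizers L and output O."""
    N = Q.shape[0]
    L = np.zeros(N)
    O = np.zeros((N, V.shape[1]))
    for i in range(N):
        w = np.exp(K @ Q[i])          # one transient row of weights
        L[i] = w.sum()
        O[i] = (w / L[i]) @ V
    return L, O

def backward_recompute(Q, K, V, O, dO, L):
    """Gradients from saved L and O, recomputing each row of P on the fly."""
    dQ, dK, dV = np.zeros_like(Q), np.zeros_like(K), np.zeros_like(V)
    D = (dO * O).sum(axis=1)          # D_i = do_i^T o_i
    for i in range(Q.shape[0]):
        p = np.exp(K @ Q[i]) / L[i]   # recomputed row i of P
        dV += np.outer(p, dO[i])      # dv_j += P_ij * do_i
        dp = V @ dO[i]                # dP_ij = do_i . v_j
        ds = p * (dp - D[i])          # dS_ij = P_ij (dP_ij - D_i)
        dQ[i] = ds @ K
        dK += np.outer(ds, Q[i])
    return dQ, dK, dV

def backward_reference(Q, K, V, dO):
    """Same gradients, but materializing the N x N matrices."""
    E = np.exp(Q @ K.T)
    P = E / E.sum(axis=1, keepdims=True)
    dP = dO @ V.T
    dS = P * (dP - (P * dP).sum(axis=1, keepdims=True))
    return dS @ K, dS.T @ Q, P.T @ dO

rng = np.random.default_rng(0)
Q, K, V, dO = (rng.standard_normal((4, 2)) for _ in range(4))
L, O = forward_stats(Q, K, V)
grads = backward_recompute(Q, K, V, O, dO, L)
grads_ref = backward_reference(Q, K, V, dO)
```

The only state carried from the forward pass is `L` and `O`, both of size $O(N)$ in the sequence length.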
How does this “recomputation” differ from classic gradient checkpointing?
Classic checkpointing stores intermediate activations at a few layers and recomputes the rest; here we recompute the softmax numerators on the fly using the already‑saved normalizers $L_i$, so we never need to store any N × N intermediate at all.
Initialize RNG state R and allocate SRAM of size M.
Choose block size B = ⌊M/(4d)⌋ and split Q, K, V into row‑blocks $Q_i$ and column‑blocks $K_j$, $V_j$.
For each column block j:
Load $K_j$, $V_j$ into SRAM.
For each row block i:
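Load $Q_i$ and the current partial output $O_i$, $\ell_i$, $m_i$ into SRAM.
Compute the scaled scores $S_{ij} = \tau\, Q_i K_j^{\top}$.
Mask $S_{ij}$ and compute the row-wise max $\tilde m_{ij} = \operatorname{rowmax}(S_{ij}^{\text{masked}})$, the unnormalized block $\tilde P_{ij} = \exp(S_{ij}^{\text{masked}} - \tilde m_{ij})$, and its row-sum $\tilde\ell_{ij} = \operatorname{rowsum}(\tilde P_{ij})$.
Compute the new running max $m_i^{\text{new}} = \max(m_i, \tilde m_{ij})$ and the new running normalizer $\ell_i^{\text{new}} = e^{m_i - m_i^{\text{new}}}\,\ell_i + e^{\tilde m_{ij} - m_i^{\text{new}}}\,\tilde\ell_{ij}$.
Apply dropout to $\tilde P_{ij}$ to obtain $\tilde P_{ij}^{\text{dropped}}$.
Update $O_i \leftarrow \operatorname{diag}(\ell_i^{\text{new}})^{-1}\bigl(\operatorname{diag}(\ell_i)\,e^{m_i - m_i^{\text{new}}}\,O_i + e^{\tilde m_{ij} - m_i^{\text{new}}}\,\tilde P_{ij}^{\text{dropped}} V_j\bigr)$.
Write $O_i$, $\ell_i \leftarrow \ell_i^{\text{new}}$, $m_i \leftarrow m_i^{\text{new}}$ back to HBM.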
Return the final output O together with the saved statistics ℓ, m and RNG state R.
Restore RNG state R from the forward pass.
Set the same block size B as in the forward pass and partition Q, K, V, O, dO into matching blocks.
Allocate zero‑initialized gradient buffers dQ, dK, dV in HBM.
For each column block j:
Load $K_j$, $V_j$ into SRAM and zero-initialize temporary gradients $\tilde{dK}_j$, $\tilde{dV}_j$.
For each row block i:
Load $Q_i$, $O_i$, $dO_i$, $\ell_i$, $m_i$ into SRAM.
Recompute $S_{ij} = \tau\, Q_i K_j^{\top}$ and its masked version $S_{ij}^{\text{masked}}$.
Reconstruct the softmax block $P_{ij} = \operatorname{diag}(\ell_i)^{-1} \exp(S_{ij}^{\text{masked}} - m_i)$.
Regenerate the dropout mask $Z_{ij}$ from $R$ and form $P_{ij}^{\text{dropped}} = P_{ij} \odot Z_{ij}$.
Accumulate $\tilde{dV}_j \leftarrow \tilde{dV}_j + (P_{ij}^{\text{dropped}})^{\top} dO_i$.
Compute $dP_{ij} = (dO_i V_j^{\top}) \odot Z_{ij}$.
Compute $D_i = \operatorname{rowsum}(dO_i \odot O_i)$.
Form $dS_{ij} = P_{ij} \odot (dP_{ij} - D_i)$.
Update $dQ_i \leftarrow dQ_i + \tau\, dS_{ij} K_j$ and $\tilde{dK}_j \leftarrow \tilde{dK}_j + \tau\, dS_{ij}^{\top} Q_i$.
Write back the updated $dQ_i$ to HBM.
After the inner loop, write $\tilde{dK}_j$, $\tilde{dV}_j$ to the global $dK_j$, $dV_j$ buffers.
Return the gradients dQ, dK, dV.
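Saving the RNG state instead of the dropout mask works because the mask can be regenerated deterministically. The sketch below illustrates the idea with NumPy. Deriving one random stream per tile from `[seed, i, j]` is a hypothetical scheme chosen for this example, not necessarily how the CUDA kernel manages its state.

```python
import numpy as np

def dropout_mask(seed, block_index, shape, p_drop):
    """Regenerate the dropout mask Z_ij for one tile from a saved seed.

    Entries are 1 / (1 - p_drop) where an element is kept and 0 where dropped.
    The per-tile seeding scheme is an assumption made for illustration.
    """
    i, j = block_index
    rng = np.random.default_rng([seed, i, j])
    keep = rng.random(shape) >= p_drop
    return keep / (1.0 - p_drop)

seed, p_drop = 1234, 0.1
Z_forward = dropout_mask(seed, (0, 1), (2, 2), p_drop)   # used in the forward pass
Z_backward = dropout_mask(seed, (0, 1), (2, 2), p_drop)  # regenerated in the backward pass
```

Since both calls return the same mask, only the seed has to be stored between the passes rather than an $N \times N$ mask.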
The blockwise design avoids materializing the $N \times N$ matrix in HBM, reducing HBM traffic from $\Theta(Nd + N^{2})$ to $\Theta(N^{2} d^{2} / M)$ while preserving exact attention, and the recomputation formulas keep the backward pass linear in extra memory.
Complexity Analysis
Proofs establish FLOP counts, memory usage, and correctness of the streaming attention algorithm.
Theorem 1 quantifies the dominant arithmetic and auxiliary storage required by the tiled attention algorithm.
Base case $j=0$ holds because the algorithm initializes $m$, $\ell$, and $O$ to the neutral values $-\infty$, $0$, and $0$ respectively.
Inductive step: assuming the invariants hold after processing $j$ column blocks, the update formulas preserve them for $j+1$.
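The rowsum $\ell^{(j+1)}$ is updated by rescaling the old sum and adding the new block’s contribution, which algebraically equals $\operatorname{rowsum}\!\bigl(\exp(S_{:,:j+1}-m^{(j+1)})\bigr)$, where $S_{:,:j+1}$ denotes the first $j+1$ column blocks of $S$.
The output $O^{(j+1)}$ is formed by rescaling the accumulated weighted values to the new rowsum and adding the new block’s contribution, yielding $P_{:,:j+1}V_{:j+1}$, the softmax over the first $j+1$ column blocks applied to the matching rows of $V$.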
By induction the invariants hold for all $j=0,\dots,T_c$, and for $j=T_c$ we recover the full attention output $O=\operatorname{softmax}(QK^{\top})V$.
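The induction can be sanity-checked numerically by asserting the three invariants after every column block. This is a sketch under assumptions: no scaling, masking, or dropout; all rows of $Q$ handled at once; and a sequence length divisible by the block size. The full score matrix is built only to serve as the reference.

```python
import numpy as np

def check_invariants(Q, K, V, block_cols):
    """Run the running-statistics update and verify m, l, O after each block."""
    N = Q.shape[0]
    m = np.full(N, -np.inf)            # base case: m = -inf
    l = np.zeros(N)                    # base case: l = 0
    O = np.zeros((N, V.shape[1]))      # base case: O = 0
    S_full = Q @ K.T                   # reference only
    for end in range(block_cols, K.shape[0] + 1, block_cols):
        start = end - block_cols
        Sj = Q @ K[start:end].T
        m_tilde = Sj.max(axis=1)
        P_tilde = np.exp(Sj - m_tilde[:, None])
        l_tilde = P_tilde.sum(axis=1)
        m_new = np.maximum(m, m_tilde)
        l_new = np.exp(m - m_new) * l + np.exp(m_tilde - m_new) * l_tilde
        O = ((l * np.exp(m - m_new))[:, None] * O
             + np.exp(m_tilde - m_new)[:, None] * (P_tilde @ V[start:end])
             ) / l_new[:, None]
        m, l = m_new, l_new
        # Invariants over the column blocks processed so far.
        S_pre = S_full[:, :end]
        m_ref = S_pre.max(axis=1)
        E = np.exp(S_pre - m_ref[:, None])
        l_ref = E.sum(axis=1)
        O_ref = (E / l_ref[:, None]) @ V[:end]
        if not (np.allclose(m, m_ref) and np.allclose(l, l_ref)
                and np.allclose(O, O_ref)):
            return False
    return True

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((6, 3)) for _ in range(3))
invariants_hold = check_invariants(Q, K, V, block_cols=2)
```

After the last block the prefix is the whole matrix, so the final check is exactly the statement $O=\operatorname{softmax}(QK^{\top})V$.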
Theorem 2 compares the memory traffic of the naïve implementation with that of the tiled (streaming) version.
Deriving feasible tile dimensions from on‑chip memory size $M$ yields the block‑size constraints.
Suppose, for contradiction, that an exact-attention algorithm uses $o\!\bigl(N^{2}d^{2}/M\bigr)$ HBM accesses for all $M\in[d,Nd]$.
In the regime $M=\Theta(Nd)$ this bound becomes $o\!\bigl(N^{2}d^{2}/(Nd)\bigr)=o(Nd)$ accesses.
However, the inputs $Q,K,V$ and the output $O$ each have size $Nd$ and reside in HBM, so any exact algorithm must perform at least $\Omega(Nd)$ accesses.
This contradicts the $o(Nd)$ bound, so no exact-attention algorithm can achieve $o\!\bigl(N^{2}d^{2}/M\bigr)$ HBM accesses for all $M\in[d,Nd]$.
Theorem 5 shows that the backward pass of FlashAttention inherits the same I/O complexity as the forward pass.
Initialize random seed $R$ and store it in HBM.
Allocate in HBM the output $O$ and row-sum vector $\ell$ initialized to $0$, and the row-max vector $m$ initialized to $-\infty$.
Partition $Q$ into $T_r$ row tiles $Q_i$ of size $B_r\times d$; similarly partition $K$, $V$ into $T_c$ column tiles $K_j$, $V_j$ of size $B_c\times d$.
For each row tile $i$: load $Q_i$, $O_i$, $\ell_i$, $m_i$ into SRAM.
For each column tile $j$: if the sparsity mask $\mathcal{M}_{ij}=1$, load $K_j$, $V_j$ into SRAM; compute the scaled product $S_{ij}=\tau\,Q_iK_j^{\top}$.
Mask $S_{ij}$, then compute its row-max $\tilde m_{ij}$, the unnormalised weights $\tilde P_{ij}=\exp(S_{ij}^{\text{masked}}-\tilde m_{ij})$, and their row-sum $\tilde\ell_{ij}$.
Compute the new running maxima $m_i^{\text{new}}=\max(m_i,\tilde m_{ij})$ and the new running sums $\ell_i^{\text{new}}=e^{m_i-m_i^{\text{new}}}\,\ell_i+e^{\tilde m_{ij}-m_i^{\text{new}}}\,\tilde\ell_{ij}$.
Apply dropout to $\tilde P_{ij}$, update $O_i\leftarrow\operatorname{diag}(\ell_i^{\text{new}})^{-1}\bigl(\operatorname{diag}(\ell_i)\,e^{m_i-m_i^{\text{new}}}\,O_i+e^{\tilde m_{ij}-m_i^{\text{new}}}\,\tilde P_{ij}^{\text{dropped}}V_j\bigr)$, then set $\ell_i\leftarrow\ell_i^{\text{new}}$ and $m_i\leftarrow m_i^{\text{new}}$.
Write the updated $O_i$, $\ell_i$, $m_i$ back to HBM before proceeding to the next column tile.
After all column tiles are processed, proceed to the next row tile.
Return the final output $O$ together with the auxiliary statistics $\ell$, $m$, and the RNG state $R$.
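The tile-skipping logic of these steps can be sketched in NumPy. Assumptions: no scaling or dropout; square tiles of size `B`; a block mask in which every row tile has at least one active block; and normalization deferred to the end of each row tile. The dense reference sets masked scores to $-\infty$, and `loads` counts how many $K_j$, $V_j$ tiles are fetched.

```python
import numpy as np

def block_sparse_attention(Q, K, V, block_mask, B):
    """Tiled attention that skips tiles whose block_mask entry is zero."""
    N = Q.shape[0]
    T = N // B
    O = np.zeros((N, V.shape[1]))
    loads = 0
    for i in range(T):
        Qi = Q[i * B:(i + 1) * B]
        m = np.full(B, -np.inf)
        l = np.zeros(B)
        acc = np.zeros((B, V.shape[1]))
        for j in range(T):
            if not block_mask[i, j]:
                continue                      # empty tile: K_j, V_j never loaded
            loads += 1
            Kj, Vj = K[j * B:(j + 1) * B], V[j * B:(j + 1) * B]
            Sij = Qi @ Kj.T
            m_new = np.maximum(m, Sij.max(axis=1))
            scale = np.exp(m - m_new)
            P = np.exp(Sij - m_new[:, None])
            l = scale * l + P.sum(axis=1)
            acc = scale[:, None] * acc + P @ Vj
            m = m_new
        O[i * B:(i + 1) * B] = acc / l[:, None]
    return O, loads

def dense_masked_reference(Q, K, V, block_mask, B):
    """Dense attention with masked-out blocks set to -inf before softmax."""
    full = np.kron(block_mask.astype(int), np.ones((B, B), dtype=int)).astype(bool)
    S = np.where(full, Q @ K.T, -np.inf)
    E = np.exp(S - S.max(axis=1, keepdims=True))
    return (E / E.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((4, 2)) for _ in range(3))
block_mask = np.array([[1, 0], [1, 1]], dtype=bool)   # 3 of 4 tiles active
O_sparse, loads = block_sparse_attention(Q, K, V, block_mask, B=2)
O_dense = dense_masked_reference(Q, K, V, block_mask, B=2)
```

The outputs match, and the number of tile loads equals the number of nonzero entries in the mask rather than the total number of tiles.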
Block-Sparse FlashAttention
Extending I/O‑aware attention with block sparsity and broader hardware scenarios.
Standard FlashAttention already avoids materialising the full $N\times N$ matrix, but many workloads contain structured zeros that can be skipped entirely.
Instead of processing every tile of the $N\times N$ score matrix, the algorithm only loads tiles that correspond to non-zero blocks in a sparsity mask, keeping all active data in fast SRAM and never touching the slow HBM for empty regions.
How does block‑sparse FlashAttention differ from applying a mask to standard FlashAttention?
Masking a dense FlashAttention still loads and computes every $B_r\times B_c$ tile before discarding the masked entries, so HBM traffic remains $O(N^{2}d^{2}/M)$. Block-sparse FlashAttention never loads the empty tiles at all; the sparsity mask drives the tile-selection loop, cutting HBM reads by the factor $s$.
Compute the output‑write cost: $N d = 4\times2 = 8$ memory words.
Compute the tile‑read cost: $s\,N^{2}d^{2}/M = 0.25 \times 4^{2} \times 2^{2} / 2 = 0.25 \times 16 \times 4 / 2 = 8$ memory words.
Total HBM I/O = $8 + 8 = 16$ words, compared with $8 + 32 = 40$ words for the dense case ($s=1$), because three empty tiles are never touched.
Even with $s=0.25$, the algorithm still writes the full $8$‑word output $O$, which dominates the cost when $s$ becomes smaller.
Block sparsity cuts the read side proportionally to $s$, but the write side is unchanged; therefore the benefit plateaus once $s$ is very low.
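The plateau is easy to see by evaluating the cost expression over a range of sparsity levels. This sketch drops constant factors and uses the toy sizes from the example above ($N=4$, $d=2$, $M=2$); the function name is illustrative.

```python
def block_sparse_hbm_accesses(N, d, M, s):
    """Output-write term N*d plus tile-read term s * N^2 * d^2 / M."""
    return N * d + s * N * N * d * d / M

N, d, M = 4, 2, 2
example = block_sparse_hbm_accesses(N, d, M, 0.25)   # 8 + 8 words
dense = block_sparse_hbm_accesses(N, d, M, 1.0)      # 8 + 32 words

# As s shrinks, the total approaches the fixed N*d write cost.
for s in (1.0, 0.5, 0.25, 0.05, 0.0):
    print(f"s={s:<5} total={block_sparse_hbm_accesses(N, d, M, s):.1f}")
```

The read term scales linearly with $s$, while the $Nd$ term sets a floor that no amount of sparsity removes.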
Beyond block sparsity, the I/O‑aware philosophy can be applied to other parts of large‑scale training pipelines.
Multi‑GPU attention can exploit the hierarchy of on‑chip SRAM, local GPU HBM, and remote GPU HBM by assigning different blocks to different devices and moving only the non‑zero tiles across the node‑level interconnect.
Sparse MLP layers often become memory‑bound; an I/O‑aware implementation would keep the sparse weight tiles in SRAM and stream only the active rows, mirroring the block‑sparse attention trick.
Kernel methods share the low‑rank structure $K_{ij}=k(x_i,x_j)$; like FlashAttention, they can recompute each kernel entry from the two input vectors $x_i,x_j$, avoiding a full $N\times N$ kernel matrix in HBM. The KeOps library demonstrates this principle for large‑scale kernel learning.
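The same recompute-instead-of-store idea can be sketched for a kernel matrix-vector product. This is plain NumPy written for illustration and does not use or reproduce the KeOps API. The Gaussian kernel, block size, and function name are assumptions made for the example.

```python
import numpy as np

def gaussian_kernel_matvec(X, b, block=16, sigma=1.0):
    """Compute K @ b for K_ij = exp(-|x_i - x_j|^2 / (2 sigma^2)) tile by tile."""
    N = X.shape[0]
    out = np.zeros(N)
    for i in range(0, N, block):
        Xi = X[i:i + block]
        for j in range(0, N, block):
            Xj = X[j:j + block]
            # Recompute one small tile of kernel entries from the inputs.
            sq = ((Xi[:, None, :] - Xj[None, :, :]) ** 2).sum(axis=-1)
            out[i:i + block] += np.exp(-sq / (2 * sigma**2)) @ b[j:j + block]
    return out

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 3))
b = rng.standard_normal(50)
y_tiled = gaussian_kernel_matvec(X, b)

# Dense reference that does build the full 50 x 50 kernel matrix.
sq_full = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
y_dense = np.exp(-sq_full / 2.0) @ b
```

Only one `block × block` tile of kernel values exists at any time, yet the result equals the dense product.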
Extended Experimental Results
FlashAttention delivers multi‑task training speedups while preserving model quality.
We evaluate FlashAttention on three representative workloads—BERT‑large, GPT‑2 (small and medium), and the Long‑Range Arena (LRA) benchmark—using the same data splits and hyper‑parameters as the MLPerf 1.1 reference implementations.
FlashAttention trains BERT-large to the target 72.0 % masked-language-model accuracy in roughly 17 minutes, about 15% faster than the MLPerf 1.1 record.
Eight A100‑80GB GPUs, batch size 448, LAMB optimizer (lr 3.75e‑3), FP16 with Apex AMP (O2); each run measured over 10 repetitions.
For GPT‑2 we follow the Megatron‑LM recipe, using mixed‑precision training on 8×A100‑40GB GPUs. Both the HuggingFace and Megatron implementations achieve identical validation perplexities, confirming functional parity.
On the five Long‑Range Arena tasks we observe that all attention variants reach comparable accuracy after hyper‑parameter tuning. FlashAttention’s wall‑clock‑time speedup is computed as the geometric mean of per‑task speedups.
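For reference, a geometric mean of per-task speedups can be computed as below. This is a minimal sketch: the function name is illustrative, and the input values are synthetic placeholders, not measurements from the benchmark.

```python
import math

def geometric_mean(values):
    """Geometric mean via the average of logarithms."""
    return math.exp(sum(math.log(v) for v in values) / len(values))

# Synthetic placeholder speedups, for illustration only.
placeholder_speedups = [1.0, 4.0]
mean_speedup = geometric_mean(placeholder_speedups)   # sqrt(1 * 4) = 2
```

Unlike the arithmetic mean, the geometric mean treats a 2× speedup and a 2× slowdown as cancelling out, which is why it is a common choice for averaging ratios.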
We also compare against Apex FMHA, the fastest publicly available short‑sequence attention implementation at the time of this work. Apex FMHA is tuned for sequence lengths ≤512, whereas FlashAttention maintains its efficiency on much longer inputs.
Hardware Benchmarks
FlashAttention delivers multi‑fold speedups across GPUs and sequence lengths.
FlashAttention is 8 % faster than Apex FMHA at sequence length 256 on an A100 GPU for the combined forward and backward pass.
Measured forward-pass runtime: FlashAttention 0.22 ms vs. Apex FMHA 0.29 ms (batch 8, 12 heads, dim 64).
Apex FMHA fuses dropout‑masked softmax and the value‑multiply into a single CUDA kernel, but it still writes the full attention matrix to HBM for the backward pass.
How does Apex FMHA’s memory handling differ from FlashAttention’s?
Apex FMHA writes the entire softmax‑masked attention matrix to HBM during the forward pass and reads it back for the backward pass. FlashAttention instead tiles the computation, keeping intermediate tiles in SRAM and recomputing the matrix on the backward pass, which avoids the large HBM traffic.
**Figure 5.** Speedup over standard PyTorch attention at different sequence lengths, on A100.
**Figure 6.** Speedup over standard PyTorch attention at different sequence lengths, on A100, with head dimension 128.