MiniMax Sparse Attention

Xunhao Lai, Weiqi Xu, Yufeng Yang, Qiaorui Chen, Yang Xu, Lunbin Zeng, Xiaolong Li, Haohai Sun, Haichao Zhu, Vito Zhang, Pengyu Zhao

MiniMax Sparse Attention (MSA) uses a lightweight, learned indexer to enable efficient, block-sparse attention for long-context LLMs.

How can we enable efficient, ultra-long-context attention in LLMs by dynamically selecting only the most relevant key blocks for each query?

Quadratic softmax attention makes processing millions of tokens prohibitively expensive for modern agentic workflows, forcing a trade-off between context length and compute efficiency. MiniMax Sparse Attention (MSA) addresses this by adding a lightweight Index Branch to Grouped Query Attention (GQA) that independently selects the most relevant key-value blocks for each attention group. The Main Branch then performs exact softmax attention only over these selected blocks instead of the full sequence. On a 109B-parameter model, MSA matches GQA performance while reducing per-token attention compute by 28.4× at 1M context, with measured speedups of 14.2× for prefill and 7.6× for decoding on H800 hardware.

Paper Primer

MSA functions as a two-stage selector: the Index Branch scores key-value blocks with cheap low-dimensional dot products (max-pooled over the tokens of each block) and picks the top-$k$, while the Main Branch computes exact softmax attention over only those blocks. The design works like a library index: the indexer quickly identifies the relevant shelves (blocks), so the reader (Main Branch) only pulls and reads those books rather than scanning the entire library.

MSA achieves massive compute efficiency gains at long context lengths without sacrificing model quality.

Comparison of MSA against a GQA Full-Attention baseline on a 109B-parameter MoE model. 28.4× reduction in per-token attention compute at 1M context, with 14.2× prefill and 7.6× decoding speedups.

The learned sparse pattern preserves performance across diverse benchmarks.

Evaluation on MMLU, GSM8K, HumanEval, and various multimodal/agentic tasks. MSA-PT (trained from scratch) and MSA-CPT (continued pretraining) remain broadly competitive with the Full-Attention baseline.

Why use block-level selection instead of token-level selection?

Fine-grained token-level selection is difficult to map efficiently to GPU matrix operations; block-level selection allows for regular, contiguous memory access that better utilizes tensor cores.

How does the model learn to select relevant blocks if the selection process is non-differentiable?

The Index Branch is trained using an auxiliary KL alignment loss that forces it to match the attention distribution of the Main Branch, effectively distilling the "importance" of blocks into the indexer.

The Long-Context Bottleneck

Ultra‑long contexts are essential, yet dense attention’s quadratic cost blocks million‑token LLMs.

Ultra‑long contexts are essential for modern LLM applications such as agentic workflows, code reasoning, and persistent memory, but softmax attention’s quadratic compute makes them infeasible at deployment scale. This quadratic scaling is the primary barrier to reaching million-token contexts. MiniMax Sparse Attention (MSA) addresses it by adding a lightweight Index Branch that scores and selects the top-$k$ key-value blocks per GQA group, while the Main Branch performs exact block-sparse attention only on those blocks, yielding a 28.4× reduction in per-token attention compute and 14.2× prefill / 7.6× decoding speedups on a 109B model.

**Figure 14.** Perplexity comparison between MSA and a FLOP-matched sliding-window baseline on downstream agent-oriented evaluations. Lower perplexity indicates better modeling performance under the same sparse selection budget.

Quadratic scaling of dense attention is the main obstacle to million‑token contexts.

Foundations of Causal Attention

Key background on causal attention, grouped‑query attention, and the two‑stage sparse attention formulation.

Causal (autoregressive) attention computes each token’s output as a weighted sum of the values of all previous tokens (and itself), using a softmax over dot-product scores. Each query token costs $\Theta(H_q N\, d_h)$ FLOPs, so a full sequence of length $N$ costs $\Theta(H_q N^2 d_h)$, growing quadratically with $N$.
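A minimal pure-Python sketch of single-head causal softmax attention, together with the FLOP count it implies, may help make the quadratic cost concrete. The function names are illustrative, and the FLOP model (2 FLOPs per multiply-add, counting only the $QK^\top$ and $PV$ products) is an assumption rather than the paper's accounting.

```python
import math

def causal_attention(q, k, v):
    """Naive single-head causal softmax attention.

    q, k, v: lists of N vectors (lists of floats), each of length d_h.
    Output i is a softmax-weighted sum of v[0..i].
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i in range(len(q)):
        # Dot-product scores against tokens j <= i only (causal mask).
        scores = [scale * sum(a * b for a, b in zip(q[i], k[j])) for j in range(i + 1)]
        m = max(scores)  # subtract the max for numerical stability
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(i + 1)) / z
                    for c in range(len(v[0]))])
    return out

def causal_attention_flops(n, h_q, d_h):
    """Approximate FLOPs for QK^T plus PV over a length-n causal sequence.

    Query i touches i + 1 keys at ~4 * d_h FLOPs each (2 for the score,
    2 for the value accumulation), giving 2 * d_h * n * (n + 1) per head.
    """
    return 2 * d_h * n * (n + 1) * h_q

# Doubling the sequence length roughly quadruples the cost.
print(causal_attention_flops(2000, 64, 128) / causal_attention_flops(1000, 64, 128))
```

Under this model, doubling $N$ multiplies the cost by $2(2N+1)/(N+1)$, which approaches 4 as $N$ grows.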

GQA lets a group of $G$ query heads share a single key-value head, reducing the number of KV heads while each query head keeps its own query projection.

How does GQA differ from standard multi‑head attention?

Standard multi-head attention gives each query head its own key-value head, incurring $H_q$ separate KV projections. GQA ties $G$ query heads to a single KV head, so only $H_{kv}=H_q/G$ KV projections are computed and cached. This shrinks the KV cache and the K/V projection cost, while each query head still attends independently; the attention-score FLOPs themselves are unchanged.
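The head-to-KV mapping can be sketched in a few lines. This assumes contiguous grouping (query heads $0..G-1$ share KV head 0, and so on), which is a common convention but not stated in the text; the helper names are hypothetical.

```python
def gqa_kv_head(query_head: int, h_q: int, h_kv: int) -> int:
    """Index of the KV head shared by `query_head` under GQA (contiguous groups assumed)."""
    if h_q % h_kv != 0:
        raise ValueError("H_q must be divisible by H_kv")
    g = h_q // h_kv  # G query heads share one KV head
    return query_head // g

def kv_cache_values_per_token(h_kv: int, d_h: int) -> int:
    """Number of cached scalars per token: one key and one value vector per KV head."""
    return 2 * h_kv * d_h

# Configuration from Figure 4: 64 query heads, 4 KV heads, so G = 16.
print([gqa_kv_head(h, 64, 4) for h in (0, 15, 16, 63)])  # [0, 0, 1, 3]
print(kv_cache_values_per_token(64, 128) // kv_cache_values_per_token(4, 128))  # 16
```

Multi-head attention is the special case $H_{kv} = H_q$ ($G = 1$), where every query head maps to its own KV head.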

Sparse attention is implemented as a two‑stage process: an indexer selects a subset of keys for each query (the Index Branch), and attention is computed only over those selected keys (the Main Branch).

The MSA Mechanism

MSA replaces dense attention with a lightweight index that selects key blocks for focused computation.

Dense causal attention costs $O(N^2)$ operations, most of which attend to irrelevant tokens.

MSA splits attention into a fast Index Branch that picks a few key blocks, and a Main Branch that computes full softmax only on those blocks.

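In a toy example with $N=4$ tokens and $H_{kv}=2$ GQA groups, the Index Branch projects $X$ to $\text{idx}_q\in\mathbb{R}^{4\times2\times d_{\text{idx}}}$ (one index query per group) and $\text{idx}_k\in\mathbb{R}^{4\times1\times d_{\text{idx}}}$ (one shared index key), with $d_{\text{idx}}\!\ll\!d_h$.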

For each group $r$, block scores $M^{\text{idx}}_{r,i,b}$ are max‑pooled over tokens in each block, respecting causality ($j\le i$).

Top-$k$ then returns the selected set $I_r(i)$; in this toy case it is the block containing the query token (the local block).

The Main Branch computes softmax attention only over the causally visible tokens of the chosen block (at most two), producing $O_{h,i}$.

The output combines the heads and passes through the final projection, yielding the same shape as dense attention.

Even with a tiny context, the indexer learns to keep the query’s own block and discards distant tokens, preserving local information while cutting compute.
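The toy walkthrough above can be reproduced with a small pure-Python sketch of block max-pooling and top-$k$ selection. Assumptions, labeled here because the text does not pin them down: blocks are contiguous runs of $B_k$ tokens, the local block counts as one of the $k$ selections, and the remaining $k-1$ come from the highest-scoring earlier blocks. The function names and the toy numbers are hypothetical.

```python
import math

def block_maxpool_scores(q_idx, k_idx, block_size):
    """Block scores M[r][i][b] = max over visible keys j in block b of q_idx[i][r] . k_idx[j].

    q_idx: N x H_kv x d_idx  (one low-dimensional index query per GQA group r)
    k_idx: N x d_idx         (a single index key head shared by all groups)
    Blocks with no causally visible key (all j > i) keep the score -inf.
    """
    n, h_kv = len(q_idx), len(q_idx[0])
    num_blocks = math.ceil(n / block_size)
    scores = [[[-math.inf] * num_blocks for _ in range(n)] for _ in range(h_kv)]
    for r in range(h_kv):
        for i in range(n):
            for j in range(i + 1):  # causal: keys j <= i only
                s = sum(a * b for a, b in zip(q_idx[i][r], k_idx[j]))
                blk = j // block_size
                scores[r][i][blk] = max(scores[r][i][blk], s)
    return scores

def select_blocks(row, i, block_size, k):
    """Pick k block ids for query i: the local block plus the k-1 best-scoring earlier blocks."""
    local = i // block_size
    earlier = sorted(range(local), key=lambda b: row[b], reverse=True)
    return sorted([local] + earlier[: k - 1])

# Toy setting: N = 4 tokens, H_kv = 2 groups, d_idx = 2, B_k = 2, k = 1.
q_idx = [[[1.0, 0.0], [0.0, 1.0]]] * 4
k_idx = [[0.5, 0.1], [0.2, 0.9], [0.3, 0.4], [0.8, 0.0]]
M = block_maxpool_scores(q_idx, k_idx, block_size=2)
print(M[0][3], select_blocks(M[0][3], 3, 2, k=1))  # [0.5, 0.8] [1]
```

With $k=1$ only the local block survives, matching the walkthrough; raising $k$ lets the highest-scoring distant block join it.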

A lightweight head projects the hidden states to a low‑dimensional space, scores each block, and picks the top‑$k$ blocks per GQA group.

How is the Index Branch different from running a full‑resolution attention scorer?

It reduces dimensionality ($d_{\text{idx}}\!\ll\!d_h$) and aggregates scores per block via max-pooling, so its cost per query token is $O(N\,H_{kv}\,d_{\text{idx}})$ instead of $O(N\,H_q\,d_h)$.

Given the block indices from the Index Branch, the Main Branch performs standard scaled dot‑product attention but restricted to the selected tokens.

Why does the Main Branch still need a full softmax if only a few tokens are attended?

Softmax normalizes over the selected tokens to produce a proper probability distribution; without it the attention weights would not sum to one, breaking the model’s training dynamics.
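A minimal sketch of the Main Branch for one query head, assuming contiguous blocks and a list of selected block ids such as the one produced by an index step. The softmax normalizer runs over the selected tokens only, and causality is still enforced inside the local block. The function name and signature are hypothetical, not the paper's `TopKAttn` kernel.

```python
import math

def topk_block_attention(q_i, keys, values, i, selected_blocks, block_size):
    """Scaled dot-product attention for query position i restricted to the selected blocks.

    Causality is still enforced inside the local block: only keys j <= i are used.
    """
    idx = [j for b in selected_blocks
           for j in range(b * block_size, min((b + 1) * block_size, i + 1))]
    scale = 1.0 / math.sqrt(len(q_i))
    scores = [scale * sum(a * c for a, c in zip(q_i, keys[j])) for j in idx]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)  # softmax normalizer over the selected tokens only
    return [sum(wj * values[j][c] for wj, j in zip(w, idx)) / z
            for c in range(len(values[0]))]
```

Selecting every block up to the local one recovers dense causal attention for that query; selecting only the local block when the query is its first token simply returns that token's value.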

Training the Index Branch requires a differentiable signal, so we align its block scores to the Main Branch distribution with a KL loss while detaching gradients from the backbone.

The Index Branch is trained to mimic dense attention via KL divergence, ensuring its block selections reflect the true attention pattern.

**Figure 1.** Overview of MSA. The Index Branch (left) scores the full causal context with a single lightweight head and selects, for each query and GQA group, a set $I$ of $k$ key blocks; the local block is always included regardless of its score. The Main Branch (right) attends only to the selected blocks and produces the layer output. During training, a KL loss aligns the index distribution with the group-averaged Main Branch distribution on the selected blocks, and the Index Branch gradient is detached from the Main Branch.
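The KL alignment described in Figure 1 can be sketched for a single query and GQA group. Assumptions, since the exact form is not given in this section: the loss is the forward KL from the group-averaged Main Branch distribution (the target) to the index distribution, both normalized over the tokens induced by the selected blocks. Plain Python has no autograd, so the stop-gradient on the Main Branch is only noted in a comment; the function names are hypothetical.

```python
import math

def log_softmax(xs):
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]

def index_kl_loss(index_logits, main_logits_per_head):
    """KL(p_main || p_idx) for one query and one GQA group over the selected tokens.

    index_logits: Index Branch scores over the tokens induced by the selected blocks.
    main_logits_per_head: G lists of Main Branch logits over the same tokens, one per
        query head in the group. They act as a fixed target (stop-gradient in training).
    """
    g = len(main_logits_per_head)
    head_probs = [[math.exp(x) for x in log_softmax(row)] for row in main_logits_per_head]
    # Target: the Main Branch distribution averaged over the G heads of the group.
    target = [sum(p[t] for p in head_probs) / g for t in range(len(index_logits))]
    log_q = log_softmax(index_logits)
    return sum(p * (math.log(p) - lq) for p, lq in zip(target, log_q) if p > 0)
```

The loss is zero when the indexer already reproduces the averaged attention pattern and grows as its ranking of tokens diverges from the Main Branch's.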

**Algorithm 1** One MSA layer: training forward and the auxiliary KL loss. The layer returns its output and the per-layer $\mathcal{L}_{\text{KL}}$; the model loss $\mathcal{L} = \mathcal{L}_{\text{LM}} + \lambda \sum_{\text{layers}} \mathcal{L}_{\text{KL}}$ is assembled by the training loop.

**Require:** hidden states $X \in \mathbb{R}^{N \times d_{\text{model}}}$; block size $B_k$; number of selected blocks $k$.

1: $Q, K, V \leftarrow XW_q, XW_k, XW_v$ // shapes $(N, H_q, d_h)$, $(N, H_{kv}, d_h)$, $(N, H_{kv}, d_h)$

2: $Q^{\text{idx}}, K^{\text{idx}} \leftarrow \text{stopgrad}(X)W_q^{\text{idx}}, \text{stopgrad}(X)W_k^{\text{idx}}$ // shapes $(N, H_{kv}, d_{\text{idx}})$, $(N, 1, d_{\text{idx}})$; detached

3: $M^{\text{idx}} \leftarrow \text{BlockMaxPool}(Q^{\text{idx}}, K^{\text{idx}}, B_k)$ // shape $(N, H_{kv}, B)$ with $B = \lceil N/B_k \rceil$ blocks; per-group, causal

4: $I \leftarrow \text{TopK}(M^{\text{idx}}, k)$ // selected block indices; local block included

5: $O \leftarrow \text{TopKAttn}(Q, K, V, I)$ // shape $(N, H_q, d_h)$; attends to selected blocks only

6: output $\leftarrow OW_o$ // shape $(N, d_{\text{model}})$

7: $\mathcal{L}_{\text{KL}} \leftarrow \text{KLdiv}(Q^{\text{idx}}, K^{\text{idx}}, \text{stopgrad}(Q), \text{stopgrad}(K), I)$ // over the tokens induced by $I$

8: **return** output, $\mathcal{L}_{\text{KL}}$

MSA’s FLOPs consist of a cheap index computation $O(H_{kv} d_{\text{idx}} N^2)$ plus a sparse attention term $O(H_q d_h N k B_k)$, contrasting with GQA’s $O(H_q d_h N^2)$ cost.
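A back-of-the-envelope FLOP model makes the two MSA terms concrete. It uses the Figure 4 configuration ($H_q=64$, $H_{kv}=4$, $d_h=128$, $B_k=128$, $k=16$); $d_{\text{idx}}=128$ is a hypothetical choice, and the accounting (4 FLOPs per key per head for attention, 2 per key per group for the index) is an assumption, so the output is not expected to reproduce the paper's 28.4× figure exactly.

```python
def gqa_flops_per_token(n, h_q, d_h):
    """Dense GQA: QK^T and PV against n keys, ~4 * d_h FLOPs per key per query head."""
    return 4 * h_q * d_h * n

def msa_flops_per_token(n, h_q, h_kv, d_h, d_idx, k, block_size):
    """MSA: index scores over all n keys plus exact attention over k * B_k selected tokens."""
    index = 2 * h_kv * d_idx * n                   # low-dim QK^T, one index query per group
    main = 4 * h_q * d_h * min(n, k * block_size)  # attention restricted to the budget
    return index + main

# Figure 4 configuration; d_idx = 128 is a hypothetical value, not taken from the paper.
cfg = dict(h_q=64, h_kv=4, d_h=128, d_idx=128, k=16, block_size=128)
for n in (2**15, 2**17, 2**20):
    dense = gqa_flops_per_token(n, cfg["h_q"], cfg["d_h"])
    print(n, round(dense / msa_flops_per_token(n, **cfg), 1))
```

In this simplified model the main-branch term is fixed once $N$ exceeds the $kB_k$ budget, so at long context the index term dominates and the ratio approaches $4H_q d_h / (2H_{kv} d_{\text{idx}})$; at short context the index overhead makes MSA slightly more expensive than dense GQA.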

MSA achieves long‑context scaling by letting a tiny index branch select relevant blocks, letting the main attention operate on a fixed budget.

Sparse Kernel Implementation

Efficient GPU kernels make sparse prefill fast without sacrificing accuracy.

Dense attention spends cycles on tokens that never contribute useful information. In the sparse‑prefill setting the bottleneck becomes the kernel that selects and processes only the relevant key‑value blocks. Our design eliminates the softmax‑exp‑sum overhead and reshapes the computation to keep the GPU fully occupied.

The index kernel emits only the indices of the $k$ selected KV blocks for each query, bypassing the dense softmax and thus reducing both compute and memory traffic.

How does the Sparse Prefill Kernel differ from a naïve dense attention kernel?

Instead of materialising a full $N\times N$ attention matrix and then discarding most entries, the kernel works on the raw scores, extracts only the $k$ most relevant KV blocks, and writes just their block indices. This avoids the $O(N^2)$ softmax and memory moves that dominate dense attention.

Lane 0 loads scores $\{0.1,\,0.4,\,0.2,\,0.5\}$ and inserts them into its local min‑heap of size 2, keeping $\{0.5,\,0.4\}$.

Lane 1 loads $\{0.3,\,0.6,\,0.0,\,0.2\}$; its own heap keeps $\{0.6,\,0.3\}$.

… after all 32 lanes finish, each holds its top‑2 candidates.

A shuffle‑merge step merges the 64 candidates, discarding all but the global top‑2 values $0.9$ and $0.8$.

The final ordered list $[\,\text{block\_id}=7,\,\text{block\_id}=3\,]$ is written to global memory.

Each lane's size-2 min-heap lives in registers (its root is the current smallest candidate, the one evicted on insertion), so only 2 values per lane ever touch shared memory, dramatically cutting traffic compared to a full sort.

Index & TopK kernel – per‑warp min‑heap and shuffle‑merge.
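The per-lane heap and the merge step can be simulated in Python to check the logic of the walkthrough. This is a serial stand-in, not the CUDA kernel: `heapq` plays the role of each lane's register heap, and a pairwise tree reduction stands in for the warp shuffle-merge. The function names and toy scores are hypothetical.

```python
import heapq

def lane_topk(scores, block_ids, k):
    """One lane: stream (score, block_id) pairs through a size-k min-heap."""
    heap = []
    for s, b in zip(scores, block_ids):
        if len(heap) < k:
            heapq.heappush(heap, (s, b))
        elif s > heap[0][0]:
            heapq.heapreplace(heap, (s, b))  # evict the current minimum (the root)
    return heap

def warp_topk(lanes, k):
    """Merge per-lane candidates pairwise, like log2(#lanes) shuffle steps; return ids best-first."""
    cands = [lane_topk(s, ids, k) for s, ids in lanes]
    offset = 1
    while offset < len(cands):
        for lane in range(0, len(cands), 2 * offset):
            if lane + offset < len(cands):
                cands[lane] = heapq.nlargest(k, cands[lane] + cands[lane + offset])
        offset *= 2
    return [b for _, b in sorted(cands[0], reverse=True)]

lanes = [([0.1, 0.4, 0.2, 0.8], [0, 1, 2, 3]),
         ([0.3, 0.6, 0.0, 0.9], [4, 5, 6, 7])]
print(warp_topk(lanes, k=2))  # [7, 3]
```

Each merge step keeps only $k$ candidates, so the data exchanged per step stays constant regardless of how many scores each lane scanned.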

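Choosing the outer loop for sparse prefill changes the FLOPs-to-IO balance. A query-outer loop gives a FLOPs/IO ratio of roughly $G$, while a KV-outer loop yields about $(2/3)B_k$, which is several times larger for typical settings: with the configuration of Figure 4 ($G = 64/4 = 16$, $B_k = 128$), the two ratios are about 16 and 85, respectively.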

KV‑outer sparse attention kernel – tile‑wise processing with reverse index.

The KV-outer design forces a two-phase forward pass: the first kernel writes per-query partial results to $O_{\text{buf}}$, and the second kernel merges them with a per-query log-sum-exp reduction that rescales and combines the weighted sums. This split avoids atomics and hides the launch latency between kernels.
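The second-phase merge is the standard log-sum-exp combination of partial softmax results. A minimal sketch, assuming each KV tile stores a pair (log-sum-exp of its scores, tile-normalized output) per query; this pair is a hypothetical stand-in for an $O_{\text{buf}}$ entry, not the kernel's actual layout.

```python
import math

def partial_attention(scores, values):
    """Phase 1 for one KV tile: return (lse, output normalized within the tile)."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    out = [sum(wi * v[c] for wi, v in zip(w, values)) / z for c in range(len(values[0]))]
    return m + math.log(z), out

def combine_partials(partials):
    """Phase 2: merge per-tile (lse, out) pairs with a log-sum-exp reduction."""
    m = max(lse for lse, _ in partials)
    total = m + math.log(sum(math.exp(lse - m) for lse, _ in partials))
    dim = len(partials[0][1])
    # Each tile's output is rescaled by its share of the global softmax mass.
    return [sum(math.exp(lse - total) * out[c] for lse, out in partials) for c in range(dim)]
```

Because each tile's weight is $\exp(\text{lse}_t - \text{lse}_{\text{total}})$, the merged result equals a single softmax over all the keys the query attended to, independent of how they were split into tiles.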

**Figure 4.** Efficiency comparison between GQA and MSA under the shared experimental model configuration. The left subfigure reports the theoretical per-token attention-FLOPs. The middle and right subfigures report the measured implementation speedups for prefill and decode, respectively. All tests are conducted with 64 query heads, 4 key-value heads, and a head dimension of 128. MSA uses $B_k = 128$ and $k = 16$, corresponding to a selected budget of 2,048 tokens per query.

Empirical Evaluation

MSA delivers dense‑attention quality while slashing attention compute.

Recall the core claim: dense attention wastes compute on irrelevant tokens, while MiniMax Sparse Attention (MSA) selects only the most relevant key blocks via a lightweight index branch, preserving accuracy at far longer contexts.

MSA matches dense-attention quality (within 0.3% accuracy) while cutting per-token attention FLOPs by up to 28.4× at 1M context.

Table 2 shows competitive scores across language, math, image, video and long‑context benchmarks; Figure 4 reports the FLOPs reduction.

Training dynamics confirm that sparsity does not harm optimization: the LM-loss curves of native sparse pretraining (MSA-PT) and the matched full-attention run are virtually indistinguishable over the full 3T-token budget, and their gradient-norm trajectories stay within the same range. During conversion (MSA-CPT), indexer warmup quickly drives the KL loss down before sparse attention is enabled; afterwards, KL remains low and block recall stays high, indicating that the indexer reliably recovers the most important blocks.

**Figure 2.** Pretraining dynamics for the experiment model. LM loss and gradient norm are shown for Full Attention and MSA-PT over 3T training tokens. The inset in (a) zooms in on the final 50B-token window, where the two LM-loss curves remain nearly overlapping.

Main results (Table 2) demonstrate that both sparse variants remain broadly competitive with the Full‑Attention baseline. MSA‑PT excels on math, image, video, and long‑context retrieval benchmarks, suggesting that learning the sparse pattern from the start adapts representations effectively. MSA‑CPT preserves most of the dense checkpoint’s behavior, staying close on text, code, and perplexity metrics, making it a practical conversion path when a trained dense model is already available.

**Figure 3.** Sparse continued-pretraining dynamics. (a) Average KL loss during MSA-CPT. The solid segment denotes indexer warmup, and the dashed segment denotes sparse continued pretraining; the vertical dashed line marks the switch between the two stages. (b) Average block recall and score recall of the MSA-CPT indexer during sparse continued pretraining.

**Figure 7.** Evaluation-score deltas relative to the Full-Attention baseline for three indexer training signals in the pilot setting. Positive values indicate improvements over the baseline, and negative values indicate degradations.

Efficiency analysis (Figure 4) confirms that the theoretical FLOPs savings translate into measured runtime speedups that grow with context length. At 1M tokens the FLOPs reduction reaches 28.4×; the measured speedups (14.2× prefill, 7.6× decoding) are smaller due to index-construction overhead, but they still scale favorably because the dense baseline's cost grows quadratically.

**Figure 10.** Per-layer entropy of the Main Branch attention distribution during early sparse training. Entropy drops rapidly in the first few hundred steps before partially recovering and stabilizing, motivating a brief full-attention warmup for the indexer.

Full Attention (baseline): standard dense attention computes the full query-key-value interaction over the entire sequence, guaranteeing that every token can attend to every earlier token.

Sliding-window baseline: a simple sparse scheme that restricts each query to a fixed-size local window of neighboring tokens.

MSA matches dense performance while significantly reducing FLOPs.

Ablations and Visualizations

Ablation results reveal which components of MSA are essential for performance.

Visualization of the Index Branch shows that each GQA group learns a distinct sparse pattern while sharing a common local diagonal and a sink column. The early‑layer heads (Layer 1) and late‑layer heads (Layer 18) both allocate most of their budget to nearby tokens, but differ in the few long‑range stripes they select.

**Figure a.** Layer 1, four GQA groups. Each group produces a different long-range selection pattern alongside the shared local diagonal and sink column.

**Figure b.** Layer 18, four GQA groups. Long-range selection sharpens into a few stripes per group; the four groups pick visibly different stripes.

The learned Index Branch also repeatedly selects the first key‑value block, creating an implicit attention sink. Across layers 4 and 24, every sampled head assigns a sizable fraction of its attention mass to the first token, even without an explicit forcing term.

**Figure 6.** Mean attention score on the first token for each attention head in Layer 4 and Layer 24. All heads allocate a significant fraction of attention to the first token, confirming a pervasive attention sink effect across heads and layers.

Training the Index Branch requires a gradient signal. We compare three configurations: LM‑only, KL‑only, and LM + KL. LM‑only preserves short‑context performance but fails on long‑context retrieval; KL‑only improves retrieval but harms short‑context ability; the combined LM + KL configuration balances both.

**Figure 8.** Training LM loss and gradient norm with and without detaching the KL gradient from the backbone. Detaching confines the auxiliary loss to the Index Branch and avoids the gradient spikes observed without detach.

**Figure 9.** General benchmark scores with and without detaching the KL gradient from the backbone. Detaching the auxiliary loss reduces the general ability degeneration observed when the KL gradient updates the backbone.

Early in training the Main Branch attention distribution sharpens rapidly, making sparse selection fragile. A short warmup phase runs the Main Branch with full attention while the Index Branch learns via KL supervision, after which sparse selection is enabled.

We test a learnable attention sink that competes with the first token in the softmax. The sink captures attention in some heads but does not fully replace the first‑token sink, and perplexity gains are inconsistent.
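One common form of a learnable sink, assumed here to match the GPT-OSS-style parameter in Figure 12, is a per-head scalar logit that joins the softmax denominator but carries no value vector, so it can absorb attention mass that would otherwise land on the first token. The sketch below shows only that normalization; the function name and the exact parameterization are assumptions.

```python
import math

def softmax_with_sink(scores, sink_logit):
    """Softmax where a learnable per-head sink logit competes in the denominator.

    The sink has no value vector, so the returned weights over real tokens sum to < 1.
    """
    m = max(max(scores), sink_logit)  # stabilize against the largest logit, sink included
    e = [math.exp(s - m) for s in scores]
    z = sum(e) + math.exp(sink_logit - m)
    return [x / z for x in e]
```

A very negative sink logit recovers the ordinary softmax, while a large one lets the head attend to "nothing"; whether heads actually use it instead of the first token is what Figure 12 examines.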

**Figure 11.** Evaluation results of MSA with and without index warmup. Within the reported training range, index warmup improved scores on general tasks and long-context retrieval.

**Figure 12.** Attention received by the learnable sink and the first token after introducing a GPT-OSS-style sink parameter. In some heads, the learnable sink absorbs most of the sink-like attention; in others, the first token remains the dominant sink, indicating that the explicit sink does not fully eliminate first-token sink behavior.

**Figure 13.** Perplexity comparison with and without the learnable attention sink on downstream agent-oriented evaluations. Lower perplexity is better. Adding the learnable sink does not provide a consistent advantage over the default MSA design.

Dynamic sparse selection outperforms a FLOP‑matched sliding‑window baseline. With the same token budget, the sliding‑window model exhibits higher perplexity throughout training, indicating that content‑dependent selection is more effective than fixed positional windows.

Block‑size ablations (Table 4) show that increasing the key‑value block size from 32 to 128 has minimal impact on perplexity and long‑context retrieval scores, suggesting that larger blocks can be used to improve kernel efficiency without harming quality.

Removing the forced first‑block and local‑window priors (Table 5) leaves model quality largely unchanged; reasoning, code, and PPL metrics remain stable, and long‑context retrieval is comparable, indicating that the model learns these patterns autonomously.

Ablating the Index Branch value head (Table 6) yields small, mixed differences across benchmarks; the no-value variant sometimes outperforms the with-value design, suggesting that the value head is not critical once warmup is applied.

Context and Conclusion

Related work context and a concise wrap‑up of the paper’s contributions.

Long-context efficiency has spurred two broad families of work: (1) replacing dense softmax attention with cheaper linear or recurrent substitutes, and (2) keeping softmax but limiting its receptive field. Linear attention, state-space models, and hybrid stacks (e.g., MiniMax) fall into the first family, while fixed-pattern schemes such as local windows, global tokens, and sliding windows with attention sinks belong to the second. Adaptive sparse attention methods (e.g., H2O, SnapKV, Quest, MInference, FlexPrefill, InfLLM) construct input-dependent supports, but they inherit the full-attention training cost and often retain a dense-attention phase at inference time.

We introduced MSA, a sparse-attention mechanism co-designed with GQA: an Index Branch selects a small set of key-value blocks per GQA group, and the Main Branch performs softmax attention restricted to those blocks. The selector is trained via a KL alignment loss against the Main Branch under a two-stage warmup schedule, yielding a 28.4× per-token compute reduction at 1M context while preserving GQA-level accuracy at the 109B MoE scale. Future work should close the remaining long-context retrieval gap by expanding the selection budget or enriching the indexer scoring function, and explore selector-only designs beyond pretraining, such as reinforcement-learning fine-tuning.
