You Only Index Once: Cross-Layer Sparse Attention with Shared Routing

Cross-Layer Sparse Attention (CLSA) accelerates long-context LLM decoding by sharing a single routing index across layers.

How can we reduce the KV-cache footprint and inference latency of long-context LLMs by sharing a single routing index across all cross-decoder layers?

Long-context LLM inference is bottlenecked by the high cost of recomputing token-level sparse routing across every decoder layer, which negates the speed gains of sparse attention. The authors introduce Cross-Layer Sparse Attention (CLSA), which computes a single top-k routing index for the shared KV cache and reuses it across all cross-decoder layers. This approach amortizes the expensive routing overhead while preserving the fine-grained selectivity of token-sparse attention. At 128K context, this architecture achieves up to 7.6× faster decoding and 17.1× higher overall throughput compared to standard Transformer baselines, with negligible impact on model quality.

Paper Primer

CLSA is a structural optimization for KV-sharing architectures like YOCO. It replaces dense cross-attention with a routed sparse mechanism: compute the routing index once, then broadcast that index to all cross-decoder layers so they attend only to the same subset of salient tokens.

CLSA achieves near-lossless performance compared to dense baselines.

Evaluation across benchmarks including ARC-Challenge, GSM8K, and DROP shows CLSA matches or exceeds dense YOCO performance, while long-context RULER retrieval remains competitive. CLSA also delivers a 17.1× overall throughput improvement at 128K context.

Shared routing eliminates the per-layer top-k bottleneck.

By computing the top-k index once for the shared KV cache, the model avoids the irregular, expensive routing operations that typically dominate per-layer latency in token-sparse architectures.

Why does this method require a multi-layer distillation objective?

Because the shared routing index must serve the needs of the entire decoder stack simultaneously, the indexer is trained via distillation to identify consensus salient tokens that remain important across all layers.

How does this differ from standard token-sparse attention?

Standard token-sparse methods recompute top-k routing independently in every layer, which is computationally expensive and poorly suited for GPU execution; CLSA computes this index exactly once per step.

The KV-Cache Bottleneck

Long‑context LLMs are bottlenecked by repeated routing computation, motivating shared‑index sparse attention.

Long‑context inference in modern LLMs is increasingly limited by decoding cost, because models must repeatedly attend to an ever‑growing KV cache while generating long reasoning chains. Existing sparse‑attention methods either use block‑structured sparsity, which speeds up GPU execution but coarsens the attention pattern, or token‑sparse routing, which preserves fine‑grained relevance yet incurs a costly top‑k selection that must be recomputed for every cross‑decoder layer. CLSA instead computes a single routing index and shares it across all cross‑decoder layers, amortizing the routing cost on top of the shared KV cache.

The key design choice is balancing how far the context can extend against the per‑token inference cost.

The CLSA Mechanism

Describes how CLSA shares a routing index across decoder layers to cut attention cost.

Figure 1 visualizes our method built on the YOCO architecture. The self‑decoder creates a single KV cache, and a lightweight indexer produces a shared routing index that all cross‑decoder layers reuse.

**Figure 1.** Overview of cross-layer sparse attention. The self-decoder first produces a shared KV cache, which is computed only once and then reused by all subsequent cross-decoder layers. During this stage, a shared query-aware indexer jointly generates the routing queries and keys and computes a token-level sparse top-k index for each query token. This sparse index is also produced only once and is shared across the following cross-decoder layers, allowing them to reuse the same selected KV positions instead of recomputing layer-specific sparse routing.

YOCO splits the model into a self‑decoder that builds a shared KV cache and a cross‑decoder that reads from that cache.

A single‑head query‑aware module projects the shared hidden states to produce a token‑level top‑$k$ index that is reused by all cross‑decoder layers.

CLSA binds the routing index to the shared KV cache, so each cross‑decoder layer attends only to the selected tokens, avoiding per‑layer index recomputation.

Compute $Q_{\text{idx}} = H W_Q = [1,0,1,0]^{\top}$ and $K_{\text{idx}} = H W_K = [0,1,1,0]^{\top}$, where $H$ denotes the shared hidden states.

Form scores $Q_{\text{idx}} K_{\text{idx}}^{\top} = [0,1,1,0]$, then take the top‑$2$ indices → $\text{idx} = \{2,3\}$ (the second and third tokens).

Extract the selected key/value rows: $K_{S_t} = K_{\{2,3\}}$, $V_{S_t} = V_{\{2,3\}}$.

For layer 1, query $Q^{(1)}$ attends only to rows $2$ and $3$, producing $O^{(1)}_{t}$; layer 2 repeats the same restricted attention with its own $Q^{(2)}$.

All cross‑decoder layers reuse the same two token positions, so the expensive top‑$k$ computation runs only once while each layer still gets a distinct query.
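The worked example above can be sketched in a few lines of plain Python. This is an illustrative sketch, not the authors' implementation: the helper names `top_k_indices` and `sparse_attention`, the toy key/value rows, and the per-layer queries are all hypothetical, and positions are 0-based here (so tokens 2 and 3 in the text become indices 1 and 2).

```python
import math

def top_k_indices(scores, k):
    """Return the positions of the k largest scores, in ascending position order."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])

def sparse_attention(query, keys, values, selected):
    """Scaled dot-product softmax attention restricted to the selected KV positions."""
    scale = 1.0 / math.sqrt(len(query))
    logits = [scale * sum(q * k for q, k in zip(query, keys[i])) for i in selected]
    peak = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(x - peak) for x in logits]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w / total * values[i][d] for w, i in zip(weights, selected))
            for d in range(dim)]

# Routing scores of the current query token against four cached tokens (from the example).
routing_scores = [0.0, 1.0, 1.0, 0.0]
shared_index = top_k_indices(routing_scores, k=2)  # computed once per decoding step

# Hypothetical shared KV cache and one query per cross-decoder layer.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
layer_queries = [[1.0, 0.0], [0.0, 1.0]]

# Every layer reuses shared_index; only the query differs between layers.
outputs = [sparse_attention(q, keys, values, shared_index) for q in layer_queries]
print(shared_index, outputs)
```

The top-k call sits outside the per-layer loop, which is the whole point of the design: the irregular selection step runs once, while each layer still produces its own output from its own query.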

Multi-Layer Distillation

Multi‑Layer Distillation trains a shared routing index to satisfy several decoder layers at once.

Sharing a single routing index across decoder layers cuts compute, but the selected tokens must be useful to every layer rather than optimal for just one. A dedicated distillation loss forces the index to capture the common signal among layers.

We train the shared index by aligning its token‑selection distribution with the average of the layers' attention distributions, so the same set of tokens works for all layers.

Layer 1: $[0.5,\,0.2,\,0.2,\,0.1]$

Layer 2: $[0.4,\,0.3,\,0.2,\,0.1]$

Layer 3: $[0.45,\,0.25,\,0.15,\,0.15]$

Average (target) distribution $\approx [0.45,\,0.25,\,0.18,\,0.12]$ (the last two entries are rounded from $0.183$ and $0.117$).

The indexer selects the two highest‑probability tokens (positions 1 and 2), forming a sparse mask $[1,1,0,0]$.

The KL divergence from the mask‑induced distribution (uniform over the two selected tokens, $[0.5,\,0.5,\,0,\,0]$) to the target is $0.5\ln(0.5/0.45) + 0.5\ln(0.5/0.25) \approx 0.40$ nats, which the optimizer minimizes.

This tiny example shows how the shared index compromises: it captures the common high‑probability tokens across layers, but the exact probabilities of each layer are flattened to a uniform sparse distribution.
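The arithmetic in this toy example can be checked with a short script. It is a minimal sketch under stated assumptions: the KL direction (mask-induced distribution relative to the target) and the natural-log base are assumptions on my part, and the variable names are hypothetical rather than taken from the paper.

```python
import math

# Per-layer attention distributions from the example above.
layer_attention = [
    [0.5, 0.2, 0.2, 0.1],
    [0.4, 0.3, 0.2, 0.1],
    [0.45, 0.25, 0.15, 0.15],
]

# Target: element-wise mean over layers.
n_layers = len(layer_attention)
target = [sum(col) / n_layers for col in zip(*layer_attention)]

# Select the k highest-probability positions and build the sparse mask.
k = 2
ranked = sorted(range(len(target)), key=lambda i: target[i], reverse=True)
selected = sorted(ranked[:k])
mask = [1 if i in selected else 0 for i in range(len(target))]
mask_dist = [m / k for m in mask]  # uniform over the selected tokens

def kl_divergence(p, q):
    """KL(p || q) in nats; terms with p_i == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

kl_mask_to_target = kl_divergence(mask_dist, target)
print(target, mask, kl_mask_to_target)
```

Note that the reverse direction, KL(target ‖ mask), would be infinite here because the target puts mass on unselected tokens, which is one reason a hard mask is only a toy stand-in for a trained indexer's soft distribution.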

Training the Shared Indexer

Defines the distillation loss that trains the shared routing index to match cross‑layer attention.

Token‑sparse attention that routes independently in each layer recomputes a top‑k index at every decoder layer, inflating decoding cost. The training objective instead aligns a single routing index with the consensus of all layers.

First the indexer learns a stable routing pattern while the backbone is frozen; then the whole model is fine‑tuned jointly with language modeling, letting the backbone adapt to the sparse attention induced by the indexer.

Layer 1, Head 1: $a^{(1,1)} = [0.4, 0.3, 0.2, 0.1]$

Layer 1, Head 2: $a^{(1,2)} = [0.3, 0.4, 0.2, 0.1]$

Layer 2, Head 1: $a^{(2,1)} = [0.35, 0.35, 0.15, 0.15]$

Layer 2, Head 2: $a^{(2,2)} = [0.25, 0.45, 0.15, 0.15]$

Aggregate $\bar{A}$ by averaging the four vectors: $\bar{A}= [0.325, 0.375, 0.175, 0.125]$.

Indexer logits $I = [0.6, 0.2, 0.1, 0.1]$ → $\text{softmax}(I) \approx [0.35, 0.23, 0.21, 0.21]$.

Summing the per‑token KL terms against the stop‑gradient target $\bar{A}$ yields $L_{\text{KD}} = \mathrm{KL}(\bar{A} \,\|\, \text{softmax}(I)) \approx 0.06$.

The example shows how averaging across layers smooths noisy head‑specific peaks, giving the indexer a clear consensus to imitate.
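The steps of this example (aggregate the head-level attention, apply a softmax to the indexer logits, take the KL divergence) can be sketched as follows. This is a hedged illustration, not the paper's training code: the function names are hypothetical, the target is simply treated as a constant to mimic the stop-gradient, and the KL direction (target relative to the indexer distribution) is an assumption.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(head_attention, indexer_logits):
    """Return (target, indexer_probs, KL(target || indexer_probs)) in nats."""
    n = len(head_attention)
    # Mean over all (layer, head) attention vectors; used as a fixed target.
    target = [sum(col) / n for col in zip(*head_attention)]
    probs = softmax(indexer_logits)
    loss = sum(t * math.log(t / p) for t, p in zip(target, probs) if t > 0)
    return target, probs, loss

head_attention = [
    [0.4, 0.3, 0.2, 0.1],      # layer 1, head 1
    [0.3, 0.4, 0.2, 0.1],      # layer 1, head 2
    [0.35, 0.35, 0.15, 0.15],  # layer 2, head 1
    [0.25, 0.45, 0.15, 0.15],  # layer 2, head 2
]
target, probs, loss = distillation_loss(head_attention, [0.6, 0.2, 0.1, 0.1])
print(target, probs, loss)
```

In a real training loop the loss would be computed on tensors with gradients flowing only into the indexer, but the arithmetic per query token is the same as in this sketch.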

Inference Efficiency Gains

Sharing a single routing index cuts inference cost while keeping quality.

Standard token‑sparse attention recomputes routing at every decoder layer; Cross‑Layer Sparse Attention (CLSA) shares a single routing index across all cross‑decoder layers, amortizing the indexer cost.

YOCO (CLSA) matches dense baselines on most tasks while improving ARC‑C, GSM8K, and DROP.

Table 2 shows ARC‑C +0.004, GSM8K +0.040, and DROP +0.004 over the dense YOCO baseline.

**Table.** Comparison of computational complexity for different models across KV Cache Memory, Prefilling Time, and Decoding Time.

Performance Benchmarks

YOCO (CLSA) matches dense baselines while improving key benchmark scores.

We evaluate three 4B‑scale models: a standard Transformer, dense YOCO, and YOCO with cross‑layer sparse attention (CLSA). All three use identical width and depth and the same training hyper‑parameters (8M tokens per step, up to 32K context, learning rates 3×10⁻⁴/3×10⁻⁵). The setup mirrors the configurations described in Section C, ensuring a fair comparison across benchmarks.

The Transformer baseline is a standard decoder‑only model that applies full self‑attention at every layer.

YOCO (Dense) separates self‑decoder and cross‑decoder layers: the self‑decoder uses sliding‑window self‑attention with a window of 512, and the cross‑decoder retains dense attention over the shared KV cache.

**Table 1.** Performance comparison of Transformer, YOCO (Dense), and YOCO (CLSA) across various benchmarks including ARC-C, BBH, GSM8K, HellaSwag, HumanEval, MMLU, DROP, and WinoGrande.

YOCO (CLSA) improves the RULER average score at 16 K context by +2.0 points over the standard Transformer.

Table 3 reports an average of 98.4 for CLSA versus 96.4 for the Transformer baseline.

This table lists various hyper-parameters and their corresponding values used in the model configuration.

Throughput and Latency Analysis

CLSA decodes up to 7.6× faster than the Transformer at 128K context.

Measured on NVIDIA B200 GPUs, CLSA achieves 7.6× higher decode throughput than the Transformer when the context length is 128K.

During prefill, YOCO (Dense) and YOCO (CLSA) are both substantially faster than the Transformer because the decoder avoids quadratic full‑context attention, and the two YOCO variants remain close in performance.

Top‑k routing is irregular and poorly matched to the wide, data‑parallel execution that dense matrix multiplies exploit; a standalone top‑k pass at 128K can take time comparable to dense attention despite involving far fewer arithmetic operations.

**Figure 3.** Inference throughput relative to the Transformer for prefill and decode across different context lengths. Both YOCO variants substantially accelerate prefill, while CLSA provides the largest decoding gains and widens its advantage as the context grows.

**Figure 4.** 128K latency analysis for different components. After amortizing routing, the amortized top-k becomes efficient. Without amortization, the unamortized top-k stage can be comparable to or even larger than dense attention.

**Figure 5.** Per-layer latency comparison across representative sparse attention methods and dense baselines at 128K context. CLSA achieves the lowest latency by amortizing routing across cross-decoder layers.

Table 4 reports attention coverage and cross‑entropy loss for three domains under sparse selection; selecting 2048 tokens recovers 84 % of dense attention mass, with cross‑entropy losses of about 0.57 for StarCoder, 1.75 for Books, and 1.08 for ArXiv.

Attention Sparsity Analysis

We contrast prior sparse‑attention and hybrid methods, highlighting how CLSA’s shared routing differs.

Table 4 quantifies how sparse attention recovers dense attention mass and model quality as the token budget grows.

**Table 4.** Attention coverage and cross-entropy loss under sparse selection across selected-token budgets. Larger budgets recover more dense attention mass. Importantly, sparse selection introduces negligible cross-entropy loss degradation. 2048 selected tokens provide a favorable trade-off across domains.

These results show a small active subset can capture most of the attention mass and preserve language‑modeling quality, even occasionally surpassing dense YOCO on the StarCoder domain.
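Attention coverage, as used here, is the share of dense attention mass that falls on the selected tokens. The sketch below illustrates the metric on a toy attention vector; it assumes an oracle selection (the true top-weighted tokens) rather than an indexer's predicted selection, and both the function name and the weights are hypothetical.

```python
def attention_coverage(dense_weights, budget):
    """Fraction of total attention mass held by the `budget` largest weights."""
    top = sorted(dense_weights, reverse=True)[:budget]
    return sum(top) / sum(dense_weights)

# Hypothetical dense attention weights for one query over six cached tokens.
weights = [0.40, 0.05, 0.25, 0.10, 0.15, 0.05]

# Coverage grows monotonically with the selected-token budget.
coverage_by_budget = {b: attention_coverage(weights, b) for b in (1, 2, 4)}
print(coverage_by_budget)
```

Because attention mass is typically concentrated on a few tokens, coverage rises quickly at small budgets and flattens out, which matches the qualitative trend the table describes.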

Prior work on sparse attention falls into three strands: training‑aware dynamic sparsity, cross‑layer token‑reuse techniques, and hybrid architectures that blend attention with alternative sequence operators.

Training Stability and Curves

Dense-stage curves confirm YOCO’s competitiveness and validate CLSA’s shared routing index.

Dense‑stage training curves on HumanEval, DROP, and HellaSwag demonstrate that YOCO stays on par with the Transformer throughout training. Because CLSA reuses the shared routing index on top of the YOCO backbone, these curves also verify that YOCO supplies a strong dense attention foundation, even for retrieval‑style tasks like DROP. Consequently, we can begin from a solid dense model and apply a near‑lossless sparse‑attention adaptation instead of redesigning the pretraining pipeline.

**Figure 7.** Dense-stage training curves on HumanEval, DROP, and HellaSwag as a function of training tokens. YOCO remains competitive with the Transformer throughout training, supporting its use as a stable dense backbone before sparse adaptation.

Training Hyper-parameters

Lists the exact optimizer and schedule settings for dense pretraining and sparse adaptation.

The table compares the performance of three different architectures—Transformer, YOCO (Dense), and YOCO (CLSA)—across five different context lengths ranging from 8K to 128K.

**Table 6.** Shared optimization settings across all training stages.

Model Architectural Details

Model sizes, hyperparameters, and latency details for all variants.

Training uses a batch size of 8 M tokens, Adam optimizer with $\beta$ = (0.9, 0.95), $\epsilon$ = 10⁻⁸, and weight decay = 0.1.

All evaluated models share the same overall width and depth: hidden size 2560, FFN width 7680, 32 layers, 20 attention heads, 4 KV heads, and head dimension 128. QK normalization is enabled and weight tying is disabled.

Latency numbers in Table 9 were derived by scaling raw per‑layer timings to milliseconds and averaging across the 32 layers; for YOCO (Dense) the attention term mixes sliding‑window attention in the first 16 layers with dense cross‑decoder attention in the second 16, while YOCO (CLSA) mixes sliding‑window and CLSA attention. The top‑k column reports the amortized per‑layer routing cost, obtained by dividing the one‑off routing latency by the full depth.
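The amortization described above is simple division, and a short sketch makes the contrast with per-layer routing explicit. The function names and the 1.0 ms routing latency are hypothetical placeholders chosen for easy arithmetic, not measurements from the paper; the number of layers to divide by is left as a parameter.

```python
def amortized_routing_ms(one_off_ms, num_layers):
    """Per-layer routing cost when one top-k pass is shared by num_layers layers."""
    if num_layers <= 0:
        raise ValueError("num_layers must be positive")
    return one_off_ms / num_layers

def total_routing_ms(one_off_ms, num_routed_layers, shared):
    """Total routing time per step: one pass if shared, one pass per layer otherwise."""
    return one_off_ms if shared else one_off_ms * num_routed_layers

# Hypothetical: a single top-k pass costing 1.0 ms, amortized over 32 layers.
per_layer = amortized_routing_ms(1.0, 32)

# Shared routing versus recomputing in each of 16 routed layers.
shared_total = total_routing_ms(1.0, 16, shared=True)
per_layer_total = total_routing_ms(1.0, 16, shared=False)
print(per_layer, shared_total, per_layer_total)
```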

**Figure 6.** Per-layer latency breakdown at 8K, 32K, and 128K context. For YOCO (Dense), the attention cost is averaged over SWA and dense attention layers. For YOCO (CLSA), the attention cost is averaged over SWA and CLSA layers, and the top-k cost is amortized across cross-decoder layers. At 128K context, the amortized top-k stage takes about 0.08 ms per layer.

Raw Throughput Data

Provides the raw inference throughput numbers underlying Figure 3.

The three tables below list the absolute token‑per‑second rates measured for pre‑fill, decode, and end‑to‑end generation. Each entry reports the throughput of a standard Transformer baseline, the dense‑cache variant YOCO (Dense), and the cross‑layer sparse variant YOCO (CLSA) at context lengths of 8 K, 16 K, 32 K, 64 K, and 128 K. All experiments used the same hardware (NVIDIA B200 GPUs) and software stack as the main paper.

**Table 11.** Raw decode throughput (tokens/s) used in the right panel of Figure 3.

**Table 12.** Raw overall throughput (tokens/s) measured under the same setup as Figure 3.

Questions & answers

What is the main contribution of this paper?

The paper introduces Cross-Layer Sparse Attention (CLSA), a method that computes a single top-k sparse routing index for the shared KV cache and reuses it across all cross-decoder layers, eliminating the need to recompute routing at every layer and amortizing the indexing overhead across the entire decoder stack.

What problem does CLSA address?

CLSA addresses the bottleneck in long-context LLM inference where recomputing token-level sparse routing independently at every decoder layer is computationally expensive and poorly suited for GPU execution, negating the speed benefits of sparse attention.

Why is recomputing top-k routing at every layer problematic?

Top-k routing is irregular and poorly matched to the wide, data-parallel execution that dense matrix multiplies exploit; a standalone top-k pass at 128K context can take time comparable to dense attention despite involving far fewer arithmetic operations.

How does CLSA work mechanically?

CLSA is built on the YOCO architecture, where a self-decoder creates a single shared KV cache; a lightweight indexer then produces one routing index that all cross-decoder layers reuse, so each layer attends only to the same subset of salient tokens rather than recomputing its own selection.

Why does CLSA require a multi-layer distillation objective during training?

Because the single shared routing index must serve all cross-decoder layers simultaneously, a dedicated distillation loss is used to train the indexer to identify consensus salient tokens that remain important across all layers, rather than tokens optimal for just one layer.

How does CLSA differ from standard token-sparse attention methods?

Standard token-sparse methods recompute top-k routing independently in every decoder layer, whereas CLSA computes this index exactly once per decoding step and broadcasts it to all cross-decoder layers, amortizing the routing cost.

What speedups does CLSA achieve?

At 128K context, CLSA achieves up to 7.6× faster decoding and 17.1× higher overall throughput compared to a standard Transformer baseline.

What models and benchmarks were used to evaluate CLSA?

Three 4B-scale models were evaluated—a standard Transformer, dense YOCO, and YOCO with CLSA—on benchmarks including HumanEval, DROP, HellaSwag, and language modeling on StarCoder, Books, and ArXiv domains, at context lengths from 8K to 128K.

What are the key architectural details of the evaluated models?

All models share hidden size 2560, FFN width 7680, 32 layers, 20 attention heads, 4 KV heads, and head dimension 128, trained with a batch size of 8M tokens, Adam optimizer (β=(0.9, 0.95), ε=10⁻⁸), weight decay 0.1, and learning rates of 3×10⁻⁴/3×10⁻⁵.

How well does sparse selection recover dense attention quality?

Selecting 2048 tokens recovers 84% of dense attention mass, with cross-entropy losses of approximately 0.57 for StarCoder, 1.75 for Books, and 1.08 for ArXiv, and sparse attention occasionally surpasses dense YOCO on the StarCoder domain.

How does CLSA perform during prefill compared to decoding?

During prefill, both YOCO (Dense) and YOCO (CLSA) are substantially faster than the standard Transformer because the decoder avoids quadratic full-context attention, and the two YOCO variants remain close to each other in prefill performance; the largest gains for CLSA appear during decoding.

What hardware was used for experiments?

All throughput and latency experiments were conducted on NVIDIA B200 GPUs; the paper does not specify the exact number of GPUs used.

What are the limitations of CLSA?

The paper does not explicitly enumerate limitations, but the shared routing index is a consensus approximation that is not optimal for any single layer individually, and the approach is designed specifically for KV-sharing architectures like YOCO rather than standard Transformers.

How does CLSA relate to prior sparse attention work?

The paper situates CLSA among three strands of prior work: training-aware dynamic sparsity, cross-layer token-reuse techniques, and hybrid architectures blending attention with alternative sequence operators; CLSA is distinguished by sharing a single routing index across all cross-decoder layers within a KV-sharing backbone.

What is the YOCO architecture and why is it the foundation for CLSA?

YOCO (You Only Cache Once) is a KV-sharing architecture where a self-decoder produces a single shared KV cache used by all cross-decoder layers; CLSA exploits this shared cache by computing one routing index that all cross-decoder layers can reuse, making the combination natural and efficient.

Does CLSA hurt model quality?

The paper reports negligible impact on model quality; training curves on HumanEval, DROP, and HellaSwag show YOCO stays on par with the standard Transformer throughout training, and sparse selection at 2048 tokens recovers 84% of dense attention mass.

How is the latency per layer computed in the paper?

Latency numbers were derived by scaling raw per-layer timings to milliseconds and averaging across the 32 layers; for YOCO (CLSA), the attention term mixes sliding-window attention in the first 16 layers with CLSA attention in the second 16, and the top-k column reports the amortized per-layer routing cost.

What training procedure is used to apply CLSA to a pretrained model?

The paper describes starting from a solid dense YOCO model and applying a subsequent sparse fine-tuning stage using the multi-layer distillation objective; the dense-stage training uses up to 32K context with 8M tokens per step.

Key terms

CLSA (Cross-Layer Sparse Attention)
A sparse attention mechanism that computes a single top-k token routing index once per decoding step and reuses it across all cross-decoder layers, rather than recomputing routing independently at each layer.
YOCO (You Only Cache Once)
A KV-sharing Transformer architecture in which a self-decoder produces one shared KV cache that all subsequent cross-decoder layers attend to, avoiding redundant KV computation.
top-k routing
A sparse selection mechanism that identifies and attends to only the k most relevant tokens in the KV cache, discarding the rest to reduce computation.
shared routing index
A single set of token indices computed once by the indexer and broadcast to all cross-decoder layers so they all attend to the same subset of tokens.
KV cache
A stored record of key and value vectors from previously processed tokens, reused during autoregressive decoding to avoid recomputing attention from scratch at each step.
token-sparse attention
An attention variant that selects a small subset of tokens to attend to based on relevance scores, rather than attending to all tokens in the context.
block-structured sparsity
A form of sparse attention that skips contiguous rectangular blocks of the attention matrix, which is GPU-friendly but coarsens the attention pattern compared to token-level selection.
multi-layer distillation
A training objective that aligns the shared routing index with the consensus attention patterns across all decoder layers simultaneously, so the selected tokens are useful to every layer.
cross-decoder layer
A decoder layer in the YOCO architecture that attends to the shared KV cache produced by the self-decoder, as opposed to computing its own local KV representations.
self-decoder
The component of the YOCO architecture that processes the input and produces the single shared KV cache consumed by all cross-decoder layers.
attention coverage
The fraction of the total dense attention mass that is captured by a sparse selection of tokens, used as a proxy for how well sparse attention approximates full attention.
sliding-window attention
An attention pattern where each token attends only to a fixed-size local window of nearby tokens, reducing quadratic complexity to linear in sequence length.
amortized routing cost
The per-layer cost of computing the routing index when that cost is divided across all layers that reuse the same index, making it much cheaper than computing a fresh index at every layer.
QK normalization
A technique that normalizes the query and key vectors (for example with LayerNorm or RMSNorm) before computing attention scores, improving training stability.
indexer
A lightweight module in CLSA that computes the shared top-k routing index used by all cross-decoder layers to select which tokens to attend to.
