Pretraining Recurrent Networks without Recurrence

Supervised Memory Training (SMT) decouples memory representation from dynamics to enable parallel, stable RNN training.

How can we train recurrent neural networks (RNNs) using supervised learning on memory states instead of backpropagation through time (BPTT)?

Backpropagation Through Time (BPTT) forces Recurrent Neural Networks (RNNs) to unroll sequences, creating sequential bottlenecks and unstable gradients that make learning long-range dependencies difficult. Supervised Memory Training (SMT) sidesteps this by training a Transformer encoder to act as a "teacher" that generates optimal memory states, reducing RNN training to a parallel, one-step supervised learning task. On language and pixel-sequence modeling, SMT outperforms BPTT in capturing long-range dependencies while requiring significantly less sequential computation.

Paper Primer

SMT treats the past as a set of timestamped events rather than a rigid sequence, allowing a Transformer encoder to learn a "predictive state"—a compressed representation of the past sufficient to predict the future. The RNN then learns to update this memory state in one-step transitions, effectively decoupling the "what to remember" (representation) from the "how to update" (dynamics).

SMT provides a stable $O(1)$ gradient path for long-range credit assignment.

Unlike BPTT, where gradients must propagate through $O(T)$ unrolled steps, SMT's gradient path length is independent of sequence length, preventing the vanishing/exploding gradient issues common in recurrent training. SMT outperforms BPTT across all synthetic tasks, including retrieval, string copying, and state tracking, where BPTT struggles as sequence length increases.

To bridge the gap between the teacher-generated memory and the RNN's own autoregressive rollouts, the authors introduce DAgger Memory Training (DMT). This fine-tuning phase uses on-policy imitation learning to correct the drift that accumulates when the RNN relies on its own predicted memories rather than the encoder's.

Why is this approach better than just using a Transformer for everything?

While Transformers are powerful, they lack a compressed memory of the past, so their memory and per-token compute costs grow linearly with sequence length. RNNs trained with SMT maintain a fixed-size memory and $O(1)$ per-step inference cost, making them more efficient for long-horizon tasks.
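To make the contrast concrete, here is a minimal sketch of how state size scales in the two cases. It assumes a single-layer key/value cache for the Transformer, which the text does not name; the function names and the `d_model` and `memory_size` parameters are hypothetical.

```python
def transformer_cache_size(num_tokens: int, d_model: int) -> int:
    """Floats held by an assumed single-layer key/value cache after num_tokens."""
    # One key vector and one value vector are kept per past token.
    return 2 * num_tokens * d_model


def rnn_memory_size(num_tokens: int, memory_size: int) -> int:
    """Floats held by a fixed-size recurrent memory; independent of num_tokens."""
    return memory_size


for T in (16, 512, 16384):
    print(T, transformer_cache_size(T, d_model=64), rnn_memory_size(T, memory_size=1024))
```

The cache grows in proportion to the number of tokens read, while the recurrent memory stays the same size however long the sequence gets.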

Does SMT replace BPTT entirely?

SMT is primarily a pretraining method. Because the teacher encoder is constrained by its own parallel architecture, the RNN may require lightweight post-training (like DMT) to adapt to specific tasks and achieve expressivity beyond the teacher's limitations.

SMT enables the scaling of nonlinear RNNs by transforming them into parallel-trainable models, offering a path to build temporal abstractions of past experience that are both computationally efficient and capable of long-range reasoning.

The Problem with Recurrent Training

RNN training struggles with long‑range credit assignment, and SMT reframes it as supervised one‑step memory learning.

Backpropagation Through Time (BPTT) trains recurrent networks by unrolling the computation graph across the sequence and propagating gradients backward. Because the gradient must travel through up to $T$ timesteps, the memory cost grows with sequence length and the gradients become unstable, often vanishing or exploding. This creates a fundamental trade‑off: longer sequences give richer context but degrade training stability.
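The vanishing and exploding behavior can be seen in the simplest possible case. The sketch below assumes a scalar linear recurrence $h_t = w\,h_{t-1}$, chosen only for illustration and not taken from the paper, where the gradient of $h_T$ with respect to $h_0$ is a product of $T$ identical factors.

```python
def gradient_through_time(w: float, num_steps: int) -> float:
    """d h_T / d h_0 for the scalar linear recurrence h_t = w * h_{t-1}."""
    grad = 1.0
    for _ in range(num_steps):
        grad *= w  # chain rule: one Jacobian factor per unrolled step
    return grad


vanishing = gradient_through_time(0.9, 100)  # shrinks toward zero
exploding = gradient_through_time(1.1, 100)  # grows without bound
print(vanishing, exploding)
```

A recurrent weight only slightly below or above 1 is enough to make the gradient shrink or grow by several orders of magnitude over 100 steps.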

Supervised Memory Training (SMT) sidesteps recurrent credit propagation by turning the problem into supervised learning on one‑step memory transition labels $(m_t, x_{t+1}) \rightarrow m_{t+1}$. The labels are obtained from a teacher Transformer encoder that learns a predictive state—a representation that captures exactly the information needed to forecast the future. With this decoupling, the RNN only needs to learn how to update its memory one step at a time, yielding a constant‑length ($O(1)$) gradient path regardless of sequence length, which enables fully parallel training.
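The one-step transition labels described above can be sketched as follows. This is a toy illustration under stated assumptions: the "teacher" is a stand-in that stores a running sum rather than a trained Transformer encoder, memories are scalars, and `make_transition_examples` and `smt_loss` are hypothetical helper names.

```python
import random


def make_transition_examples(teacher_memories, inputs):
    """Pair each (m_t, x_{t+1}) with its label m_{t+1}.

    teacher_memories[t] is the teacher memory after reading inputs[0..t].
    """
    return [
        ((teacher_memories[t], inputs[t + 1]), teacher_memories[t + 1])
        for t in range(len(teacher_memories) - 1)
    ]


def smt_loss(update_fn, examples):
    """Mean squared error over independent one-step transitions."""
    errors = [(update_fn(m, x) - target) ** 2 for (m, x), target in examples]
    return sum(errors) / len(errors)


def biased_update(m, x):
    """A deliberately imperfect stand-in for the RNN updater."""
    return m + 0.5 * x


# Toy stand-in for the teacher: the memory is a running sum of the inputs.
inputs = [1.0, 2.0, 3.0, 4.0]
teacher_memories = [1.0, 3.0, 6.0, 10.0]
examples = make_transition_examples(teacher_memories, inputs)

# No example depends on the updater's output for another one, so order is irrelevant.
shuffled = examples[:]
random.Random(0).shuffle(shuffled)
print(smt_loss(biased_update, examples), smt_loss(biased_update, shuffled))
```

Because every example carries its own input memory and its own label, the loss is the same in any order, which is what allows all timesteps to be processed in parallel.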

**Figure 1.** **BPTT vs SMT.** Left: BPTT trains an RNN by recurrently unrolling the “updater” network in time, and backpropagating gradients through the entire graph. Right: Supervised Memory Training (SMT) trains an RNN with supervised learning on one-step memory transition labels, which are generated by a Transformer encoder-decoder model pair trained to produce predictive states. SMT is fully time-parallel. In SMT, the longest gradient path between tokens is $O(1)$ (compared to $O(T)$ in BPTT), which stabilizes gradients, making learning long-range dependencies qualitatively easier.

BPTT's reliance on unrolling creates a fundamental trade‑off between sequence length and training stability.

Supervised Memory Training

We describe how SMT and DMT replace BPTT by supervising memory states with a teacher encoder.

Standard BPTT forces the RNN to unroll over the entire sequence, inflating memory use to $O(MT)$ (where $M$ is the memory state size and $T$ is the sequence length) and exposing gradients to vanishing or exploding dynamics. Those two drawbacks motivate a different training regime that separates memory representation from its dynamics.

SMT teaches the RNN to copy a “teacher” memory that a parallel encoder has already compressed, turning the recurrent update into a simple supervised prediction.

How does SMT differ from ordinary behavior cloning on hidden states?

In ordinary cloning the RNN tries to mimic its own hidden state, which is still learned end‑to‑end. SMT instead provides an external “teacher” memory $m_t$ generated by a parallel encoder, so the RNN only learns a deterministic mapping $ (m_t, x_{t+1}) \rightarrow m_{t+1}$.

Consider a toy example with a two-dimensional memory. Suppose the teacher memory is $m_1=(1,2)$ and the RNN predicts $\hat{m}_1=(0.8,1.9)$. The squared error for the first transition is $\|m_1 - \hat{m}_1\|^2 = (1-0.8)^2 + (2-1.9)^2 = 0.05$.

For the next transition, the teacher memory is $m_2=(2,4)$ and the RNN predicts $\hat{m}_2=(1.9,3.8)$.

The second error is $\|m_2 - \hat{m}_2\|^2 = (2-1.9)^2 + (4-3.8)^2 = 0.05$.

Summing the two errors gives a total supervised loss of $L_{\text{SMT}} = 0.05 + 0.05 = 0.10$ for the two steps.

Even with a tiny memory vector, the supervised loss directly measures how well the RNN copies the teacher’s representation, making the learning signal dense and stable.
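The arithmetic in this toy example can be checked with a few lines of code. The values $m_1=(1,2)$, $\hat{m}_1=(0.8,1.9)$, $m_2=(2,4)$, and $\hat{m}_2=(1.9,3.8)$ are read off from the squared-error expressions; `squared_error` is a hypothetical helper name.

```python
def squared_error(target, prediction):
    """Squared Euclidean distance between a teacher memory and a prediction."""
    return sum((t - p) ** 2 for t, p in zip(target, prediction))


m1, m1_hat = (1.0, 2.0), (0.8, 1.9)
m2, m2_hat = (2.0, 4.0), (1.9, 3.8)

err1 = squared_error(m1, m1_hat)  # (0.2)^2 + (0.1)^2
err2 = squared_error(m2, m2_hat)  # (0.1)^2 + (0.2)^2
total = err1 + err2
print(round(err1, 6), round(err2, 6), round(total, 6))
```

Values are rounded for display because floating-point subtraction leaves tiny residues in the last digits.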

DMT fine‑tunes the RNN on its own predicted memories, correcting the drift that accumulates when the model is unrolled autoregressively.

Why is DMT still considered “lightweight” if it unrolls the RNN over the whole sequence?

Because the loss only compares the RNN’s memory to the pre‑computed teacher memory; gradients do not need to propagate through the encoder, and the credit path for long‑range dependencies is already encoded in $m_t$, so the computational overhead is essentially a single forward pass.
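The drift that DMT targets can be illustrated with a toy rollout. This sketch shows only the problem, not DMT itself. It assumes scalar memories, a stand-in teacher that accumulates its inputs, and a student whose one-step update carries a small constant bias; all names and values are hypothetical.

```python
def teacher_update(m, x):
    """Toy teacher dynamics: accumulate the input."""
    return m + x


def student_update(m, x, bias=0.01):
    """Student with a small, constant one-step error."""
    return m + x + bias


inputs = [1.0 if t % 2 == 0 else -1.0 for t in range(50)]

teacher = [0.0]
for x in inputs:
    teacher.append(teacher_update(teacher[-1], x))

# Off-policy (SMT-style): every step starts from the teacher's memory.
one_step_errors = [
    abs(student_update(teacher[t], x) - teacher[t + 1]) for t, x in enumerate(inputs)
]

# On-policy rollout: every step starts from the student's own previous prediction.
rollout = [0.0]
for x in inputs:
    rollout.append(student_update(rollout[-1], x))
rollout_errors = [abs(rollout[t + 1] - teacher[t + 1]) for t in range(len(inputs))]

print(max(one_step_errors), rollout_errors[-1])
```

The one-step error stays at the size of the bias, while the rollout error grows with every step. That gap between off-policy and on-policy error is what on-policy imitation is meant to close.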

**Figure 15.** Model Architecture for SMT. Left: The encoder reads the input context tokens and a set of learned register tokens, and outputs the memory, $m_t$, which is a set of memory tokens. The decoder takes in this memory and the future input tokens and predicts the future output tokens, using a causal mask. This setup forces information from the context to be compressed into a memory that is useful for predicting the future outputs, given future inputs. Middle: Our RNN maps $(m_t, x_{t+1})$ to $m_{t+1}$ using a Transformer-backbone. Since the memory is a list of tokens and the input is a token, we simply use a full attention Transformer to transform the current memory into the next timestep's memory. Right: Readout is performed by a full attention Transformer over the memory tokens.

**Figure 2.** **SMT vs DMT.** SMT trains the RNN with behavior cloning on the encoder-generated memory states (off-policy imitation learning). DMT unrolls the RNN with its own memory states and then imitates the encoder trajectory (on-policy imitation learning). Figure design inspired by Jacobs et al. [59].

**Table 1.** Resource requirements. $T$ is token sequence length. $T_c$ is SMT encoder context length. For RNNs, $M$ is the memory state size. We ignore log terms for simplicity. LA denotes linear attention (in its parallel and recurrent form). Complexity classes are from Merrill et al. [84].

Empirical Performance and Scaling

SMT→DMT reduces test loss by 0.30 when scaling context length, beating BPTT on long‑horizon tasks.

When the context length $T_c$ is increased from 16 to 512 on TinyStories, the SMT→DMT test loss falls by about 0.30, an improvement that BPTT does not match.

Figure 7 shows a smooth loss drop from ≈0.85 to ≈0.55 as $T_c$ grows, while the BPTT curve remains above 0.80.

We evaluate three training regimes—BPTT, SMT, and SMT→DMT—on nonlinear RNNs built on Transformer, MLP, and GRU backbones. Datasets include TinyStories (character‑level language modeling), MNIST pixel‑sequence, and Sketchy (sparse line‑art). Synthetic tasks isolate gradient stability, memory capacity, state tracking, associative recall, and in‑context learning. All experiments exclude linear RNNs and Transformers, which belong to a different model class.

**Figure 3.** Synthetic Task Experiments. We evaluate BPTT, SMT, and SMT→DMT using five synthetic tasks with various settings to probe different properties of the algorithms. * signifies that the SMT Encoder is the teacher Transformer (not an RNN) and is used only as a reference. Across all tasks and task settings, SMT→DMT outperforms BPTT, signaling that SMT has better gradient properties, memory utilization, state tracking, associative recall, and in-context learning than BPTT.

**Figure 4.** Attneave’s MNIST Generation. BPTT fails to effectively capture the long-range dependencies required for pixel sequence modeling, even with a GRU. SMT→DMT captures these dependencies with a non-gated RNN architecture. More samples are in Appendix Figure 17.

**Figure 5.** Attneave's Sketchy Generation. SMT→DMT captures the stroke structure of human-drawn sketches through only pixel sequence modeling on sparse images. More samples are in Appendix Figure 18.

**Figure 6.** **Sequential Compute and Data Efficiency.** We sweep training hyperparameters for BPTT, SMT, and SMT→DMT and plot the resulting runs’ performance along sequential compute (SeqFLOPs) used and data processed (Tokens), across different RNN architectures and datasets. Runs are capped at one day on an H200 GPU. * signifies that the SMT Encoder is the teacher Transformer (not an RNN) and is used only as a reference. Generally, SMT and SMT→DMT are more efficient than BPTT in sequential compute, and around the same or better efficiency in data.

**Figure 7.** Scaling Context and Memory. SMT→DMT shows smooth performance improvements as you increase the context length and the memory size in TinyStories.

**Figure 8. Scaling Model Size.** Sweeping the width and depth of the RNN and teacher shows smooth performance improvements in TinyStories. The RNN imitates the teacher performance better at larger scale.

**Figure 9.** Scaling Laws for Compression. We plot iso-loss contours for SMT-trained encoder models across a range of memory state sizes and training compute budgets. For a fixed target performance, SMT can achieve higher compression (smaller memory size) using additional compute. This result suggests a new property to scale when given more training compute: memory state compression.

SMT/DMT consistently outperforms BPTT on long‑horizon tasks while maintaining better scaling properties.

Context and Prior Approaches

We survey prior approaches to training recurrent models and situate our method among them.

Training RNNs via BPTT requires unrolling the network and back‑propagating through time, which is memory‑intensive; SMT instead trains the RNN to predict teacher‑generated memory states, turning recurrence into a supervised problem. Early work on recurrent networks was motivated by their resemblance to biological brains and their applicability to any sequential task. Researchers explored many learning algorithms, from random guessing to evolutionary methods, Hebbian learning, and real‑time recurrent learning.

BPTT remains the only widely adopted training algorithm, but it suffers from unstable gradients that vanish, explode, or show high variance. Architectural tweaks such as residual connections and gating mechanisms gave rise to the LSTM and GRU families, while orthogonal weight parameterizations were introduced to prevent exponential growth or decay across time.

Other research directions added external memory modules, hierarchical modeling, and various ways to increase expressivity without sacrificing stability. More recently, linear state‑space models, linear attention mechanisms, and fully nonlinear RNN designs have re‑emerged, often appearing in diffusion‑based architectures, looped Transformers, and reasoning‑oriented networks.

Transformers introduced time‑parallel training, which leveraged modern hardware to scale sequence modeling dramatically. Linear RNNs can be parallelized efficiently using the associative scan algorithm, eliminating the sequential bottleneck of classic recurrent updates.
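As a sketch of why linear recurrences parallelize, the code below runs an associative scan over the recurrence $h_t = a_t h_{t-1} + b_t$ with $h_0 = 0$. Each step is an affine map, and composing affine maps is associative, so prefixes can be combined in $O(\log T)$ rounds. The Hillis-Steele layout and all names are illustrative choices, not taken from the paper, and the rounds are evaluated in ordinary Python rather than on parallel hardware.

```python
def combine(left, right):
    """Compose two affine maps h -> a*h + b, applying `left` first."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


def associative_scan(elements):
    """Inclusive prefix scan in O(log T) rounds (Hillis-Steele style).

    Within a round every position is updated independently, so the round
    could run in parallel; here it is a plain list comprehension.
    """
    result = list(elements)
    offset = 1
    while offset < len(result):
        result = [
            result[i] if i < offset else combine(result[i - offset], result[i])
            for i in range(len(result))
        ]
        offset *= 2
    return result


def sequential_states(a, b, h0=0.0):
    """Reference: the usual step-by-step recurrence."""
    states, h = [], h0
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        states.append(h)
    return states


a = [0.5, 2.0, -1.0, 0.25, 3.0]
b = [1.0, -2.0, 0.5, 4.0, 1.5]
# With h0 = 0, the offset term of each prefix map is the hidden state itself.
scanned = [offset_term for _, offset_term in associative_scan(list(zip(a, b)))]
print(scanned)
print(sequential_states(a, b))
```

The trick depends on the update being affine in the previous state, which is why it does not carry over directly to nonlinear RNNs.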

Parallelizing nonlinear RNNs has been tackled by formulating the forward pass as an iterative optimization problem and solving the resulting system of equations with Newton’s method. This approach inherits BPTT’s $O(T)$ credit‑path length and adds convergence concerns, whereas SMT achieves an $O(1)$ credit‑path by training a predictive state encoder.
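A minimal sketch of the Newton-style formulation is shown below, assuming a scalar update $h_t = \tanh(w\,h_{t-1} + x_t)$ chosen only for illustration. The whole trajectory is treated as the unknown of the system $h_t - f(h_{t-1}, x_t) = 0$, and each Newton step reduces to a linear recurrence. That recurrence is solved with a plain loop here, where a parallel implementation would use a scan. All names are hypothetical, and this is not the code of any particular published method.

```python
import math


def step(h_prev, x, w=0.8):
    """Nonlinear RNN update h_t = tanh(w * h_{t-1} + x_t)."""
    return math.tanh(w * h_prev + x)


def step_jacobian(h_prev, x, w=0.8):
    """Derivative of step with respect to h_prev."""
    return w * (1.0 - math.tanh(w * h_prev + x) ** 2)


def sequential_rollout(xs, h0=0.0):
    hs, h = [], h0
    for x in xs:
        h = step(h, x)
        hs.append(h)
    return hs


def newton_rollout(xs, h0=0.0, num_iters=None):
    """Solve h_t - step(h_{t-1}, x_t) = 0 for all t at once by Newton iteration."""
    T = len(xs)
    guess = [0.0] * T
    for _ in range(T if num_iters is None else num_iters):
        prev_old = [h0] + guess[:-1]
        # These evaluations are independent across t (the parallelizable part).
        f = [step(p, x) for p, x in zip(prev_old, xs)]
        jac = [step_jacobian(p, x) for p, x in zip(prev_old, xs)]
        # Linearized recurrence: h_t = f_t + jac_t * (h_{t-1} - prev_old_t).
        new, h_prev = [], h0
        for t in range(T):
            h_prev = f[t] + jac[t] * (h_prev - prev_old[t])
            new.append(h_prev)
        guess = new
    return guess


xs = [0.3, -1.2, 0.7, 0.1, -0.4, 0.9]
exact = sequential_rollout(xs)
approx = newton_rollout(xs)
print(max(abs(a - b) for a, b in zip(exact, approx)))
```

In this sketch, iteration $k$ makes the first $k$ states exact, so at most $T$ iterations recover the sequential result. The number of iterations needed in practice is the convergence concern mentioned above.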

The computational‑complexity view ties a model’s sequential depth to the class of problems it can solve; models with constant or logarithmic depth per layer (e.g., Transformers, linear RNNs) are provably limited to low‑depth tasks. Nonlinear RNNs, whose depth grows with sequence length, can in principle address problems requiring deep sequential computation.

Predictive State Representations (PSRs) model partially observed dynamical systems by predicting future observations; they have been incorporated into RNNs, but those works still rely on BPTT and thus lack time‑parallel training.

Cross‑architecture teacher‑student distillation methods, such as Next‑Latent Prediction (NextLat), train an RNN with memory‑state supervision from a Transformer and under certain hyper‑parameters become equivalent to SMT. The Recurrent Transformer attains an $O(1)$ gradient path by attending to all past hidden states, but its memory grows unboundedly during inference.

Because the teacher model in SMT is time‑parallel, its expressivity may limit the RNNs it trains, sometimes necessitating BPTT fine‑tuning. Our current SMT variant trains only a single memory state per sequence; training all states showed no benefit at our scale, though this may change for larger models. DMT offers a remedy but is not time‑parallel, though it could be parallelized via DEER.

RNNs hold the promise of solving tasks that span unbounded horizons, yet traditional training methods struggle with credit assignment over long sequences. Our approach eliminates the $O(T)$ credit‑assignment bottleneck by providing an $O(1)$ connection path, enabling memories that become useful many steps later.

Ablations and Technical Analysis

Ablation results quantify how each component affects training efficiency and performance.

This appendix reports ablation studies that isolate the contribution of each training component.

Joint training of SMT eliminates the need for a large future horizon, achieving near‑zero loss for all future lengths $T_f$.

Figure 10 shows the Detached SMT curve staying above a loss of 3–4 for $T_f<64$, while the Joint SMT curve remains at essentially 0 across the entire range.

SMT gradients remain stable across timesteps, unlike BPTT which exhibits exploding or vanishing gradients.

Figure 11 plots gradient magnitudes; BPTT reaches $10^{2}$ at early timesteps, whereas SMT stays around $10^{-7}$ throughout.

Applying DMT reduces rollout drift, lowering the $1 - R^{2}$ drift error by roughly $0.8$ compared with plain SMT.

In the left panel of Figure 12, the SMT curve climbs toward an error of 1.0, while the DMT-enhanced curve stays near 0.2.
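The drift measure in Figure 12 is $1 - R^2$ between the RNN's memory prediction and the encoder's memory. A minimal sketch of that quantity follows, assuming memories are flattened into plain lists of scalars. How the paper aggregates over memory dimensions is not stated, so the flattening and the name `one_minus_r2` are assumptions.

```python
def one_minus_r2(targets, predictions):
    """1 - R^2: residual sum of squares over total sum of squares of the targets."""
    mean_target = sum(targets) / len(targets)
    ss_res = sum((t - p) ** 2 for t, p in zip(targets, predictions))
    ss_tot = sum((t - mean_target) ** 2 for t in targets)
    return ss_res / ss_tot


teacher = [1.0, 2.0, 3.0, 4.0]
print(one_minus_r2(teacher, teacher))      # perfect prediction
print(one_minus_r2(teacher, [2.5] * 4))    # always predicting the mean
```

A value of 0 means the prediction matches the teacher exactly, and a value of 1 means it is no better than always predicting the teacher's mean.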

SMT→DMT RNN generalizes to much longer sequences than the Transformer teacher, keeping test loss $0.23$ lower at $T=512$.

Figure 13 shows the SMT→DMT line at a loss of about 0.12 versus the Transformer line at roughly 0.35 for the longest sequence.

Optimal performance requires a moderate dynamics loss and an almost‑zero uniformity loss; the best RNN test loss observed is $0.52$.

Figure 16's heatmap reaches its lowest test loss of $0.52$ when $\lambda_{\text{dyn}}$ is moderate and $\lambda_{\text{unif}}$ is near zero.

**Figure 10.** **Joint SMT Ablation.** Here, the task requires credit assignment across $T$ timesteps. When the RNN is detached during SMT, $T_f$ must be large enough to capture the task signal ($T_f = T$). With joint training, SMT solves the task even when $T_f$ is small.

**Figure 11.** Gradient Properties of BPTT and SMT. In the needle retrieval task, the loss is applied at the last timestep. BPTT propagates gradients backward through all timesteps, risking vanishing/exploding gradients for each $m_t$, depending on the weight initialization. SMT is non-recurrent and has a $\mathcal{O}(1)$ credit path length, making its gradients agnostic to initialization and time-horizon.

**Figure 12.** Impact of DMT across many runs with different SMT $\lambda_{dec}$ and $\lambda_{dyn}$ hyperparameters. Left: Applying DMT reduces the drift of the RNN rollout (measured with $1 - R^2$ of RNN memory prediction $\hat{m}_t$ of encoder ground truth $m_t$). Middle: DMT significantly improves RNN performance across settings. Right: The one-step drift of the RNN only partially correlates with the rollout drift.

**Figure 13.** Sequence Length Generalization. An SMT→DMT trained RNN generalizes better than its Transformer teacher when evaluated on sequence lengths longer than training. The task is synthetic state tracking.

**Figure 16.** Sweep of $\lambda_{dyn}$ and $\lambda_{unif}$. Cell color indicates the RNN test loss for each setting. Top number in each cell is the RNN test loss. Bottom number in each cell shows the $\mathcal{L}^{unif}$. $\mathcal{L}^{unif}$ varies from 0 (collapsed latent space) to $-4$ (fully uniform latent space).
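The caption does not give a formula for $\mathcal{L}^{unif}$. One common definition whose range is consistent with the stated endpoints is the hypersphere uniformity loss of Wang and Isola, $\log \mathbb{E}\,\exp(-t\,\|z_i - z_j\|^2)$ with $t = 2$ on unit-norm vectors. The sketch below assumes that definition, which may not be the one the paper uses.

```python
import math
from itertools import combinations


def uniformity_loss(points, t=2.0):
    """log of the mean over distinct pairs of exp(-t * squared distance)."""
    pair_terms = [
        math.exp(-t * sum((a - b) ** 2 for a, b in zip(p, q)))
        for p, q in combinations(points, 2)
    ]
    return math.log(sum(pair_terms) / len(pair_terms))


collapsed = [(1.0, 0.0, 0.0)] * 4  # every latent is the same point
orthogonal = [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)]  # well spread out

print(uniformity_loss(collapsed))
print(uniformity_loss(orthogonal))
```

Under this assumed definition, identical latents give a loss of 0, and mutually orthogonal unit vectors, whose pairwise squared distance is 2, give a loss of $-4$.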

Questions & answers

What is the main contribution of this paper?

The paper introduces Supervised Memory Training (SMT), a method that pretrains RNNs without Backpropagation Through Time (BPTT) by using a Transformer encoder to generate teacher memory states, turning recurrent training into a parallel, one-step supervised learning problem with an O(1) credit-assignment path.

What problem does SMT address and why does it matter?

SMT addresses the fundamental limitations of BPTT, which requires unrolling sequences and propagating gradients through up to T timesteps, causing O(MT) memory costs and vanishing or exploding gradients that make learning long-range dependencies difficult. This matters because longer sequences provide richer context but are precisely where BPTT becomes most unstable.

How does Supervised Memory Training (SMT) work?

SMT trains a Transformer encoder to learn a 'predictive state' (a compressed representation of the past sufficient to predict the future) and uses these encoder-generated memory states as labels for one-step supervised transitions of the form (m_t, x_{t+1}) → m_{t+1}. Because every transition has its own teacher-provided label, the RNN learns memory dynamics in parallel across timesteps rather than through sequential unrolling.

What is DAgger Memory Training (DMT) and why is it needed?

DMT is a fine-tuning phase that applies on-policy imitation learning to correct the distributional drift that accumulates when the RNN relies on its own predicted memories during autoregressive rollout rather than the teacher encoder's memories. It is described as lightweight because the loss only compares the RNN's memory to pre-computed teacher memories, and gradients do not propagate through the encoder.

Does SMT completely replace BPTT?

SMT is primarily a pretraining method and does not fully replace BPTT; because the teacher encoder is constrained by its own parallel architecture, the RNN may require lightweight post-training such as DMT to adapt to specific tasks and achieve expressivity beyond the teacher's limitations.

What datasets and benchmarks are used to evaluate SMT?

The paper evaluates SMT on TinyStories (character-level language modeling), MNIST pixel-sequence modeling, and Sketchy (sparse line-art), as well as synthetic tasks designed to isolate gradient stability, memory capacity, state tracking, associative recall, and in-context learning. The paper explicitly excludes linear RNNs from these experiments.

What RNN architectures are tested with SMT?

The paper evaluates three training regimes—BPTT, SMT, and SMT→DMT—on nonlinear RNNs built on Transformer, MLP, and GRU backbones.

What are the key empirical results of SMT compared to BPTT?

SMT and DMT consistently outperform BPTT on long-horizon tasks while maintaining better scaling properties, as demonstrated across language modeling, pixel-sequence modeling, and synthetic tasks. The paper does not report a single aggregate numeric improvement figure but characterizes the advantage as consistent across all evaluated settings.

Why is an RNN trained with SMT more efficient than a Transformer at inference time?

Transformers lack a compressed memory of the past, so their memory and per-token compute costs grow linearly with sequence length, whereas RNNs trained with SMT maintain a fixed-size memory and O(1) per-step inference cost, making them more efficient for long-horizon tasks.

What are the limitations of SMT acknowledged in the paper?

The paper acknowledges that the teacher encoder's expressivity may limit the RNNs it trains, sometimes necessitating BPTT fine-tuning; the current SMT variant trains only a single memory state per sequence (training all states showed no benefit at the evaluated scale, though this may change for larger models); and DMT is not time-parallel, though the paper notes it could potentially be parallelized via DEER.

How does SMT differ from ordinary behavior cloning on hidden states?

In ordinary behavior cloning the RNN tries to mimic its own hidden state, which is still learned end-to-end. SMT instead provides an external teacher memory m_t generated by a parallel encoder, so the RNN only learns a deterministic mapping (m_t, x_{t+1}) → m_{t+1} without any gradient flowing back through the sequence.

How does SMT relate to prior work such as Next-Latent Prediction (NextLat) and Predictive State Representations (PSRs)?

The paper notes that Next-Latent Prediction (NextLat) trains an RNN with memory-state supervision from a Transformer and under certain hyperparameters becomes equivalent to SMT. Predictive State Representations (PSRs) model partially observed dynamical systems by predicting future observations and have been incorporated into RNNs, but those prior works still rely on BPTT and thus lack time-parallel training, unlike SMT.

How does SMT compare to parallelizing nonlinear RNNs via Newton's method?

Parallelizing nonlinear RNNs via Newton's method formulates the forward pass as an iterative optimization problem but inherits BPTT's O(T) credit-path length and adds convergence concerns, whereas SMT achieves an O(1) credit-path by training a predictive state encoder.

What is the theoretical motivation for using nonlinear RNNs rather than linear RNNs or Transformers?

From a computational-complexity perspective, models with constant or logarithmic depth per layer—such as Transformers and linear RNNs—are provably limited to low-depth tasks, whereas nonlinear RNNs, whose depth grows with sequence length, can in principle address problems requiring deep sequential computation.

Who are the authors of this paper and where was it published?

The paper does not specify author names or a publication venue in the provided text; it is available on arXiv at https://arxiv.org/abs/2606.06479.

How would a practitioner apply or reproduce SMT?

A practitioner would first train a Transformer encoder to produce predictive memory states from input sequences, then use those states as one-step supervised labels to train an RNN's memory transition function (m_t, x_{t+1}) → m_{t+1} without unrolling, and optionally fine-tune with DMT using on-policy imitation learning to correct autoregressive drift. The paper does not specify code availability or exact hyperparameter settings in the provided text.

Key terms

Backpropagation Through Time (BPTT)
The standard algorithm for training recurrent networks by unrolling the computation graph across all timesteps and propagating gradients backward, which causes memory costs and gradient instability to grow with sequence length.
Supervised Memory Training (SMT)
A pretraining method that replaces BPTT by using a Transformer encoder to generate target memory states, turning RNN training into a parallel, one-step supervised learning problem with an O(1) credit-assignment path.
DAgger Memory Training (DMT)
A fine-tuning phase built on on-policy imitation learning that corrects the distributional drift accumulated when an RNN trained with SMT generates its own memory states during autoregressive rollout.
Predictive state
A compressed representation of all past observations that contains exactly the information needed to forecast future observations, learned here by the Transformer teacher encoder.
Recurrent Neural Network (RNN)
A neural network that processes sequences by maintaining a hidden memory state that is updated at each timestep based on the current input and the previous state.
Transformer encoder
A neural network architecture that processes entire sequences in parallel using self-attention, used in SMT as the teacher that generates optimal memory state labels.
Distributional drift (covariate shift)
The accumulation of errors that occurs when a model is trained on one distribution of inputs (teacher-generated memories) but must operate on a different distribution at inference (its own predicted memories).
Predictive State Representations (PSRs)
A framework for modeling partially observed dynamical systems by representing the system's state as predictions of future observations rather than latent variables.
Next-Latent Prediction (NextLat)
A cross-architecture teacher-student distillation method that trains an RNN using memory-state supervision from a Transformer, which the paper notes becomes equivalent to SMT under certain hyperparameters.
Associative scan
A parallel algorithm that allows linear RNNs to be trained efficiently by exploiting the associativity of their recurrence, eliminating the sequential bottleneck of classic recurrent updates.
DEER
A method mentioned in the paper as a potential way to parallelize DMT, though the paper does not elaborate on its details in the provided text.
GRU (Gated Recurrent Unit)
A type of recurrent neural network architecture that uses gating mechanisms to control information flow, helping mitigate vanishing gradient problems compared to simple RNNs.
O(1) inference cost
A computational complexity property meaning that the time and memory required per inference step remains constant regardless of how long the input sequence is, as achieved by fixed-size RNN memory states.
TinyStories
A character-level language modeling dataset used in the paper's empirical evaluation of SMT versus BPTT.
Sketchy
A sparse line-art dataset used in the paper's empirical evaluation of sequence modeling with SMT.
Credit assignment
The problem of determining which past inputs or decisions are responsible for a current outcome, which in BPTT requires gradients to travel through all T timesteps of a sequence.
