Learning to (Learn at Test Time): RNNs with Expressive Hidden States
Abstract
The paper proposes a new class of sequence modeling layers called Test-Time Training (TTT) layers, in which the hidden state is itself a machine learning model and the update rule is a step of self-supervised learning. The authors introduce two instantiations, TTT-Linear and TTT-MLP, and show that they match or outperform strong Transformer and modern RNN (Mamba) baselines, with the largest gains in long context. The paper also describes practical innovations that improve the wall-clock efficiency of TTT layers.
Q&A
[01] Method
1. What is the key idea behind TTT layers? The key idea is to make the hidden state a machine learning model itself, and the update rule a step of self-supervised learning: updating the hidden state on each input token is literally one training step on that token, even at test time, hence "test-time training". This lets the hidden state compress the context seen so far into the weights of a model, capturing the underlying structures and relationships that matter for performance in long context.
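A minimal sketch of this inner loop, written as an illustration rather than the paper's implementation: the hidden state is a linear model W, the self-supervised task is the naive reconstruction of each token from a low-rank-corrupted view, and the update rule is one step of gradient descent per token. LayerNorm, the residual connection, and the learnable views discussed below are omitted.

```python
import numpy as np

def ttt_forward(xs, eta=0.1, corrupt_rank=4, seed=0):
    """Per-token TTT loop with a linear model W as the hidden state.

    For each token x_t:
      1. update rule: one gradient step on the self-supervised loss
         ||x_tilde_t @ W - x_t||^2, where x_tilde_t is a corrupted view of x_t
      2. output rule: z_t = x_t @ W_t, using the freshly updated state
    """
    n, d = xs.shape
    rng = np.random.default_rng(seed)
    P = rng.normal(size=(d, corrupt_rank)) / np.sqrt(d)   # fixed low-rank corruption
    W = np.zeros((d, d))                                  # hidden state = model weights
    zs = np.empty_like(xs)
    for t, x in enumerate(xs):                            # one "training step" per token
        x_tilde = (x @ P) @ P.T                           # corrupted (low-rank) view of x_t
        grad = 2.0 * np.outer(x_tilde, x_tilde @ W - x)   # d/dW ||x_tilde @ W - x||^2
        W = W - eta * grad                                # update rule: one SGD step
        zs[t] = x @ W                                     # output rule on the updated state
    return zs, W
```

The sequence of states W_1, ..., W_T plays the role of an RNN hidden state, but its capacity is the capacity of the inner model rather than a fixed-size vector.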
2. How do TTT layers differ from RNN layers and self-attention?
- RNN layers compress context into a fixed-size hidden state, limiting their performance in long context.
- Self-attention stores the entire historical context explicitly as a growing key-value cache, so it compresses nothing, but its complexity is quadratic in context length.
- TTT layers aim to strike a balance: they compress context into a model-based hidden state, which can be more expressive than the fixed-size state of an RNN layer while keeping linear complexity.
3. How is the self-supervised task designed in TTT layers? The self-supervised task is itself learned as part of the outer loop: rather than fixing a reconstruction loss by hand, learnable projections produce a training view (the input to the inner model), a label view (the reconstruction target), and a test view (used to produce the layer's output). Because these projections are trained end-to-end with the rest of the network, the self-supervised task is optimized for the final goal of next-token prediction.
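Concretely, writing the learnable projections as θ_K, θ_V, θ_Q, the inner-loop loss, update rule, and output rule become:

$$
\ell(W; x_t) = \bigl\| f(\theta_K x_t;\, W) - \theta_V x_t \bigr\|^2, \qquad
W_t = W_{t-1} - \eta\, \nabla_W \ell(W_{t-1}; x_t), \qquad
z_t = f(\theta_Q x_t;\, W_t),
$$

where W is updated in the inner loop, while θ_K, θ_V, θ_Q are updated in the outer loop together with the rest of the network.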
4. What are the two practical innovations introduced to improve the wall-clock time of TTT layers?
- Mini-batch TTT: instead of taking one sequential gradient step per token, the gradients for all tokens in a mini-batch of size b are taken with respect to the same base state (the hidden state at the end of the previous mini-batch), so they can be computed in parallel.
- Dual form: reformulating the mini-batch update and its outputs as a small number of matrix-matrix multiplications, without materializing the per-token gradients and intermediate hidden states, which maps much better onto modern GPU/TPU hardware. Both ideas are sketched below for the simplest linear case.
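A rough numpy sketch of both ideas under simplifying assumptions: a linear hidden state with the plain reconstruction loss, and no learnable views, LayerNorm, or residual. The function names are illustrative, not the paper's code; the two routines compute identical results, but the dual form does so with a handful of matrix-matrix products.

```python
import numpy as np

def ttt_minibatch_primal(W0, X, eta):
    """Mini-batch TTT, naive (primal) form.

    All gradients in the mini-batch are taken at W0, the hidden state from
    the end of the previous mini-batch. Loss per token: ||x @ W - x||^2.
    """
    b, d = X.shape
    grads = [2.0 * np.outer(X[s], X[s] @ W0 - X[s]) for s in range(b)]  # per-token (d, d)
    Z = np.empty_like(X)
    for s in range(b):
        W_s = W0 - eta * sum(grads[: s + 1])   # cumulative gradient-descent steps
        Z[s] = X[s] @ W_s                      # output for token s
    W_b = W0 - eta * sum(grads)                # hidden state after the mini-batch
    return Z, W_b

def ttt_minibatch_dual(W0, X, eta):
    """Mini-batch TTT, dual form: same outputs via matrix-matrix multiplications."""
    Delta = X @ W0 - X                         # (b, d): one reconstruction residual per token
    Attn = np.tril(X @ X.T)                    # (b, b): causally masked token inner products
    Z = X @ W0 - 2.0 * eta * Attn @ Delta      # all b outputs at once
    W_b = W0 - 2.0 * eta * X.T @ Delta         # final state, no per-token gradients stored
    return Z, W_b

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    W0 = rng.normal(size=(16, 16))
    X = rng.normal(size=(8, 16))               # one mini-batch of b=8 tokens
    Zp, Wp = ttt_minibatch_primal(W0, X, eta=0.01)
    Zd, Wd = ttt_minibatch_dual(W0, X, eta=0.01)
    assert np.allclose(Zp, Zd) and np.allclose(Wp, Wd)
```

In the full method this is applied per mini-batch of b tokens within the sequence, with W0 carried over from one mini-batch to the next; b trades off parallelism against the number of effective gradient steps.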
5. What are the theoretical connections between TTT layers and other sequence modeling layers?
- TTT layers whose hidden state is a linear model, updated by batch gradient descent from a zero initialization, are equivalent to linear attention (a short derivation is sketched below).
- TTT layers with a Nadaraya-Watson estimator as the hidden state are equivalent to self-attention.
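A sketch of the first equivalence, following the paper's construction (constants can be absorbed into the learning rate η): take f(x; W) = Wx, views K_s = θ_K x_s, V_s = θ_V x_s, Q_s = θ_Q x_s, loss ℓ(W; x_s) = ‖W K_s − V_s‖², initialization W_0 = 0, and batch gradient descent, i.e., all gradients taken at W_0. Then

$$
\nabla_W \ell(W_0; x_s) = 2\,(W_0 K_s - V_s)\,K_s^\top = -2\, V_s K_s^\top,
\qquad
W_t = W_0 - \eta \sum_{s=1}^{t} \nabla_W \ell(W_0; x_s) = 2\eta \sum_{s=1}^{t} V_s K_s^\top,
$$

$$
z_t = f(\theta_Q x_t;\, W_t) = W_t\, Q_t = 2\eta \sum_{s=1}^{t} V_s\, \bigl(K_s^\top Q_t\bigr),
$$

which is (unnormalized) linear attention up to the constant 2η. The second equivalence is analogous, with a nonparametric Nadaraya-Watson kernel estimator as the hidden state whose kernel-weighted average recovers the attention weighting of self-attention.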
[02] Experiments
1. How do TTT-Linear and TTT-MLP perform compared to the Transformer and Mamba baselines?
- At short context (2k), the performance is comparable, with Transformer and Mamba being slightly better.
- At longer context (8k and 32k), TTT-Linear and TTT-MLP significantly outperform Mamba, while Transformer remains competitive but is less efficient in terms of FLOPs.
2. What is the effect of using the Mamba backbone vs. the Transformer backbone for TTT layers? The Mamba backbone, which includes temporal convolutions, performs better for TTT layers; the benefit is larger when the hidden state is less expressive, i.e., it helps TTT-Linear more than TTT-MLP.
3. Why do the results not cleanly fit a linear scaling trend as observed in prior work? The authors note that following the Chinchilla recipe does not lead to a clean linear fit in their experiments, likely due to differences in dataset, context length, tokenizer, and architecture compared to the original Chinchilla study. They encourage the community to explore solutions to this problem.
4. How do TTT-Linear and TTT-MLP perform in terms of wall-clock time compared to the baselines? With preliminary systems optimization, TTT-Linear is already faster than Transformer at 8k context and matches Mamba in wall-clock time. TTT-MLP still faces challenges in memory I/O but shows larger potential in long context.