Gemma Scope: Open Sparse Autoencoders Everywhere All At Once on Gemma 2
Abstract
The article discusses sparse autoencoders (SAEs), an unsupervised method for learning a sparse decomposition of a neural network's latent representations into interpretable features. It introduces Gemma Scope, an open suite of JumpReLU SAEs trained on the layers and sub-layers of the Gemma 2 language models. The goal is to enable and accelerate interpretability and safety research by the broader community.
Q&A
[01] Training Details
1. What data was used to train the SAEs?
- The SAEs were trained on Gemma 2 activations generated from text drawn from the same distribution as Gemma 1's pretraining data, with the exception of the SAEs trained on the instruction-tuned (IT) Gemma 2 9B model.
- The activation vectors were normalized to have unit mean squared norm before training.
2. How were the SAEs trained?
- The SAEs were optimized with Adam, using a cosine learning-rate warmup and a linear warmup of the sparsity coefficient.
- The JumpReLU threshold was initialized as a vector of zeros, and the encoder and decoder weights were initialized with He-uniform initialization (a minimal sketch of this setup follows the list below).
- SAEs with 16.4K latents were trained on 4B tokens, 1M-width SAEs on 16B tokens, and all other SAEs on 8B tokens.
3. What infrastructure was used to train the SAEs?
- Most SAEs were trained using TPUv3 in a 4x2 configuration, with some wider SAEs trained on TPUv5p.
- A shared server system was used to enable parallel disk reads and dynamic data fetching to overcome disk throughput limitations.
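Putting the training details from [01] together, here is a minimal PyTorch sketch of a JumpReLU SAE on normalized activations. This is not the Gemma Scope codebase: the dimensions, learning rate, and batch size are illustrative assumptions, and the straight-through estimators used to train the threshold, the L0 sparsity penalty, and the warmup schedules are omitted.

```python
# Minimal sketch of a JumpReLU SAE as described above (not the official Gemma Scope code).
# Dimensions, learning rate, and batch size are illustrative assumptions.
import torch
import torch.nn as nn


class JumpReLUSAE(nn.Module):
    def __init__(self, d_model: int, n_latents: int):
        super().__init__()
        # Encoder and decoder weights: He-uniform (Kaiming uniform) initialization.
        self.W_enc = nn.Parameter(nn.init.kaiming_uniform_(torch.empty(d_model, n_latents)))
        self.W_dec = nn.Parameter(nn.init.kaiming_uniform_(torch.empty(n_latents, d_model)))
        self.b_enc = nn.Parameter(torch.zeros(n_latents))
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        # JumpReLU threshold: initialized as a vector of zeros.
        self.threshold = nn.Parameter(torch.zeros(n_latents))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        pre_acts = x @ self.W_enc + self.b_enc
        # JumpReLU: keep only pre-activations strictly above the learned per-latent threshold.
        # (The hard gate has zero gradient w.r.t. the threshold; the paper trains it with
        # straight-through estimators, which this sketch omits.)
        return pre_acts * (pre_acts > self.threshold)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encode(x) @ self.W_dec + self.b_dec


def normalize_activations(x: torch.Tensor) -> torch.Tensor:
    """Rescale activations to unit mean squared norm (estimated here from one batch)."""
    return x / x.pow(2).sum(dim=-1).mean().sqrt()


sae = JumpReLUSAE(d_model=2304, n_latents=16_384)   # e.g. a 16.4K-width SAE
opt = torch.optim.Adam(sae.parameters(), lr=7e-5)   # Adam; warmup schedules omitted

x = normalize_activations(torch.randn(4096, 2304))  # stand-in for LM activations
recon = sae(x)
# Reconstruction term of the loss; the full objective adds a sparsity penalty.
loss = (recon - x).pow(2).sum(dim=-1).mean()
loss.backward()
opt.step()
```

In the actual setup, the normalization constant would be estimated over the activation dataset rather than a single batch, and the cosine learning-rate and linear sparsity-coefficient warmups mentioned above would be applied on top of this loop.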
[02] Evaluation
1. How was the sparsity-fidelity tradeoff evaluated?
- Sparsity was measured by the mean L0-norm of the latent activations, i.e. the average number of latents active per token.
- Fidelity was measured with two metrics: delta LM loss (the increase in the LM's cross-entropy loss when the SAE reconstruction is spliced into the forward pass) and the fraction of variance unexplained (FVU); both are sketched in code after this section's Q&A.
2. How did the performance of the SAEs vary by sequence position?
- Reconstruction loss increased rapidly over the first few tokens, then plateaued or increased more gradually.
- Delta LM loss appeared to be slightly lower for the first few tokens, particularly for residual stream SAEs.
3. How did the performance of SAEs vary with width?
- Wider SAEs provided better reconstruction fidelity at a given level of sparsity.
- However, wider SAEs also exhibited "feature splitting", where a latent in a narrow SAE splits into multiple more specialized latents in a wider SAE.
4. How did SAEs trained on the base Gemma 2 models perform on the instruction-tuned (IT) Gemma 2 9B model?
- SAEs trained on the base models were able to faithfully reconstruct the activations of the IT model, with only a small increase in delta LM loss compared to SAEs trained directly on the IT model.
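To make the sparsity and fidelity metrics from [02] concrete, here is a short sketch of how they can be computed. This is not the paper's evaluation code; the tensor shapes and stand-in data are illustrative assumptions.

```python
# Illustrative implementations of the evaluation metrics described above.
import torch


def mean_l0(latent_acts: torch.Tensor) -> torch.Tensor:
    """Sparsity: average number of non-zero latents per token (mean L0-norm)."""
    return (latent_acts != 0).float().sum(dim=-1).mean()


def fraction_variance_unexplained(x: torch.Tensor, recon: torch.Tensor) -> torch.Tensor:
    """FVU: reconstruction error normalized by the total variance of the activations."""
    return (x - recon).pow(2).sum() / (x - x.mean(dim=0)).pow(2).sum()


def delta_lm_loss(clean_ce_loss: float, spliced_ce_loss: float) -> float:
    """Increase in cross-entropy loss when the SAE reconstruction replaces the activations."""
    return spliced_ce_loss - clean_ce_loss


# Stand-in data with illustrative shapes (tokens x model width / SAE width).
x = torch.randn(1024, 2304)
recon = x + 0.1 * torch.randn_like(x)                       # pretend SAE reconstruction
latent_acts = torch.relu(torch.randn(1024, 16_384) - 2.0)   # mostly-zero latents
print(mean_l0(latent_acts).item(), fraction_variance_unexplained(x, recon).item())
```

Computing delta LM loss in practice requires running the LM twice over the same tokens, once unmodified and once with the SAE reconstruction spliced in at the relevant site, and comparing the resulting cross-entropy losses.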