We report a number of developing ideas on the Anthropic interpretability team, which might be of interest to researchers working actively in this space. Some of these are emerging strands of research on which we expect to publish more in the coming months. Others are minor points we wish to share, since we're unlikely to ever write a paper about them.
We'd ask you to treat these results like those of a colleague sharing some thoughts or preliminary experiments for a few minutes at a lab meeting, rather than a mature paper.
Tom Conerly, Hoagy Cunningham, Adly Templeton, Jack Lindsey, Basil Hosmer, and Adam Jermyn
An earlier version of this page incorrectly wrote the initialization as U(-\frac{1}{n}, \frac{1}{n}) instead of U(-\frac{1}{\sqrt{n}}, \frac{1}{\sqrt{n}}).
Since our last publication, we’ve made some improvements to how we train sparse autoencoders and crosscoders. While we haven’t extensively ablated all the decisions here, we wanted to share a description of our setup in the hope that it will be a useful starting point for external groups training sparse autoencoders. Our setup uses techniques from Rajamanoharan et al (2024).
Let n be the input dimension, o the output dimension, and m the autoencoder hidden layer dimension. Let s be the size of the dataset. Given encoder weights W_e \in R^{m \times n}, decoder weights W_d \in R^{o \times m}, log thresholds t \in R^{m}, biases b_e \in R^{m}, b_d \in R^{o}, and hyperparameters w, \lambda_S, \lambda_P, \varepsilon, and c, the operations and loss function over a dataset X \in R^{s \times n}, Y \in R^{s \times o} with datapoints x \in R^{n}, y \in R^{o} are:
f(x) = \text{JumpReLU}( W_e x+b_e, t)
\text{JumpReLU}(x, t) = \begin{cases} x& \text{if } x > \exp(t)\\ 0 & \text{otherwise} \end{cases}
\dfrac{\mathrm{d}\text{JumpReLU}(x, t)}{\mathrm{d}x}(x, t) = \begin{cases} 1& \text{if } x > \exp(t)\\ 0 & \text{otherwise}\end{cases}
\dfrac{\mathrm{d}\text{JumpReLU}(x, t)}{\mathrm{d}t}(x, t) = \begin{cases} -\frac{\exp(t)}{\varepsilon}& \text{if } -\frac{1}{2} < \frac{x - \exp(t)}{\varepsilon} < \frac{1}{2}\\ 0 & \text{otherwise}\end{cases}
\hat{y}(x) = W_d f(x)+b_d
\mathcal{L}(x, y) = ||y-\hat{y}(x)||_2^2 + \lambda_S\sum_i \tanh(c \ast|f_i(x)| ||W_{d,i}||_2) + \mathcal{L_P}(x)
\mathcal{L_P}(x) = \lambda_P\sum_i \text{ReLU}(\exp(t_i) - f_i(x)) ||W_{d,i}||_2
Our implementation of JumpReLU uses a straight-through estimator of the gradient through the discontinuity of the nonlinearity as in Rajamanoharan et al. (2024), but unlike Rajamanoharan et al. we allow the gradient to flow through the straight-through estimator to all model parameters, not just the JumpReLU thresholds. Also note that we use a tanh penalty to encourage sparsity rather than the penalty introduced by Rajamanoharan et al.
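As a rough illustration, the JumpReLU forward pass and the two pseudo-derivatives written out above can be sketched for a single scalar pre-activation. This is a minimal sketch in plain Python, not our training code: the function names are hypothetical, and in practice these rules would be wired into the backward pass of an autodiff framework rather than called directly.

```python
import math

def jump_relu(x, t):
    """Forward pass: keep x only when it exceeds the threshold exp(t)."""
    return x if x > math.exp(t) else 0.0

def jump_relu_grad_x(x, t):
    """Derivative with respect to the input x, as defined in the text."""
    return 1.0 if x > math.exp(t) else 0.0

def jump_relu_grad_t(x, t, eps=2.0):
    """Straight-through pseudo-derivative with respect to the log threshold t.

    Nonzero only when x lies in a window of width eps centred on exp(t).
    """
    theta = math.exp(t)
    inside_window = -0.5 < (x - theta) / eps < 0.5
    return -theta / eps if inside_window else 0.0
```

For example, with t = 0 (threshold 1) and \varepsilon = 2, an input of 1.5 passes through unchanged and receives a threshold pseudo-gradient of -0.5, while an input of 3.0 lies outside the window and receives none.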
\mathcal{L_P}, which we call the pre-act loss, applies a small penalty to features which don't fire. We've found this extremely helpful in reducing dead features. Note that this provides a gradient signal whenever a feature is inactive, so its appropriate scale is lower than that of the other loss terms by a factor of the typical feature activation density.
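To make the three loss terms concrete, here is a minimal sketch of the loss for one datapoint, following the formulas above with the threshold in the pre-act loss taken per feature, \exp(t_i). The function name and argument layout are hypothetical; `dec_norms[i]` stands for ||W_{d,i}||_2 and is assumed to be precomputed.

```python
import math

def sae_loss(y, y_hat, f, dec_norms, t, lam_s, lam_p, c=4.0):
    """Loss for a single datapoint.

    y, y_hat  : target and reconstruction (length o)
    f         : feature activations f_i(x) (length m)
    dec_norms : decoder column norms ||W_{d,i}||_2 (length m)
    t         : log thresholds (length m)
    """
    # Squared L2 reconstruction error.
    recon = sum((a - b) ** 2 for a, b in zip(y, y_hat))
    # tanh sparsity penalty on norm-weighted activations.
    sparsity = lam_s * sum(
        math.tanh(c * abs(fi) * ni) for fi, ni in zip(f, dec_norms)
    )
    # Pre-act loss: penalises features whose activation is below threshold.
    pre_act = lam_p * sum(
        max(math.exp(ti) - fi, 0.0) * ni
        for fi, ni, ti in zip(f, dec_norms, t)
    )
    return recon + sparsity + pre_act
```

Because the tanh term saturates near 1 for each strongly active feature, the sparsity penalty behaves roughly like a count of active features scaled by \lambda_S.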
We use c=4, \varepsilon=2, \lambda_P=3\ast10^{-6} and values of \lambda_S around 10. b_d is initialized to all zeros. t is initialized to 0.1.
We initialize W_d from U(-\frac{1}{\sqrt{n}}, \frac{1}{\sqrt{n}}). If X=Y we initialize W_e = \frac{n}{m}W_d^T. If X \ne Y, we initialize W_e from U(-\frac{1}{\sqrt{m}}, \frac{1}{\sqrt{m}}).
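The weight initialization can be sketched as follows. This is an illustrative sketch using nested lists rather than tensors; the function name is hypothetical, and it assumes W_d is stored with shape (o, m), one column per feature, so that W_d f(x) in \hat{y}(x) = W_d f(x)+b_d is well defined.

```python
import math
import random

def init_weights(n, m, o, tied, seed=0):
    """Return (W_e, W_d) with shapes (m, n) and (o, m).

    tied=True corresponds to the X = Y case, which requires o == n.
    """
    rng = random.Random(seed)
    bound_d = 1.0 / math.sqrt(n)
    W_d = [[rng.uniform(-bound_d, bound_d) for _ in range(m)] for _ in range(o)]
    if tied:
        assert o == n, "tied initialization needs matching input/output sizes"
        # W_e = (n / m) * W_d^T
        W_e = [[(n / m) * W_d[j][i] for j in range(n)] for i in range(m)]
    else:
        bound_e = 1.0 / math.sqrt(m)
        W_e = [[rng.uniform(-bound_e, bound_e) for _ in range(n)] for _ in range(m)]
    return W_e, W_d
```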
We initialize b_e by examining a subset of the data and picking a constant per feature such that each feature activates \frac{10000}{m} of the time. In aggregate roughly 10,000 features will fire per datapoint. We think this initialization is important for avoiding dead features.
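One way to pick such a constant is to take a quantile of each feature's pre-activations over the data subset. The sketch below is one possible reading of that procedure, not necessarily the exact one we use; the function name is hypothetical, and `target_frac` would be \frac{10000}{m} in the setup described here.

```python
import math

def init_encoder_bias(W_e, sample, t, target_frac):
    """Choose b_e so each feature fires on about target_frac of the sample.

    W_e    : encoder weights, shape (m, n)
    sample : subset of the data, shape (k, n)
    t      : log thresholds, length m
    """
    b_e = []
    for w_i, t_i in zip(W_e, t):
        # Pre-activations without bias, largest first.
        pre = sorted(
            (sum(w * v for w, v in zip(w_i, x)) for x in sample),
            reverse=True,
        )
        k = min(int(round(target_frac * len(pre))), len(pre) - 1)
        # Shift so the (k+1)-th largest value sits at the threshold exp(t_i);
        # the k larger values then exceed it. Ties and floating-point
        # rounding can move the realised count slightly.
        b_e.append(math.exp(t_i) - pre[k])
    return b_e
```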
The rows of the dataset X are shuffled. The dataset is scaled by a single constant such that \mathbb{E}_{x \in X}[||x||_2] = \sqrt{n}. The goal of this scaling is for the same value of \lambda_S to mean the same thing across datasets generated by transformers of different sizes.
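The scaling step amounts to computing the mean L2 norm of the datapoints and multiplying everything by \sqrt{n} divided by that mean. A minimal sketch, with a hypothetical function name and the dataset held as a list of rows:

```python
import math

def scale_dataset(X):
    """Scale X by one constant so the mean L2 norm of its rows is sqrt(n).

    Returns the scaled data and the constant, so the same constant can be
    reused elsewhere if needed.
    """
    n = len(X[0])
    mean_norm = sum(math.sqrt(sum(v * v for v in x)) for x in X) / len(X)
    scale = math.sqrt(n) / mean_norm
    return [[v * scale for v in x] for x in X], scale
```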
During training we use the Adam optimizer with beta1=0.9, beta2=0.999, and no weight decay. Our learning rate varies based on scaling laws, but 2e-4 is a reasonable default. The learning rate is decayed linearly to zero over the last 20% of training. We also vary the number of training steps based on scaling laws. We use a batch size of 32,768, which we believe to be under the critical batch size. The gradient norm is clipped to 1 (using clip_grad_norm). We warm up \lambda_S during training: it is initially 0 and increases linearly to its final value over the entire training period. A reasonable default for \lambda_S is 20 given our other parameter settings.
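The two schedules described here are simple functions of the training step. The sketch below uses hypothetical function names and assumes steps are counted from 0 to `total_steps`:

```python
def learning_rate(step, total_steps, base_lr=2e-4, decay_frac=0.2):
    """Constant learning rate, then linear decay to zero over the last
    decay_frac of training."""
    decay_start = total_steps * (1.0 - decay_frac)
    if step < decay_start:
        return base_lr
    return base_lr * (total_steps - step) / (total_steps - decay_start)

def sparsity_coeff(step, total_steps, final_lambda_s):
    """Sparsity coefficient: 0 at the start, rising linearly to its final
    value at the end of training."""
    return final_lambda_s * step / total_steps
```

With 1,000 total steps and a final \lambda_S of 20, the learning rate stays at 2e-4 until around step 800 and reaches 1e-4 at step 900, while the sparsity coefficient is 10 at the halfway point.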
Conceptually a feature’s activation is now \mathbf{f}_i ||W_{d,i}||_2 instead of \mathbf{f}_i. To simplify our analysis code we construct a model which makes identical predictions but has an L2 norm of 1 on the columns of W_d. We do this by setting W_e' = W_e ||W_d||_2, b_e' = b_e ||W_d||_2, W_d' = \frac{W_d}{||W_d||_2}, and b_d'=b_d, where ||W_d||_2 denotes the per-feature column norms ||W_{d,i}||_2, applied to the corresponding row of W_e, entry of b_e, and column of W_d.
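A small sketch of this rescaling, with a check that predictions are unchanged, might look like the following. The function names are hypothetical and W_d is assumed to have shape (o, m). One point is an assumption on our part for this sketch rather than something stated above: since the JumpReLU compares pre-activations to \exp(t), the log thresholds are shifted by the log of the column norm so that the same features fire after rescaling.

```python
import math

def forward(x, W_e, b_e, t, W_d, b_d):
    """y_hat = W_d JumpReLU(W_e x + b_e, t) + b_d for one datapoint."""
    pre = [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(W_e, b_e)]
    f = [z if z > math.exp(ti) else 0.0 for z, ti in zip(pre, t)]
    return [sum(w * fi for w, fi in zip(row, f)) + b for row, b in zip(W_d, b_d)]

def normalize_decoder(W_e, b_e, t, W_d):
    """Fold decoder column norms into the encoder so W_d has unit columns.

    Assumes no decoder column has zero norm.
    """
    m = len(W_e)
    norms = [math.sqrt(sum(row[i] ** 2 for row in W_d)) for i in range(m)]
    W_e_n = [[w * norms[i] for w in W_e[i]] for i in range(m)]
    b_e_n = [b_e[i] * norms[i] for i in range(m)]
    # Assumption (not stated in the text): exp(t') = exp(t) * norm.
    t_n = [t[i] + math.log(norms[i]) for i in range(m)]
    W_d_n = [[row[i] / norms[i] for i in range(m)] for row in W_d]
    return W_e_n, b_e_n, t_n, W_d_n
```

Calling `forward` with the original and the rescaled parameters on the same input should give the same output up to floating-point error, apart from pre-activations that sit almost exactly on a threshold.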