Transformer Circuits Thread

Superposition, Memorization, and Double Descent

Authors

Tom Henighan, Shan Carter, Tristan Hume, Nelson Elhage, Robert Lasenby, Stanislav Fort, Nicholas Schiefer, Christopher Olah

Affiliation

Anthropic

Published

January 5, 2023
* Core Research Contributor; ‡ Correspondence to colah@anthropic.com; Author contributions statement below.

In a recent paper , we found that simple neural networks trained on toy tasks often exhibit a phenomenon called superposition , where they represent more features than they have neurons. Our investigation was limited to the infinite-data, underfitting regime. But there's reason to believe that understanding overfitting might be important if we want to succeed at mechanistic interpretability, and that superposition might be a central part of the story.

Why should mechanistic interpretability care about overfitting? Despite overfitting being a central problem in machine learning, we have little mechanistic understanding of what exactly is going on when deep learning models overfit or memorize examples. Additionally, previous work has hinted that there may be an important link between overfitting and learning interpretable features .

So understanding overfitting is important, but why should it be relevant to superposition? Consider the case of a language model which verbatim memorizes text. How can it do this? One naive idea is that it might use neurons to create a lookup table mapping sequences to arbitrary continuations. For every sequence of tokens it wishes to memorize, it could dedicate one neuron to detecting that sequence, and then implement arbitrary behavior when it fires. The problem with this approach is that it's extremely inefficient – but it seems like a perfect candidate for superposition, since each case is mutually exclusive and can't interfere.

In this note, we offer a very preliminary investigation of training the same toy models in our previous paper on limited datasets. Despite being extremely simple, the toy model turns out to be a surprisingly rich case study for overfitting. In particular, we find the following:

Experiment Setup

We hypothesize that real neural networks perform operations in a sparse, high-dimensional “feature” space, but these features are difficult for us to see directly because they’re stored in superposition. Motivated by this, we attempt to simulate this feature space using synthetic input vectors x which are sparse, high-dimensional, and non-negative (similar to our previous paper). Concretely, x \in \mathbb{R}^n is a n=10,000 dimensional vector. We let individual features x_i = 0 with probability S=0.999, but otherwise uniformly distributed between [0, 1]. However in contrast to our previous work, we then rescale x such that ||x||^2 =1, as this will make memorization of training examples easier One could instead untie model weights W and W_{up} \neq W^T to achieve this, though the resulting h_i will generally have different lengths, and so not quite form perfect T-gons. Untied weights like this are likely a more faithful representation of most real neural networks. in the limit of very large n, one should expect training examples x_i to have equal norms thanks to the central limit theorem.. We also consider training sets of finite size T, whereas our previous work only considered T=\infty. We will use X \in \mathbb{R}^{n \times T} to refer to the matrix of training data, where each column X_i is a training data point.

We consider the the "ReLU Output" toy model, defined as

ReLU Output Model

h~=~Wx

x'~=~\text{ReLU}(W^Th+b)~=~\text{ReLU}(W^TWx + b)

where W is Xavier-initialized . Models are trained to minimize mean-squared reconstruction error.

L = \frac{1}{T} \sum_x \sum_i I_i (x_i - x'_i)^2

In this work we limit ourselves to uniform importance I_i = 1 \quad \forall i.

We use 50,000 full-batch updates, as opposed to mini-batch, using the AdamW optimizer. Our learning rate schedule includes a 2,500 step linear-warmup to 1e-3, followed by a cosine-decay to zero. The number of training updates, the use of full-batch optimization, and the annealed learning rate all seem important for our results. We focus on very low-dimensional hidden spaces which seems to make this a more difficult optimization problem. Additionally, having extremely sparse features makes gradients much noisier. For this reason,  it's important to optimize for a large number of steps using full-batches to produce these results. In some experiments, we'll further train multiple models with different random seeds for parameter-initialization and select the one with the lowest training loss in order to avoid local minima. However we qualitatively find these sensitivities are greatly reduced for models with m>3.

Unless otherwise specified, we use a weight decay of 1e-2. In line with previous work, we find that double descent is strongest with low values of weight decay and vice versa.

"Datapoint Features" vs "Generalizing Features"

In the "normal superposition" we described in our previous paper, we found that the model embeds more features than it has dimensions, often mapping them to regular polytopes. For example, if the model has a two dimensional hidden space, sparse features will be organized as a pentagon:

But what happens if we train models on finite datasets instead? It turns out that the models we find will often look very messy and confusing if you try to look at them from the perspective of features, but very simple and clean if you look at them from the perspective of data point activations.

Let's visualize a few ReLU-output models trained on datasets of different sizes, with many sparse features. We'll focus on models with m=2 hidden dimensions. We pick a two-dimensional hidden space to allow explicit visualization, while the task setup is chosen to produce the most extreme version of the phenomenon we wish to highlight.We find that models most clearly "treat individual data points as features" when the data points are orthogonal and of similar magnitudes (see this Colab notebook for more details). Highly sparse features increase the probability of orthogonal data points. Having lots of features avoids the trivial case where data points are just a single feature activating. Ideally, we'd have preferred to just use enough features that the central limit theorem causes data points to have similar norms, but achieving this with high sparsity would require annoyingly large vectors. If we visualize the resulting columns of W, as in the previous paper, everything is "messy", forming complicated scatters of features rather than clean polygons. But if we instead visualize the training-set hidden-vectors (h_i = W X_i), we see clean structure:

The data points – rather than the features – are being represented as polytopes!Intuitively this makes sense: if T \ll n, it must be easier to represent the T training examples, rather than the n features,  in the hidden m-dimensional space. We might think of these models as operating on a different kind of feature: a "single data point feature" that can be used for memorization, instead of a "generalizing feature". This suggests a naive mechanistic theory of overfitting and memorization: memorization and overfitting occur when models operate on "data point features" instead of "generalizing features". We expect this naive theory to be overly simplistic, but it seems possible that it's gesturing at useful principles!

How Do Models Change with Dataset Size?

What happens as we make the dataset larger? Clearly our toy models behave very differently in the small data regime where they "use data points as features" and the infinite data regime where they learn the real, generalizing features. What happens in between?

In the original paper, the notion of "feature dimensionality" was helpful for studying how the geometry of features varies as we change models. For this note, we'll extend our notion of feature dimensionality (which we will denote as D_{f_i} for the remainder of this text) to also define the dimensionality of a training example, D_{X_i}, as:

D_{X_i} ~=~ \frac{||h_i||^2}{\sum_j (\hat{h_i} \cdot h_j)^2}

where h_i=W X_i is the hidden-vector associated with a training example X_i and \hat{h_i} is the unit vector in the direction of h_i. If the training examples are embedded in an T-gon in m=2, we should expect each training example to have dimensionality m/T.

We can now visualize how the geometry of features and data points changes as we vary the size of the dataset. In the middle pane below is a scatter plot of both feature and training-example dimensionalities for varying the dataset size (we will discuss the test loss in the top pane in a later section).

In the small data regime on the left, we see that while the feature dimensionalities are small, the training-example dimensionalities follow m/T as expected. In the red vector-plots on the bottom, we see that the models are forming T-gons with the h_i. In fact, the ReLU output model can memorize any number of orthonormal training examples using T-gons if given infinite floating-point precision. This motivated our choice to normalize our training examples, though this constraint can be lifted if the W and W_{up} are untied. See this Colab notebook for more details.

In the large data regime on the right, we see 5 features whose dimensionalities are large, while the rest of the feature and training-example dimensionalities are small. The blue vector plots show that those 5 features are represented in a pentagon, while the rest are largely ignored. We provide some intuition as to why one should expect this ~5 feature solution in this colab. The fractional dimension of the pentagon features is notably less than the expected 2/5. We believe this is due to there being many other features (9,995) whose individually small contributions add-up to a significant fraction of the denominator in D_{f_i}.

Most data examples have nonzero values for only zero or one of the 5 pentagon features, causing the hidden-vectors to also trace out a pentagon in the bottom-right red subfigure. The outliers represent rare cases with >1 nonzero values.

In between these two extremes, things are messier and harder to interpret.

We did not use a consistent scale for the red and blue vector plots in the previous figure. Using a consistent scale (see below figure) reveals that lengths of both hidden and feature vectors vary widely with dataset size, peaking around T=30 before rapidly declining in the middle regime, and plateauing in the large-data regime. Plotting the Frobenius norm of W and the l2-norm of b reveals the same trend for the model parameters.

A few comments on these trends:

Superposition Double Descent

The phenomenon of models behaving very differently in two different regimes, with strange behavior in between, is eerily reminiscent of double descent , especially "data double descent" (eg. ). One striking phenomenon of data double descent is that test loss gets worse before it gets better – in violation of naive intuitions that more data should always reduce overfitting!

For a given T, the model’s solution will depend on the randomly-chosen training set, where some will lend themselves to memorization (e.g. orthogonal training examples) and others to generalization. To ensure our results aren’t a fluke, we trained 4 models with different dataset random seeds for each dataset size. We then plotted the average test loss (top pane in our first figure), revealing a clear bump at the transition between the "data points in superposition" regime and the "generalizing features in superposition" regime.

It’s interesting to note that we’re observing double-descent in the absence of label noise. That is to say: the inputs and targets are exactly the same. Here, the “noise” arises from the lossy compression happening in the downprojection. It is impossible to encode 10,000 features into 2 neurons with a linear projection, even in the sparse limit. Thus the reconstruction is necessarily imperfect, giving rise to unavoidable reconstruction error and consequently, double-descent .

The Effect of m on Double Descent

At this point, it's natural to wonder whether the double descent might be an artifact of only having m=2 hidden dimensions,or difficulties with optimization. In this section, we'll confirm that this isn't the case. We'll also be able to explore a theme in some of the double descent literature – understanding it not as a one-dimensional phenomenon, but as a multi-dimensional interaction between model size, dataset size, and training .

We visualize double descent as a two-dimensional function varying both the number of training examples, T, and the number of hidden dimensions, m. All other hyperparameters are the same as above. We train four models for each (T,m) configuration, averaging the resulting losses.  We empirically found optimization to yield much more consistent results for m>3.

There are clearly regions where "double descent" occurs – regions where bigger models or more data hurt performance.

Consistent with prior work on double descent, these results are sensitive to weight decay and the number of training epochs. In the appendix, we show that for m=4 models

Conclusion

We find that, in toy models, memorization can be understood as models learning "single data point features" in superposition. These models exhibit double descent as they transition from this strategy of representing data points to representing features.

There is much more to explore. The most obvious question is whether the naive mechanistic theory of overfitting that these results suggest generalizes at all to real models. But there's also a lot to ask in the context of toy models:







Comments & Replications

Inspired by the original Circuits Thread and Distill's Discussion Article experiment, the authors invited several external researchers who we had previously discussed our preliminary results with to comment on this work. Their comments are included below.

Replication

Adam Jermyn is an independent researcher focused on AI alignment and interpretability.

After seeing preliminary results, I replicated the results in the section “How Do Models Change with Dataset Size?” for models with hidden dimension m=2. Overall I found good qualitative agreement. There are some quantitative differences between my results and those shown in the paper, but nothing that I expect to affect any of the conclusions.

The figure below corresponds to the first figure in that section, and shows qualitatively similar features:

In particular, this replication shows the same division into three regimes, of memorizing samples from small datasets, learning generalizing features from large datasets, and doing something more complicated in between, and the sample and feature embeddings look qualitatively similar between my models and the ones shown in the paper..

There are three differences between this and the corresponding figure in the paper that I can see, and I think they may be related:

  1. The onset of the generalizing feature regime is at larger dataset sizes for my models. I see it at T=20,000 and the paper has it starting at T=5,000.
  2. The onset of the intermediate regime is at smaller dataset sizes for my models (T ~ 100 versus T ~ 700), as measured either by deviations from uniform sample embeddings or by jumps in the test loss.
  3. The test losses in the middle regime are much larger for my models: as high as 5,000 versus ~1.01 in the paper.

I ran my models multiple times and verified that the different instances replicate these differences. I have not been able to pin down where these differences come from, and as far as I can tell I have trained my models precisely as described in the text, though it is certainly possible that I have missed something!

I also reproduced the second figure of the same section:

The general trends are very similar. In particular:

  1. The bias norms start out smaller than the weight norms for small dataset sizes, become larger than the weight norms for somewhat larger datasets, and are then smaller in the generalizing regime.
  2. Both norms rise with increasing dataset size, and rapidly fall back down once the models learn generalizing features.

There are again differences, though these are quantitative rather than qualitative. In particular, the peak bias norms in my models are roughly 3 times larger than those in the paper, and I see a rise in the weight norms over the range T=100-1000 whereas the figure in the paper shows more of a plateau.


Original Authors' Response: Thanks for replicating this! It's really nice to see that everything qualitatively reproduced. We're uncertain what caused the shift in the dataset size at which the transition occurs. It seems like there must be some hyperparameter difference between our setups, but we're uncertain what it is! However, since we only really care about the existence of the transition, and not exactly where it falls for this toy problem, we're not that concerned about identifying the exact difference.

Replication

Marius Hobbhahn is a PhD student at the University of Tuebingen.

I replicated most findings in the “Superposition, Memorization, and Double Descent” paper. I changed the setup by reducing the sparsity and the number of features by 10x respectively. I still find the double descent phenomenon as described in the paper with very similar constellations for features and hidden vectors. I also found double descent in multiple other settings, e.g. with different loss functions or when adding a ReLU activation between the layers. My preliminary takeaway from these findings is that the double descent is a fairly regular phenomenon that we should expect to happen in many settings. (Details can be found in my post More Findings on Memorization and Double Descent.)

Extension: What controls the generalization scale?

Adam Jermyn is an independent researcher focused on AI alignment and interpretability.

One question I had reading this paper is: what sets the scale at which models learn generalizing features? When I asked this, the authors proposed two potential hypotheses:

  1. This is the scale where superposition of data points fails (e.g. due to weight decay constraining the weight norm)
  2. This is the scale where features occur multiple times, in different combinations, allowing them to be distinguished.

The first hypothesis predicts that increasing the weight decay rate should decrease the generalizing scale.

The figure below shows the dimensionalities of features for models trained with different weight decay rates. Lines show the maximum feature dimension and points and lines are colored by the weight decay rate.

The generalizing scale corresponds to a jump in the dimensionalities. Importantly this scale does not appear to change with the weight decay rate, which is evidence against the first hypothesis.

The second hypothesis predicts that the generalizing scale occurs once the dataset is large enough that it contains multiple instances of each feature. That is, it occurs at T \propto 1 / (1-S).

The figure below shows the dimensionalities of features for models trained with different weight decay rates. Lines show the maximum feature dimension and points and lines are colored by the feature frequency (1-S). The horizontal scale now is the expected number of times any given feature appears in the dataset (T (1-S)), and a prediction of the second hypothesis is that the generalizing scale should be fixed on this scale.

Indeed that appears to be the case! Models trained with very different sparsities learn generalizing features once datasets are large enough to see each feature roughly 10 times.

While this is suggestive, it is not clear that this is the whole story. For instance, for models with more hidden dimensions the dimensionality curves don’t lie as cleanly on top of each other (see below), and there are other trends that are puzzling (e.g. the peak feature dimensions decrease as the datasets grow post-generalization), so it seems possible that there is more going on.

Extension: Repeated Data Points

Adam Jermyn is an independent researcher focused on AI alignment and interpretability.

When the authors shared a preliminary draft, they suggested it might be interesting to look at what happens when individual datapoints are repeated in the dataset.

When a datapoint appears a small number of times (2-3) the phenomenology is the same as in this paper, but with more repeats models switch to learning a combination of datapoints and generalizing features.

The figure below shows training histories of the feature and sample dimensions (left panels) as well as the final feature and sample embeddings (right panels) for a model with T=30,000 and a single feature (black) appearing 5 times. The repeated feature is embedded alongside four generalizing features and suppresses the fifth, effectively replacing one of the generalizing features that would ordinarily be learned.