In a recent paper, Toy Models of Superposition, we showed that small ReLU networks trained on sparse synthetic data can represent more features than they have dimensions by storing them in superposition.
Why should mechanistic interpretability care about overfitting? Despite overfitting being a central problem in machine learning, we have little mechanistic understanding of what exactly is going on when deep learning models overfit or memorize examples. Additionally, previous work has hinted that there may be an important link between overfitting and learning interpretable features.
So understanding overfitting is important, but why should it be relevant to superposition? Consider the case of a language model which verbatim memorizes text. How can it do this? One naive idea is that it might use neurons to create a lookup table mapping sequences to arbitrary continuations. For every sequence of tokens it wishes to memorize, it could dedicate one neuron to detecting that sequence, and then implement arbitrary behavior when it fires. The problem with this approach is that it's extremely inefficient – but it seems like a perfect candidate for superposition, since each case is mutually exclusive and can't interfere.
In this note, we offer a very preliminary investigation of training the same toy models from our previous paper on limited datasets. Despite being extremely simple, the toy model turns out to be a surprisingly rich case study for overfitting. In particular, we find the following:
We hypothesize that real neural networks perform operations in a sparse, high-dimensional “feature” space, but these features are difficult for us to see directly because they’re stored in superposition. Motivated by this, we attempt to simulate this feature space directly by training on synthetic input vectors with known sparse structure.
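As a concrete illustration, here is one way such synthetic sparse inputs can be generated (a sketch; the dimensions and sparsity level here are illustrative, not necessarily the exact values used in the experiments):

```python
import numpy as np

def sample_inputs(n_examples, n_features, p_active, seed=0):
    """Sample synthetic sparse inputs: each feature is independently
    active with probability p_active, taking a uniform value in [0, 1)
    when active and 0 otherwise."""
    rng = np.random.default_rng(seed)
    active = rng.random((n_examples, n_features)) < p_active
    values = rng.random((n_examples, n_features))
    return active * values

# A small dataset of high-dimensional, very sparse feature vectors.
X = sample_inputs(n_examples=1000, n_features=10000, p_active=0.001)
```

On average each example has about `n_features * p_active` = 10 active features, so most features are zero in most examples, which is the sparse regime where superposition is expected.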
We consider the "ReLU Output" toy model, defined as

$$h = Wx, \qquad x' = \mathrm{ReLU}(W^\top h + b)$$

with loss

$$L = \sum_x \sum_i I_i \left(x_i - x'_i\right)^2$$

where $x \in \mathbb{R}^n$ is a sparse input vector, $W \in \mathbb{R}^{m \times n}$ (with $m < n$) projects the features into the hidden space, $b \in \mathbb{R}^n$ is a bias, and $I_i$ is the importance of feature $i$.
In this work we limit ourselves to uniform importance, $I_i = 1$ for all features.
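In code, the model and loss can be sketched as follows (numpy, with illustrative dimensions; the initialization scheme here is an assumption for illustration):

```python
import numpy as np

def forward(W, b, X):
    """ReLU-output toy model: h = W x, then x' = ReLU(W^T h + b).
    X has shape (batch, n); W has shape (m, n); b has shape (n,)."""
    H = X @ W.T                        # down-project n features to m hidden dims
    return np.maximum(0.0, H @ W + b)  # reconstruct and rectify

def loss(W, b, X):
    """Summed squared reconstruction error with uniform importance,
    averaged over the batch."""
    return ((X - forward(W, b, X)) ** 2).sum(axis=-1).mean()

n, m = 10000, 2
rng = np.random.default_rng(0)
W = rng.normal(size=(m, n)) / np.sqrt(n)  # arbitrary small init (assumption)
b = np.zeros(n)
```

Because $m \ll n$, the reconstruction is necessarily lossy, and training must decide which structure in the data is worth the scarce hidden dimensions.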
We use 50,000 full-batch updates (as opposed to mini-batches), using the AdamW optimizer.
Unless otherwise specified, we use a weight decay of
In the "normal superposition" regime we described in our previous paper, we found that the model embeds more features than it has dimensions, often mapping them to regular polytopes. For example, if the model has a two-dimensional hidden space, sparse features will be organized as a pentagon:
But what happens if we train models on finite datasets instead? It turns out that the models we find will often look very messy and confusing if you try to look at them from the perspective of features, but very simple and clean if you look at them from the perspective of data point activations.
Let's visualize a few ReLU-output models trained on datasets of different sizes, with many sparse features. We'll focus on models with n = 10,000 features and m = 2 hidden dimensions.
The data points – rather than the features – are being represented as polytopes!
What happens as we make the dataset larger? Clearly our toy models behave very differently in the small data regime where they "use data points as features" and the infinite data regime where they learn the real, generalizing features. What happens in between?
In the original paper, the notion of "feature dimensionality" was helpful for studying how the geometry of features varies as we change models. For this note, we'll extend our notion of feature dimensionality (which we will denote as $D^W_i$) to an analogous dimensionality for training examples (denoted $D^H_x$):

$$D^W_i = \frac{\|W_i\|^2}{\sum_j \left(\hat{W}_i \cdot W_j\right)^2}, \qquad D^H_x = \frac{\|h_x\|^2}{\sum_y \left(\hat{h}_x \cdot h_y\right)^2}$$

where $W_i$ is the embedding of feature $i$ (the $i$-th column of $W$), $h_x = Wx$ is the hidden activation on training example $x$, and hats denote unit vectors.
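Both dimensionalities have the same functional form, so they can be computed with a few lines of numpy (a sketch, applied either to the columns of the weight matrix for features, or to the hidden activations of the training examples):

```python
import numpy as np

def dimensionality(V, eps=1e-12):
    """Fractional dimensionality of each column v_i of V:
        D_i = ||v_i||^2 / sum_j (v_i_hat . v_j)^2
    where v_i_hat is the unit vector along v_i."""
    norms = np.linalg.norm(V, axis=0)
    unit = V / np.maximum(norms, eps)
    overlaps_sq = (unit.T @ V) ** 2   # (v_i_hat . v_j)^2 for all pairs
    return norms ** 2 / np.maximum(overlaps_sq.sum(axis=1), eps)
```

As sanity checks: mutually orthogonal unit vectors each get dimensionality 1; k copies of the same unit vector each get 1/k; and five unit vectors arranged in a pentagon in two dimensions each get 2/5.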
We can now visualize how the geometry of features and data points changes as we vary the size of the dataset. In the middle pane below is a scatter plot of both feature and training-example dimensionalities for varying the dataset size (we will discuss the test loss in the top pane in a later section).
In the small data regime on the left, we see that while the feature dimensionalities are small, the training-example dimensionalities follow roughly m/T, where T is the number of training examples and m = 2 the number of hidden dimensions: the data points share the hidden space evenly, as we'd expect if the model were devoting its capacity to the data points themselves.
In the large data regime on the right, we see 5 features whose dimensionalities are large, while the rest of the feature and training-example dimensionalities are small. The blue vector plots show that those 5 features are represented in a pentagon, while the rest are largely ignored. We provide some intuition as to why one should expect this ~5 feature solution in this colab. The fractional dimension of the pentagon features is notably less than the expected 2/5. We believe this is due to there being many other features (9,995) whose individually small contributions add up to a significant fraction of the denominator of the feature-dimensionality formula.
Most data examples have nonzero values for at most one of the 5 pentagon features, causing the hidden vectors to also trace out a pentagon in the bottom-right red subfigure. The outliers are the rare examples with more than one nonzero pentagon feature.
In between these two extremes, things are messier and harder to interpret.
We did not use a consistent scale for the red and blue vector plots in the previous figure. Using a consistent scale (see the figure below) reveals that the lengths of both hidden and feature vectors vary widely with dataset size, peaking around the transition between the two regimes.
A few comments on these trends:
The phenomenon of models behaving very differently in two different regimes, with strange behavior in between, is eerily reminiscent of double descent.
For a given model size, double descent is the observation that test loss first decreases as the dataset grows, then increases near the point where the model can just barely fit the training data, and finally decreases again in the large-data limit.
It’s interesting to note that we’re observing double descent in the absence of label noise. That is to say: the inputs and targets are exactly the same. Here, the “noise” arises from the lossy compression happening in the downprojection. It is impossible to encode 10,000 features into 2 neurons with a linear projection, even in the sparse limit. Thus the reconstruction is necessarily imperfect, giving rise to unavoidable reconstruction error and, consequently, double descent.
At this point, it's natural to wonder whether the double descent might be an artifact of only having m = 2 hidden dimensions.
We visualize double descent as a two-dimensional function, varying both the number of training examples, T, and the number of hidden dimensions, m.
There are clearly regions where "double descent" occurs – regions where bigger models or more data hurt performance.
Consistent with prior work on double descent, these results are sensitive to weight decay and the number of training epochs. In the appendix, we show how the results change as these hyperparameters are varied.
We find that, in toy models, memorization can be understood as models learning "single data point features" in superposition. These models exhibit double descent as they transition from this strategy of representing data points to representing features.
There is much more to explore. The most obvious question is whether the naive mechanistic theory of overfitting that these results suggest generalizes at all to real models. But there's also a lot to ask in the context of toy models:
Inspired by the original Circuits Thread and Distill's Discussion Article experiment, the authors invited several external researchers, with whom we had previously discussed our preliminary results, to comment on this work. Their comments are included below.
Marius Hobbhahn is a PhD student at the University of Tuebingen.
I replicated most findings in the “Superposition, Memorization, and Double Descent” paper. I changed the setup by reducing the sparsity and the number of features by 10x respectively. I still find the double descent phenomenon as described in the paper with very similar constellations for features and hidden vectors. I also found double descent in multiple other settings, e.g. with different loss functions or when adding a ReLU activation between the layers. My preliminary takeaway from these findings is that the double descent is a fairly regular phenomenon that we should expect to happen in many settings. (Details can be found in my post More Findings on Memorization and Double Descent.)
Adam Jermyn is an independent researcher focused on AI alignment and interpretability.
One question I had reading this paper is: what sets the scale at which models learn generalizing features? When I asked this, the authors proposed two potential hypotheses:
The first hypothesis predicts that increasing the weight decay rate should decrease the generalizing scale.
The figure below shows the dimensionalities of features for models trained with different weight decay rates. Lines show the maximum feature dimension and points and lines are colored by the weight decay rate.
The generalizing scale corresponds to a jump in the dimensionalities. Importantly this scale does not appear to change with the weight decay rate, which is evidence against the first hypothesis.
The second hypothesis predicts that the generalizing scale occurs once the dataset is large enough that it contains multiple instances of each feature. That is, it occurs once the expected number of appearances of each feature, T(1 − S) for dataset size T and sparsity S, exceeds a small threshold.
The figure below shows the dimensionalities of features for models trained with different sparsities. Lines show the maximum feature dimension, and points and lines are colored by the feature frequency (1 − S).
Indeed that appears to be the case! Models trained with very different sparsities learn generalizing features once datasets are large enough to see each feature roughly 10 times.
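The arithmetic behind this threshold is simple; under the assumption that features fire independently with frequency 1 − S, the expected number of training examples containing a given feature is just T(1 − S):

```python
def expected_feature_occurrences(T, feature_freq):
    """Expected number of training examples in which a given feature is
    active, assuming features fire independently with probability
    feature_freq = 1 - S."""
    return T * feature_freq

# Illustrative numbers: at feature frequency 1e-3, a dataset of
# T = 10,000 examples is expected to contain each feature ~10 times.
```

Rescaling the dataset-size axis by the feature frequency is what lets curves at very different sparsities be compared on a common footing.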
While this is suggestive, it is not clear that this is the whole story. For instance, for models with more hidden dimensions the dimensionality curves don’t lie as cleanly on top of each other (see below), and there are other trends that are puzzling (e.g. the peak feature dimensions decrease as the datasets grow post-generalization), so it seems possible that there is more going on.
When the authors shared a preliminary draft, they suggested it might be interesting to look at what happens when individual datapoints are repeated in the dataset.
When a datapoint appears a small number of times (2-3) the phenomenology is the same as in this paper, but with more repeats models switch to learning a combination of datapoints and generalizing features.
The figure below shows training histories of the feature and sample dimensions (left panels) as well as the final feature and sample embeddings (right panels) for a model with T=30,000 and a single datapoint (black) appearing 5 times. The repeated datapoint is embedded alongside four generalizing features and suppresses the fifth, effectively replacing one of the generalizing features that would ordinarily be learned.
Replication
After seeing preliminary results, I replicated the results in the section “How Do Models Change with Dataset Size?” for models with hidden dimension m = 2. Overall I found good qualitative agreement. There are some quantitative differences between my results and those shown in the paper, but nothing that I expect to affect any of the conclusions.
The figure below corresponds to the first figure in that section, and shows qualitatively similar features:
In particular, this replication shows the same division into three regimes: memorizing samples from small datasets, learning generalizing features from large datasets, and doing something more complicated in between. The sample and feature embeddings look qualitatively similar between my models and the ones shown in the paper.
There are three differences between this and the corresponding figure in the paper that I can see, and I think they may be related:
I ran my models multiple times and verified that the different instances replicate these differences. I have not been able to pin down where these differences come from, and as far as I can tell I have trained my models precisely as described in the text, though it is certainly possible that I have missed something!
I also reproduced the second figure of the same section:
The general trends are very similar. In particular:
There are again differences, though these are quantitative rather than qualitative. In particular, the peak bias norms in my models are roughly 3 times larger than those in the paper, and I see a rise in the weight norms over the range T=100-1000 whereas the figure in the paper shows more of a plateau.
Original Authors' Response: Thanks for replicating this! It's really nice to see that everything qualitatively reproduced. We're uncertain what caused the shift in the dataset size at which the transition occurs. It seems like there must be some hyperparameter difference between our setups, but we're uncertain what it is! However, since we only really care about the existence of the transition, and not exactly where it falls for this toy problem, we're not that concerned about identifying the exact difference.