Eric Michaud on Quantum Interpretability

Eric Michaud is a PhD student in the Department of Physics at MIT working with Max Tegmark on improving our scientific/theoretical understanding of deep learning – understanding what deep neural networks do internally and why they work so well. This is part of a broader interest in the nature of intelligent systems, which previously led him to work with SETI astronomers, with Stuart Russell’s AI alignment group (CHAI), and with Erik Hoel on a project related to integrated information theory.

In this episode we mostly talk about Eric’s paper, The Quantization Model of Neural Scaling, but also about Grokking, in particular his two recent papers, Towards Understanding Grokking: an effective theory of representation learning, and Omnigrok: Grokking Beyond Algorithmic Data.

(Note: this episode was recorded outside, without any preparation, so this conversation is a bit less structured than usual. It’s also more technical and might require some familiarity with deep learning, or at least with scaling.)



Eric’s research: interpretability, grokking, scaling, quantization

Michaël: I’m here with Eric Michaud. We were at a party yesterday, and I asked him if he wanted to do an interview while hiking. And he said yes. So Eric, who are you?

Eric: Yeah, I’m a PhD student at MIT in Max Tegmark’s group. And I think about what neural networks do, what they learn, why they work well, and this sort of thing.

Michaël: And today we’re going to be talking about grokking and quanta…

Eric: Scaling and that sort of thing, yeah.

Michaël: Can you just say a few words about what you’re interested in, in your research?

Eric: Yeah. So I guess the most interesting thing to talk about today is this paper from me and our group called The Quantization Model of Neural Scaling. It’s basically about the difference between what small networks and large networks learn, or more generally, what effect scaling has on what neural networks learn. How should we think about the difference between small and large networks? How does this inform how we might try to do mechanistic interpretability, and what are the right objects of study when trying to understand the networks mechanistically? We could also talk about grokking. I was just a middle author on some papers on grokking from our group last year.

Michaël: So you prefer to talk about papers where you’re the first author.

Eric: We could talk a little bit about grokking, too. There’s something kind of related, I think, between the grokking stuff and some of the scaling stuff.

How this helps with reducing existential risk from AI

Michaël: Do you have any pitch for how this helps make sure we’re alive at the end of the century?

Eric: Yeah. So I guess part of the risk from advanced AI systems arises from the fact that we don’t understand the internals of the models very well. If we could deeply and rigorously understand, mechanistically, the computations happening in the model, and ideally, eventually, who knows, understand how it reasons about various things like deception, then you might be able to avoid some of the worst-case alignment failures. So I don’t know. I’m just generally hopeful that if we had a much more rigorous understanding of neural networks, we’d be in a much better position. This is a little bit risky, because you also might discover something that accelerates things.

Eric: But I don’t know, on net it feels potentially quite good to have a better understanding of these systems. I’m not the first person to make this kind of comparison, but the engineering is a little bit ahead of the science, maybe similarly to how people had steam engines before we had our modern understanding of thermodynamics. Now we have these deep neural networks, and we’re good at the engineering. We empirically observe some interesting properties, like scaling laws, which maybe point at something principled going on, so we could hope to eventually have a better theory for how to think about the networks. And it feels like having that would be very useful for alignment, at least hopefully.

Eric’s Background: From Math to Astronomy and AI alignment

Michaël: And yeah, can you say a few words about your background? How did you come to do this kind of research? Were you doing physics before? When did you get interested in this kind of research: did you decide to help push AI towards good outcomes and choose this research for that reason, or were you just generally interested in deep learning and found this kind of work useful and interesting?

Eric: Yeah, so I did a bunch of different things as an undergrad. I was an undergrad just down the hill here in Berkeley. I was a math major, but I took a bunch of CS and physics. And during undergrad, I wasn’t just doing AI stuff.

Eric: For a while, I worked with SETI astronomers, doing radio astronomy, looking for aliens, and this kind of thing. And I worked with-

Michaël: Are there any aliens out there?

Eric: I didn’t see any.

Michaël: Is Robin Hanson right?

Eric: Yeah, I don’t really have very sophisticated takes on grabby aliens or this kind of thing. But it does feel like, well, it’s tough to work in SETI. It was really fun, and the people are great. You have to have a good sense of humor, I think, to do it. But it’s a little bit tricky, because you could potentially spend your whole life searching and never find anything. And maybe that’s the most likely outcome: that you don’t find anything.

Michaël: Maybe it’s the same with trying to interpret neural networks. You might not find anything.

Eric: Like in the SETI case? Hopefully not. But anyway, yeah. So I worked with SETI people. And I worked with this neuroscientist, Erik Hoel. That did involve deep learning: it was an information-theory type of quantity we were measuring in neural networks. But I guess my first introduction to the AI alignment folks was when I interned at CHAI with Adam Gleave. It was kind of a slow build towards doing that. I’d been aware of some of these concerns for a while. I had read Life 3.0, for instance, as a freshman in college, and now Max, the author, is my advisor, which is kind of fun. I just slowly became more aware of these kinds of things, and it fit well with my interests.

Mechanistic Interpretability, Polysemanticity, Right Level of Abstraction

Eric: Maybe I’ll say a little bit about what I’ve been thinking about lately. It feels like there’s not an obvious answer to what the right things to look at within the network are when trying to understand it mechanistically. There’s this recent paper from OpenAI trying to do automated interpretability on the neurons of GPT-2, and it seems like a lot of the neurons are not super interpretable. Maybe you should look at something else instead. I don’t know exactly what this would be, perhaps combinations of the neurons, if superposition is going on. Superposition is one possible explanation for the polysemanticity of neurons: many neurons in networks activate in lots of different circumstances.

Eric: You might hope for each neuron to be very specific to some particular type of input, so that it responds to something very particular. But in general, a bunch of neurons are messier and respond to lots of different things. One reason this could be the case is superposition, where there are more features than there are orthogonal directions in the activation space. So the network has to place the features such that they’re not orthogonal, and that can lead to polysemanticity. But I don’t know. It feels like there’s also an analogy to the brain here, where you have neurons, and you can ask: what are the right things to look at within the brain?
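To make the superposition point concrete, here is a minimal numpy sketch under a loudly hypothetical setup: 100 features are given random unit directions in a 20-dimensional space (trained networks arrange features far more cleverly than this). With more features than dimensions the Gram matrix has rank at most 20, so the features cannot all be orthogonal, and any single coordinate, read as a "neuron", carries noticeable weight from many features.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_features = 20, 100  # hypothetical: 100 features squeezed into 20 dimensions

# One unit-norm direction per feature (rows of W); random placement is an assumption.
W = rng.normal(size=(n_features, d))
W /= np.linalg.norm(W, axis=1, keepdims=True)

# If all features were orthogonal, the Gram matrix would be the identity.
gram = W @ W.T
rank = np.linalg.matrix_rank(gram)  # at most d, so it cannot be the 100x100 identity
interference = np.abs(gram - np.eye(n_features)).max()

# Treat coordinate 0 as a "neuron" and count the features that write noticeably to it.
neuron = 0
features_on_neuron = int(np.sum(np.abs(W[:, neuron]) > 0.2))

print(f"rank={rank}, max interference={interference:.2f}, "
      f"features on neuron {neuron}: {features_on_neuron}")
```

The threshold of 0.2 for "noticeable" is arbitrary; the point is only that one basis direction ends up responding to several features at once.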

Eric: You could try to look at individual neurons, but you could also look at groups of neurons or groups of groups of neurons. And well, what is the level of abstraction there that is going to have the most information about the things that you care about or the dynamics of the system? I wonder if there’s a similar kind of thing that you could ask about deep neural networks.

Michaël: I guess some of the work has been on identifying circuits and higher-level structures that implement simple functions. I feel like these are the groups of neurons you’re talking about: circuits that do specific things.

Eric: Yeah, some of the language here I’m familiar with from the theoretical neuroscientist Erik Hoel, who for a time was interested in the question: what is the functional unit of the brain? And you might wonder, what is the functional unit of a deep neural network? What is the right thing to look at if you want to understand it? Maybe something like circuits.

Eric: I don’t have a super precise definition of what a circuit is. But it feels like that might be more meaningful, because it’s a collection of things which together implement some specific computation. Individually, the pieces might not make a lot of sense, but the computation implemented as a whole by that subpart of the system might make more sense.

Michaël: Yeah, when you look at the computation implemented by the network, do you look at the floating-point arithmetic? How are you thinking about this? I know that in your quanta paper, you talk about things being discrete.

What do we mean by discrete when training neural networks?

Michaël: I’m curious, what do you mean by discrete in these kind of scenarios, when everything is floating point arithmetic?

Eric: Yeah, that’s a great question. Yeah, so I guess maybe there are a couple high-level models you could have for thinking about what neural networks are doing and what they consist of. Maybe there’s one model, which is that, oh, it’s just kind of performing this generic function approximation. There’s some manifold that the data lives on. And then the network is just approximating some function defined on this manifold. And bigger networks globally approximate this function with greater resolution. And then there’s maybe this other view of neural networks, which maybe sees them as collections of algorithms or something, collections of these circuits, which perform specialized computations. And so maybe an example of something that feels more discrete, an algorithm in the network, is the circuit that does induction. So one thing that language models are really good at is if there’s some text in their context where there’s some sequence, ABC, and then at the end of the context, you’re predicting the next token. And the last two tokens in the context are AB, then they’re pretty good at saying that C is going to come after that.
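The copying behavior described here can be written as a hand-coded rule. This is a minimal sketch of the behavior only, not of how an induction circuit implements it inside a transformer, and the function name `induction_predict` is hypothetical: find the most recent earlier occurrence of the current token and predict whatever followed it.

```python
from typing import Optional, Sequence


def induction_predict(tokens: Sequence[str]) -> Optional[str]:
    """Predict the next token by copying from earlier in the context.

    Looks for the most recent earlier occurrence of the last token and
    returns the token that followed it, or None if there is no match.
    """
    last = tokens[-1]
    # Scan backwards, skipping the final position itself.
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == last:
            return tokens[i + 1]
    return None


context = ["A", "B", "C", "x", "y", "A", "B"]
print(induction_predict(context))  # C
```

Because the rule only consults the current context, it works for arbitrary token sequences it has never seen before, which is the sense in which this goes beyond fixed bigram or trigram statistics.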

Eric: And this is not the kind of thing that’s purely learned over the course of training in the sense of bigram or trigram statistics: no matter what the context is, they’re able to do this copying. And it seems, or at least as reported in this paper from Anthropic, that this capability emerges in the model in a phase-transition type of way: there’s some period in the training process where the capability is not really present, then there’s a transition, and for the rest of training it is present in the model. So this maybe points towards a kind of discreteness: even if the underlying dynamics is a smooth walk around parameter space, the net effect, whether the right parameters are configured in the right way and can coordinate with each other in this multi-part algorithm, is roughly either present or not.

Eric: Now, it’s a little bit tricky, because it’s not fundamentally discrete. For probably any one of these phase transitions that arise out of the smooth training process, you could define a progress measure, like the distance from your current point in parameter space to a point that would implement the algorithm, and watch that distance decrease over training. But at least from the perspective of the loss, there’s maybe in practice only a short window in which the capability is only partly present. For most of training, it’s either absent or present. And this is what we mean when we talk about discreteness in our paper.
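A toy illustration of this point, with entirely hypothetical functional forms: a progress measure (the distance to an algorithm-implementing point in parameter space) that shrinks smoothly, and a loss that only drops once that distance falls below a threshold. The distance never jumps, yet the loss spends only a small fraction of training between its two plateaus.

```python
import numpy as np

steps = np.arange(10_000)

# Hypothetical progress measure: distance to the algorithm shrinks smoothly.
distance = np.exp(-steps / 2_000)

# Hypothetical loss: a steep sigmoid in the distance, centered at 0.1.
threshold, width = 0.1, 0.005
loss = 1.0 / (1.0 + np.exp(-(distance - threshold) / width))

# Fraction of training during which the capability is only "partly" present.
in_between = np.mean((loss > 0.05) & (loss < 0.95))
print(f"fraction of steps in transition: {in_between:.3f}")
```

Nothing here is fitted to a real network; the sketch only shows how a smooth quantity can produce a loss curve that looks discrete.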

Michaël: Cool. Thank you very much. This was Eric Michaud after literally one mile of hiking. We’ll go more into details of your papers in the next few hours.

The Quantization Model of Neural Scaling

Michaël: So I’m here with Eric Michaud. And as you can see, there are some cows in the background looking at us. And Eric, I still don’t really understand your quantum paper.

Main idea of the paper

Michaël: What’s the name of the paper and the one-tweet summary?

Eric: Yeah, so the name of the paper is The Quantization Model of Neural Scaling. And the one-tweet summary is that it’s possible for smooth loss curves to be averages over lots of small, discrete phase changes in network performance. What if there were a bunch of things that you need to learn to do prediction well in something like language? These things could be pieces of knowledge or abilities to perform specific types of computation. We can imagine enumerating this set of things that you need to learn to do prediction well, and we call these the quanta of the prediction problem. Then, what if the frequency with which each of these quanta, each piece of knowledge or computational ability, is useful for prediction in natural data followed a power law? What’s that power law? It means that the probability that the n-th most useful quantum is needed for prediction on a given sample is proportional to n to the minus alpha or something. So basically, it drops off like 1 over n to some power.
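A minimal numeric sketch of this story, under assumptions that are ours rather than taken from the paper: the k-th quantum is useful with frequency proportional to k^-(alpha+1) (we use alpha+1 so that the mean loss comes out as roughly n^-alpha), every sample needs exactly one quantum, a model of a given size has learned exactly the n most frequent quanta, and per-sample loss is 0 or 1 depending on whether its quantum has been learned. Each sample’s loss is a step function of n, yet the mean loss falls off as a power law.

```python
import numpy as np


def quanta_frequencies(num_quanta: int, alpha: float) -> np.ndarray:
    """Frequencies p_k proportional to k^-(alpha + 1), k = 1..num_quanta, summing to 1."""
    k = np.arange(1, num_quanta + 1, dtype=float)
    p = k ** -(alpha + 1.0)
    return p / p.sum()


def mean_loss(p: np.ndarray, n_learned: int,
              loss_learned: float = 0.0, loss_unlearned: float = 1.0) -> float:
    """Mean loss of a model that has learned the n most frequently useful quanta.

    Each sample needs one quantum, and its loss is a step function:
    loss_learned if that quantum is learned, loss_unlearned otherwise.
    """
    unlearned_mass = p[n_learned:].sum()
    return loss_learned + (loss_unlearned - loss_learned) * unlearned_mass


alpha = 0.5
p = quanta_frequencies(1_000_000, alpha)
ns = np.unique(np.logspace(1, 3, 20).astype(int))  # bigger models learn more quanta
losses = np.array([mean_loss(p, n) for n in ns])
slope, _ = np.polyfit(np.log(ns), np.log(losses), 1)
print(f"fitted slope {slope:.3f}, predicted {-alpha}")
```

The fitted log-log slope lands close to -alpha: the smooth power law is entirely an average over discrete steps.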

Michaël: And if you look at the scaling curve, when you have this power law, it’s a straight line on a log-log plot.

Eric: Right, so the famous scale-is-all-you-need plots are log-log plots. And they show power laws, straight lines on the log-log plot, for compute, data, and parameters.
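The reason a power law shows up as a straight line is one line of algebra. Writing the loss as L(N) = c N^(-α), with c and α constants fitted to the data (our notation, for illustration), and taking logarithms gives a relation that is linear in log N, with slope -α:

```latex
L(N) = c\,N^{-\alpha}
\quad\Longrightarrow\quad
\log L(N) = \log c - \alpha \log N
```

The same holds with N read as compute or data set size, each with its own exponent.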

Michaël: When you talk about the things that are useful for prediction, to give concrete examples, those are predicting a certain token in some kind of language modeling task, right? There are specific tasks that the model learns or doesn’t learn?

Eric: Yeah, so if you think about what is required to do prediction well in natural language and predict the next token in a document across documents on the internet, there are a huge number of different things that you need to know as a language model in order to be able to do this well. If you’re doing prediction or something in a news article that appeared on the internet and it states some fact, then in order to predict certain words in the sentence that state the fact, then the language model kind of needs to know that fact. The president of the United States in 2020 was dot, dot, dot.

Michaël: And so basically, your subtasks are where the model knows some specific piece of knowledge. And then there are more reasoning tasks or mathematical tasks. Do those count as quanta as well?

Clustering of samples, similar gradients, cosine similarity

Eric: Yeah, so in the paper, we tried to actually enumerate what some of these are for language modeling. We did this by clustering together samples based on whether the model’s gradients were similar for those samples. Some of the clusters you get are junk, but many of them are quite interesting.

Michaël: What do you mean, having similar gradients?

Eric: Yeah, so for a given sample you evaluate the model’s cross-entropy loss. A sample here is predicting a token from some context in some document on the internet. You look at the model’s cross-entropy loss in predicting this token and backpropagate gradients through the model parameters. That gives you a big vector: the gradient of the model’s loss with respect to its parameters. Then you want to group together samples whose gradients point in similar directions in this very high-dimensional space. We happened to use a spectral clustering approach, where we compute the cosine similarity between the gradient vectors for all pairs of samples, which is fairly expensive.
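Here is a minimal numpy sketch of the clustering step as described: per-sample gradient vectors, pairwise cosine similarity, then spectral clustering. Several pieces are assumptions: the gradients are synthetic vectors with three planted groups (a stand-in for real per-token gradients from a language model), cosine similarities are turned into non-negative affinities via A = 1 - arccos(C)/pi (one common choice, not necessarily the paper’s), and the spectral step is the standard normalized-affinity embedding followed by k-means, written out by hand rather than taken from a particular library. The n x n similarity matrix is why this gets expensive: its size grows quadratically with the number of samples.

```python
import numpy as np


def cosine_similarity_matrix(grads: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between per-sample gradient vectors (one per row)."""
    unit = grads / np.linalg.norm(grads, axis=1, keepdims=True)
    return np.clip(unit @ unit.T, -1.0, 1.0)  # n x n: quadratic in the number of samples


def kmeans(points: np.ndarray, k: int, n_iter: int = 100) -> np.ndarray:
    """Lloyd's k-means with deterministic farthest-point initialization."""
    centers = [points[0]]
    for _ in range(1, k):
        d2 = np.min([((points - c) ** 2).sum(axis=1) for c in centers], axis=0)
        centers.append(points[np.argmax(d2)])
    centers = np.array(centers)
    for _ in range(n_iter):
        d2 = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        new_centers = np.array([
            points[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels


def spectral_cluster_gradients(grads: np.ndarray, k: int) -> np.ndarray:
    """Spectral clustering of samples, using gradient cosine similarity as affinity."""
    C = cosine_similarity_matrix(grads)
    A = 1.0 - np.arccos(C) / np.pi       # assumed angular affinity, in [0, 1]
    deg = A.sum(axis=1)
    M = A / np.sqrt(np.outer(deg, deg))  # normalized affinity D^-1/2 A D^-1/2
    _, vecs = np.linalg.eigh(M)          # eigenvalues in ascending order
    U = vecs[:, -k:]                     # top-k eigenvectors as the embedding
    U = U / np.linalg.norm(U, axis=1, keepdims=True)
    return kmeans(U, k)


# Synthetic stand-in for per-token gradients: 3 planted groups, 20 samples each.
rng = np.random.default_rng(0)
groups = np.repeat(np.arange(3), 20)
directions = rng.normal(size=(3, 50))
grads = directions[groups] + 0.05 * rng.normal(size=(60, 50))
labels = spectral_cluster_gradients(grads, k=3)
print(labels)
```

Cluster labels come out in an arbitrary order; what matters is that samples sharing a planted direction end up in the same cluster.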

Eric: But the clusters that you get are sometimes fairly interesting. It was for a fairly small language model, so they’re not super sophisticated. But there are things like predicting a newline at the end of a line of text within a document where the line lengths are limited. In order to predict the newline, the model has to count the lengths of the previous lines in the document, and then it can use that to accurately predict when a newline should come.
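As a toy picture of the computation behind this cluster, here is a hand-written rule (hypothetical, and not the network’s actual mechanism) that estimates the wrap width from the previous lines and predicts a newline when the next word would overflow it. It assumes hard-wrapped text with single spaces between words.

```python
from typing import List


def predict_newline(previous_lines: List[str], current_line: str, next_word: str) -> bool:
    """Guess whether a newline comes before next_word in hard-wrapped text.

    Assumes the text was wrapped at a fixed width, estimated as the longest
    previous line, and that words are separated by single spaces.
    """
    width = max(len(line) for line in previous_lines)  # "count line lengths"
    return len(current_line) + 1 + len(next_word) > width


previous = ["the quick brown fox", "jumps over the lazy"]  # both 19 characters
print(predict_newline(previous, "dog and then it", "ran"))       # False
print(predict_newline(previous, "dog and then it ran", "away"))  # True
```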

Eric: And you can find a large number of clusters where what the samples have in common just seems to be that they’re the same type of problem, or that doing prediction on them requires the same piece of knowledge. You might call these the quanta, or evidence of there being quanta, although it’s a little bit tricky, because in doing the clustering we enforce this discreteness: every sample is a member of one particular cluster and not another. Anyway, it’s complicated and weird. Who knows whether this is even the right model for thinking about the networks.

Eric: But it would be very exciting if it were the true model, because it would tell you that there’s a set of things which, if you enumerated them, would let you understand the network’s performance and what it has learned. There’s this set of pieces of knowledge or pieces of computation that are needed, and you could describe what they are. You could find them in the network and maybe hope to mechanistically understand the whole network by decomposing it into how it implements each of these things, how it learns each piece of knowledge or each piece of computation.

Experiment specifics: Pythia models from Eleuther, 70 million to 12 billion parameters, per-token loss

Michaël: You said you tried on smaller models. What kind of model are we talking about? What kind of size are we talking about?

Eric: So I was using the Pythia models from EleutherAI. They range from 70 million to 12 billion parameters. I was going up to the second-to-last one, the 6.9 billion parameter model.

Michaël: So what kind of experiments did you run? You did the clustering on all of the things. And you showed that like, there’s a new line token? There’s maybe other tokens. Do you have examples of other simple predictions that you can observe, the subtasks being learned at some point? Like, are there specific examples of subtasks?

Eric: Yeah, so you can look for tokens for which the scaling curves are sharp. You have these models of different scales trained by EleutherAI, and I can evaluate them on a large number of tokens, so I have a large number of loss values. Then you can look at each of these individual loss curves: the loss as a function of model scale, the number of model parameters. That’s the scaling curve. We know that the average of the scaling curves over all the tokens in the corpus looks like a power law. But do they individually look like power laws? You see a range of different behaviors, at least with the seeds of the models that were released. You can find individual curves where the loss is high and then drops to approximately zero at a given scale. In general, these seem more like tokens that involve facts, whereas for curves that look really smooth, it’s not clear that there’s a specific piece of knowledge that would let you do prediction well at all. It might just be based on intuition, or many heuristics that all contribute together.
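One simple way to separate sharp from smooth per-token curves is sketched below. The `drop_sharpness` criterion is a hypothetical choice of ours, not the paper’s: it measures what share of a token’s total loss improvement happens in the single largest step between consecutive model sizes. The model sizes are the Pythia sizes up to 6.9B; both loss curves are made up for illustration.

```python
import numpy as np


def drop_sharpness(losses: np.ndarray) -> float:
    """Share of a token's total loss drop that happens in its single largest step.

    `losses` is the per-token loss for models ordered by increasing size.
    Values near 1 look like a phase change; small values look gradual.
    """
    drops = -np.diff(losses)
    total = losses[0] - losses[-1]
    if total <= 0:
        return 0.0  # the loss did not improve with scale
    return float(drops.max() / total)


params = np.array([70e6, 160e6, 410e6, 1.0e9, 1.4e9, 2.8e9, 6.9e9])  # Pythia sizes up to 6.9B
sharp = np.array([5.0, 4.9, 4.9, 0.2, 0.1, 0.1, 0.1])  # made-up "fact" token
smooth = 3.0 * (params / params[0]) ** -0.3             # made-up gradual power law
print(drop_sharpness(sharp), drop_sharpness(smooth))
```

Any such threshold is a judgment call; with real curves, noise between model seeds would also need to be taken into account.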

Michaël: Yeah, I was going to ask, how many tokens are we talking about? Is it the number of tokens in the byte pair encoding from OpenAI, or thousands of tokens? Do you look at them individually?

Eric: Well, there’s a bit of notational overload here: you can talk about tokens in the vocabulary, but you can also talk about tokens in documents. There are about 50,000 tokens in the vocabulary of the tokenizer. But the Pile, the data set the models were trained on, has a test set with a few hundred thousand documents, I think, so it definitely has over 100 million tokens. And you can imagine computing a loss, or a scaling curve, on each one of those 100 million tokens.

Automatic Clustering, Size and Kinds of Clusters

Michaël: And so the clustering you did is automatic, right? You didn’t force a cluster on the newline thing. You just made the clusters automatically and then looked at the different clusters. Approximately how many clusters did you find in total?

Eric: Yeah, so when I did this in the paper, I used 10,000 samples, 10,000 tokens that the model got low loss on. And I was mostly looking at an experiment where I produced 400 clusters. But this is a little bit tricky, because the algorithm has to put everything in a cluster. So there are a bunch of clusters which are low quality and a bunch which are much higher quality.

Eric: I don’t know, this is kind of a cool exercise, and maybe useful for mechanistic interpretability, because the end result is something like: maybe we have 200 clusters of model behaviors, where for each sample the model was really good at predicting the token from its context, and maybe it was good at that for similar reasons within each cluster. Then you could hope to find a circuit for each cluster of behavior.

Power Law of how useful a token is for prediction

Michaël: In your thread about quanta, you mentioned something about the frequency of some words in the training data set. And there’s something about the number of parameters and frequency. I don’t really understand what this whole deal is about.

Eric: Yeah, so the core hypothesis of the model that we make in the paper, in addition to there being a discrete set of things you need to learn, is that there’s a power law governing how useful these things are for prediction. You can imagine some fact, like, I don’t know, Max Tegmark is a physicist at blank, where there’s some frequency with which knowing this piece of knowledge is useful for prediction.

Eric: Maybe one in a million or one in a billion next-word prediction problems on the internet require this piece of knowledge. And then there are other, much more frequently useful pieces of knowledge if you’re going to do language modeling, like understanding basic grammar. If you could only learn so many things, you should probably learn the basic grammatical rules before you learn anything about who Max Tegmark is.

Ordering Clusters Depending On How Useful They Are For Prediction

Michaël: Say you have a very small model that only has one million parameters. It doesn’t have the capacity to remember who Michaël Trazzi or Eric Michaud is. But if it’s a one trillion parameter model, then maybe it can remember all these facts. And so there’s an ordering in the quanta you learn, and you start with the quanta that are most useful for your training.

Eric: Yeah, if you can only learn so many of these things, you should learn the ones which reduce the mean loss the most. And these are the things we sort of conjecture that are most frequently useful for prediction. So the most frequently useful pieces of knowledge maybe will be learned by the small networks and then the big networks will like, you know, have circuits or they’ll know the knowledge that is much more niche, niche physics or niche facts, which smaller models just don’t have the capacity to learn.

Empirical Evidence For The Power Laws

Michaël: Yeah, do you have any empirical evidence for like, have you done experiments or just a conjecture on like, what do you think your theory might predict?

Eric: So when we did the clustering of the samples in language, you could look at the size of the clusters and ask whether the cluster sizes drop off as a power law.

Michaël: Yeah, how do you measure the size? The number of things inside, or the distance in some space?

Eric: No, yeah, what we did was just look at the number of samples in each cluster. We did some filtering, but basically we just took a bunch of samples that the model got low loss on, clustered them, and looked at the size of the clusters. This is really messy, and there are a bunch of reasons why this could fail even if there was a power law. If you look at the size of the clusters versus their ordering by size, it’s a pretty messy curve. But eventually it looks roughly power-law-like, very roughly with the exponent that we would expect from the empirical scaling laws for the language models themselves.

Michaël: So where’s the power law? Is it a power law in the number of items in each cluster?

Eric: Yeah, so cluster 100 has, I don’t know, 50 elements, and cluster 200 has 20 elements, and the size of the clusters drops off. If you put it on a log-log plot of the size of each cluster versus its rank, once you order them all by size, then eventually it roughly looks like a straight line, but it’s a very tentative result.
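To see what a straight line on a log-log rank-size plot means in practice, here is a small sketch (the function name `rank_size_slope` is hypothetical) that counts cluster sizes from cluster labels, sorts them from largest to smallest, and fits the slope of log size against log rank. The labels are constructed to follow an exact power law with exponent 1.5, purely for illustration; real cluster sizes, as described above, are much messier. If cluster sizes track how often each quantum is useful, the fitted slope estimates that frequency exponent (our reading, not a claim about the paper’s exact procedure).

```python
import numpy as np


def rank_size_slope(labels: np.ndarray) -> float:
    """Slope of log(cluster size) versus log(rank), clusters sorted largest first."""
    _, sizes = np.unique(labels, return_counts=True)
    sizes = np.sort(sizes)[::-1]  # largest cluster gets rank 1
    ranks = np.arange(1, len(sizes) + 1)
    slope, _ = np.polyfit(np.log(ranks), np.log(sizes), 1)
    return float(slope)


# Made-up labels: cluster r (r = 1..50) gets about 10,000 * r^-1.5 samples.
ranks = np.arange(1, 51)
sizes = np.round(10_000 * ranks ** -1.5).astype(int)
labels = np.repeat(np.arange(50), sizes)
print(f"fitted rank-size slope: {rank_size_slope(labels):.3f}")
```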

Concrete examples of Clusters

Michaël: And so, when you talk about clusters for grammar stuff, do the clusters reveal rules like I am, you are, or the S at the end of verbs?

Eric: Yeah, there really are clusters like that, which we find by doing this. There’s a cluster for when you have D-O-N apostrophe, and then the model predicts T after that, for don’t. There are a huge number of examples like this; commas after the word however is another one of the clusters. Knowing that commas often come after the word however turns out to be very frequently useful for predicting the next token in internet text. So it’s a piece of knowledge, or what we would call a quantum, that even the small language models possess.

Answer To Skeptics: What The Quanta Predicts And Doesn’t Predict

Michaël: If I were someone on YouTube who disliked everything I watched and was posting an angry comment, I would say: yes, your theory seems kind of nice, but it’s just putting a nice name, quanta, on things. I feel like everything can be described as a quantum if you think hard enough about it. I don’t really see anything new that your model predicts, and I don’t really get why it’s useful to think about. Do you have anything to say to those people?

Eric: Yeah, so there are a lot of ways in which this model might not actually describe what’s going on in language models or real neural networks. But I could say this. There are other models of where scaling laws come from for neural networks. They say, for instance, that the exponent of the power law comes from the dimension of some data manifold, some manifold that the data lies on. We’re saying that maybe it comes from a power-law distribution over the frequencies with which these quanta are useful. One of the things we do in the paper is construct data where our story of scaling is true, where smooth power laws really do average over lots of these phase-change type things.

Eric: And there are these discrete sort of subtasks in the prediction problem. And so it’s sort of possible, like, I don’t know, I think the interesting point that the paper makes is that it’s possible for this type of story to still give rise to what we observe with power law scaling, and that it’s possible for data with the right structure for this to be the right story. And then it’s sort of still an open question whether language data has this kind of structure, whether it’s more sensible to think about the problem of predicting language as this kind of smooth interpolation type of thing versus a bunch of discrete pieces of knowledge or algorithms kind of story. And I’m not really sure which one is right, but at least I think it’s maybe useful to have a counterpoint to the sort of everything is smooth type story.

Empirical Evidence in the paper

Michaël: Basically, you’re saying that you can explain scaling laws pretty easily with your model, and that for some things in language modeling, you can explain them with quanta. Is this basically right? So for scaling, you’re pretty sure your model explains a lot of things that are already explained by other papers, and in language modeling you’re explaining some stuff?

Eric: One thing is that our model makes a prediction about the relationship between the data and parameter scaling exponents, whereas some other models would say that the data scaling exponent and the parameter scaling exponent should be the same. In the appendix of the paper, we plotted, for a bunch of other papers that have done empirical studies of neural scaling, the parameter scaling exponent and the data scaling exponent they measured. And frankly, the points are all over the figure. It’s a total mess. The Chinchilla scaling law’s point in this plane of data and parameter scaling exponents is right in between the theory line from other models and the line from our model. So it’s pretty ambiguous at this point what is correct.

Michaël: Are you saying that maybe the experiments in Chinchilla don’t fit with your experiments or your theory?

Eric: No, so I guess we would expect the data scaling exponent to be less than the parameter scaling exponent, and this is indeed what they observed in the Chinchilla scaling law study. So it seems somewhat encouraging, but if you look at the precise values of these scaling exponents, they’re not exactly what we would predict.

Michaël: So it’s sort of in the right direction, but not precisely what we would expect.

Eric: Yeah. In the right direction, but not exactly what you would like. And then other studies are totally different. With OpenAI’s original Kaplan et al. study, it was the opposite: the data scaling was steeper than the parameter scaling. And then there are a bunch of other data points on the figure, maybe from vision papers. Most of these are below the line where the two scaling exponents are equal, which is encouraging for our theory, but there are other points which are above it. So overall, things seem messier than I would have expected with neural scaling in the wild.

Michaël: And just to be clear, when you talk about scaling exponents for parameters and dataset size, we’re talking about the steepness of the straight line in the log-log plot, right?

Eric: Yeah, it’s the slope of the power law in the log-log plot.
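Since a power law L = c · N^(−α) becomes the straight line log L = log c − α log N, the exponent can be estimated by least squares on the logged values. The sketch below is a generic illustration, not the fitting procedure from any of the papers mentioned; the data is synthetic, generated from a made-up law with c = 5 and α = 0.34.

```python
import math


def fit_power_law(xs: list[float], ys: list[float]) -> tuple[float, float]:
    """Least-squares fit of log(y) = log(c) - alpha * log(x).

    Returns (alpha, c): alpha is minus the slope of the line in log-log space.
    """
    log_x = [math.log(x) for x in xs]
    log_y = [math.log(y) for y in ys]
    n = len(xs)
    mean_x = sum(log_x) / n
    mean_y = sum(log_y) / n
    cov = sum((lx - mean_x) * (ly - mean_y) for lx, ly in zip(log_x, log_y))
    var = sum((lx - mean_x) ** 2 for lx in log_x)
    slope = cov / var
    intercept = mean_y - slope * mean_x
    return -slope, math.exp(intercept)


# Synthetic example: loss = 5 * N^-0.34 for model sizes from 1e6 to 1e9 parameters.
sizes = [1e6, 1e7, 1e8, 1e9]
losses = [5.0 * n ** -0.34 for n in sizes]
alpha, c = fit_power_law(sizes, losses)
print(f"alpha = {alpha:.3f}, c = {c:.3f}")  # prints alpha = 0.340, c = 5.000
```

On real measurements the points are noisy and often bend away from a pure power law at the ends, so the fitted slope depends on which range you fit.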

Michaël: Yeah. I think now I kind of get your paper. Is there anything else you think is important?

Eric: Yeah, so I’d just say that I thought this was useful to put out because if it were true, it would be very exciting. It would be a first shot at a paradigm for a science of deep learning: what if there are these things, and we could understand the network’s performance just by reference to whether or not it learns them? We could understand scaling as just a matter of learning more of these things, and the process of training as learning more and more of them. And if this were true, then hopefully we could mechanistically understand the networks by enumerating what these quanta are. But in practice, I expect things to be a bit trickier than the theory laid out in the paper.
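A toy version of "scaling as learning more quanta", as a sketch under assumptions: following the quantization model paper as I understand it, quanta are used with power-law (Zipfian) frequencies p_k ∝ k^−(α+1), a model that has learned the n most frequent quanta gets zero loss on them, and every unlearned quantum costs a fixed loss whenever it is needed. The helper names and the choice α = 0.5 are hypothetical. Under these assumptions the loss falls off roughly as n^−α.

```python
import math


def zipf_frequencies(num_quanta: int, alpha: float) -> list[float]:
    """Frequencies p_k proportional to k^-(alpha + 1), normalized to sum to 1."""
    weights = [k ** -(alpha + 1.0) for k in range(1, num_quanta + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def loss_after_learning(freqs: list[float], n_learned: int, unlearned_loss: float = 1.0) -> float:
    """Mean loss when the n_learned most frequent quanta are learned.

    Simplifying assumption: a learned quantum contributes zero loss, and an
    unlearned one contributes `unlearned_loss` every time it is needed.
    """
    return unlearned_loss * sum(freqs[n_learned:])


alpha = 0.5
freqs = zipf_frequencies(num_quanta=1_000_000, alpha=alpha)
small, large = 10, 1000
slope = (math.log(loss_after_learning(freqs, large) / loss_after_learning(freqs, small))
         / math.log(large / small))
print(f"log-log slope = {slope:.3f} (compare with -alpha = {-alpha})")
```

The finite number of quanta slightly distorts the tail, which is why the slope comes out close to, but not exactly, −α.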

Michaël: Maybe the discovery of quanta in scaling will trigger a new field of quantum interpretability. I’m very excited about this. Maybe we’ll talk more about grokking in the next 20 minutes of the hike. See you in the next bit.


What’s Grokking?

Michaël: Eric, this is maybe the last shot of the vlog. I met you at NeurIPS, and you keep giving those talks about grokking, so I think it’s worth talking about a little bit. And recently you’ve published two papers on grokking at ICLR.

Eric: One at NeurIPS and one at ICLR.

Michaël: Yeah, so do you want to quickly explain what grokking is, so that our viewers can grok grokking?

Eric: Sure. Yeah, so grokking is this phenomenon where neural networks can generalize long after they first overfit their training data. This was first discovered by some folks at OpenAI. Basically, they were training small transformer models to learn basic math operations, for example modular addition, where the network only sees some fraction of all possible pairs of inputs. When they trained the networks on this task, the networks would first completely memorize the training data, the pairs of examples they saw, and would not generalize at all to unseen examples. But if they kept training for way longer than it took the network to overfit, eventually the network would generalize. So this is just a surprising thing to observe in neural networks: they perfectly fit the training data, and then somehow you keep training and eventually they switch over to doing something else internally, which actually allows them to generalize.
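To make the setup concrete, here is a minimal sketch of how such a modular-addition dataset could be built and split. The modulus 113 comes from later in this conversation; the 30% training fraction, the seed, and the function name are hypothetical choices, not the settings of the OpenAI experiments.

```python
import random


def modular_addition_split(p: int, train_fraction: float, seed: int = 0):
    """Build every (a, b) -> (a + b) mod p example and split it into train and test.

    The network only ever sees the training pairs; grokking is when test accuracy
    jumps long after the training pairs have been memorized.
    """
    examples = [((a, b), (a + b) % p) for a in range(p) for b in range(p)]
    rng = random.Random(seed)
    rng.shuffle(examples)
    n_train = int(train_fraction * len(examples))
    return examples[:n_train], examples[n_train:]


train, test = modular_addition_split(p=113, train_fraction=0.3)
print(len(train), len(test))  # 3830 8939
```

Because every pair appears exactly once, the test set contains only input pairs the network has never seen, so test accuracy above chance can only come from learning the operation itself.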

Michaël: Because in normal deep learning, when your network has very low train loss and very high test loss, it’s because it overfit, right? Normally you’re stuck, but in this case, if you wait long enough, it works. I was listening to Neel Nanda talk with Lawrence Chan about their work on grokking, and they were saying that this is kind of different from double descent. Is there a sense in which grokking is when you have a small dataset and go through it many times, whereas double descent is more when, as you go through your entire dataset, at some point it generalizes?

Eric: I’m not actually super clear right now on the distinction between grokking and double descent and the exact setup there. There is this paper from some folks at Cambridge on unifying grokking and double descent, so maybe these things are two sides of the same coin. Overall, I think there’s some message here about there being different mechanisms that the network can use to fit the training data. Maybe these are learned at different rates and incentivized by different amounts over the training process, and this can lead to some wonky training dynamics where more than one thing is happening inside the network during training.

Recent paper: Towards Understanding Grokking

Michaël: So you’ve published two papers on grokking recently. Why are you so excited about grokking? And what are the main takeaways from these papers?

Eric: Yeah, so grokking is kind of exciting just because it’s weird and surprising. It feels like progress if we can understand something in these networks that surprises us. And things that surprise us are often good targets: if we can understand them, then maybe they’ll tell us something more general. So what is to be learned in general about grokking? There are a couple of things here. In our first paper, we looked at representation learning, and at how generalization on the math operations these networks are learning depends on the network learning particular structured representations of its inputs.

Eric: And it turns out that if, for instance, the network is learning to do modular addition, then it will arrange the embedding vectors for each of the inputs in a particular way in the embedding space of these small transformers. So if you’re doing addition mod 113, there’s an embedding vector for 0, 1, 2, all the way up to 112. And eventually it learns to put these in a circle that goes from 0, 1, all the way up to 112 and then back to 0, which is exactly how we visualize modular addition in terms of clock math. It seems the networks learn something similar here.
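An idealized picture of why a circle is a natural representation for modular addition, as a sketch rather than the network’s actual learned embedding: it assumes each residue x sits on a single unit circle at angle 2πx/p, with real embeddings being higher-dimensional and noisier. Adding b then amounts to rotating by b "hours", and wrapping past 112 back to 0 happens automatically. All function names are hypothetical.

```python
import math


def clock_embedding(x: int, p: int) -> tuple[float, float]:
    """Place residue x on the unit circle at angle 2*pi*x/p (a clock face with p hours)."""
    angle = 2 * math.pi * x / p
    return (math.cos(angle), math.sin(angle))


def rotate(point: tuple[float, float], b: int, p: int) -> tuple[float, float]:
    """Rotate a point by the angle that represents b, i.e. 'add b' on the clock."""
    theta = 2 * math.pi * b / p
    x, y = point
    return (x * math.cos(theta) - y * math.sin(theta),
            x * math.sin(theta) + y * math.cos(theta))


def decode(point: tuple[float, float], p: int) -> int:
    """Read off the residue whose clock position is closest to the point."""
    return min(range(p), key=lambda c: math.dist(point, clock_embedding(c, p)))


p = 113
print(decode(rotate(clock_embedding(100, p), 20, p), p))  # 7, since (100 + 20) mod 113 = 7
```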

Michaël: Is the learning of all these modular representations happening through grokking or is it independent?

Eric: Yeah, so there’s this thing that happens in the network where first it memorizes. The network quickly finds some way of memorizing; I don’t really understand mechanistically how it does this, but early on in training it memorizes. Then there’s this slower process of forming the circuit that actually performs some very structured computation, which is equivalent in some way to the operation the network as a whole is trying to implement. If you need to generalize on modular addition, then the only way you’re going to do that is if internally you implement an algorithm that does something like modular addition.

Relationship with Neel Nanda’s Modular Addition Tweets

Michaël: Was that the same thing Neel Nanda discovered and posted about on Twitter, this modular addition algorithm built from sums of cosines and sines?

Eric: Yeah, so we saw that there was this structure, this ring shape in how the network organizes its representations, its embeddings, but we didn’t really understand how it was being used. Then Neel, in his description of how the networks fully implement modular addition, explained why we were seeing this: it’s the first step in a multi-step algorithm that performs a computation equivalent to modular addition and allows the network to generalize.
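A sketch of the kind of cosine-and-sine algorithm being referred to, written by hand rather than extracted from a network. It assumes a prime modulus and an arbitrary illustrative set of frequencies (1, 5, 17); a trained network reportedly settles on its own small set of key frequencies. For each frequency, the angle-addition identities turn the "circle" embeddings of a and b into cos w(a+b) and sin w(a+b), and every candidate answer c is scored by cos w(a+b−c), which peaks only at the correct c.

```python
import math


def fourier_modular_add(a: int, b: int, p: int, freqs: tuple[int, ...] = (1, 5, 17)) -> int:
    """Compute (a + b) mod p using only products of cosines and sines.

    For each frequency k (w = 2*pi*k/p):
      1. embed a and b as (cos wa, sin wa) and (cos wb, sin wb);
      2. combine them with the angle-addition identities to get cos w(a+b), sin w(a+b);
      3. score every candidate c by cos w(a+b-c) = cos w(a+b) cos wc + sin w(a+b) sin wc.
    With p prime and k not a multiple of p, each score is maximal only at
    c = (a + b) mod p, so the argmax of the summed scores recovers the answer.
    """
    logits = [0.0] * p
    for k in freqs:
        w = 2 * math.pi * k / p
        ca, sa = math.cos(w * a), math.sin(w * a)
        cb, sb = math.cos(w * b), math.sin(w * b)
        cos_sum = ca * cb - sa * sb  # cos(w(a + b))
        sin_sum = sa * cb + ca * sb  # sin(w(a + b))
        for c in range(p):
            logits[c] += cos_sum * math.cos(w * c) + sin_sum * math.sin(w * c)
    return max(range(p), key=logits.__getitem__)


print(fourier_modular_add(100, 20, 113))  # 7
```

Summing over several frequencies makes the correct answer stand out more sharply than any single frequency would, since every wrong c loses a little score at each frequency.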

Michaël: So yeah, if people want to learn more about this, what’s the name of the paper?

Eric: Yeah, so the paper on representation learning was called Towards Understanding Grokking: An Effective Theory of Representation Learning. It was an oral at NeurIPS 2022.

Omnigrok: Grokking Beyond Algorithmic Data

Michaël: And what’s the other paper that you presented I think at ICLR?

Eric: Yeah, so I went to Rwanda to present this paper. It’s called Omnigrok: Grokking Beyond Algorithmic Data.

Michaël: Such a cool name.

Eric: Yeah. So both these papers were led by Ziming Liu, a grad student in the physics department with me. Omnigrok is a bit of a cute name. But basically we showed in Omnigrok that it was possible to observe grokking, this delayed generalization, on tasks other than the algorithmic tasks where grokking was first observed by people at OpenAI.

Eric: But we could also observe it on MNIST and other standard tasks in deep learning. Basically, the way we did this was that, for MNIST for instance, instead of training on 50,000 examples, we would train on fewer samples, maybe a thousand. Then we would initialize the network weights to be really big and train with weight decay. When you do this, it turns out that the network will first memorize, and then, as you continue training with weight decay, eventually the network will generalize.
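A toy calculation, not the Omnigrok training setup, to show one reason a large initialization plus weight decay can delay things: if weight decay were the only force on the weights (a strong simplification that ignores the loss gradient), each decoupled weight-decay step multiplies the weights by (1 − lr · wd). The number of steps needed to bring a large initial norm down to a given target then grows with the log of the initialization scale and roughly inversely with lr · wd. The learning rate, decay strength, target norm, and function names are hypothetical.

```python
import math


def steps_to_shrink(init_scale: float, target_norm: float, lr: float, weight_decay: float) -> int:
    """Steps of pure decoupled weight decay (w <- (1 - lr * wd) * w) needed to bring
    a weight norm of `init_scale` down to `target_norm`.

    Ignores the loss gradient entirely, which is a strong simplification.
    """
    shrink = 1.0 - lr * weight_decay
    return math.ceil(math.log(target_norm / init_scale) / math.log(shrink))


def norm_after(init_scale: float, steps: int, lr: float, weight_decay: float) -> float:
    """Weight norm after `steps` pure weight-decay updates."""
    return init_scale * (1.0 - lr * weight_decay) ** steps


# Larger initializations take longer to decay back to the same norm.
for scale in (1.0, 8.0, 64.0):
    print(scale, steps_to_shrink(scale, target_norm=1.0, lr=1e-3, weight_decay=1.0))
```

In a real run the loss gradient also moves the weights, so this only illustrates the time scale weight decay sets, not when generalization actually happens.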

Eric: So this is the Omni: we can grok beyond algorithmic data. We can grok on anything, maybe, or maybe not. But there are at least a few things we can grok on that are not math operations.

Michaël: Yeah. From this to everything. Was it the first time someone showed grokking on MNIST?

Eric: I think so.

Final Words on grokking, generalisation and interpretability

Michaël: Cool. If there are people watching who are PhD students in deep learning and want to collaborate with you, what lines of research might you pursue in the future, on grokking or on capabilities? Is there anything you’re interested in that you can share?

Eric: Yeah. No, you should not collaborate with me. No, so I don’t know. I guess, what is the overarching message here? With both the quanta scaling stuff and the grokking stuff, we hope to identify mechanisms in the model that are responsible for certain behaviors or for the model generalizing. In the case of grokking, there are multiple circuits or mechanisms going on in the model: a memorizing mechanism and a generalizing mechanism. I think this way of thinking was maybe first put forth by Neel Nanda in his blog post on grokking. You can imagine tracking how the generalizing circuit forms over time. Maybe there are many copies of the mechanism in the network and a lot going on, but you can think about the network as consisting of some number of these mechanisms, which you can imagine enumerating. And more generally, beyond grokking, in large language models and otherwise, we might hope to decompose their behavior in terms of a bunch of these mechanisms. If you could do this, then you could hope to do interpretability, but maybe also other things like mechanistic anomaly detection. You might eventually be able to say: ah, yes, when the network made a prediction on this problem, it used this and this and this mechanism, or these were the relevant ones. And maybe there’s a mechanism in sufficiently large language models that reasons about deception or something. Then you could just identify it, identify when it’s being used by the model for inference, and not worry about sharp left turns or whatever.

Michaël: Towards Omnigrok and detecting sharp left turns. Thank you, Eric. And if you want to watch more of the show, subscribe to my YouTube channel, support me on Patreon, and see you tomorrow for the next video of The Inside View.