\section{Introduction}

Grokking is a phenomenon in neural network training in which generalization occurs well past the point of overfitting, contrary to common machine learning intuition. It was first reported by \cite{grokking_first} and, until recently, had only been observed on toy examples.

In this work, we attempt to reproduce the findings of the paper ``Towards Understanding Grokking: An Effective Theory of Representation Learning'' \cite{understanding_grokking}, in which the authors capture grokking on the MNIST dataset by constructing a dedicated setup. They also visualize the division of the learning process into four phases using phase diagrams obtained from a grid search over hyperparameters.

\section{Scope of reproducibility}
\label{sec:claims}
The authors consider grokking as one of four learning phases. \emph{Grokking} and \emph{comprehension} both lead to generalization, though in the former generalization is delayed. \emph{Memorization} is a synonym for overfitting: the model performs well on the training set but cannot generalize to the validation set. Finally, \emph{confusion} indicates a training failure in which the model cannot even memorize the training set. One of the main contributions of the paper is an attempt to describe grokking from a theoretical point of view, while we focus on its experimental part. Therefore, we study the following claims from the original paper:
\begin{itemize}
    \item \textbf{Claim 1}: The model training may happen in four phases depending on the training hyperparameters. \label{claim:1}
    \item \textbf{Claim 2}: Grokking is an intermediate phase between comprehension and memorization. It occurs if the training hyperparameters are not tuned properly.
    \label{claim:2}
    \item \textbf{Claim 3}: Grokking is a more general phenomenon which is not restricted to toy algorithmic datasets. \label{claim:3}
\end{itemize}

\section{Methodology}
We conduct our experiments using the PyTorch framework \cite{pytorch}. We re-implemented both the toy setup and the MNIST experiments ourselves.

\subsection{Datasets}

\subsubsection{Toy model} Following the authors, we consider a toy dataset consisting of integer pairs $(i, j)$, where $i, j \in \{0, \dots, p - 1\}$. The target is defined as $i + j$, i.e., non-modular addition of the two features. There are $p(p+1)/2$ unique pairs (the pairs $(i, j)$ and $(j, i)$ are considered the same). We set $p=10$, which gives a total of $55$ pairs, and use a $45/10$ random train-validation split.
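
As a quick check of the pair count (a worked step in our own notation, not taken from the original paper), the unordered pairs split into $p$ pairs with $i = j$ and $p(p-1)/2$ pairs with $i < j$:
\[
  \#\{(i, j) : 0 \le i \le j \le p - 1\} = p + \frac{p(p-1)}{2} = \frac{p(p+1)}{2},
  \qquad
  \frac{10 \cdot 11}{2} = 55 \text{ for } p = 10 .
\]
The sums $i + j$ range over $\{0, \dots, 2p - 2\}$, which gives $2p - 1 = 19$ distinct target values.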

\subsubsection{MNIST}
The authors use the default version of the MNIST dataset \cite{mnist}, but for training, they randomly sample 1k images out of a possible 60k. The validation set is kept unchanged.

\subsection{Model descriptions}

\subsubsection{Toy model} The training setup is borrowed from the original paper. The toy model consists of two parts: an embedding encoder and an MLP decoder (denoted as $\text{Dec}$). Input integers $(i, j)$ are interpreted as characters and mapped to $1$-D embeddings $\mathbf{E}_i, \mathbf{E}_j$. The model output is $\text{Dec}(\mathbf{E}_i + \mathbf{E}_j)$. In general, we solve a classification task, predicting one of $2p-1$ possible sums (we call them classes for clarity). Following the authors, we train the model in two setups: classification and regression. The former simply optimizes the cross-entropy loss, while for the latter we generate $2p-1$ fixed random Gaussian vectors serving as targets, one for each class, and minimize the mean squared error (MSE) between the model output and the target vector of the ground-truth class. We set the dimension of these target vectors to $30$. The final prediction is the class with the highest probability in classification and the class whose target vector is nearest to the model output in regression. The decoder is a $3$-layer MLP with a Tanh activation, with $1-200-200-19$ dimensions for classification (as $2p-1 = 19$ outputs for $p=10$) and $1-200-200-30$ dimensions for regression. The embeddings are initialized from a uniform distribution $U[-\frac{1}{2}, \frac{1}{2}]$, while the default PyTorch initialization is used for the linear layers of the MLP.
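
The two prediction rules can be written compactly. The notation below is ours and is only a sketch: $\mathbf{z}_{ij}$ denotes the decoder output, $\mathbf{t}_k$ the fixed $30$-dimensional random target vector of class $k$, and we assume that ``nearest'' refers to the Euclidean distance, which the description above does not state explicitly.
\[
  \mathbf{z}_{ij} = \text{Dec}(\mathbf{E}_i + \mathbf{E}_j), \qquad
  \hat{y}_{\text{cls}} = \arg\max_{k} \, (\mathbf{z}_{ij})_k, \qquad
  \hat{y}_{\text{reg}} = \arg\min_{k} \, \| \mathbf{z}_{ij} - \mathbf{t}_k \|_2 ,
\]
where $k$ runs over the $2p - 1$ classes. Taking the largest logit is equivalent to taking the highest softmax probability. In the regression setup, the per-example training loss is $\| \mathbf{z}_{ij} - \mathbf{t}_{i+j} \|_2^2$ up to a normalization constant.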

\subsubsection{MNIST}
The authors use the same $3$-layer MLP with a hidden dimension of $200$ but with a ReLU activation. The input image is flattened before being fed to the model, so the overall dimensions are $784 - 200 - 200 - 10$. As the whole theory of grokking in the paper is based on the encoder-decoder idea, the authors call the first two layers an ``encoder'' and the last one a ``decoder''. They initialize the model weights with Kaiming uniform initialization \cite{kaiming_init} and then multiply them by $9$, i.e., the initialization has an increased scale. Instead of the cross-entropy loss, the MSE loss with one-hot targets is used. The authors find that this helps the network fit the training set quickly and the validation set much later.
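
In symbols (our notation, given as a sketch rather than the authors' exact formulation), let $f_\theta$ be the network, $\mathbf{e}_y$ the one-hot vector of label $y$, $B$ a mini-batch, and $W_0^{(l)}$ a Kaiming uniform sample for layer $l$:
\[
  W^{(l)} = \alpha \, W_0^{(l)}, \quad \alpha = 9, \qquad
  \mathcal{L}(\theta) = \frac{1}{|B|} \sum_{(\mathbf{x}, y) \in B} \| f_\theta(\mathbf{x}) - \mathbf{e}_y \|_2^2 .
\]
Whether the loss is additionally averaged over the $10$ output coordinates (as the default reduction of PyTorch's MSE loss does) only rescales $\mathcal{L}$ by a constant. Whether the biases are rescaled together with the weights is not specified in the description above.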

\subsection{Hyperparameters}

\subsubsection{Toy model} We train the toy model with AdamW \cite{adamw} for $10^5$ iterations, as proposed by the authors. The learning rate $\eta_{\text{emb}} = 10^{-3}$ and weight decay $\lambda_{\text{emb}} = 0$ of the embedding encoder remain fixed, while we sweep over the decoder hyperparameters in search of different learning phases. The ranges are taken from the original paper: $[10^{-5}, 10^{-2}]$ for the learning rate (logarithmic scale), $[0, 20]$ for the weight decay in classification, and $[0, 10]$ for the weight decay in regression (both weight decay ranges are linear). The batch size equals $45$, meaning that a full-batch gradient is computed at each iteration. Both the train-validation split seed and the model initialization seed are fixed across runs.
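
For illustration, suppose each axis of the sweep uses $21$ evenly spaced points that include both endpoints. The $21 \times 21$ resolution matches the diagrams described in the computational requirements, while the even spacing and the inclusion of endpoints are our assumptions. Under these assumptions, the grid values for $k = 0, \dots, 20$ would be
\[
  \eta_k = 10^{-5 + 3k/20}, \qquad
  \lambda_k^{\text{cls}} = 20 \cdot \frac{k}{20} = k, \qquad
  \lambda_k^{\text{reg}} = 10 \cdot \frac{k}{20} = \frac{k}{2} ,
\]
so that $\eta_0 = 10^{-5}$ and $\eta_{20} = 10^{-2}$, with neighbouring learning rates differing by a factor of $10^{3/20} \approx 1.41$.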

\subsubsection{MNIST}
Similar to the toy model, the authors use the AdamW optimizer but fix the learning rate $\eta_{\text{enc}} = 10^{-3}$ only for the encoder part of the MLP. During the sweep, the weight decay varies for the whole model over a logarithmic range of $[10^{-5}, 10^{1}]$, while the learning rate varies only for the decoder over a logarithmic range of $[10^{-6}, 10^{0}]$. According to the paper, the model is trained for $10^5$ iterations with a batch size of $200$. Both the data generation seed and the model initialization seed are kept the same across runs.

\subsection{Experimental setup and code}
For each pair of learning rate and weight decay, we train a model for $10^5$ iterations, as proposed by the authors. We then classify the outcome of training as one of the four phases. To do so, for the toy example we count the number of steps needed to reach 90\% accuracy on the training and validation sets, $T_{train}$ and $T_{val}$, respectively. For the MNIST dataset, the threshold is 60\% accuracy. The authors define the phases as in Table~\ref{tab:1}.

\begin{table}[]
    \centering
    \renewcommand{\arraystretch}{1.2}
    \begin{tabular}{ cccc } 
     \hline
     {} & \multicolumn{3}{c}{Criteria} \\
     \cline{2-4}
     Phase & $T_{train} \le 10^5$ & $T_{val} \le 10^5$ & $T_{val} - T_{train} < 10^3$ \\ 
     \hline
     \textbf{Comprehension} & + & + & +\\ 
     \textbf{Grokking} & + & + & --\\ 
     \textbf{Memorization} & + & -- & Not Applicable \\ 
     \textbf{Confusion} & -- & -- & Not Applicable \\ 
     \hline
    \end{tabular}
    \caption{The definition of training phases.}
    \label{tab:1}
\end{table}
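
Table~\ref{tab:1} can equivalently be read as a decision rule. In the sketch below (our notation), $T_{\max} = 10^5$ is the training budget, $\Delta = 10^3$ is the delay threshold, and $T_{train}$ or $T_{val}$ is taken to be $\infty$ when the accuracy threshold is never reached; this last convention is our assumption.
\[
  \text{phase} =
  \left\{
  \begin{array}{ll}
    \text{comprehension}, & T_{train} \le T_{\max}, \; T_{val} \le T_{\max}, \; T_{val} - T_{train} < \Delta, \\
    \text{grokking},      & T_{train} \le T_{\max}, \; T_{val} \le T_{\max}, \; T_{val} - T_{train} \ge \Delta, \\
    \text{memorization},  & T_{train} \le T_{\max} < T_{val}, \\
    \text{confusion},     & T_{train} > T_{\max}, \; T_{val} > T_{\max}.
  \end{array}
  \right.
\]
The remaining combination, $T_{val} \le T_{\max} < T_{train}$, is not covered by the table.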

\subsection{Computational requirements}

We carried out the toy model experiments on $16$ Intel Xeon Gold 6240R CPU cores in parallel because we were limited in computational resources. We found CPU computations faster than running on a single NVIDIA V100. Every phase diagram has $21 \cdot 21 = 441$ pixels, each representing a separate launch (i.e., a unique pair of weight decay and learning rate values). Full training for $10^5$ iterations takes about $5$ minutes on a single CPU core, so the budget for a single diagram is $441 \cdot 5 \div 16 \div 60 \approx 2.3$ hours on $16$ cores.

In fact, we used early stopping once both training and validation accuracy had reached the threshold of $90\%$, so the actual budget can be reduced by up to a factor of $5$, depending on the share of ``comprehension'' runs in the diagram.

For the MNIST dataset experiments, we used a single NVIDIA V100 GPU. The full training of a $3$-layer MLP took approximately $30$ minutes. To draw a phase diagram, we ran $14 \cdot 15 = 210$ training processes. Without early stopping, the total time would therefore have been about $210 \cdot 30 \div 60 = 105$ hours. Early stopping with an accuracy threshold of $60\%$ roughly halved this cost, leading to approximately $50$ GPU hours.

Experiments with the increased number of maximum iterations took proportionally more time for both the toy model and MNIST. Overall, all experiments took roughly $30$ CPU and $125$ GPU hours.

\section{Results}
\label{sec:results}

Generally, we succeeded in reproducing the phase diagrams from the original paper (for both the toy model and MNIST). All three claims proposed by the authors agree with our experiments. Our contribution can be summarized as follows:
\begin{enumerate}
    \item We reproduced the toy model (Addition classification and Addition regression) and obtained diagrams with all 4 phases present, supporting Claim~\hyperref[claim:1]{1}. In addition, we study how the phase diagrams change for different random seeds and decoder activation functions in Appendix~\hyperref[app:A]{A}, \hyperref[app:B]{B}.
    \item We reproduced the MNIST experiments following the authors' code. We observed all 4 phases, supporting Claims~\hyperref[claim:1]{1}, \hyperref[claim:3]{3}.
    \item Claim~\hyperref[claim:2]{2}, which describes the intermediate position of grokking, also holds in both experimental setups.
    \item We set up additional experiments, described in Section~\ref{sec:additional-results}, which show smooth transitions between phases, complementing Claim~\hyperref[claim:1]{1}. We also question the existence of memorization on MNIST, illustrating that this phase is in fact even more delayed grokking.
\end{enumerate}

\subsection{Results reproducing original paper}

\subsubsection{Toy model}

\begin{figure}[h]
\caption{Learning phases over decoder learning rate and weight decay for the toy model classification (left) and regression (right) setups.}
\centering
\includegraphics[width=0.85\textwidth]{images/toy_phases.png}
\label{fig:toy-phases}
\end{figure}

The phase diagrams for both classification and regression setups are shown in Figure~\ref{fig:toy-phases}. The corresponding diagrams plotted by the authors can be found in Section 4.1 of the original paper (Figure 6). All $4$ phases are present in both diagrams, supporting Claim~\hyperref[claim:1]{1}. The phases without generalization (memorization and confusion) occur for large values of the decoder learning rate (in both setups) or for large values of the decoder weight decay (classification setup). The remaining hyperparameter values enable generalization, leading to comprehension (mainly for smaller values of weight decay) or grokking (for larger values of weight decay). The area of grokking in both setups is located between comprehension and memorization, in agreement with Claim~\hyperref[claim:2]{2}. However, the phase borders depend strongly on the data seed and slightly on the model seed, which we discuss in Appendix~\hyperref[app:A]{A}.

\subsubsection{MNIST}

\begin{figure}[h]
\caption{Phase diagram for MNIST.}
\centering
\includegraphics[width=0.75\textwidth]{images/mnist-original-phases-100000.png}
\label{fig:mnist-phases}
\end{figure}

In Figure~\ref{fig:mnist-phases}, we show the reproduced phase diagram of learning dynamics on MNIST. The corresponding phase diagram from the authors' experiments can be seen in Section 4.3 of the original paper (Figure 8). Our diagram contains all $4$ phases in the same order as in the diagram from the paper, with grokking sandwiched between memorization and comprehension, which supports all three claims from Section~\ref{sec:claims}. However, the area of the confusion phase is much smaller in our results. In Section~\ref{sec:additional-results}, we extend the boundaries of the hyperparameter ranges and obtain a more similar phase diagram.
In the same Section~\ref{sec:additional-results}, we argue that there is in fact no memorization phase for MNIST, so grokking does not necessarily lie between comprehension and memorization. This indicates an inaccuracy in Claim~\hyperref[claim:2]{2}, but it does not contradict any intuition about grokking.

\subsection{Results beyond original paper}
\label{sec:additional-results}

As a further investigation of grokking in the scope of the proposed setup, we ask two questions:
\begin{enumerate}
\item Do we need to discretize phases?
\item Does memorization exist or is it even more delayed grokking?
\end{enumerate}

\subsubsection{Question 1}\label{sec:question_1}
By asking the first question, we argue that the boundaries between phases are blurred and cannot be set by a fixed threshold. To examine this, we construct smooth phase diagrams using the following procedure. We train models using the same setup, but for each model we record the optimization steps at which it reaches the desired train and validation accuracy, $T_{train}$ and $T_{val}$, respectively. We then split the diagram into two parts. The first part corresponds to comprehension and grokking, which both generalize to the validation set. For this part we plot the difference between validation and train steps, $T_{val} - T_{train}$. The second part corresponds to memorization and confusion. Since in these phases the model does not reach the desired validation accuracy, we plot the number of train steps $T_{train}$ for memorization and the maximum number of optimization steps for confusion.
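
The quantity plotted in each pixel of a smooth diagram can be summarized as follows (our notation; $T_{\max}$ denotes the maximum number of optimization steps):
\[
  v =
  \left\{
  \begin{array}{ll}
    T_{val} - T_{train}, & \text{if } T_{val} \le T_{\max} \text{ (comprehension or grokking)}, \\
    T_{train},           & \text{if } T_{train} \le T_{\max} < T_{val} \text{ (memorization)}, \\
    T_{\max},            & \text{if } T_{train} > T_{\max} \text{ (confusion)}.
  \end{array}
  \right.
\]
Within the generalizing part, $v$ varies continuously from comprehension to grokking, so no threshold such as $10^3$ is needed to draw it.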

\subsubsection{Question 2}\label{sec:question_2}
To answer the second question, we train the toy model for $3 \cdot 10^5$ iterations ($3\times$ the original training budget) and the MNIST model for $1.5 \cdot 10^5$ iterations ($1.5\times$ the original training budget). We plot the smooth phase diagrams described in Section~\ref{sec:question_1} and compare them to our reproduction of the original diagrams.

\subsubsection{Results (Q1)}
We ran the smooth-diagram experiments for both toy model setups and the MNIST dataset. The resulting diagrams (Figures~\ref{fig:cls-smooth}, \ref{fig:regr-smooth}, \ref{fig:mnist-smooth}) show that for the toy model, there is \emph{no distinct gap} between comprehension and grokking, or between memorization and confusion, while for the MNIST dataset, there is \emph{a noticeable gap} between all four phases.

\begin{figure}[h]
\caption{Discrete (left) and smooth (right) phase diagrams for the Addition classification toy model. The discrete diagram has $10^5$ max iterations of training, while the smooth diagram is extended to $3 \cdot 10^5$ training iterations.}
\centering
\includegraphics[width=0.9\textwidth]{images/cls_smooth.png}
\label{fig:cls-smooth}
\end{figure}

\begin{figure}[h]
\caption{Discrete (left) and smooth (right) phase diagrams for the Addition regression toy model. The discrete diagram has $10^5$ max iterations of training, while the smooth diagram is extended to $3 \cdot 10^5$ training iterations.}
\centering
\includegraphics[width=0.9\textwidth]{images/regr_smooth.png}
\label{fig:regr-smooth}
\end{figure}

\begin{figure}[h]
\caption{Discrete (left) and smooth (right) phase diagrams for the MNIST dataset. We highlight the range of hyperparameters from the original paper with a red rectangle. The discrete diagram has $10^5$ max iterations of training, while the smooth diagram is extended to $1.5 \cdot 10^5$ training iterations.}
\centering
\includegraphics[width=1\textwidth]{images/mnist-smooth.png}
\label{fig:mnist-smooth}
\end{figure}

\subsubsection{Results (Q2)}
Considering the toy model (Figures~\ref{fig:cls-smooth}, \ref{fig:regr-smooth}), we observe that after increasing the number of training iterations, the overall phase diagram is preserved. A few pixels on the boundary of memorization and grokking have converted from the former to the latter phase. However, a vast memorization area is still present in the diagram, meaning that this phase actually exists in the toy setup (i. e. memorization \emph{is not more delayed grokking} but rather a separate phase).

The results for the MNIST dataset are depicted in Figure~\ref{fig:mnist-smooth} and lead us to the following observations. First, memorization in the upper left corner \emph{is actually an even more delayed grokking}, because training the model for additional steps leads to convergence on the validation set. Second, there is a notable gap in the number of convergence steps on the validation set between the grokking and comprehension phases. Therefore, these phases are indeed well separated. Third, we obtained a new, unexpected memorization phase in the upper right corner. We investigated how train and validation accuracy change over the optimization process for this area of hyperparameters. The results can be seen in Figure~\ref{fig:mnist-strange-memorization} (left). It seems that the model falls into an unstable local minimum and is then kicked out of it during further optimization steps because of the large learning rate. Hence, while such behaviour is memorization by definition, it is not what is usually meant by that name. We argue that this configuration is closer to confusion, as training finishes with a model that has random accuracy on both training and validation sets, in spite of visiting a local minimum during optimization.

\begin{figure}[h]
\caption{Train and validation accuracy during the optimization process with the learning rate equal to $50$ and the weight decay equal to $10^{-7}$ (left). Convergence iterations for different pairs of weight decay and learning rate on the MNIST \emph{training} set (right).}
\centering
{\includegraphics[width=0.45\textwidth]{images/mnist-lr50-wd1e-7.png}}
\raisebox{-0.15\height}{\includegraphics[width=0.51\textwidth]{images/mnist-train-smooth-150000.png}}
\label{fig:mnist-strange-memorization}
\end{figure}

Looking at Figure~\ref{fig:mnist-smooth}, one might think that every pair of hyperparameters in the comprehension phase is equally suitable, as they all lead to a small difference between validation and training convergence iterations. However, it is critical to keep in mind the convergence iteration on the training set alone (e.g., in terms of iterations, it is better to have $T_{train}=1000, T_{val}=5000$ than $T_{train}=50000, T_{val}=50001$). It follows from Figure~\ref{fig:mnist-strange-memorization} (right) that convergence is fast only for a small portion of hyperparameters, and, surprisingly, this area lies close to the border of confusion.

\section{Discussion}

As discussed in prior works \cite{sharp_minima, entropy_sgd}, poor model generalization (i. e. memorization) is related to localization in a sharp local minimum. Then, due to a small shift between train and validation datasets, the model reaches high accuracy on the training dataset and nearly random accuracy on the validation dataset. Thus, we argue that for more complex datasets (compared to the toy examples), which have a negligible distribution shift, it is almost impossible to fall into such sharp narrow local minima, so memorization is not observed in practice.

\subsection{What was easy}

The authors did an excellent job describing all the setups for each experiment. However, some crucial details were missing, such as the activation function for the toy model and the batch size for the MNIST dataset. We contacted the authors and asked for these values. After that, it was easy to reproduce the results.

\subsection{What was difficult}

Reproduction and additional research on the toy model were difficult because the authors did not specify the activation function for the MLP, and we assumed it to be \emph{ReLU}, as for the MNIST model. In practice, it turned out to be \emph{Tanh}, which completely changed the training phases. The impact of an activation function is described in Appendix~\hyperref[app:B]{B}. Also, the resulting phase diagrams strongly depend on the data generation seed, which is shown in Appendix~\hyperref[app:A]{A}. We spent significant time trying to achieve the authors' results before discovering that this was the source of the problem.

\subsection{Communication with original authors}

We contacted the authors twice: the first time to ask for the MNIST dataset code and the second time to ask about the toy model, as we thought we had misinterpreted the setup. Both times, the authors responded very quickly and provided all the information we needed to fill in the blanks, which helped us a lot. We thank the authors for their attention and involvement.

\section*{Acknowledgements}

We would like to thank Sergey Samsonov and Alexey Naumov who encouraged us to participate in ML Reproducibility Challenge 2022. We also thank Maxim Kodryan for the valuable comments on the manuscript. The empirical results were supported in part through the computational resources of HPC facilities at HSE University~\cite{hpc}.
