## Source Code for the paper _Grokking at the Edge of Numerical Stability_

These files are the accompanying code for the paper _Grokking at the Edge of Numerical Stability_. This README explains how to reproduce the paper's main results.

## Replicating paper plots

The code that generates several of the paper's main plots is in `paper_plots.ipynb`. Each plot has a set of keys that select the results it draws on. Each key encodes the hyperparameter values of an experiment that must be run before the plot can be generated. Once the experiments for every key of a plot have been run, you can generate that plot from `paper_plots.ipynb`.
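Judging from the Figure 1 example key given later in this README, a key appears to be a `|`-separated string: the first field names the task and each remaining field is a `name-value` pair. The hypothetical helper `parse_experiment_key` below is a minimal sketch, assuming that format, for reading off the hyperparameters a plot needs. It is not part of the repository, and keys for other tasks (for example MNIST or Sparse Parity) may follow a different layout.

```python
def parse_experiment_key(key: str) -> tuple[str, dict[str, str]]:
    """Split a plot key into its task name and a dict of hyperparameters.

    Hypothetical helper. Assumes the first `|`-separated field is the task
    and every later field has the form `name-value`.
    """
    task, *fields = key.split("|")
    params = {}
    for field in fields:
        # Split on the first '-' only, so values such as '1e-3' stay intact.
        name, value = field.split("-", 1)
        params[name] = value
    return task, params


# Example key copied from the Figure 1 StableMax plot in this README.
key = (
    "add_mod|num_epochs-80000|train_fraction-0.4|loss_function-stablemax"
    "|log_frequency-5000|lr-0.01|batch_size-5107|float_precision-64"
)
task, params = parse_experiment_key(key)
print(task, params)
```

Values are kept as strings because they are ultimately passed back to the scripts on the command line.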

## Running MLP experiments

Experiments from the paper that use MLPs can be replicated with `grokking_experiments.py`. Run `python grokking_experiments.py --help` for the full list of arguments. The script supports Sparse Parity; modular addition, subtraction and multiplication; and MNIST.
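Several runs that differ in one hyperparameter can be scripted around `grokking_experiments.py`. The sketch below is illustrative only: it reuses the flags from the Figure 1 command in this README, and the swept `train_fraction` values and the helper names `build_command` and `sweep_train_fraction` are hypothetical. By default it only builds the command lines (`dry_run=True`) instead of launching them.

```python
import subprocess
import sys

# Fixed settings taken from the Figure 1 StableMax command in this README.
BASE_ARGS = {
    "binary_operation": "add_mod",
    "num_epochs": "80000",
    "loss_function": "stablemax",
    "log_frequency": "5000",
    "lr": "0.01",
}


def build_command(script: str, args: dict[str, str]) -> list[str]:
    """Turn a dict of hyperparameters into an argv list for `script` (hypothetical helper)."""
    cmd = [sys.executable, script]
    for name, value in args.items():
        cmd += [f"--{name}", str(value)]
    return cmd


def sweep_train_fraction(fractions, dry_run=True):
    """Build one command per training fraction; launch them only if dry_run is False."""
    commands = [
        build_command("grokking_experiments.py", {**BASE_ARGS, "train_fraction": str(f)})
        for f in fractions
    ]
    if not dry_run:
        for cmd in commands:
            # Runs sequentially and raises on the first failing run.
            subprocess.run(cmd, check=True)
    return commands


for cmd in sweep_train_fraction([0.3, 0.4, 0.5]):
    print(" ".join(cmd[1:]))
```

Passing the arguments as a list rather than a single shell string avoids any quoting issues when values contain special characters.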


## Running Transformer experiments

Transformer experiments are only implemented for the modular arithmetic tasks. They can be run with `grokking_experiments_transformers.py`, which takes the same command-line arguments as `grokking_experiments.py`. The `transformer.py` code is taken from Nanda et al.: https://github.com/mechanistic-interpretability-grokking/progress-measures-paper.git
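Because both scripts take the same arguments, one argument list can be reused for an MLP run and a Transformer run of the same modular arithmetic task. The sketch below assumes the `add_mod` settings from the Figure 1 example; the subset of flags shown is an illustrative choice. It only builds the commands; launching them is left to `subprocess.run`.

```python
import sys

# Hyperparameters shared by both scripts (values from the Figure 1 example).
# The Transformer script only supports modular arithmetic tasks such as add_mod.
ARGS = ["--binary_operation", "add_mod", "--train_fraction", "0.4", "--lr", "0.01"]

# The same argument list is reused unchanged for the MLP and Transformer scripts.
SCRIPTS = ["grokking_experiments.py", "grokking_experiments_transformers.py"]

commands = {script: [sys.executable, script, *ARGS] for script in SCRIPTS}

for script, cmd in commands.items():
    print(" ".join(cmd[1:]))

# To launch a run: import subprocess; subprocess.run(commands[SCRIPTS[1]], check=True)
```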


## Example experiment command

For example, to replicate the StableMax plot in Figure 1, look up the key in the cell above this plot in `paper_plots.ipynb`, which is `add_mod|num_epochs-80000|train_fraction-0.4|loss_function-stablemax|log_frequency-5000|lr-0.01|batch_size-5107|float_precision-64`. Extracting the arguments from this key gives the command that replicates this experiment:

```bash
python grokking_experiments.py --binary_operation add_mod --num_epochs 80_000 --train_fraction 0.4 --loss_function stablemax --log_frequency 5000 --lr 0.01
```
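The extraction step can be automated. The hypothetical helper `key_to_command` below is a minimal sketch, assuming the first `|`-separated field of a key is the value of `--binary_operation` and every other `name-value` field maps to a `--name value` flag. The command above leaves out `batch_size` and `float_precision`, and this sketch does not confirm whether the script accepts them as flags, so the `skip` parameter drops them.

```python
import shlex


def key_to_command(key: str, skip=()) -> str:
    """Rebuild a `grokking_experiments.py` command line from a plot key.

    Hypothetical helper. Assumes the first `|`-separated field is the value
    of --binary_operation and every other `name-value` field maps to a
    `--name value` flag. Fields named in `skip` are omitted from the command.
    """
    task, *fields = key.split("|")
    parts = ["python", "grokking_experiments.py", "--binary_operation", task]
    for field in fields:
        name, value = field.split("-", 1)  # split once so values like '1e-3' survive
        if name not in skip:
            parts += [f"--{name}", value]
    return shlex.join(parts)  # quotes any value that would need it in a shell


key = (
    "add_mod|num_epochs-80000|train_fraction-0.4|loss_function-stablemax"
    "|log_frequency-5000|lr-0.01|batch_size-5107|float_precision-64"
)
print(key_to_command(key, skip={"batch_size", "float_precision"}))
```

With `batch_size` and `float_precision` skipped, this prints the command above, with `80000` in place of `80_000`. `shlex.join` requires Python 3.8 or later.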
