{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# How sharp is Occam's razor? Toy models for complexity via grokking\n"
      ],
      "metadata": {
        "id": "RANQmBmspRfA"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "By XXXX-5 Jaburi"
      ],
      "metadata": {
        "id": "lVauJQLWP9lc"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 1. Introduction"
      ],
      "metadata": {
        "id": "d9neL8KoqhBl"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Why do neural networks generalize? While a full explanation is still missing, an important principle seems to be Occam's razor. Many versions of this principle exist, but for this document we will summarize it as\n",
        "> *simpler solutions will be preferred.*\n",
        "\n",
        "What remains elusive is the notion of *simple*, or equivalently the notion of complexity. Several competing notions have been proposed, for example:\n",
        "\n",
        "1.   Kolmogorov complexity,\n",
        "2.   Notions from statistical learning theory, such as VC dimension or Rademacher complexity.\n",
        "\n",
        "But each comes with its own downsides: Kolmogorov complexity requires fixing a universal Turing machine (UTM) and is in general uncomputable. The more concrete notions mentioned in 2. have not held up under empirical scrutiny; see [Zhang et al.](XXXX)\n",
        "\n",
        "As there is a lot of uncertainty about which theoretical notion of complexity is empirically valid, in this document we take the opposite approach: starting from empirical results, we want to gather intuition that can guide us towards a theoretical notion explaining why DNNs prefer to learn certain solutions over others.\n",
        "\n",
        "To do this, we propose a toy model for studying generalization behaviour. More precisely, in our setup multiple hypotheses are consistent with the data. A useful theory of complexity should predict which hypothesis is likely to be learned.\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "id": "ZkFNH_hRUmgs"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 2. Method"
      ],
      "metadata": {
        "id": "NgmXjLVKeO14"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "We start with the high-level picture. Consider two supervised learning tasks, in which we try to learn two different functions $f_1, f_2: X\\to Y$. These yield two datasets $D_1, D_2\\subset X\\times Y$ (the graphs of $f_1$ and $f_2$), and we consider their intersection $D_{12}=D_1\\cap D_2$, i.e. the points on which $f_1$ and $f_2$ agree. We can ask:\n",
        "\n",
        "> If we train a model on the dataset $D_{12}$, which function will it learn?\n",
        "\n",
        "By Occam's razor, the simpler function should be the one that is learned during training (or at least the one that is more likely to be learned).\n",
        "\n",
        "<img src=\"XXXX\" width =500>\n",
        "\n",
        "\n",
        "More generally, we could start with a family of functions $\\{f_i: X\\to Y\\}_{i\\in I}$ and consider $D_I=\\bigcap_{i\\in I} D_i$. Is there then a hierarchy determining the learning behaviour?"
      ],
      "metadata": {
        "id": "VZpevKJ2eVU1"
      }
    },
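    {
      "cell_type": "markdown",
      "source": [
        "As a minimal illustration of this setup, the following sketch builds $D_1$, $D_2$ and $D_{12}$ for two hypothetical toy functions on $X=\\mathbb{Z}/4\\times\\mathbb{Z}/4$ (addition and subtraction mod 4, chosen only for illustration; they are not the tasks studied below). Each $D_i$ is taken to be the graph $\\{(x, f_i(x)) : x\\in X\\}$, so $D_{12}$ consists exactly of the points on which $f_1$ and $f_2$ agree, and any model that fits $D_{12}$ perfectly is consistent with both functions."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "from itertools import product\n",
        "\n",
        "\n",
        "def graph(f, X):\n",
        "    \"\"\"Dataset D_f = {(x, f(x)) : x in X} as a set of (input, label) pairs.\"\"\"\n",
        "    return {(x, f(x)) for x in X}\n",
        "\n",
        "\n",
        "# Hypothetical toy tasks on X = Z/4 x Z/4 (illustration only)\n",
        "n = 4\n",
        "X = list(product(range(n), repeat=2))\n",
        "\n",
        "\n",
        "def f1(x):\n",
        "    return (x[0] + x[1]) % n  # addition mod 4\n",
        "\n",
        "\n",
        "def f2(x):\n",
        "    return (x[0] - x[1]) % n  # subtraction mod 4\n",
        "\n",
        "\n",
        "D1, D2 = graph(f1, X), graph(f2, X)\n",
        "D12 = D1 & D2  # points fitted by both f1 and f2\n",
        "\n",
        "# f1 and f2 agree iff 2*b = 0 mod 4, i.e. b in {0, 2}: 8 of the 16 inputs\n",
        "print(len(D12), \"of\", len(X))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },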
    {
      "cell_type": "markdown",
      "source": [
        "To get meaningful results, we impose some requirements:\n",
        "\n",
        "\n",
        "1.   The functions $f_1, f_2$ should capture some structure. Using random maps $f_1,f_2: X\\to Y$ is unlikely to result in interesting or informative generalization behaviour.\n",
        "2.   At least initially, the training data should be free of noise, to rule out other potential biases. Later on, however, this raises an interesting question in its own right: (how) does noise favour certain solutions?\n",
        "3.   The training dataset should be sufficiently large.\n",
        "\n",
        "We propose grokking, where MLPs learn simple algorithmic tasks, as a good toy model. Not only are the three conditions above fulfilled, but the model additionally already exhibits a ranking between two options: memorization and generalization. Furthermore, we find a setup in which the training dataset $D_{12}$ covers as much as $75\\%$ of the input space $X$ (which serves as the test set), while remaining consistent with both tasks $f_1$ and $f_2$. We can also make $X$ arbitrarily large.\n",
        "\n",
        "In the ideal case, we would have a notion such as $c(\\text{memorization})>c(f_2)> c(f_1)$, where $c$ denotes complexity. We would then expect $f_1$ to be preferred over time, or at least to be learned more often."
      ],
      "metadata": {
        "id": "b1PLFg5msuID"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Possible measurements for \"complexity\"\n"
      ],
      "metadata": {
        "id": "KSF-3op97Yz2"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "What kind of measurements can we do to observe complexity? For example:\n",
        "\n",
        "1.   Training time until grokking is completed. Do more complex tasks require longer training time?\n",
        "2.   The learning coefficient as described in [Lau et al.](XXXX). Building on the widely applicable Bayesian information criterion (WBIC), their approach probes the local geometry of the loss landscape.\n",
        "\n",
        "  Motivated by the free energy formula of singular learning theory (see [*loc. cit.*](XXXX), p. 5), the learning coefficient should approximate a notion of \"effective parameters\" used by the MLP. Fewer effective parameters means a simpler model.\n",
        "3.   Circuit size. In particular, [Varma et al.](XXXX) suggest that \"*...weight decay prefers circuits [...] that require less parameter norm to produce a given logit value*\".\n",
        "\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "id": "E2vmpNv87dxW"
      }
    },
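    {
      "cell_type": "markdown",
      "source": [
        "Two of these measurements reduce to simple arithmetic once the relevant quantities have been logged. In the sketch below (hypothetical helper names, not part of our code), `grokking_epoch` takes the \"time until grokking\" to be the first epoch at which test accuracy reaches a threshold (the default $0.99$ is an assumption), and `llc_estimate` evaluates the local learning coefficient estimator of [Lau et al.](XXXX), $\\hat\\lambda(w^*) = n\\beta^*\\big(\\mathbb{E}_w[L_n(w)] - L_n(w^*)\\big)$ with $\\beta^* = 1/\\log n$. Here $n$ is the number of training samples, $L_n$ the training loss, and the expectation is over parameters $w$ sampled near the trained parameters $w^*$ (e.g. with SGLD); the sampling step itself is not shown, and the sampled losses are taken as input."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import math\n",
        "\n",
        "\n",
        "def grokking_epoch(test_accuracies, threshold=0.99):\n",
        "    \"\"\"First epoch whose test accuracy reaches `threshold`, or None if never.\"\"\"\n",
        "    for epoch, acc in enumerate(test_accuracies):\n",
        "        if acc >= threshold:\n",
        "            return epoch\n",
        "    return None\n",
        "\n",
        "\n",
        "def llc_estimate(sampled_losses, loss_at_w_star, n):\n",
        "    \"\"\"Estimate n * beta * (E_w[L_n(w)] - L_n(w_star)) with beta = 1 / log n.\n",
        "\n",
        "    sampled_losses: training losses L_n(w) at parameters w sampled around w_star\n",
        "        (e.g. by SGLD; the sampler is not shown here).\n",
        "    loss_at_w_star: training loss L_n(w_star) at the trained parameters.\n",
        "    n: number of training samples.\n",
        "    \"\"\"\n",
        "    beta = 1 / math.log(n)\n",
        "    mean_loss = sum(sampled_losses) / len(sampled_losses)\n",
        "    return n * beta * (mean_loss - loss_at_w_star)\n",
        "\n",
        "\n",
        "print(grokking_epoch([0.1, 0.5, 0.8, 0.995, 1.0]))  # 3"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },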
    {
      "cell_type": "markdown",
      "source": [
        "Running the experiments described below is a way to test these hypotheses empirically. The reader is invited to skip to Section 4 to run the experiments themselves, or to Sections 5 and 6 for an overview of the current results."
      ],
      "metadata": {
        "id": "N3jJhFWaN059"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 3. Description of the experiments\n"
      ],
      "metadata": {
        "id": "bHlC_tbm5Zoq"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "We first describe the algorithmic tasks $f_1, f_2$ that are grokked. Grokking was first observed in [Power et al.](XXXX) when learning binary operations such as addition in the finite group ℤ/nℤ: the model quickly reaches perfect accuracy on the training data but fails to generalize; only after much longer training does it also reach perfect accuracy on the whole test data.\n",
        "\n",
        "[Nanda et al.](XXXX) reverse-engineered the algorithm learned for modular addition, which is based on discrete Fourier transforms. A more general approach for finite groups, using representation theory, was given in [Chughtai et al.](XXXX). An interesting distinction between learned algorithms, the Clock vs. Pizza algorithms, was identified in [Zhong et al.](XXXX)\n"
      ],
      "metadata": {
        "id": "-IJGUJ1qrFyh"
      }
    },
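    {
      "cell_type": "markdown",
      "source": [
        "The core idea behind the Fourier-based algorithm described by [Nanda et al.](XXXX) can be checked numerically. The toy sketch below is our own illustration with hand-picked (hypothetical) frequencies and no trained weights: it scores each candidate answer $c\\in\\mathbb{Z}/p$ by $\\sum_k \\cos\\big(2\\pi k(a+b-c)/p\\big)$, computing $\\cos$ and $\\sin$ of $2\\pi k(a+b)/p$ from those of $a$ and $b$ alone via the angle-addition formulas. For prime $p$, each term equals $1$ exactly when $c\\equiv a+b \\pmod p$, so the score is maximized at the correct answer."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import math\n",
        "\n",
        "\n",
        "def fourier_logits(a, b, p, freqs):\n",
        "    \"\"\"Score each c in Z/p by sum over k in freqs of cos(2*pi*k*(a + b - c) / p).\"\"\"\n",
        "    logits = []\n",
        "    for c in range(p):\n",
        "        score = 0.0\n",
        "        for k in freqs:\n",
        "            w = 2 * math.pi * k / p\n",
        "            # cos and sin of w*(a+b), built from cos/sin of w*a and w*b only\n",
        "            cos_ab = math.cos(w * a) * math.cos(w * b) - math.sin(w * a) * math.sin(w * b)\n",
        "            sin_ab = math.sin(w * a) * math.cos(w * b) + math.cos(w * a) * math.sin(w * b)\n",
        "            # cos(w*(a+b-c)) = cos(w*(a+b)) cos(w*c) + sin(w*(a+b)) sin(w*c)\n",
        "            score += cos_ab * math.cos(w * c) + sin_ab * math.sin(w * c)\n",
        "        logits.append(score)\n",
        "    return logits\n",
        "\n",
        "\n",
        "p = 113  # prime modulus\n",
        "freqs = [14, 35, 41]  # hypothetical key frequencies\n",
        "a, b = 70, 60\n",
        "logits = fourier_logits(a, b, p, freqs)\n",
        "print(max(range(p), key=logits.__getitem__), (a + b) % p)  # 17 17"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },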
    {
      "cell_type": "markdown",
      "source": [
        "In our case, we study two different group operations $+_1, +_2$ on a set $M$ of size $100$, i.e. the functions $$f_1,f_2: M\\times M\\to M$$ are given by $f_i(a,b)=a+_i b$. Thus $|X|=|M|^2 =10000$, $|Y|=100$, and the training dataset has size $|D_{12}|=7500$.\n",
        "\n",
        "**The following paragraph constructs the two groups and may be safely skipped.**\n",
        "\n",
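        "Consider more generally $M=\\{1,...,2\\cdot N\\}$ (in the above $N=50$), whose elements we identify with pairs $(a,s)$ with $a\\in\\mathbb{Z}/N$ and $s\\in\\mathbb{Z}/2$; in the code, index $i\\in\\{0,\\dots,2N-1\\}$ stands for $(i \\bmod N, \\lfloor i/N\\rfloor)$. We can endow $M$ with two group structures:\n",
        "\n",
        "1.   The commutative group $\\mathbb{Z}/N\\times \\mathbb{Z}/2$, defined by $(a,b)+_1(c,d)=(a+c,b+d)$,\n",
        "2.   The [semidirect product](XXXX) $\\mathbb{Z}/N \\rtimes \\mathbb{Z}/2$, defined by $(a,0)+_2(b,c)=(a+b,c)$ and $(a,1)+_2(b,c)=(a+(N/2+1)\\cdot b, 1+c)$. Multiplication by $N/2+1$ is an automorphism of $\\mathbb{Z}/N$ of order $2$, as this construction requires, precisely when $4\\mid N$; for $N=50$ it is not even injective ($26\\cdot 25\\equiv 0 \\pmod{50}$), so strictly speaking $+_2$ is then not a group law.\n",
        "\n",
        "One can verify that $+_1$ and $+_2$ always agree on $75\\%$ of $X$: they agree whenever the left argument has second component $0$ (half of all pairs), and when it is $1$ they agree iff $(N/2)\\cdot b\\equiv 0 \\pmod N$, i.e. iff $b$ is even (half of the remaining pairs). Hence they agree on $\\tfrac12+\\tfrac14=\\tfrac34$ of $X$."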
      ],
      "metadata": {
        "id": "F6qVlWfWkLSx"
      }
    },
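    {
      "cell_type": "markdown",
      "source": [
        "The $75\\%$ agreement can also be checked by brute force. The minimal sketch below re-implements both operations in plain Python on pairs $(a,s)\\in\\mathbb{Z}/N\\times\\mathbb{Z}/2$, using the twist $b\\mapsto(N/2+1)\\,b$ that the `GroupData` code below passes to `twisted_group` (the function names here are hypothetical, not part of our code). It also tests whether this twist is a bijection of $\\mathbb{Z}/N$ that squares to the identity, which a semidirect product with $\\mathbb{Z}/2$ requires; the last line compares $N=50$ with $N=52$."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "from itertools import product\n",
        "\n",
        "\n",
        "def op1(x, y, N):\n",
        "    \"\"\"Direct product Z/N x Z/2: (a, s) +_1 (b, c) = (a + b, s + c).\"\"\"\n",
        "    (a, s), (b, c) = x, y\n",
        "    return ((a + b) % N, (s + c) % 2)\n",
        "\n",
        "\n",
        "def op2(x, y, N):\n",
        "    \"\"\"Twisted product: (a, s) +_2 (b, c) = (a + k^s * b, s + c) with k = N/2 + 1.\"\"\"\n",
        "    (a, s), (b, c) = x, y\n",
        "    k = N // 2 + 1\n",
        "    twisted_b = (k * b) % N if s == 1 else b\n",
        "    return ((a + twisted_b) % N, (s + c) % 2)\n",
        "\n",
        "\n",
        "def agreement_fraction(N):\n",
        "    \"\"\"Fraction of all pairs (x, y) in M x M on which +_1 and +_2 agree.\"\"\"\n",
        "    M = list(product(range(N), range(2)))\n",
        "    pairs = list(product(M, repeat=2))\n",
        "    return sum(op1(x, y, N) == op2(x, y, N) for x, y in pairs) / len(pairs)\n",
        "\n",
        "\n",
        "def twist_is_involutive_automorphism(N):\n",
        "    \"\"\"True iff b -> (N/2 + 1) * b is a bijection of Z/N that squares to the identity.\"\"\"\n",
        "    k = N // 2 + 1\n",
        "    is_bijective = len({(k * b) % N for b in range(N)}) == N\n",
        "    return is_bijective and all((k * k * b) % N == b for b in range(N))\n",
        "\n",
        "\n",
        "print(agreement_fraction(50))  # 0.75\n",
        "print(twist_is_involutive_automorphism(50), twist_is_involutive_automorphism(52))  # False True"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },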
    {
      "cell_type": "markdown",
      "source": [
        "We trained two different MLP architectures: the first is the one used in [Chughtai et al.](XXXX), the second is the one used in [Investigating the learning coefficient of modular addition: hackathon project](XXXX). In both cases there are two embedding matrices, which are not tied.\n",
        "\n",
        "By default, we use the Adam optimizer with learning rate $0.01$ and weight decay $0.0002$. Both models have roughly 16-17k parameters.\n",
        "\n",
        "In this colab we use the latter model."
      ],
      "metadata": {
        "id": "XayZcoqD0nr8"
      }
    },
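    {
      "cell_type": "markdown",
      "source": [
        "As a sanity check of the parameter count, the sketch below counts weights and biases layer by layer for the `MLP` and `MLP2` classes defined in the Code section, assuming the default `Parameters` values there ($|M|=100$, embedding dimension $32$, hidden size $64$); the helper names are hypothetical. With a model instance at hand, `sum(p.numel() for p in model.parameters())` gives the same number directly in PyTorch."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "def linear_params(d_in, d_out):\n",
        "    \"\"\"Parameters of a linear layer with bias: weight matrix plus bias vector.\"\"\"\n",
        "    return d_in * d_out + d_out\n",
        "\n",
        "\n",
        "def count_mlp(n, embed_dim, hidden):\n",
        "    \"\"\"`MLP`: concatenated embeddings -> one linear layer -> activation -> unembedding.\"\"\"\n",
        "    embeddings = 2 * n * embed_dim  # untied left and right embedding matrices\n",
        "    return embeddings + linear_params(2 * embed_dim, hidden) + linear_params(hidden, n)\n",
        "\n",
        "\n",
        "def count_mlp2(n, embed_dim, hidden):\n",
        "    \"\"\"`MLP2`: separate linear maps for the two embeddings, summed before the activation.\"\"\"\n",
        "    embeddings = 2 * n * embed_dim\n",
        "    return embeddings + 2 * linear_params(embed_dim, hidden) + linear_params(hidden, n)\n",
        "\n",
        "\n",
        "print(count_mlp(100, 32, 64), count_mlp2(100, 32, 64))  # 17060 17124"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },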
    {
      "cell_type": "markdown",
      "source": [
        "<img src=\"XXXX\" width =500>\n",
        "\n",
        "<img src=\"XXXX\" width=400>\n"
      ],
      "metadata": {
        "id": "nrICIjWa1bw7"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Left: Image taken from [Chughtai et al.](XXXX)\n",
        "\n",
        "Right: Image taken from [Investigating the learning coefficient of modular addition: hackathon project](XXXX)"
      ],
      "metadata": {
        "id": "gw1aiiyeq0Qh"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Code"
      ],
      "metadata": {
        "id": "mvHru6Tw8hbZ"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "You can run all cells in this section and then proceed to the next section.\n",
        "\n",
        "(Also available at XXXX)"
      ],
      "metadata": {
        "id": "EhN26oNyhUQR"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2xfI9vD5F6fU",
        "colab": {
          "base_uri": "XXXX"
        },
        "outputId": "dbdb38e2-d493-472d-d405-f8ba54b9b486"
      },
      "outputs": [],
      "source": [
        "%pip install wandb torch tqdm einops"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WpO4_jSaF2iV"
      },
      "outputs": [],
      "source": [
        "import torch as t\n",
        "import os\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.utils.data import Dataset\n",
        "from tqdm import tqdm\n",
        "import random\n",
        "import wandb\n",
        "import copy\n",
        "from einops import rearrange\n",
        "from dataclasses import dataclass\n",
        "\n",
        "from datetime import datetime\n",
        "os.environ['CUDA_LAUNCH_BLOCKING'] = \"1\"  # synchronous CUDA kernel launches: easier debugging, slower runs\n",
        "device = t.device(\"cuda\" if t.cuda.is_available() else \"cpu\")\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def twisted_group(group, automorphism=lambda x: x):\n",
        "    \"\"\"Construct the semidirect product of `group` (a cyclic group) with Z/2Z.\n",
        "\n",
        "    Index i < n stands for (i, 0) and index i >= n for (i - n, 1), where n is the\n",
        "    order of `group`. The product is (a, s) * (b, c) = (a + phi^s(b), s + c) with\n",
        "    phi = `automorphism`; this is a group law only if phi is an automorphism of\n",
        "    `group` with phi(phi(x)) = x.\n",
        "    \"\"\"\n",
        "    n = group.size(dim=0)\n",
        "    new_group = t.zeros((2 * n, 2 * n), dtype=t.int64)\n",
        "\n",
        "    for i in range(2 * n):\n",
        "        for j in range(2 * n):\n",
        "            a, s = i % n, i // n  # left operand (a, s)\n",
        "            b, c = j % n, j // n  # right operand (b, c)\n",
        "            if s == 1:\n",
        "                b = automorphism(b) % n  # twist the right operand by phi\n",
        "            # first component a + phi^s(b), second component s + c mod 2\n",
        "            new_group[i, j] = group[a, b] + n * ((s + c) % 2)\n",
        "\n",
        "    return new_group\n",
        "\n",
        "\n",
        "def cyclic(params):\n",
        "    \"\"\"Addition table of Z/N_1Z: entry (i, j) is (i + j) mod N_1.\"\"\"\n",
        "    cyclic_group = t.zeros((params.N_1, params.N_1), dtype=t.int64)\n",
        "    for i in range(params.N_1):\n",
        "        for j in range(params.N_1):\n",
        "            cyclic_group[i, j] = (i + j) % params.N_1\n",
        "    return cyclic_group\n",
        "\n",
        "\n",
        "class GroupData(Dataset):\n",
        "    def __init__(self, params):\n",
        "        # G_1: direct product with Z/2; G_2: twisted by x -> (N_1 // 2 + 1) * x\n",
        "        self.group1 = twisted_group(cyclic(params))\n",
        "        self.group2 = twisted_group(cyclic(params), lambda x: (params.N_1 // 2 + 1) * x)\n",
        "        # All triples (a, b, a +_k b), enumerated in the same row-major order\n",
        "        self.group1_list = [\n",
        "            (i, j, self.group1[i, j].item())\n",
        "            for i in range(self.group1.size(0))\n",
        "            for j in range(self.group1.size(1))\n",
        "        ]\n",
        "        self.group2_list = [\n",
        "            (i, j, self.group2[i, j].item())\n",
        "            for i in range(self.group2.size(0))\n",
        "            for j in range(self.group2.size(1))\n",
        "        ]\n",
        "\n",
        "        # Triples on which the two operations disagree (set lookups keep this fast)\n",
        "        group1_set, group2_set = set(self.group1_list), set(self.group2_list)\n",
        "        self.group1_only = [item for item in self.group1_list if item not in group2_set]\n",
        "        self.group2_only = [item for item in self.group2_list if item not in group1_set]\n",
        "\n",
        "        if params.data_group1 and not params.data_group2:\n",
        "            self.train_data = self.group1_list\n",
        "        elif params.data_group2 and not params.data_group1:\n",
        "            self.train_data = self.group2_list\n",
        "        else:\n",
        "            # intersection of G_1 and G_2 (both lists share the same ordering)\n",
        "            self.train_data = [\n",
        "                i for i, j in zip(self.group1_list, self.group2_list) if i == j\n",
        "            ]\n",
        "\n",
        "        # add points from G_1 exclusively\n",
        "        self.train_data = self.train_data + random.sample(\n",
        "            self.group1_only, params.add_points_group1\n",
        "        )\n",
        "        # add points from G_2 exclusively\n",
        "        self.train_data = self.train_data + random.sample(\n",
        "            self.group2_only, params.add_points_group2\n",
        "        )\n",
        "\n",
        "        self.train_data_tensor = t.tensor(self.train_data).to(device)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        a, b, label = self.train_data_tensor[idx]\n",
        "        return [a, b], label\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.train_data)\n"
      ],
      "metadata": {
        "id": "PHKD89WPsiwn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class MLP(t.nn.Module):\n",
        "    def __init__(self, params):\n",
        "        super().__init__()\n",
        "        self.Embedding_left = t.nn.Embedding(params.N, params.embed_dim)\n",
        "        self.Embedding_right = t.nn.Embedding(params.N, params.embed_dim)\n",
        "        self.linear = t.nn.Linear(params.embed_dim * 2, params.hidden_size, bias=True)\n",
        "        if params.activation == \"gelu\":\n",
        "            self.activation = t.nn.GELU()\n",
        "        elif params.activation == \"relu\":\n",
        "            self.activation = t.nn.ReLU()\n",
        "        else:\n",
        "            raise ValueError(f\"Unknown activation: {params.activation}\")\n",
        "        self.Umbedding = t.nn.Linear(params.hidden_size, params.N, bias=True)\n",
        "\n",
        "    def forward(self, a):\n",
        "        x1 = self.Embedding_left(a[0])\n",
        "        x2 = self.Embedding_right(a[1])\n",
        "        x12 = t.cat([x1, x2], -1)\n",
        "        hidden = self.linear(x12)\n",
        "        hidden = self.activation(hidden)\n",
        "        out = self.Umbedding(hidden)\n",
        "        return out"
      ],
      "metadata": {
        "id": "yJuMyWuCss3d"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class MLP2(t.nn.Module):\n",
        "    def __init__(self, params):\n",
        "        super().__init__()\n",
        "        self.Embedding_left = t.nn.Embedding(params.N, params.embed_dim)\n",
        "        self.Embedding_right = t.nn.Embedding(params.N, params.embed_dim)\n",
        "        self.linear_left = t.nn.Linear(params.embed_dim, params.hidden_size, bias=True)\n",
        "        self.linear_right = t.nn.Linear(params.embed_dim, params.hidden_size, bias=True)\n",
        "        if params.activation == \"gelu\":\n",
        "            self.activation = t.nn.GELU()\n",
        "        elif params.activation == \"relu\":\n",
        "            self.activation = t.nn.ReLU()\n",
        "        else:\n",
        "            raise ValueError(f\"Unknown activation: {params.activation}\")\n",
        "        self.Umbedding = t.nn.Linear(params.hidden_size, params.N, bias=True)\n",
        "\n",
        "    def forward(self, a):\n",
        "        x1 = self.Embedding_left(a[0])\n",
        "        x2 = self.Embedding_right(a[1])\n",
        "        hidden_x1 = self.linear_left(x1)\n",
        "        hidden_x2 = self.linear_right(x2)\n",
        "        hidden_sum = hidden_x1 + hidden_x2\n",
        "        hidden = self.activation(hidden_sum)\n",
        "        out = self.Umbedding(hidden)\n",
        "        return out"
      ],
      "metadata": {
        "id": "cIAcnXpVs1Y0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "@dataclass\n",
        "class Parameters:\n",
        "    N_1: int = 50  # order of the cyclic factor Z/N_1\n",
        "    # N and max_steps_per_epoch are computed from the class-level defaults below;\n",
        "    # they do not update automatically when N_1 or batch_size is overridden.\n",
        "    N: int = N_1 * 2  # number of group elements |M|\n",
        "    embed_dim: int = 32\n",
        "    hidden_size: int = 64\n",
        "    num_epoch: int = 2000\n",
        "    batch_size: int = 512\n",
        "    activation: str = \"relu\"  # gelu or relu\n",
        "    checkpoint_every: int = 5\n",
        "    max_steps_per_epoch: int = N * N // batch_size\n",
        "    train_frac: float = 1  # fraction of the training data actually used\n",
        "    weight_decay: float = 0.0002\n",
        "    lr: float = 0.01\n",
        "    beta_1: float = 0.9\n",
        "    beta_2: float = 0.98\n",
        "    warmup_steps: int = 0\n",
        "    optimizer: str = \"adam\"  # adamw or adam or sgd\n",
        "    data_group1: bool = True  # training data G_1\n",
        "    data_group2: bool = True  # training data G_2\n",
        "    add_points_group1: int = 0  # add points from G_1 only\n",
        "    add_points_group2: int = 0  # add points from G_2 only\n"
      ],
      "metadata": {
        "id": "rfsCFg39s5Ly"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Accuracy and cross entropy loss functions."
      ],
      "metadata": {
        "id": "1KVJC8PX3oz3"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def loss_fn(logits, labels):\n",
        "    \"\"\"\n",
        "    Compute cross entropy loss.\n",
        "\n",
        "    Args:\n",
        "        logits (Tensor): (batch, group.order) tensor of logits\n",
        "        labels (Tensor): (batch) tensor of labels\n",
        "\n",
        "    Returns:\n",
        "        float: cross entropy loss\n",
        "    \"\"\"\n",
        "    log_probs = logits.log_softmax(dim=-1)\n",
        "    correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]\n",
        "    return -correct_log_probs.mean()\n",
        "\n",
        "\n",
        "def get_accuracy(logits, labels):\n",
        "    \"\"\"\n",
        "    Compute accuracy of model.\n",
        "\n",
        "    Args:\n",
        "        logits (torch.tensor): (batch, group.order) tensor of logits\n",
        "        labels (torch.tensor): (batch) tensor of labels\n",
        "\n",
        "    Returns:\n",
        "        float: accuracy\n",
        "    \"\"\"\n",
        "    return ((logits.argmax(-1) == labels).sum() / len(labels)).item()\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "def test_loss(model, params, Group_Dataset):\n",
        "    \"\"\"Evaluate on all N*N input pairs; return (losses, accuracies) w.r.t. the labels of G_1 and G_2.\"\"\"\n",
        "    test_labels_x = t.tensor([num for num in range(params.N) for _ in range(params.N)]).to(device)\n",
        "    test_labels_y = t.tensor([num % params.N for num in range(params.N * params.N)]).to(device)\n",
        "\n",
        "    logits = model([test_labels_x, test_labels_y])\n",
        "    labels_group_1 = rearrange(Group_Dataset.group1, \"a b-> (a b)\").to(device)\n",
        "    labels_group_2 = rearrange(Group_Dataset.group2, \"a b-> (a b)\").to(device)\n",
        "\n",
        "    loss_group_1 = loss_fn(logits, labels_group_1)\n",
        "    loss_group_2 = loss_fn(logits, labels_group_2)\n",
        "\n",
        "    accuracy_group_1 = get_accuracy(logits, labels_group_1)\n",
        "    accuracy_group_2 = get_accuracy(logits, labels_group_2)\n",
        "\n",
        "    return (loss_group_1, loss_group_2), (accuracy_group_1, accuracy_group_2)\n",
        "\n",
        "random.seed(42)\n",
        "\n",
        "def random_indices(full_dataset, params):\n",
        "    \"\"\"Pick a random subset of the dataset's indices, of size train_frac * len(full_dataset).\"\"\"\n",
        "    num_indices = int(len(full_dataset) * params.train_frac)\n",
        "    picked_indices = random.sample(list(range(len(full_dataset))), num_indices)\n",
        "    return picked_indices\n"
      ],
      "metadata": {
        "id": "gG-Y7c0svfa_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Training function"
      ],
      "metadata": {
        "id": "PRaPj6DA37Pl"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def train(model, params):\n",
        "    current_time = datetime.today().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
        "    wandb.init(\n",
        "        project=\"Grokking ambiguous data\",\n",
        "        name=f\"experiment_{current_time}\",\n",
        "        config={\n",
        "            \"Epochs\": params.num_epoch,\n",
        "            \"Batch size\": params.batch_size,\n",
        "            \"Cardinality\": params.N,\n",
        "            \"Embedded dimension\": params.embed_dim,\n",
        "            \"Hidden dimension\": params.hidden_size,\n",
        "            \"Training\": (params.data_group1, params.data_group2),\n",
        "            \"Added points\": (params.add_points_group1, params.add_points_group2),\n",
        "            \"Train frac\": params.train_frac,\n",
        "            \"Weight decay\": params.weight_decay,\n",
        "            \"Learning rate\": params.lr,\n",
        "            \"Warm up steps\": params.warmup_steps,\n",
        "        },\n",
        "    )\n",
        "    Group_Dataset = GroupData(params=params)\n",
        "\n",
        "    train_data = t.utils.data.Subset(\n",
        "        Group_Dataset, random_indices(Group_Dataset, params)\n",
        "    )\n",
        "    # Full-batch training: one optimizer step per epoch\n",
        "    train_loader = DataLoader(\n",
        "        dataset=train_data,\n",
        "        batch_size=len(train_data),\n",
        "        shuffle=True,\n",
        "        drop_last=False,\n",
        "    )\n",
        "\n",
        "    criterion = t.nn.CrossEntropyLoss()\n",
        "\n",
        "    if params.optimizer == \"sgd\":\n",
        "        optimizer = t.optim.SGD(model.parameters(), lr=params.lr)\n",
        "    elif params.optimizer == \"adam\":\n",
        "        # Note: beta_1/beta_2 are only passed to AdamW; Adam uses PyTorch's default betas\n",
        "        optimizer = t.optim.Adam(\n",
        "            model.parameters(),\n",
        "            weight_decay=params.weight_decay,\n",
        "            lr=params.lr,\n",
        "        )\n",
        "    elif params.optimizer == \"adamw\":\n",
        "        optimizer = t.optim.AdamW(\n",
        "            model.parameters(),\n",
        "            weight_decay=params.weight_decay,\n",
        "            lr=params.lr,\n",
        "            betas=(params.beta_1, params.beta_2),\n",
        "        )\n",
        "    else:\n",
        "        raise ValueError(f\"Unknown optimizer: {params.optimizer}\")\n",
        "\n",
        "    average_loss_training = 0.0  # mean training loss of the previous epoch\n",
        "    step = 0  # global optimizer step, used for the learning-rate warm-up\n",
        "    for epoch in range(params.num_epoch):\n",
        "        # Evaluate on all of X against the labels of both G_1 and G_2\n",
        "        with t.no_grad():\n",
        "            model.eval()\n",
        "            losses_test, accuracies_test = test_loss(model, params, Group_Dataset)\n",
        "            wandb.log(\n",
        "                {\n",
        "                    \"Loss G_1\": losses_test[0].item(),\n",
        "                    \"Loss G_2\": losses_test[1].item(),\n",
        "                    \"Accuracy G_1\": accuracies_test[0],\n",
        "                    \"Accuracy G_2\": accuracies_test[1],\n",
        "                    \"Training loss\": average_loss_training,\n",
        "                },\n",
        "                step=epoch,\n",
        "            )\n",
        "\n",
        "        model.train()\n",
        "        running_loss, num_batches = 0.0, 0\n",
        "        for x, z in train_loader:\n",
        "            # Linear learning-rate warm-up over the first `warmup_steps` steps\n",
        "            if step < params.warmup_steps:\n",
        "                lr = step * params.lr / float(params.warmup_steps)\n",
        "            else:\n",
        "                lr = params.lr\n",
        "            for g in optimizer.param_groups:\n",
        "                g[\"lr\"] = lr\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "            output = model(x)\n",
        "            loss = criterion(output, z)\n",
        "            running_loss += loss.item()\n",
        "            num_batches += 1\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            step += 1\n",
        "        average_loss_training = running_loss / num_batches\n",
        "\n",
        "    wandb.finish()"
      ],
      "metadata": {
        "id": "AkNCWEy2yN1s"
      },
      "execution_count": null,
      "outputs": []
    },
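    {
      "cell_type": "markdown",
      "source": [
        "The warmup in the training loop above is a linear ramp. Writing $\\eta$ for `params.lr`, $T$ for `params.warmup_steps` and $t$ for `global_step`, the learning rate used at step $t$ is\n",
        "\n",
        "$$\n",
        "\\eta_t = \\min\\Big(\\frac{t}{T},\\, 1\\Big)\\,\\eta .\n",
        "$$\n",
        "\n",
        "As an example with purely illustrative values, take $\\eta = 10^{-3}$ and $T = 10$. Steps $t = 0, 1, 2$ then use the learning rates $0$, $10^{-4}$ and $2\\cdot 10^{-4}$, and every step with $t \\geq 10$ uses the full rate $10^{-3}$. Note that `train_loader` yields the whole training set as a single batch (`batch_size=len(train_data)`), so each epoch consists of exactly one optimizer step."
      ],
      "metadata": {}
    },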
    {
      "cell_type": "markdown",
      "source": [
        "## 4. Running experiments"
      ],
      "metadata": {
        "id": "S-64hhnV4hkE"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "We ran the following experiments:\n",
        "\n",
        "1.   Train a network on a random subset of size $40\\%$ of $D_1$ or $D_2$ respectively (i.e. this is the usual grokking set up)\n",
        "2.   Train a network on $D_{12}$\n",
        "3.   Train a network on $40\\%$ of $D_{12}$\n",
        "4.   Train a network on $D_{12}$ but add a small amount of points from $D_1$ or $D_2$ respectively\n"
      ],
      "metadata": {
        "id": "32TFnS4LcSil"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "\n",
        "To run the experiments yourself, make sure that you ran all the cells in the previous \"Code\" section. You will also need a wandb API key. Alternatively, you can just skip to the next section for a summary."
      ],
      "metadata": {
        "id": "eu9oNfC-7Ys6"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### 4.1. Usual grokking\n",
        "Train a network on a random subset of size $40\\%$ of $D_1$ and $D_2$ respectively i.e. this is the usual grokking set up. This is mostly a sanity check, that the model with its parameters is indeed capable of grokking. Nothing unusual happening here."
      ],
      "metadata": {
        "id": "19Gh4J3548xe"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q3tZseKpO5hp",
        "colab": {
          "base_uri": "XXXX",
          "height": 514,
          "referenced_widgets": [
            "3995eff9f3e442e889b7598c5fa8f967",
            "8ba7a1031446473198c0352ca4c2e1b8",
            "4c3a27b3806141e19dfd1386d995946e",
            "f2fc2ba4d9e64fb7925032a37922f6b4",
            "0ff803c547cd453aa54756b9c62c47cf",
            "79f1abb960514a379b866167c66c2b8b",
            "658ca26d7d6d446ba9578bcaa3224e5c",
            "734edd7eb4634cd39c9169f95c43a28b"
          ]
        },
        "outputId": "4ff8194a-baff-4430-83e3-05eda462b46c"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "Tracking run with wandb version 0.16.3"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "Run data is saved locally in <code>/content/wandb/run-20240305_150315-hdablvuh</code>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "Syncing run <strong><a href='XXXX' target=\"_blank\">experiment_2024-03-05 15:02:49</a></strong> to <a href='XXXX' target=\"_blank\">Weights & Biases</a> (<a href='XXXX' target=\"_blank\">docs</a>)<br/>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              " View project at <a href='XXXX' target=\"_blank\">XXXX</a>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              " View run at <a href='XXXX' target=\"_blank\">XXXX</a>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded\\r'), FloatProgress(value=0.10963748894783377, max=1.…"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "3995eff9f3e442e889b7598c5fa8f967"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "<style>\n",
              "    table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
              "    .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
              "    .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
              "    </style>\n",
              "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>Accuracy G_1</td><td>▁▂▃▄▄▄▄▄▄▄▄▄▅▇▇█████████████████████████</td></tr><tr><td>Accuracy G_2</td><td>▁▂▃▄▄▄▄▄▄▄▄▄▅▇▇█████████████████████████</td></tr><tr><td>Loss G_1</td><td>▄▅▇██▇▆▅▄▄▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>Loss G_2</td><td>▂▃▅▇█▇▆▅▄▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>Training loss</td><td>█▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>Accuracy G_1</td><td>0.996</td></tr><tr><td>Accuracy G_2</td><td>0.7484</td></tr><tr><td>Loss G_1</td><td>0.04031</td></tr><tr><td>Loss G_2</td><td>3.68317</td></tr><tr><td>Training loss</td><td>0.0013</td></tr></table><br/></div></div>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              " View run <strong style=\"color:#cdcd00\">experiment_2024-03-05 15:02:49</strong> at: <a href='XXXX' target=\"_blank\">XXXX</a><br/>Synced 4 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "Find logs at: <code>./wandb/run-20240305_150315-hdablvuh/logs</code>"
            ]
          },
          "metadata": {}
        }
      ],
      "source": [
        "ExperimentsParameters = Parameters(data_group1= True, data_group2=False, train_frac=0.4)\n",
        "\n",
        "model = MLP2(ExperimentsParameters).to(device=device)\n",
        "\n",
        "train(model=model, params=ExperimentsParameters)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Now let's grok $D_2$."
      ],
      "metadata": {
        "id": "1oBpEUllMNyI"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "ExperimentsParameters = Parameters(data_group1= False, data_group2= True, train_frac=0.4)\n",
        "\n",
        "model = MLP2(ExperimentsParameters).to(device=device)\n",
        "\n",
        "train(model=model, params=ExperimentsParameters)"
      ],
      "metadata": {
        "id": "lNxWVFeG7hPn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### 4.2. Train a network on ambiguous training data $D_{12}$"
      ],
      "metadata": {
        "id": "ffM2p3Sb8eZa"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Now we train the model on the full ambiguous dataset $D_{12}$."
      ],
      "metadata": {
        "id": "ZhWSKv51R4Pv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "ExperimentsParameters = Parameters()\n",
        "\n",
        "model = MLP2(ExperimentsParameters).to(device=device)\n",
        "\n",
        "train(model=model, params=ExperimentsParameters)"
      ],
      "metadata": {
        "id": "p0ufjK_Y8p7a"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### 4.3. Train a network on 40% of the ambiguous training data $D_{12}$"
      ],
      "metadata": {
        "id": "wBCjcvZIEpz4"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "ExperimentsParameters = Parameters(train_frac=0.4)\n",
        "\n",
        "model = MLP2(ExperimentsParameters).to(device=device)\n",
        "\n",
        "train(model=model, params=ExperimentsParameters)"
      ],
      "metadata": {
        "id": "1iqvT4NnE1aw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### 4.4. Train a network on ambiguous training data $D_{12}$ and add points form $D_1\\setminus D_2$ or $D_2\\setminus D_1$"
      ],
      "metadata": {
        "id": "SQj4WSyvE96-"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "ExperimentsParameters = Parameters(add_points_group1 = 1) # Alternatively add_points_group2\n",
        "\n",
        "model = MLP2(ExperimentsParameters).to(device=device)\n",
        "\n",
        "train(model=model, params=ExperimentsParameters)"
      ],
      "metadata": {
        "id": "fSOtydMFGHDV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 5. Results\n"
      ],
      "metadata": {
        "id": "ZYdzxLBI830h"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "So far, most of our focus has been on experiment 4.2, i.e. we train the model on the full ambiguous data set $D_{12}$.\n",
        "Let us first compare the accuracy for $f_1$ and $f_2$ (denoted by G_1 and G_2 in the graph below).\n"
      ],
      "metadata": {
        "id": "iNT9aNMK86wh"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "\n",
        "\n",
        "<img src=\"XXXX\" width =500>\n",
        "<img src=\"XXXX\" width =500>"
      ],
      "metadata": {
        "id": "1Rl27rus-8Zz"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Interactive version: [Accuracy G_1](XXXX)\n",
        "[Accuracy G_2](XXXX)"
      ],
      "metadata": {
        "id": "Go2vERjp_4ov"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "We make the following observations:\n",
        "\n",
        "\n",
        "1.   In three cases $f_1$ was fully grokked,\n",
        "2.   On the other hand $f_2$ was never grokked,\n",
        "3.   In most cases the accuracy of $f_1$ plateaus at some intermediate level,\n",
        "4.   The accuracy of $f_2$ bumps up once. But notice that in this run the accuracy of $f_1$ and $f_2$ goes up simulatenously (around 400-800 steps).\n",
        "\n"
      ],
      "metadata": {
        "id": "nELO_951AKoV"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Some consequences are:\n",
        "\n",
        "1. It seems like $f_1$ is easier to learn than $f_2$. From a human intuition point of view this is reasonable. $f_1$ is a commutative group, as opposed to $f_2$ which is not commutative and by default seems easier to comprehend. Group operations are monsters of symmetry and the easiest form of a symmetry is $x+_1y=y+_1x$,\n",
        "2. But this raises the question: What happens at most of the other runs? Can we give some comprehensible interpretation to the intermediate levels?\n",
        "3. How can we understand the bump which occurs simulatenously in $f_1$ and $f_2$?\n",
        "\n"
      ],
      "metadata": {
        "id": "FfSbShnDBSm5"
      }
    },
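    {
      "cell_type": "markdown",
      "source": [
        "To make the symmetry argument in point 1 concrete, here is a small worked example. The groups behind $f_1$ and $f_2$ are defined earlier in the notebook. Purely for illustration, we take the cyclic group $\\mathbb{Z}_6$ as an example of a commutative group and the symmetric group $S_3$ (permutations composed right to left) as an example of a non-commutative one, both of order $6$:\n",
        "\n",
        "$$\n",
        "\\begin{aligned}\n",
        "\\mathbb{Z}_6:&\\quad 2 + 5 \\equiv 1 \\equiv 5 + 2 \\pmod{6},\\\\\n",
        "S_3:&\\quad (1\\,2)(1\\,3) = (1\\,3\\,2) \\neq (1\\,2\\,3) = (1\\,3)(1\\,2).\n",
        "\\end{aligned}\n",
        "$$\n",
        "\n",
        "For a commutative operation the Cayley table is symmetric about its diagonal, so (ignoring the other group axioms) its upper triangle already determines it. A non-commutative table has no such redundancy."
      ],
      "metadata": {}
    },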
    {
      "cell_type": "markdown",
      "source": [
        "Let's look at the loss curves."
      ],
      "metadata": {
        "id": "8WoTrGW7Xngq"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "<img src=XXXX width =500>\n",
        "<img src=XXXX width =500>"
      ],
      "metadata": {
        "id": "g0q0zNqYReyl"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Interactive version: [Loss G_1](XXXX)\n",
        "[Loss G_2](XXXX)"
      ],
      "metadata": {
        "id": "lOdQTE7TCu0g"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "We observe again that $f_2$ exhibits some sort of slingshot behaviour between 400-800 steps."
      ],
      "metadata": {
        "id": "xuYVtnblDgP3"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 6. Evaluation of results\n"
      ],
      "metadata": {
        "id": "ret7bJmivUW7"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's focus on two runs (more to be added).\n",
        "\n",
        "The first corresponds to 4.1 and the second to 4.3.\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "id": "-wDmk4Lzv4f-"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### 6.1 Usual grokking and local learning coefficient measure (as in 4.1)"
      ],
      "metadata": {
        "id": "DS5w1KB4B6HQ"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's look at the losses:\n",
        "\n",
        "<img src=XXXX"
      ],
      "metadata": {
        "id": "-YTPFwi2Wkcv"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "This is what we expected to see. The training loss (in yellow) goes down very quickly. The test loss of $G_1$ goes down to approx. $0$ eventually."
      ],
      "metadata": {
        "id": "X2rFYfmIVY81"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Here is a visualization of the training run (open gif in new tab to view the run again):\n",
        "\n",
        "<img src=XXXX\n",
        "\n",
        "Each cell corresponds to one data point $x\\in X$. A cell is:\n",
        "- orange, if the model learned $G_1$,\n",
        "- green, if the model learned $G_2$,\n",
        "- blue, if the model learned the data point correctly and the output is the same for $G_1$ and $G_2$.\n",
        "\n",
        "\n",
        "In this case we see that roughly $75$% is blue and since we grokked $G_1$, the rest is orange.\n"
      ],
      "metadata": {
        "id": "veZlddpQVbBQ"
      }
    },
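    {
      "cell_type": "markdown",
      "source": [
        "The colouring rule can be written out explicitly. Write $\\hat y(x)$ for the model's prediction on $x$, and $f_1(x)$, $f_2(x)$ for the labels under $G_1$ and $G_2$. This is our own notation; the plotting code is not part of this section.\n",
        "\n",
        "$$\n",
        "\\text{colour}(x) = \\begin{cases}\n",
        "\\text{blue} & \\text{if } \\hat y(x) = f_1(x) = f_2(x),\\\\\n",
        "\\text{orange} & \\text{if } \\hat y(x) = f_1(x) \\neq f_2(x),\\\\\n",
        "\\text{green} & \\text{if } \\hat y(x) = f_2(x) \\neq f_1(x).\n",
        "\\end{cases}\n",
        "$$\n",
        "\n",
        "Points on which the prediction matches neither label fall outside these three cases. Once $G_1$ is fully grokked, every cell is blue or orange. The accuracy on $G_2$ is then approximately the fraction of blue cells, i.e. the fraction of points on which $f_1$ and $f_2$ agree. This is consistent with the run summary printed under the first experiment cell in 4.1, which reports an accuracy of $0.996$ for $G_1$ and $0.7484$ for $G_2$."
      ],
      "metadata": {}
    },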
    {
      "cell_type": "markdown",
      "source": [
        "Now let's look at the local learning coefficient and the accuracy of $G_1$.\n",
        "\n",
        "<img src=XXXX\n",
        "\n"
      ],
      "metadata": {
        "id": "JemHC9zobdKq"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Something that is interesting here: The LLC changes over time and indicates that the model is doing something!\n",
        "\n",
        "Of course in this case, we knew what to test for (namely the accuracy of $G_1$). But imagine we were training the model without this knowledge. Then we would just see (almost) $0$ training loss, but not that something is happening in the background!\n",
        "\n",
        "The measurements are still a bit all over the place. Better calibration (running more finegrained sweeps over LLC estimation hyperparameters) and more accurate measurements (sampling more chains from the local loss landscape) might give a better picture."
      ],
      "metadata": {
        "id": "65ff7IauW6E4"
      }
    },
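    {
      "cell_type": "markdown",
      "source": [
        "For reference, here is a sketch of how the LLC is typically estimated. The estimation code is not shown in this section, so we assume the standard SGLD-based estimator of Lau et al. (2023), as implemented for example in the `devinterp` library. Given a checkpoint $w^*$, $n$ training samples and the empirical loss $L_n$, the estimate is\n",
        "\n",
        "$$\n",
        "\\hat\\lambda(w^*) = n\\beta^*\\Big(\\mathbb{E}_{w}\\big[L_n(w)\\big] - L_n(w^*)\\Big), \\qquad \\beta^* = \\frac{1}{\\log n},\n",
        "$$\n",
        "\n",
        "where the expectation is taken over the localized tempered posterior $p(w) \\propto \\exp\\big(-n\\beta^* L_n(w) - \\tfrac{\\gamma}{2}\\lVert w - w^*\\rVert^2\\big)$. In practice it is approximated by averaging $L_n$ along SGLD chains started at $w^*$:\n",
        "\n",
        "$$\n",
        "w_{t+1} = w_t - \\frac{\\epsilon}{2}\\Big(n\\beta^*\\,\\nabla L_m(w_t) + \\gamma\\,(w_t - w^*)\\Big) + \\xi_t, \\qquad \\xi_t \\sim \\mathcal{N}(0, \\epsilon I),\n",
        "$$\n",
        "\n",
        "Here $L_m$ is the loss on a minibatch of size $m$. The step size $\\epsilon$, the localization strength $\\gamma$ and the number and length of the chains are the estimation hyperparameters mentioned above; more chains reduce the variance of the estimate. For a regular model with $d$ parameters the learning coefficient equals $d/2$, and more degenerate solutions have a smaller value."
      ],
      "metadata": {}
    },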
    {
      "cell_type": "markdown",
      "source": [
        "### 6.2 Grokking and local learning coefficient measure in the ambiguous case (as in 4.3)\n",
        "\n"
      ],
      "metadata": {
        "id": "J6VCthHdW7Vo"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's again look at the losses first:\n",
        "<img src=XXXX"
      ],
      "metadata": {
        "id": "N3yC7o3xYmvx"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "There is a weird slingshot behaviour. Although $G_2$ starts off better, in the end $G_1$ is grokked.\n",
        "Let's look at the training run again (open gif in new tab\n",
        " to see the run again)"
      ],
      "metadata": {
        "id": "ETDuVeHzdS2H"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "<img src=XXXX\n",
        "\n",
        "We can see that at initialization that $G_2$ is prefered: in the first frame at epoch $0$ there are more green points, But then eventually $G_1$ trumps it, as also the above losses indicate."
      ],
      "metadata": {
        "id": "Pn1wrFRPc152"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Now let's look at the local learning coefficient and the accuracy of $G_1$.\n",
        "\n",
        "<img src=XXXX"
      ],
      "metadata": {
        "id": "-vqKcLUsdse1"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "This is slightly odd: It seems like we are grokking $G_1$ twice.\n",
        "\n",
        "\n",
        "First, around step $5$ and then again around $20$ (corresponding to epoch $285$ and $1140$).\n",
        "\n",
        "Also, while its seems like step $5$ is spotted by the LLC, the change in step $20$ goes unchanged.\n"
      ],
      "metadata": {
        "id": "huP0ww9Te78K"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 7. Further steps"
      ],
      "metadata": {
        "id": "m0xmNWfhGkRj"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Next steps which are already in progress:**\n",
        "\n",
        "- Run more experiments to get more reliable data on outcomes\n",
        "- Run experiments where there are more than $2$ group structures that can be learned"
      ],
      "metadata": {
        "id": "FuNGQ5jdGKod"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Beyond that**:\n",
        "\n",
        "\n",
        "- Try to solve the mystery of the intermediate grokking and ungrokking\n",
        "- Start measuring the formation of circuit (size)"
      ],
      "metadata": {
        "id": "MS38T1IMHJpx"
      }
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": [],
      "gpuType": "T4",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "3995eff9f3e442e889b7598c5fa8f967": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "VBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "VBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "VBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_8ba7a1031446473198c0352ca4c2e1b8",
              "IPY_MODEL_4c3a27b3806141e19dfd1386d995946e"
            ],
            "layout": "IPY_MODEL_f2fc2ba4d9e64fb7925032a37922f6b4"
          }
        },
        "8ba7a1031446473198c0352ca4c2e1b8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "LabelModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "LabelModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "LabelView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0ff803c547cd453aa54756b9c62c47cf",
            "placeholder": "​",
            "style": "IPY_MODEL_79f1abb960514a379b866167c66c2b8b",
            "value": "0.011 MB of 0.011 MB uploaded\r"
          }
        },
        "4c3a27b3806141e19dfd1386d995946e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_658ca26d7d6d446ba9578bcaa3224e5c",
            "max": 1,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_734edd7eb4634cd39c9169f95c43a28b",
            "value": 1
          }
        },
        "f2fc2ba4d9e64fb7925032a37922f6b4": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0ff803c547cd453aa54756b9c62c47cf": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "79f1abb960514a379b866167c66c2b8b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "658ca26d7d6d446ba9578bcaa3224e5c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "734edd7eb4634cd39c9169f95c43a28b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}