{
 "cells": [
  {
   "metadata": {
    "id": "704ffe9c3fc64d90"
   },
   "cell_type": "markdown",
   "source": [
    "# Decomposed Learning: An Avenue for Mitigating Grokking\n",
    "\n",
    "This notebook provides an example of decomposed learning for the modular addition grokking task: $a + b \\mod(prime) = y$, where $prime=97$.\n",
    "\n",
    "***Codebase is inspired by this notebook: https://github.com/enerrio/grokking-mlp/blob/main/grokking-modadd.ipynb***\n",
    "\n",
    "**Runtime: CPU**\n",
    "\n",
    "**PLEASE NOTE:**\n",
    " - ***The only dataset condition explored is the one using 80% of the data for training. This is because it takes approximately 45 minutes to train, whereas with 65% and 50% of the data for training it takes approximately 1.5 and 3 hours respectively.***\n",
    "\n",
    " - ***In this notebook, we train only one baseline model. Each model takes approximately 45 minutes to train, so producing an average across 10 runs in serial would take 45 (minutes per run) x 10 (runs) = 450 minutes, i.e. 7.5 hours.***\n",
    "\n",
    " - ***We train only one decomposed model for each rank condition, for only 1,000 epochs instead of 10,000, so that section takes ~18 minutes instead of 3 hours.***\n",
    "\n",
    " - ***Due to floating-point arithmetic and differences between CPUs, the results may vary slightly from the paper; however, the general trend will be consistent.***\n",
    "\n",
    "***Approximate runtime of the notebook: ~1 hour***\n",
    "\n",
    "\n"
   ],
   "id": "704ffe9c3fc64d90"
  },
  {
   "metadata": {
    "id": "591c2040a1b0722c"
   },
   "cell_type": "markdown",
   "source": [
    "## Creating the Dataset\n",
    "\n",
    "The following section creates the dataset that is used for training and testing."
   ],
   "id": "591c2040a1b0722c"
  },
  {
   "metadata": {
    "id": "initial_id"
   },
   "cell_type": "code",
   "source": [
    "from tqdm import tqdm\n",
    "import torch\n",
    "import numpy as np\n",
    "from sklearn import model_selection\n",
    "import random\n",
    "from einops.layers.torch import Rearrange\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd"
   ],
   "id": "initial_id",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "10b42e8c53c53e25",
    "outputId": "e8041407-51bb-41ab-813c-231a467f6a74"
   },
   "cell_type": "code",
   "source": [
    "test_splits = [0.5, 0.35, 0.2]\n",
    "# Config\n",
    "SEED = 0\n",
    "prime = 97\n",
    "operation_token = \"+\"\n",
    "equal_token = \"=\"\n",
    "# Create token dictionaries\n",
    "char2idx = {str(i): i for i in range(prime)}\n",
    "char2idx[operation_token] = len(char2idx)\n",
    "char2idx[equal_token] = len(char2idx)\n",
    "idx2char = {v: k for k, v in char2idx.items()}\n",
    "\n",
    "# Create all permutations of (a + b) % prime\n",
    "a = np.zeros((prime,prime), dtype=np.uint8)\n",
    "b = np.zeros((prime,prime), dtype=np.uint8)\n",
    "for i in range(prime):\n",
    "    for j in range(prime):\n",
    "        a[i, j] = i\n",
    "        b[i, j] = j\n",
    "a = a.flatten()\n",
    "b = b.flatten()\n",
    "operation = np.zeros((prime*prime,), dtype=np.uint8) + char2idx[operation_token]\n",
    "equal = np.zeros((prime*prime,), dtype=np.uint8) + char2idx[equal_token]\n",
    "# Create Features\n",
    "features = np.stack([a, operation, b, equal], axis=-1).astype(np.uint8)\n",
    "# Create Labels\n",
    "y = (a + b) % prime\n",
    "\n",
    "# Create Train Test Split\n",
    "for test_split in tqdm(test_splits):\n",
    "  torch.manual_seed(SEED)\n",
    "  np.random.seed(SEED)\n",
    "  random.seed(SEED)\n",
    "  X_train, X_test, Y_train, Y_test = model_selection.train_test_split(\n",
    "      features, y, test_size=test_split, random_state=SEED)\n",
    "  # Create Dataset Dictionary\n",
    "  dataset = {\n",
    "            \"train\": {\"features\": torch.tensor(X_train), \"label\":torch.tensor(Y_train)},\n",
    "            \"test\": {\"features\": torch.tensor(X_test), \"label\":torch.tensor(Y_test)}\n",
    "           }\n",
    "  torch.save(dataset, f\"a{operation_token}b_mod_{prime}_test_split_{test_split:.2f}_seed_{SEED}.pth\")\n"
   ],
   "id": "10b42e8c53c53e25",
   "outputs": [],
   "execution_count": null
  },
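  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "As a quick sanity check (a minimal sketch, not part of the pipeline above), each example is encoded as four tokens `[a, '+', b, '=']`: the operator and equals tokens take the next two indices after the digit tokens, and the label is the modular sum.\n",
    "\n",
    "```python\n",
    "prime = 97\n",
    "char2idx = {str(i): i for i in range(prime)}\n",
    "char2idx[\"+\"] = len(char2idx)  # token index 97\n",
    "char2idx[\"=\"] = len(char2idx)  # token index 98\n",
    "a, b = 95, 7\n",
    "features = [a, char2idx[\"+\"], b, char2idx[\"=\"]]\n",
    "label = (a + b) % prime\n",
    "print(features, label)  # [95, 97, 7, 98] 5\n",
    "```"
   ],
   "id": "token-encoding-sketch"
  },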
  {
   "metadata": {
    "id": "7a383b53919c6d9e"
   },
   "cell_type": "markdown",
   "source": [
    "## Define Model\n",
    "\n",
    "This section creates the model definition. The model was adapted from https://github.com/enerrio/grokking-mlp/blob/main/grokking-modadd.ipynb."
   ],
   "id": "7a383b53919c6d9e"
  },
  {
   "metadata": {
    "id": "ec2dbdb506915659"
   },
   "cell_type": "code",
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"Adapted from https://github.com/enerrio/grokking-mlp/blob/main/grokking-modadd.ipynb\"\"\"\n",
    "\n",
    "    def __init__(self, in_features, out_features):\n",
    "        super().__init__()\n",
    "        self.embed = nn.Embedding(in_features, 128)\n",
    "        # Change shape from batch, token, model_width to batch, token*model_width, for example\n",
    "        # 1,4,128 becomes 1, 512.\n",
    "        self.rearrange = Rearrange(\"batch token model_width -> batch (token model_width)\")\n",
    "        self.linear = nn.Linear(4*128, 128)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.out = nn.Linear(128, out_features)\n",
    "        self.initialize_params()\n",
    "\n",
    "    def initialize_params(self):\n",
    "\n",
    "        nn.init.kaiming_normal_(self.linear.weight, nonlinearity=\"relu\")\n",
    "        nn.init.zeros_(self.linear.bias)\n",
    "\n",
    "        nn.init.kaiming_normal_(self.out.weight, nonlinearity=\"relu\")\n",
    "        nn.init.zeros_(self.out.bias)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.embed(x)\n",
    "        x = self.rearrange(x)\n",
    "        x = self.linear(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.out(x)\n",
    "        return x"
   ],
   "id": "ec2dbdb506915659",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {
    "id": "31802b064a29293"
   },
   "cell_type": "markdown",
   "source": [
    "## Define SVD $(U_k\\Sigma_k V_k^T)$ Embedding and Linear Layers\n",
    "\n",
    "This section defines the SVDEmbedding Layer and SVDLinear Layers."
   ],
   "id": "31802b064a29293"
  },
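  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Before defining the layers, a minimal sketch of the idea (illustrative values only): `torch.linalg.svd` factorises a weight matrix $W$ into $U\\Sigma V^T$, and keeping only the first $k$ components gives the rank-$k$ approximation $W_k = U_k\\Sigma_k V_k^T$, which has the same shape as $W$ but rank at most $k$:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "W = torch.randn(99, 128)\n",
    "u, s, v = torch.linalg.svd(W, full_matrices=False)  # v is already V^T\n",
    "k = 16\n",
    "W_k = u[:, :k] @ torch.diag(s[:k]) @ v[:k, :]\n",
    "print(W_k.shape, torch.linalg.matrix_rank(W_k).item())  # same shape as W, rank <= 16\n",
    "```"
   ],
   "id": "svd-truncation-sketch"
  },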
  {
   "metadata": {
    "id": "3934987ef20c0f4a"
   },
   "cell_type": "markdown",
   "source": [
    "### SVDEmbedding"
   ],
   "id": "3934987ef20c0f4a"
  },
  {
   "metadata": {
    "id": "ae08253749754644"
   },
   "cell_type": "code",
   "source": [
    "class SVDEmbedding(nn.Module):\n",
    "    \"\"\"Convert an Embedding Layer into an SVDEmbedding Layer.\"\"\"\n",
    "    def __init__(self, layer:nn.Embedding, rank:int):\n",
    "        \"\"\"Convert an Embedding Layer into an SVDEmbedding Layer.\n",
    "        Args:\n",
    "            layer: Layer to be converted.\n",
    "            rank: Rank that the layer will have.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        self.rank = rank\n",
    "        # Embedding Layer Properties\n",
    "        self.padding_idx = layer.padding_idx\n",
    "        self.max_norm = layer.max_norm\n",
    "        self.norm_type = layer.norm_type\n",
    "        self.scale_grad_by_freq = layer.scale_grad_by_freq\n",
    "        self.sparse = layer.sparse\n",
    "        # Perform SVD on the weights of the Embedding Layer (https://docs.pytorch.org/docs/stable/generated/torch.linalg.svd.html)\n",
    "        # By default V is returned as V^T.\n",
    "        u,s,v = torch.linalg.svd(layer.weight.data.detach(), full_matrices=False)\n",
    "        # Create $U_K$ and allow gradient updates\n",
    "        self.U_approx = torch.nn.Parameter(u[:,:rank].clone().detach(), requires_grad=True)\n",
    "        # Create $\\Sigma_K$ and allow gradient updates\n",
    "        self.S_approx = torch.nn.Parameter(s[:rank].clone().detach(), requires_grad=True)\n",
    "        # Create $V_K^T$ and allow gradient updates\n",
    "        self.V_approx = torch.nn.Parameter(v[:rank,:].clone().detach(), requires_grad=True)\n",
    "\n",
    "    def forward(self, input):\n",
    "        # S_approx is diagonalised to make it a matrix\n",
    "        return torch.nn.functional.embedding(\n",
    "            input,\n",
    "            self.U_approx @ torch.diag(self.S_approx) @ self.V_approx,\n",
    "            self.padding_idx,\n",
    "            self.max_norm,\n",
    "            self.norm_type,\n",
    "            self.scale_grad_by_freq,\n",
    "            self.sparse,\n",
    "        )"
   ],
   "id": "ae08253749754644",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {
    "id": "d8bf84791fc14841"
   },
   "cell_type": "markdown",
   "source": [
    "### SVDLinear Layer"
   ],
   "id": "d8bf84791fc14841"
  },
  {
   "metadata": {
    "id": "43628721245dde1f"
   },
   "cell_type": "code",
   "source": [
    "class SVDLinear(nn.Module):\n",
    "    \"\"\"Convert a Linear Layer into an SVDLinear Layer.\"\"\"\n",
    "    def __init__(self, layer:nn.Linear, rank:int):\n",
    "        \"\"\"Convert a Linear Layer into an SVDLinear Layer.\n",
    "        Args:\n",
    "            layer: Layer to be converted.\n",
    "            rank: Rank that the layer will have.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.rank = rank\n",
    "        # Perform SVD on the weights of the Linear Layer (https://docs.pytorch.org/docs/stable/generated/torch.linalg.svd.html)\n",
    "        # By default V is returned as V^T.\n",
    "        u,s,v = torch.linalg.svd(layer.weight.data.detach(), full_matrices=False)\n",
    "        self.bias = layer.bias\n",
    "        # Create $U_K$ and allow gradient updates\n",
    "        self.U_approx = torch.nn.Parameter(u[:,:rank].clone().detach(), requires_grad=True)\n",
    "        # Create $\\Sigma_K$ and allow gradient updates\n",
    "        self.S_approx = torch.nn.Parameter(s[:rank].clone().detach(), requires_grad=True)\n",
    "        # Create $V_K^T$ and allow gradient updates\n",
    "        self.V_approx = torch.nn.Parameter(v[:rank,:].clone().detach(), requires_grad=True)\n",
    "\n",
    "    def forward(self, input):\n",
    "        # S_approx is diagonalised to make it a matrix\n",
    "        return torch.nn.functional.linear(input, self.U_approx@torch.diag(self.S_approx)@self.V_approx, self.bias)"
   ],
   "id": "43628721245dde1f",
   "outputs": [],
   "execution_count": null
  },
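  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "A rough parameter-count sketch (illustrative, using this notebook's hidden layer sizes): a full `nn.Linear` weight of shape `(out, in)` stores `out*in` values, while the rank-$k$ factors $U_k$, $\\Sigma_k$, $V_k^T$ store `out*k + k + k*in`. At 12.5% of the ranks the hidden linear layer is far smaller, while at 100% rank the factored form actually stores more values than the original matrix:\n",
    "\n",
    "```python\n",
    "def svd_linear_params(out_f, in_f, k):\n",
    "    # U_k: (out_f, k), Sigma_k: (k,), V_k^T: (k, in_f)\n",
    "    return out_f * k + k + k * in_f\n",
    "\n",
    "out_f, in_f = 128, 512  # this notebook's hidden linear layer\n",
    "print(out_f * in_f)                         # 65536 (full weight matrix)\n",
    "print(svd_linear_params(out_f, in_f, 16))   # 10256 (12.5% of ranks)\n",
    "print(svd_linear_params(out_f, in_f, 128))  # 82048 (100% of ranks)\n",
    "```"
   ],
   "id": "svd-param-count-sketch"
  },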
  {
   "metadata": {
    "id": "c9bd40f9bf015b4a"
   },
   "cell_type": "markdown",
   "source": [
    "## Utility Functions"
   ],
   "id": "c9bd40f9bf015b4a"
  },
  {
   "metadata": {
    "id": "2c0b71457a7d9f22"
   },
   "cell_type": "markdown",
   "source": [
    "### Data Loader Reproducibility"
   ],
   "id": "2c0b71457a7d9f22"
  },
  {
   "metadata": {
    "id": "fb7fb3ebc6b212dc"
   },
   "cell_type": "code",
   "source": [
    "\n",
    "def seed_worker(worker_id):\n",
    "    # https://docs.pytorch.org/docs/stable/notes/randomness.html#dataloader\n",
    "    worker_seed = torch.initial_seed() % 2**32\n",
    "    np.random.seed(worker_seed)\n",
    "    random.seed(worker_seed)"
   ],
   "id": "fb7fb3ebc6b212dc",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {
    "id": "82e97598d28b1252"
   },
   "cell_type": "markdown",
   "source": [
    "### Training"
   ],
   "id": "82e97598d28b1252"
  },
  {
   "metadata": {
    "id": "334b27618d43e644"
   },
   "cell_type": "code",
   "source": [
    "def train_step(net, train_dataloader, criterion, optimizer):\n",
    "    # Step through the dataset\n",
    "    net.train()\n",
    "    total_labels = total_correct = total_loss = 0.\n",
    "    for inputs, labels in train_dataloader:\n",
    "        optimizer.zero_grad()\n",
    "        outputs = net(inputs)\n",
    "        outputs = outputs.type(torch.float64)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        total_correct += (outputs.detach().cpu().softmax(-1).argmax(-1) == labels.detach().cpu()).numpy().sum().item()\n",
    "        total_labels += labels.size(0)\n",
    "        total_loss += loss.item()\n",
    "    epoch_acc = (total_correct / total_labels) * 100.\n",
    "    epoch_loss = total_loss / len(train_dataloader)\n",
    "    return epoch_acc, epoch_loss"
   ],
   "id": "334b27618d43e644",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {
    "id": "c0350dfbd6a04d9d"
   },
   "cell_type": "code",
   "source": [
    "def eval_step(net, test_dataloader, criterion):\n",
    "    total_labels = total_correct = total_loss = 0.\n",
    "    with torch.no_grad():\n",
    "        net.eval()\n",
    "        for inputs, labels in test_dataloader:\n",
    "            # forward pass\n",
    "            outputs = net(inputs)\n",
    "            test_loss = criterion(outputs, labels)\n",
    "            # calculate metrics\n",
    "            total_correct += (outputs.detach().cpu().softmax(-1).argmax(-1) == labels.detach().cpu()).numpy().sum().item()\n",
    "            total_loss += test_loss.item()\n",
    "            total_labels += labels.size(0)\n",
    "    test_acc = (total_correct / total_labels) * 100.\n",
    "    test_loss = total_loss / len(test_dataloader)\n",
    "    return test_acc, test_loss"
   ],
   "id": "c0350dfbd6a04d9d",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {
    "id": "6e0374ba30f14140"
   },
   "cell_type": "code",
   "source": [
    "def train_and_evaluate(net, train_dataloader, test_dataloader, criterion, optimizer, epochs, save_name):\n",
    "    stats = {\n",
    "        \"epoch_train_losses\": [],\n",
    "        \"epoch_train_accs\": [],\n",
    "        \"epoch_test_losses\": [],\n",
    "        \"epoch_test_accs\": [],\n",
    "    }\n",
    "    for epoch in tqdm(range(epochs)):\n",
    "        train_acc, train_loss = train_step(net, train_dataloader, criterion, optimizer)\n",
    "        stats[\"epoch_train_losses\"].append(train_loss)\n",
    "        stats[\"epoch_train_accs\"].append(train_acc)\n",
    "        test_acc, test_loss = eval_step(net, test_dataloader, criterion)\n",
    "        stats[\"epoch_test_losses\"].append(test_loss)\n",
    "        stats[\"epoch_test_accs\"].append(test_acc)\n",
    "        torch.save(stats, save_name)\n",
    "    return stats"
   ],
   "id": "6e0374ba30f14140",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Configs for Training"
   ],
   "metadata": {
    "id": "S1nKGSrnBSPU"
   },
   "id": "S1nKGSrnBSPU"
  },
  {
   "metadata": {
    "id": "86fd58ebc059e02a"
   },
   "cell_type": "code",
   "source": [
    "SEED = 0\n",
    "DATA_ORDER_SEED = 0\n",
    "PRIME = 97\n",
    "TEST_SPLIT = 0.2\n",
    "EPOCHS = 10_000\n",
    "DEVICE = torch.device(\"cpu\")"
   ],
   "id": "86fd58ebc059e02a",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {
    "id": "746dc484af0ad8ad"
   },
   "cell_type": "markdown",
   "source": [
    "## Train Base Models\n",
    "\n",
    "Main Run Loop.\n",
    "\n",
    "In this notebook, we will only train one model, using 80% of the data for training, as it would take several hours to create the mean report when one run takes approximately 45 minutes.\n"
   ],
   "id": "746dc484af0ad8ad"
  },
  {
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "9b76cfbb1ea7421e",
    "outputId": "c3e39b99-5e26-42c1-a0b0-c755aec151a6"
   },
   "cell_type": "code",
   "source": [
    "print(f\"TRAINING With {(1-TEST_SPLIT)*100}% of Data for {EPOCHS} Epochs.\")\n",
    "### Ensure Reproducibility:\n",
    "os.environ[\"PYTHONHASHSEED\"] = str(SEED)\n",
    "torch.use_deterministic_algorithms(True)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.set_num_threads(1)\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "rng_generator = torch.Generator().manual_seed(DATA_ORDER_SEED)\n",
    "# Get dataset\n",
    "dataset = torch.load( f\"a+b_mod_{PRIME}_test_split_{TEST_SPLIT:.2f}_seed_{DATA_ORDER_SEED}.pth\", weights_only=True)\n",
    "train_dataset_features = dataset[\"train\"][\"features\"].clone().to(torch.long, copy=True)\n",
    "train_dataset_labels = dataset[\"train\"][\"label\"].clone().to(torch.long,copy=True)\n",
    "test_dataset_features = dataset[\"test\"][\"features\"].clone().to(torch.long,copy=True)\n",
    "test_dataset_labels = dataset[\"test\"][\"label\"].clone().to(torch.long,copy=True)\n",
    "\n",
    "train_dataset = TensorDataset(train_dataset_features.to(DEVICE), train_dataset_labels.to(DEVICE))\n",
    "test_dataset = TensorDataset(test_dataset_features.to(DEVICE), test_dataset_labels.to(DEVICE))\n",
    "# Create DataLoader\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=512,\n",
    "    shuffle=True,\n",
    "    worker_init_fn=seed_worker,\n",
    "    generator=rng_generator,\n",
    "    num_workers=0,\n",
    ")\n",
    "\n",
    "test_dataloader = DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=512,\n",
    "    shuffle=False,\n",
    "    worker_init_fn=seed_worker,\n",
    "    generator=rng_generator,\n",
    "    num_workers=0,\n",
    ")\n",
    "\n",
    "print(f\"Total samples in train set: {len(train_dataset):,}\")\n",
    "print(f\"Total samples in val set: {len(test_dataset):,}\")\n",
    "\n",
    "print(f\"Total number of batches in train dataloader: {len(train_dataloader):,}\")\n",
    "print(f\"Total number of batches in validation dataloader: {len(test_dataloader):,}\")\n",
    "\n",
    "# Create Baseline Model\n",
    "os.makedirs(f\"a+b_mod_{PRIME}/init/\", exist_ok=True)\n",
    "net = MLP(99,99)\n",
    "if os.path.isfile(f\"a+b_mod_{PRIME}/init/mlp_seed_{SEED}.pth\"):\n",
    "    net.load_state_dict(torch.load(f\"a+b_mod_{PRIME}/init/mlp_seed_{SEED}.pth\", weights_only=True))\n",
    "else:\n",
    "    torch.save(net.state_dict(), f\"a+b_mod_{PRIME}/init/mlp_seed_{SEED}.pth\")\n",
    "net.to(DEVICE)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.AdamW(net.parameters(), lr=1e-3, betas=(0.9, 0.98))\n",
    "\n",
    "print(f\"Number of parameters in model: {sum([x.numel() for x in net.parameters()]):,}\")\n",
    "os.makedirs(f\"a+b_mod_{PRIME}/test_split/{TEST_SPLIT:.2f}/base/\", exist_ok=True)\n",
    "# Train Model\n",
    "save_name = f\"a+b_mod_{PRIME}/test_split/{TEST_SPLIT:.2f}/base/stats_seed_{SEED}_dos_{DATA_ORDER_SEED}.pth\"\n",
    "train_and_evaluate(net, train_dataloader, test_dataloader, criterion, optimizer, EPOCHS, save_name)"
   ],
   "id": "9b76cfbb1ea7421e",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Train Decomposed Learning Models\n",
    "\n",
    "Main Run Loop.\n",
    "\n",
    "In this notebook, we train only one decomposed model for each rank condition (12.5%, 25%, 50%, and 100% of the ranks), for only 1,000 epochs instead of 10,000, so that section takes ~18 minutes instead of 3 hours."
   ],
   "metadata": {
    "id": "YzPtOHH0BiZr"
   },
   "id": "YzPtOHH0BiZr"
  },
  {
   "cell_type": "code",
   "source": [
    "SEED = 0\n",
    "DATA_ORDER_SEED = 0\n",
    "PRIME = 97\n",
    "TEST_SPLIT = 0.2\n",
    "EPOCHS = 1_000\n",
    "DEVICE = torch.device(\"cpu\")"
   ],
   "metadata": {
    "id": "3zTkrRlj-Owa"
   },
   "id": "3zTkrRlj-Owa",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "for rank in [0.125, 0.25, 0.5, 1.0]:\n",
    "  EMBEDDING_AND_OUT_RANK = int(99*rank)\n",
    "  LINEAR_LAYER_RANK = int(128*rank)\n",
    "  title_text= f\"TRAINING With {(1-TEST_SPLIT)*100}% of Data @ {rank*100}% of Ranks for {EPOCHS} Epochs.\"\n",
    "  print(\"*\"*len(title_text))\n",
    "  print(title_text)\n",
    "  print(\"*\"*len(title_text))\n",
    "  ### Ensure Reproducibility:\n",
    "  os.environ[\"PYTHONHASHSEED\"] = str(SEED)\n",
    "  torch.use_deterministic_algorithms(True)\n",
    "  torch.backends.cudnn.deterministic = True\n",
    "  torch.backends.cudnn.benchmark = False\n",
    "  torch.set_num_threads(1)\n",
    "  torch.manual_seed(SEED)\n",
    "  np.random.seed(SEED)\n",
    "  random.seed(SEED)\n",
    "  rng_generator = torch.Generator().manual_seed(DATA_ORDER_SEED)\n",
    "  # Get dataset\n",
    "  dataset = torch.load( f\"a+b_mod_{PRIME}_test_split_{TEST_SPLIT:.2f}_seed_{DATA_ORDER_SEED}.pth\", weights_only=True)\n",
    "  train_dataset_features = dataset[\"train\"][\"features\"].clone().to(torch.long, copy=True)\n",
    "  train_dataset_labels = dataset[\"train\"][\"label\"].clone().to(torch.long,copy=True)\n",
    "  test_dataset_features = dataset[\"test\"][\"features\"].clone().to(torch.long,copy=True)\n",
    "  test_dataset_labels = dataset[\"test\"][\"label\"].clone().to(torch.long,copy=True)\n",
    "\n",
    "  train_dataset = TensorDataset(train_dataset_features.to(DEVICE), train_dataset_labels.to(DEVICE))\n",
    "  test_dataset = TensorDataset(test_dataset_features.to(DEVICE), test_dataset_labels.to(DEVICE))\n",
    "  # Create DataLoader\n",
    "  train_dataloader = DataLoader(\n",
    "      train_dataset,\n",
    "      batch_size=512,\n",
    "      shuffle=True,\n",
    "      worker_init_fn=seed_worker,\n",
    "      generator=rng_generator,\n",
    "      num_workers=0,\n",
    "  )\n",
    "\n",
    "  test_dataloader = DataLoader(\n",
    "      test_dataset,\n",
    "      batch_size=512,\n",
    "      shuffle=False,\n",
    "      worker_init_fn=seed_worker,\n",
    "      generator=rng_generator,\n",
    "      num_workers=0,\n",
    "  )\n",
    "\n",
    "  print(f\"Total samples in train set: {len(train_dataset):,}\")\n",
    "  print(f\"Total samples in val set: {len(test_dataset):,}\")\n",
    "\n",
    "  print(f\"Total number of batches in train dataloader: {len(train_dataloader):,}\")\n",
    "  print(f\"Total number of batches in validation dataloader: {len(test_dataloader):,}\")\n",
    "\n",
    "  # Get Model Init\n",
    "  os.makedirs(f\"a+b_mod_{PRIME}/init/\", exist_ok=True)\n",
    "  net = MLP(99,99)\n",
    "  if os.path.isfile(f\"a+b_mod_{PRIME}/init/mlp_seed_{SEED}.pth\"):\n",
    "      net.load_state_dict(torch.load(f\"a+b_mod_{PRIME}/init/mlp_seed_{SEED}.pth\", weights_only=True))\n",
    "  else:\n",
    "      torch.save(net.state_dict(), f\"a+b_mod_{PRIME}/init/mlp_seed_{SEED}.pth\")\n",
    "  net.to(DEVICE)\n",
    "  # Convert To UΣV^T\n",
    "  net.embed = SVDEmbedding(net.embed, EMBEDDING_AND_OUT_RANK)\n",
    "  net.linear = SVDLinear(net.linear, LINEAR_LAYER_RANK)\n",
    "  net.out = SVDLinear(net.out, EMBEDDING_AND_OUT_RANK)\n",
    "  criterion = nn.CrossEntropyLoss()\n",
    "  optimizer = optim.AdamW(net.parameters(), lr=1e-3, betas=(0.9, 0.98))\n",
    "\n",
    "  print(f\"Number of parameters in model: {sum([x.numel() for x in net.parameters()]):,}\")\n",
    "  os.makedirs(f\"a+b_mod_{PRIME}/test_split/{TEST_SPLIT:.2f}/rank/{rank}/\", exist_ok=True)\n",
    "  # Train Model\n",
    "  save_name = f\"a+b_mod_{PRIME}/test_split/{TEST_SPLIT:.2f}/rank/{rank}/stats_seed_{SEED}_dos_{DATA_ORDER_SEED}.pth\"\n",
    "  train_and_evaluate(net, train_dataloader, test_dataloader, criterion, optimizer, EPOCHS, save_name)\n",
    "  print()"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "oxHep55t9Fa7",
    "outputId": "a9798a89-a105-4d63-80fd-26dee753fb0d"
   },
   "id": "oxHep55t9Fa7",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Figures and Tables\n",
    "\n",
    "This section creates the plots and tables as they were created for the paper.\n",
    "\n",
    "***PLEASE NOTE: The plot and table will not be exactly the same as in the paper, as this notebook only explores one seed and one training-set size.***\n"
   ],
   "metadata": {
    "id": "gcku6zFSoRjs"
   },
   "id": "gcku6zFSoRjs"
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Configs"
   ],
   "metadata": {
    "id": "TVEP1HeLo5eZ"
   },
   "id": "TVEP1HeLo5eZ"
  },
  {
   "cell_type": "code",
   "source": [
    "num_to_average_over = 1\n",
    "colors = [\n",
    "    \"tab:orange\",\n",
    "    \"tab:red\",\n",
    "    \"tab:green\",\n",
    "    \"tab:purple\",\n",
    "    \"tab:brown\",\n",
    "    \"tab:pink\",\n",
    "    \"tab:cyan\",\n",
    "]\n",
    "folder = \"a+b_mod_97/test_split\"\n",
    "percentages = [\"0.20\"]\n",
    "rank_plots = {1.0: \"100%\", 0.5: \"50%\", 0.25: \"25%\", 0.125: \"12.5%\"}"
   ],
   "metadata": {
    "id": "Sm5PYsGno9Mp"
   },
   "id": "Sm5PYsGno9Mp",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Plotting Figure\n",
    "\n",
    "\n",
    "***PLEASE NOTE: The plot will not be exactly the same as in the paper, as this notebook only explores one seed and one dataset size.***"
   ],
   "metadata": {
    "id": "zsBrSPMmsIBW"
   },
   "id": "zsBrSPMmsIBW"
  },
  {
   "cell_type": "code",
   "source": [
    "for percentage in percentages:\n",
    "    fig, ax = plt.subplots(figsize=(9, 5), constrained_layout=True)\n",
    "    # Baseline Plot\n",
    "    train_steps = []\n",
    "    test_steps = []\n",
    "    for i in range(num_to_average_over):\n",
    "        train_stats = torch.load(\n",
    "            f\"{folder}/{percentage}/base/stats_seed_{i}_dos_0.pth\", weights_only=True\n",
    "        )\n",
    "        train_steps.append(train_stats[\"epoch_train_accs\"])\n",
    "        test_steps.append(train_stats[\"epoch_test_accs\"])\n",
    "    train_steps = np.array(train_steps)\n",
    "    test_steps = np.array(test_steps)\n",
    "    # Train Plot\n",
    "    ax.plot(train_steps.mean(axis=0), label=\"Train Baseline\", color=\"k\")\n",
    "    SEM = train_steps.std(axis=0) / np.sqrt(train_steps.shape[0])\n",
    "    ax.fill_between(\n",
    "        np.arange(len(SEM)),\n",
    "        train_steps.mean(axis=0) + SEM,\n",
    "        train_steps.mean(axis=0) - SEM,\n",
    "        alpha=0.3,\n",
    "        color=\"k\",\n",
    "    )\n",
    "    # Test Plot\n",
    "    ax.plot(test_steps.mean(axis=0), label=\"Test Baseline\", color=\"k\", linestyle=\"--\")\n",
    "    SEM = test_steps.std(axis=0) / np.sqrt(test_steps.shape[0])\n",
    "    ax.fill_between(\n",
    "        np.arange(len(SEM)),\n",
    "        test_steps.mean(axis=0) + SEM,\n",
    "        test_steps.mean(axis=0) - SEM,\n",
    "        alpha=0.3,\n",
    "        color=\"k\",\n",
    "    )\n",
    "    # Ranks Plot\n",
    "    for c_idx, rank in enumerate(list(rank_plots.keys())):\n",
    "        train_steps = []\n",
    "        test_steps = []\n",
    "        for i in range(num_to_average_over):\n",
    "            train_stats = torch.load(\n",
    "                f\"{folder}/{percentage}/rank/{rank}/stats_seed_{i}_dos_0.pth\",\n",
    "                weights_only=True,\n",
    "            )\n",
    "            train_steps.append(train_stats[\"epoch_train_accs\"])\n",
    "            test_steps.append(train_stats[\"epoch_test_accs\"])\n",
    "        train_steps = np.array(train_steps)\n",
    "        test_steps = np.array(test_steps)\n",
    "        # Train Plot\n",
    "        ax.plot(\n",
    "            train_steps.mean(axis=0),\n",
    "            label=f\"Rank {rank_plots[rank]}\",\n",
    "            color=colors[c_idx],\n",
    "        )\n",
    "        SEM = train_steps.std(axis=0) / np.sqrt(train_steps.shape[0])\n",
    "        ax.fill_between(\n",
    "            np.arange(len(SEM)),\n",
    "            train_steps.mean(axis=0) + SEM,\n",
    "            train_steps.mean(axis=0) - SEM,\n",
    "            alpha=0.3,\n",
    "            color=colors[c_idx],\n",
    "        )\n",
    "        # Test Plot\n",
    "        ax.plot(\n",
    "            test_steps.mean(axis=0),\n",
    "            label=f\"Rank {rank_plots[rank]}\",\n",
    "            color=colors[c_idx],\n",
    "            linestyle=\"--\",\n",
    "        )\n",
    "        SEM = test_steps.std(axis=0) / np.sqrt(test_steps.shape[0])\n",
    "        ax.fill_between(\n",
    "            np.arange(len(SEM)),\n",
    "            test_steps.mean(axis=0) + SEM,\n",
    "            test_steps.mean(axis=0) - SEM,\n",
    "            alpha=0.3,\n",
    "            color=colors[c_idx],\n",
    "        )\n",
    "\n",
    "    ax.set_xscale(\"log\")\n",
    "    ax.grid(True)\n",
    "    ax.set_xlabel(\"Epochs\")\n",
    "    ax.set_ylabel(\"Accuracy\")\n",
    "    ax.legend(\n",
    "        loc=\"upper center\",\n",
    "        bbox_to_anchor=(0.5, 1.1),\n",
    "        fancybox=False,\n",
    "        ncol=len(list(rank_plots.keys())) + 1,\n",
    "    )  # Add 1 for the baseline model\n",
    "    # Save Figures\n",
    "    os.makedirs(f\"{folder}/figures\", exist_ok=True)\n",
    "    plt.savefig(\n",
    "        f\"{folder}/figures/grokking_{percentage}_test_accuracy.png\",\n",
    "        dpi=400,\n",
    "        bbox_inches=\"tight\",\n",
    "        pad_inches=0.01,\n",
    "    )"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 528
    },
    "id": "9EesdJ4qoQx7",
    "outputId": "e908a927-0548-4620-aa5f-4c03ac47dfe0"
   },
   "id": "9EesdJ4qoQx7",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Create Table Data\n",
    "\n",
    "This outputs the table data.\n",
    "\n",
    "***PLEASE NOTE: The table will not be exactly the same as in the paper, as this notebook only explores one seed and one dataset size.***"
   ],
   "metadata": {
    "id": "jnKJ-yslrZW3"
   },
   "id": "jnKJ-yslrZW3"
  },
  {
   "cell_type": "code",
   "source": [
    "for percentage in percentages:\n",
    "  max_acc = []\n",
    "  steps_acc = []\n",
    "  results = {}\n",
    "  for i in range(num_to_average_over):\n",
    "      train_stats = torch.load(\n",
    "          f\"{folder}/{percentage}/base/stats_seed_{i}_dos_0.pth\", weights_only=True\n",
    "      )\n",
    "      max_acc.append(max(train_stats[\"epoch_test_accs\"]))\n",
    "      steps_acc.append(np.argmax(train_stats[\"epoch_test_accs\"]))\n",
    "  max_acc = np.array(max_acc)\n",
    "  steps_acc = np.array(steps_acc)\n",
    "\n",
    "  results[\"Baseline\"] = [\n",
    "      f\"{max_acc.mean():.3f} $\\pm$ {max_acc.std()/np.sqrt(len(max_acc)):.3f}\",\n",
    "      f\"{steps_acc.mean():.3f} $\\pm$ {steps_acc.std()/np.sqrt(len(steps_acc)):.3f}\",\n",
    "  ]\n",
    "\n",
    "  for c_idx, rank in enumerate(list(rank_plots.keys())):\n",
    "      max_acc = []\n",
    "      steps_acc = []\n",
    "      for i in range(num_to_average_over):\n",
    "          train_stats = torch.load(\n",
    "              f\"{folder}/{percentage}/rank/{rank}/stats_seed_{i}_dos_0.pth\", weights_only=True\n",
    "          )\n",
    "          max_acc.append(max(train_stats[\"epoch_test_accs\"]))\n",
    "          steps_acc.append(np.argmax(train_stats[\"epoch_test_accs\"]))\n",
    "      max_acc = np.array(max_acc)\n",
    "      steps_acc = np.array(steps_acc)\n",
    "      results[f\"{rank}\"] = [\n",
    "          f\"{max_acc.mean():.3f} $\\pm$ {max_acc.std()/np.sqrt(len(max_acc)):.3f}\",\n",
    "          f\"{steps_acc.mean():.3f} $\\pm$ {steps_acc.std()/np.sqrt(len(steps_acc)):.3f}\",\n",
    "      ]\n",
    "\n",
    "\n",
    "  data = pd.DataFrame(results).T\n",
    "  data.columns = [\"Best Accuracy\", \"Epochs For Best Accuracy\"]\n",
    "  data.index.name = \"Condition\"\n",
    "  data.to_csv(f\"{folder}/{percentage}/grokking_{percentage}_test_accuracy.csv\", float_format=\"%.3f\")\n",
    "  print(data)"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "bqhlgWHZrg1t",
    "outputId": "a01a67b9-df54-4df3-b28d-3e4908ee995a"
   },
   "id": "bqhlgWHZrg1t",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "## The End\n",
    "\n",
    "This is the end of the notebook."
   ],
   "metadata": {
    "id": "3d-a1ZdJsaCA"
   },
   "id": "3d-a1ZdJsaCA"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3"
  },
  "colab": {
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
