{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OfW60Lq4-hXk",
        "outputId": "eb3b4eeb-dcf1-4e70-d88c-66504a39706a"
      },
      "outputs": [],
      "source": [
        "# Codebase and experiments for the paper \"Grokking as the transition from lazy to rich training dynamics\"\n",
        "# We use torch some of the time, and JAX some of the time (mostly when working with linearizations and NTKs)\n",
        "!pip3 install neural-tangents\n",
        "!pip3 install einops\n",
        "!pip3 install torch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6HSNA2G6Joau"
      },
      "outputs": [],
      "source": [
        "# List of imports\n",
        "from copy import deepcopy\n",
        "from datetime import datetime\n",
        "from functools import *\n",
        "from google.colab import drive\n",
        "from itertools import product\n",
        "from jax import grad, jit, jacfwd, jacrev, lax, random, vmap\n",
        "from jax.example_libraries import optimizers\n",
        "from matplotlib import cm, pyplot as plt\n",
        "from neural_tangents import stax\n",
        "from pathlib import Path\n",
        "from scipy.interpolate import BSpline, interp1d, make_interp_spline\n",
        "from torch import matmul as mm\n",
        "from torch._C import wait\n",
        "from torch.autograd import Variable\n",
        "from torch.utils.data import DataLoader, Dataset, random_split\n",
        "from torchvision.datasets import MNIST\n",
        "from torchvision.transforms import ToTensor\n",
        "from tqdm import tqdm\n",
        "import einops\n",
        "import gc\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import matplotlib\n",
        "import neural_tangents as nt\n",
        "import numpy as np\n",
        "import optax\n",
        "import os\n",
        "import pandas as pd\n",
        "import pickle\n",
        "import plotly.express as px\n",
        "import plotly.graph_objects as go\n",
        "import plotly.io as pio\n",
        "import random as nrandom\n",
        "import seaborn as sns\n",
        "import shutil\n",
        "import sys\n",
        "import time\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "E3k-01E1JqFn"
      },
      "outputs": [],
      "source": [
        "# Seed torch's global RNG.\n",
        "torch.manual_seed(42)\n",
        "# Prefer the GPU when one is visible; otherwise stay on the CPU.\n",
        "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
        "# Colormap object used by the plotting cells.\n",
        "cmap = matplotlib.colormaps.get_cmap('tab20')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "UQ7A-GG2M_63",
        "outputId": "4d3c660e-9848-48e1-9acb-c3368a1bacfc"
      },
      "outputs": [],
      "source": [
        "# Figure 1\n",
        "D = 100  # input dimension\n",
        "P = 550  # number of training points\n",
        "N = 500  # hidden width\n",
        "\n",
        "def target_fn(beta, X):\n",
        "    \"\"\"Quadratic single-index target (beta . x)**2 / 2, one value per column x of X.\"\"\"\n",
        "    return (X.T @ beta)**2/2.0\n",
        "\n",
        "# Every array gets its own PRNG key. JAX keys must not be reused: draws made with the\n",
        "# same key are not independent, so the initial weights would be tied to the inputs X.\n",
        "X = random.normal(random.PRNGKey(0), (D,P))/ jnp.sqrt(D)\n",
        "Xt = random.normal(random.PRNGKey(1), (D,1000))/ jnp.sqrt(D)\n",
        "beta = random.normal(random.PRNGKey(2), (D,))\n",
        "\n",
        "y = target_fn(beta, X)\n",
        "yt = target_fn(beta,Xt)\n",
        "\n",
        "W = random.normal(random.PRNGKey(3), (N, D))\n",
        "\n",
        "a = random.normal(random.PRNGKey(4), (N, ))\n",
        "params = [a, W]\n",
        "alpha = 1 # scaling parameter, NOT weight norm scale\n",
        "eps = 0.02\n",
        "\n",
        "def NN_func2(params,X):\n",
        "    \"\"\"Network output without readouts: alpha * mean over hidden units of phi(W x, eps).\n",
        "\n",
        "    params is the pair [a, W]; only W is used. X holds one input per row.\n",
        "    alpha and eps are read from the notebook-level globals, which keeps the\n",
        "    f(params, x) signature that nt.empirical_ntk_fn and nt.linearize are given.\n",
        "    \"\"\"\n",
        "    _, W = params\n",
        "    h = W @ X.T\n",
        "\n",
        "    # jnp (not np) so the reduction is an ordinary JAX op under jit / grad / NTK tracing\n",
        "    f = alpha * jnp.mean(phi(h,eps),axis=0) # w/o readouts\n",
        "    return f\n",
        "\n",
        "\n",
        "def phi(z, eps = 0.25):\n",
        "    \"\"\"Linear-plus-quadratic activation: z + (eps / 2) * z**2.\"\"\"\n",
        "    quadratic_part = 0.5*eps*z**2\n",
        "    return z + quadratic_part\n",
        "\n",
        "# Empirical NTK of NN_func2; called elsewhere in this cell as ntk_fn(x1, x2_or_None, params).\n",
        "ntk_fn = nt.empirical_ntk_fn(\n",
        "    NN_func2,\n",
        "    trace_axes=(),\n",
        "    vmap_axes=0,\n",
        "    implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES,\n",
        ")\n",
        "\n",
        "def kernel_regression(X, y, Xt, yt, params, which='test'):\n",
        "    \"\"\"Mean-squared error of kernel regression with the empirical NTK at `params`.\n",
        "\n",
        "    The regression is fitted on (X, y) with no ridge term. The error is measured on\n",
        "    (Xt, yt) when which == 'test' and on the training set itself for any other value.\n",
        "    X and Xt are transposed before being handed to ntk_fn.\n",
        "    \"\"\"\n",
        "    K_train = ntk_fn(X.T, None, params)\n",
        "\n",
        "    # Dual coefficients of the kernel interpolant.\n",
        "    coef = jnp.linalg.solve(K_train, y)\n",
        "\n",
        "    if which == 'test':\n",
        "        X_eval, labels = Xt, yt\n",
        "    else:\n",
        "        X_eval, labels = X, y\n",
        "\n",
        "    # Cross-kernel between the evaluation points and the training points, computed once.\n",
        "    k_eval_train = jnp.squeeze(ntk_fn(X_eval.T, X.T, params))\n",
        "    estimates = jnp.dot(k_eval_train, coef)\n",
        "    mse = jnp.mean((estimates - labels) ** 2)\n",
        "    return mse\n",
        "\n",
        "\n",
        "def kalignment(K, train_y):\n",
        "    \"\"\"Alignment between a kernel matrix K and a label vector train_y.\n",
        "\n",
        "    The labels are mean-centred and K has its column means subtracted; the value\n",
        "    returned is y_c^T K_c y_c / (||K_c||_F * ||y_c||**2).\n",
        "    \"\"\"\n",
        "    y_col = train_y.reshape(-1, 1)\n",
        "    y_col = y_col - y_col.mean(axis=0)\n",
        "    K_centered = K - K.mean(axis=0)\n",
        "    # (1, 1) quadratic form; the trace below extracts its single entry\n",
        "    quad_form = y_col.T @ K_centered @ y_col\n",
        "    normalizer = jnp.linalg.norm(K_centered) * (jnp.linalg.norm(y_col)**2)\n",
        "    return jnp.trace(quad_form)/normalizer\n",
        "\n",
        "# Sweep grid and logging cadence for the Figure 1 run.\n",
        "alphas = [1]\n",
        "epsilons = [0.02]\n",
        "epochs = 100000\n",
        "CENTER_LOSS = True\n",
        "TRAIN_READOUTS = False\n",
        "ntk_interval = 100\n",
        "\n",
        "\n",
        "for alpha in alphas:\n",
        "  for eps in epsilons:\n",
        "    # NN_func2 reads the notebook-level `alpha` / `eps` that these loops rebind.\n",
        "    kaligns_test = []\n",
        "    kalign_epochs = []  # epoch index of every entry in kaligns_test\n",
        "    epochs_to_plot = []\n",
        "    dots = []\n",
        "\n",
        "    Cs, As = [], []\n",
        "    actual_w1aligns, actual_w2aligns = [], []\n",
        "    w1_aligns, w2_aligns = [], []\n",
        "    w1_vars, w2_vars, ws_covs = [], [], []\n",
        "    vars_compute_interval = 50\n",
        "\n",
        "    lamb = 0\n",
        "    eta = N/alpha**2\n",
        "    opt_init, opt_update, get_params = optimizers.sgd(eta)\n",
        "    opt_init_lin, opt_update_lin, get_params_lin = optimizers.sgd(eta)\n",
        "\n",
        "    opt_state = opt_init(params)\n",
        "    opt_state_lin = opt_init_lin(params)\n",
        "\n",
        "    # Linearization of the network around the initial parameters (built once, used for both terms of lin_loss).\n",
        "    f_lin = nt.linearize(NN_func2, params)\n",
        "    f_lin0 = f_lin\n",
        "    lin_tr_losses = []\n",
        "    lin_te_losses = []\n",
        "\n",
        "\n",
        "    # NN_func2 takes one input per row, so the data is transposed in BOTH branches.\n",
        "    if CENTER_LOSS:\n",
        "      # Subtract the output at initialization so the trained function starts at zero.\n",
        "      loss_fn = jit(lambda p, X, y: jnp.mean( ( NN_func2(p, X.T)- NN_func2(params,X.T) - y )**2))\n",
        "    else:\n",
        "      loss_fn = jit(lambda p, X, y: jnp.mean( ( NN_func2(p, X.T) - y )**2 / alpha**2 ))\n",
        "\n",
        "    lin_loss = jit(lambda p, X, y: jnp.mean((f_lin(p, X.T) - f_lin0(params, X.T) - y)**2)  )\n",
        "    grad_loss_lin = jit(grad(lin_loss, 0))\n",
        "\n",
        "    reg_loss = jit(lambda p, X, y: loss_fn(p,X,y) + lamb / alpha * optimizers.l2_norm(p)**2 )\n",
        "\n",
        "    grad_loss = jit(grad(reg_loss,0))\n",
        "\n",
        "    tr_losses = []\n",
        "    te_losses = []\n",
        "\n",
        "    alignments, alignmentst = [], []\n",
        "\n",
        "    t1s, t2s, t3s, epochs_to_compute = [], [], [], []\n",
        "    t1sm, t2sm, t3sm, ts_summ = [], [], [], []\n",
        "    ts_sum = []\n",
        "\n",
        "    # Test MSE of kernel regression with the NTK at initialization (horizontal reference line in the plot).\n",
        "    kmse = kernel_regression(X, y, Xt, yt, get_params(opt_state))\n",
        "\n",
        "    for t in tqdm(range(epochs)):\n",
        "      opt_state = opt_update(t, grad_loss(get_params(opt_state), X, y), opt_state)\n",
        "      pars = get_params(opt_state)\n",
        "\n",
        "      train_loss = loss_fn(pars, X, y)\n",
        "      test_loss = loss_fn(pars, Xt, yt)\n",
        "      tr_losses += [train_loss]\n",
        "      te_losses += [test_loss]\n",
        "\n",
        "      # Linearized model: step first, then evaluate, so that both learning curves\n",
        "      # are logged after the same number of updates.\n",
        "      opt_state_lin = opt_update_lin(t, grad_loss_lin(get_params_lin(opt_state_lin), X, y), opt_state_lin)\n",
        "      lin_pars = get_params_lin(opt_state_lin)\n",
        "\n",
        "      lin_tr_losses += [ lin_loss(lin_pars, X, y) ]\n",
        "      lin_te_losses += [ lin_loss(lin_pars, Xt, yt) ]\n",
        "\n",
        "      if t % vars_compute_interval == 0:\n",
        "          epochs_to_compute.append(t)\n",
        "      if t % ntk_interval == 0 and t>0:\n",
        "        K_test = ntk_fn(Xt.T, None, pars)\n",
        "        cka_test = kalignment(K_test, yt)\n",
        "        kaligns_test += [ cka_test ]\n",
        "        kalign_epochs += [ t ]\n",
        "\n",
        "      if t % 5000 == 0 and t>0:\n",
        "        max_t = t\n",
        "        # Interpolate against the epochs at which the alignment was actually measured.\n",
        "        interpolator = interp1d(kalign_epochs, kaligns_test, kind='linear', fill_value='extrapolate')\n",
        "        interpolated_kaligns = interpolator(np.arange(max_t))\n",
        "\n",
        "\n",
        "        fig, ax1 = plt.subplots()\n",
        "\n",
        "        col = cmap(0)\n",
        "        ax1.plot(np.array(tr_losses), linestyle='--', label=rf'Train Loss', color=col, lw=2)\n",
        "        ax1.plot(np.array(te_losses), label=rf'Test Loss', color=col, lw=2)\n",
        "        ax1.plot(np.array(lin_tr_losses), color='black', linestyle='--', label=f'Linearized train loss')\n",
        "        ax1.plot(np.array(lin_te_losses), color='black', label=f'Linearized test loss')\n",
        "        ax1.axhline(kmse, color='r', label=rf'$K_0$ regression MSE')\n",
        "        ax1.set_xlabel('Epochs', fontsize=20)\n",
        "        ax1.set_xscale('log')\n",
        "        ax1.set_ylabel('MSE', fontsize=20)\n",
        "        ax1.legend(loc='lower left', bbox_to_anchor=(0, 0.1))\n",
        "\n",
        "        # Second y-axis: kernel-target alignment of the current NTK on the test set.\n",
        "        ax2 = ax1.twinx()\n",
        "        ax2.plot(interpolated_kaligns, linestyle='--', color='green', label=r'NTK alignment, $\\frac{y^T K_0y}{||K_0||_F||y||^2}$', lw=2)\n",
        "        ax2.legend(loc='upper left', bbox_to_anchor=(0, 0.58))\n",
        "        ax2.set_ylabel('NTK alignment', fontsize=20)\n",
        "        plt.tight_layout()\n",
        "        plt.show()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "PN_PgYlaM_9o",
        "outputId": "fcfdada5-8a5e-46a0-eaf9-a3cdb3bbfa1e"
      },
      "outputs": [],
      "source": [
        "# Figure 2. Modular arithmetic.\n",
        "\n",
        "# Task\n",
        "p = 23  # modulus (p = 13 is a smaller alternative)\n",
        "D = 2*p  # Input units (each input is a pair of one-hot vectors)\n",
        "alpha = 0.9  # Fraction of data used for training\n",
        "seed= 1\n",
        "\n",
        "# Model and optimisation\n",
        "N = 100  # hidden width passed to SimpleMLP\n",
        "scale = 1\n",
        "lr = 1e2 / scale**2\n",
        "epochs = 50000\n",
        "ntk_interval = int(epochs/10)\n",
        "\n",
        "scales = [1]  # alternative sweep: [0.1, 0.5, 1, 2, 5][::-1]\n",
        "trs, tls = [], []\n",
        "\n",
        "class SimpleMLP(nn.Module):\n",
        "    \"\"\"Two-layer, bias-free MLP with a quadratic activation.\n",
        "\n",
        "    forward(x) = layer2(layer1(x)**2) * scale / (input_dim * hidden_dim) * scale\n",
        "    \"\"\"\n",
        "    def __init__(self, input_dim, hidden_dim, output_dim, scale=1):\n",
        "        super(SimpleMLP, self).__init__()\n",
        "        self.D = input_dim\n",
        "        self.N = hidden_dim\n",
        "        self.scale = scale\n",
        "        self.layer1 = nn.Linear(input_dim, hidden_dim, bias=False)\n",
        "        self.layer2 = nn.Linear(hidden_dim, output_dim, bias=False)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.layer1(x)\n",
        "        x = x**2  # Quadratic activation function\n",
        "        x = self.layer2(x) * self.scale / (self.D * self.N)\n",
        "        # Use the scale this model was built with instead of the notebook-level `scale`\n",
        "        # variable, so the output cannot change if that global is rebound later.\n",
        "        # NOTE(review): the scale enters twice (net factor scale**2); kept as is to\n",
        "        # preserve existing results - confirm this is intended.\n",
        "        return x * self.scale\n",
        "\n",
        "\n",
        "def modulo(x):\n",
        "    \"\"\"One-hot label (as a list of length p) for an input made of two concatenated one-hot vectors.\n",
        "\n",
        "    The positions of the two ones give the operands: the first directly, the second\n",
        "    after subtracting the offset p. The label encodes their sum modulo p.\n",
        "    \"\"\"\n",
        "    first_pos, second_pos = torch.where(x == 1)[0]\n",
        "    residue = (first_pos + (second_pos - p)) % p\n",
        "    return np.eye(p)[residue].tolist()\n",
        "\n",
        "# Dataset: all p**2 operand pairs. Each input is a pair of one-hot vectors; each output is a one-hot vector\n",
        "X = torch.tensor([np.eye(p)[first].tolist() + np.eye(p)[second].tolist() for first in range(p) for second in range(p)])\n",
        "y = torch.tensor([modulo(x) for x in X])\n",
        "\n",
        "# Seeded random split into train and test subsets\n",
        "torch.manual_seed(seed)\n",
        "indices = torch.randperm(p**2)\n",
        "split_at = int(alpha*(p**2))\n",
        "train_indices, test_indices = indices[:split_at], indices[split_at:]\n",
        "X_train, X_test = X[train_indices], X[test_indices]\n",
        "y_train, y_test = y[train_indices], y[test_indices]\n",
        "\n",
        "tr_losses, te_losses = [], []\n",
        "print(f\"Grokking modular addition: p: {p}, D: {D}, train-set-frac: {alpha}, scale: {scale}, lr: {round(lr, 2)}, seed: {seed}, N: {N}, epochs: {epochs}\")\n",
        "\n",
        "# Model with every weight re-drawn from a standard normal\n",
        "model = SimpleMLP(D, N, p, scale)\n",
        "for param in model.parameters():\n",
        "    nn.init.normal_(param, mean=0, std=1)\n",
        "\n",
        "loss_fn = nn.MSELoss()\n",
        "optimizer = optim.SGD(model.parameters(), lr=lr)\n",
        "\n",
        "# Per-epoch histories filled in by the training loop\n",
        "train_loss, test_loss = [], []\n",
        "train_acc, test_acc = [], []\n",
        "wnorms, kaligns, knorms, gnorms = [], [], [], [] # progress measures\n",
        "train_kaligns = []\n",
        "norms1, norms2, = [], []\n",
        "\n",
        "for epoch in tqdm(range(epochs)):\n",
        "    # Forward pass (full batch)\n",
        "    y_pred = model(X_train)\n",
        "    loss = loss_fn(y_pred, y_train) / float(scale**2)\n",
        "    train_loss.append(loss.item())\n",
        "\n",
        "    predicted = torch.argmax(y_pred, dim=1)\n",
        "    correct = (predicted == torch.argmax(y_train, dim=1)).sum().item()\n",
        "    train_acc.append(100 * correct / len(y_train))\n",
        "\n",
        "    # Backward pass\n",
        "    optimizer.zero_grad()\n",
        "    loss.backward()\n",
        "\n",
        "    # Progress measures, taken before the step so they describe the weights that produced `loss`.\n",
        "    # gnorms stores the SQUARED total gradient norm; wnorms the sum of the two layers' norms.\n",
        "    total_grad_norm = sum((par.grad.data.norm(2).item() ** 2 for par in model.parameters() if par.grad is not None), 0.0)\n",
        "\n",
        "    t1, t2 = list(model.parameters())[:2]\n",
        "    norm1 = torch.norm(t1).item()\n",
        "    norm2 = torch.norm(t2).item()\n",
        "\n",
        "    norms1.append(norm1)\n",
        "    norms2.append(norm2)\n",
        "    wnorms.append(norm1 + norm2)\n",
        "    gnorms.append(total_grad_norm)\n",
        "\n",
        "    optimizer.step()\n",
        "\n",
        "    # Evaluate on test set (after the step)\n",
        "    with torch.no_grad():\n",
        "        y_pred_test = model(X_test)\n",
        "        # Same 1/scale**2 normalisation as the training loss so the two curves share units.\n",
        "        test_loss.append((loss_fn(y_pred_test, y_test) / float(scale**2)).item())\n",
        "\n",
        "        predicted_test = torch.argmax(y_pred_test, dim=1)\n",
        "        correct_test = (predicted_test == torch.argmax(y_test, dim=1)).sum().item()\n",
        "        test_acc.append(100 * correct_test / len(y_test))\n",
        "\n",
        "# Append this run's histories to the per-run collections.\n",
        "trs += [ train_acc ]\n",
        "tls += [ test_acc ]\n",
        "\n",
        "tr_losses += [ train_loss ]\n",
        "te_losses += [ test_loss ]\n",
        "\n",
        "cmap = plt.get_cmap('tab20')\n",
        "colors = [cmap(i) for i in range(cmap.N)]\n",
        "dashed = dict(linestyle='--', dashes=(3, 4))\n",
        "\n",
        "# Figure 2a: accuracy (left axis) and weight norm (right axis) against epochs.\n",
        "fig, ax1 = plt.subplots(figsize=(9, 6))\n",
        "plt.rcParams.update({'font.size': 14})\n",
        "\n",
        "ax1.set_xlabel('Epochs', fontsize=20)\n",
        "ax1.set_ylabel('Accuracy', fontsize=20)\n",
        "ax1.plot(range(epochs), train_acc, color=colors[0], label='Train accuracy', **dashed)\n",
        "ax1.plot(range(epochs), test_acc, color=colors[0], label='Test accuracy')\n",
        "\n",
        "ax2 = ax1.twinx()\n",
        "ax2.plot(range(epochs), wnorms, color=colors[2], label='Weight Norm', linewidth=2, **dashed)\n",
        "ax2.tick_params(axis='y')\n",
        "ax2.set_ylabel('Weight Norm', fontsize=20)\n",
        "\n",
        "# One legend holding the entries of both axes.\n",
        "lines1, labels1 = ax1.get_legend_handles_labels()\n",
        "lines2, labels2 = ax2.get_legend_handles_labels()\n",
        "lines, labels = lines1 + lines2, labels1 + labels2\n",
        "ax1.legend(lines, labels, loc='lower right', fontsize=20)\n",
        "\n",
        "plt.tight_layout()\n",
        "plt.savefig('2a')\n",
        "plt.show()\n",
        "\n",
        "# Figure 2b: train and test loss against epochs.\n",
        "plt.figure(figsize=(9, 6))\n",
        "plt.plot(train_loss, label='Train loss', color=colors[0], linewidth=2.0, **dashed)\n",
        "plt.plot(test_loss, label='Test loss', color=colors[0], linewidth=2.0)\n",
        "plt.xlabel('Epochs', fontsize=20)\n",
        "plt.ylabel('MSE', fontsize=20)\n",
        "plt.legend(fontsize=20)\n",
        "plt.tight_layout()\n",
        "plt.savefig('2b')\n",
        "plt.show()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 737
        },
        "id": "MCfGcL5K8UzG",
        "outputId": "f0122371-919e-4b50-aa2c-e3a2558899a4"
      },
      "outputs": [],
      "source": [
        "# 2(c). Sweep over laziness.\n",
        "# Figure 2. Modular arithmetic.\n",
        "\n",
        "p = 23  # modulus (p = 13 is a smaller alternative)\n",
        "D = 2*p  # Input units (each input is a pair of one-hot vectors)\n",
        "alpha = 0.9  # Fraction of data used for training\n",
        "scale = 1\n",
        "seed= 1\n",
        "N = 100\n",
        "# Defined here so the cell does not rely on an `epochs` value left over from another cell.\n",
        "epochs = 100000\n",
        "\n",
        "ntk_interval = int(epochs/10)\n",
        "\n",
        "scales = [1]  # alternative sweep: [0.1, 0.5, 1, 2, 5][::-1]\n",
        "trs, tls = [], []\n",
        "\n",
        "class SimpleMLP(nn.Module):\n",
        "    \"\"\"Two-layer, bias-free MLP with a quadratic activation.\n",
        "\n",
        "    forward(x) = layer2(layer1(x)**2) * scale / (input_dim * hidden_dim) * scale\n",
        "    \"\"\"\n",
        "    def __init__(self, input_dim, hidden_dim, output_dim, scale=1):\n",
        "        super(SimpleMLP, self).__init__()\n",
        "        self.D = input_dim\n",
        "        self.N = hidden_dim\n",
        "        self.scale = scale\n",
        "        self.layer1 = nn.Linear(input_dim, hidden_dim, bias=False)\n",
        "        self.layer2 = nn.Linear(hidden_dim, output_dim, bias=False)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.layer1(x)\n",
        "        x = x**2  # Quadratic activation\n",
        "        x = self.layer2(x) * self.scale / (self.D * self.N)\n",
        "        # Use the scale this model was built with instead of the notebook-level `scale`\n",
        "        # loop variable, so the output cannot change once the sweep rebinds it.\n",
        "        # NOTE(review): the scale enters twice (net factor scale**2); kept as is to\n",
        "        # preserve existing results - confirm this is intended.\n",
        "        return x * self.scale\n",
        "\n",
        "\n",
        "def modulo(x):\n",
        "    \"\"\"One-hot label (as a list of length p) for an input made of two concatenated one-hot vectors.\n",
        "\n",
        "    The positions of the two ones give the operands: the first directly, the second\n",
        "    after subtracting the offset p. The label encodes their sum modulo p.\n",
        "    \"\"\"\n",
        "    first_pos, second_pos = torch.where(x == 1)[0]\n",
        "    residue = (first_pos + (second_pos - p)) % p\n",
        "    return np.eye(p)[residue].tolist()\n",
        "\n",
        "# Dataset: all p**2 operand pairs. Each input is a pair of one-hot vectors; each output is a one-hot vector\n",
        "X = torch.tensor([np.eye(p)[first].tolist() + np.eye(p)[second].tolist() for first in range(p) for second in range(p)])\n",
        "y = torch.tensor([modulo(x) for x in X])\n",
        "\n",
        "# Seeded random split into train and test subsets\n",
        "torch.manual_seed(seed)\n",
        "indices = torch.randperm(p**2)\n",
        "split_at = int(alpha*(p**2))\n",
        "train_indices, test_indices = indices[:split_at], indices[split_at:]\n",
        "X_train, X_test = X[train_indices], X[test_indices]\n",
        "y_train, y_test = y[train_indices], y[test_indices]\n",
        "\n",
        "scales = [0.5, 1, 1.5, 2]\n",
        "epochs = 100000\n",
        "# Loss histories, one entry per scale. Created once, outside the loop, so every scale's curve is kept.\n",
        "tr_losses, te_losses = [], []\n",
        "for scale in scales:\n",
        "  # Learning rate shrinks as 1/scale**2, matching the 1/scale**2 factor on the loss below.\n",
        "  lr = 1e2 / scale**2\n",
        "  print(f\"Grokking modular addition: p: {p}, D: {D}, train-set-frac: {alpha}, scale: {scale}, lr: {round(lr, 2)}, seed: {seed}, N: {N}, epochs: {epochs}\")\n",
        "  model = SimpleMLP(D, N, p, scale)\n",
        "  for param in model.parameters(): # what if we change NTK0's spectral structure?\n",
        "      nn.init.normal_(param, mean=0, std=1)\n",
        "\n",
        "  loss_fn = nn.MSELoss()\n",
        "  optimizer = optim.SGD(model.parameters(), lr=lr)\n",
        "\n",
        "  # Per-epoch histories for this scale\n",
        "  train_loss, test_loss = [], []\n",
        "  train_acc, test_acc = [], []\n",
        "  wnorms, kaligns, knorms, gnorms = [], [], [], [] # progress measures\n",
        "  train_kaligns = []\n",
        "  norms1, norms2, = [], []\n",
        "\n",
        "  for epoch in tqdm(range(epochs)):\n",
        "      # Forward pass (full batch)\n",
        "      y_pred = model(X_train)\n",
        "      loss = loss_fn(y_pred, y_train) / float(scale**2)\n",
        "      train_loss.append(loss.item())\n",
        "\n",
        "      predicted = torch.argmax(y_pred, dim=1)\n",
        "      correct = (predicted == torch.argmax(y_train, dim=1)).sum().item()\n",
        "      train_acc.append(100 * correct / len(y_train))\n",
        "\n",
        "      # Backward pass\n",
        "      optimizer.zero_grad()\n",
        "      loss.backward()\n",
        "\n",
        "      # Progress measures, taken before the step so they describe the weights that produced `loss`.\n",
        "      # gnorms stores the SQUARED total gradient norm; wnorms the sum of the two layers' norms.\n",
        "      total_grad_norm = sum((par.grad.data.norm(2).item() ** 2 for par in model.parameters() if par.grad is not None), 0.0)\n",
        "\n",
        "      t1, t2 = list(model.parameters())[:2]\n",
        "      norm1 = torch.norm(t1).item()\n",
        "      norm2 = torch.norm(t2).item()\n",
        "\n",
        "      norms1.append(norm1)\n",
        "      norms2.append(norm2)\n",
        "      wnorms.append(norm1 + norm2)\n",
        "      gnorms.append(total_grad_norm)\n",
        "\n",
        "      optimizer.step()\n",
        "\n",
        "      # Evaluate on test set (after the step)\n",
        "      with torch.no_grad():\n",
        "          y_pred_test = model(X_test)\n",
        "          # Same 1/scale**2 normalisation as the training loss so the two curves share units.\n",
        "          test_loss.append((loss_fn(y_pred_test, y_test) / float(scale**2)).item())\n",
        "\n",
        "          predicted_test = torch.argmax(y_pred_test, dim=1)\n",
        "          correct_test = (predicted_test == torch.argmax(y_test, dim=1)).sum().item()\n",
        "          test_acc.append(100 * correct_test / len(y_test))\n",
        "\n",
        "  trs += [ train_acc ]\n",
        "  tls += [ test_acc ]\n",
        "\n",
        "  tr_losses += [ train_loss ]\n",
        "  te_losses += [ test_loss ]\n",
        "\n",
        "# Accuracy curves of the scale sweep: one colour per scale, dashed = train, solid = test.\n",
        "colors = ['g', 'b', 'r', 'black']\n",
        "plt.figure(figsize=(9,6))\n",
        "plt.rcParams.update({'font.size': 14})\n",
        "\n",
        "plt.xlabel('Epochs', fontsize=20)\n",
        "plt.ylabel('Accuracy', fontsize=20)\n",
        "\n",
        "for i in range(len(trs)):\n",
        "  # x-range follows the recorded history so the plot cannot go out of sync with `epochs`\n",
        "  plt.plot(range(len(trs[i])), trs[i], color=colors[i], label=rf'Train, $\\alpha={scales[i]}$', linestyle='--', linewidth=2.0, dashes=(3, 4))\n",
        "  plt.plot(range(len(tls[i])), tls[i], color=colors[i], label=rf'Test, $\\alpha={scales[i]}$', linewidth=2.0)\n",
        "\n",
        "plt.legend(bbox_to_anchor=(0.67, 0.9), fontsize=16)\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "uhIkDb81WmpT",
        "outputId": "e799452a-78ae-4c84-8526-d7f41e182dec"
      },
      "outputs": [],
      "source": [
        "# Figure 3(a) and 3(b). Sweep over alpha and epsilon.\n",
        "\n",
        "def phi(z, eps = 0.25):\n",
        "    \"\"\"Linear-plus-quadratic activation: z + (eps / 2) * z**2.\"\"\"\n",
        "    return z + 0.5*eps*z**2\n",
        "\n",
        "def NN_func(params,X, alpha, eps=0.25):\n",
        "    \"\"\"Two-layer network with readouts: (alpha / N) * sum_i a_i * phi(w_i . x / sqrt(D)).\n",
        "\n",
        "    params is the pair [a, W]; each column of X is one input.\n",
        "    \"\"\"\n",
        "    a, W = params\n",
        "\n",
        "    D = W.shape[1]\n",
        "    N = a.shape[0]\n",
        "\n",
        "    h = W @ X / jnp.sqrt(D)\n",
        "    f = alpha/N * phi(h, eps = eps).T @ a\n",
        "    return f\n",
        "\n",
        "def NN_func2(params,X, alpha, eps=0.25):\n",
        "    \"\"\"Network without readouts: alpha * mean_i phi(w_i . x / sqrt(D)).\n",
        "\n",
        "    params is the pair [a, W]; `a` is ignored. Each column of X is one input.\n",
        "    \"\"\"\n",
        "    _, W = params\n",
        "\n",
        "    D = W.shape[1]\n",
        "\n",
        "    h = W @ X / jnp.sqrt(D)\n",
        "    f = alpha * jnp.mean( phi(h, eps = eps), axis = 0)\n",
        "    return f\n",
        "\n",
        "\n",
        "def target_fn(beta, X):\n",
        "    \"\"\"Quadratic target (beta . x / sqrt(D))**2 per column x of X; D is the notebook-level input dimension.\"\"\"\n",
        "    return (X.T @ beta / jnp.sqrt(D))**2\n",
        "\n",
        "\n",
        "D = 100  # input dimension\n",
        "P = 450  # number of training points\n",
        "N = 500  # hidden width\n",
        "\n",
        "# Every array gets its own PRNG key. JAX keys must not be reused: draws made with the\n",
        "# same key are not independent, so the initial weights would be tied to the inputs X.\n",
        "X = random.normal(random.PRNGKey(0), (D,P))\n",
        "Xt = random.normal(random.PRNGKey(1), (D,1000))\n",
        "beta = random.normal(random.PRNGKey(2), (D,))\n",
        "\n",
        "y = target_fn(beta, X)\n",
        "yt = target_fn(beta,Xt)\n",
        "\n",
        "a = random.normal(random.PRNGKey(4), (N, ))\n",
        "W = random.normal(random.PRNGKey(3), (N, D))\n",
        "params = [a, W]\n",
        "\n",
        "\n",
        "eta = 0.5 * N\n",
        "lamb = 0.0\n",
        "opt_init, opt_update, get_params = optimizers.sgd(eta)\n",
        "\n",
        "alphas = [2**(-5),0.25,0.5,1.0,2.0,4.0,8.0,16,32]\n",
        "\n",
        "# One entry per alpha, appended by the sweep loop.\n",
        "all_tr_losses = []\n",
        "all_te_losses = []\n",
        "all_acc_tr = []\n",
        "all_acc_te = []\n",
        "\n",
        "param_movement = []\n",
        "\n",
        "for alpha in alphas:\n",
        "    # Every alpha restarts from the same initial parameters.\n",
        "    opt_state = opt_init(params)\n",
        "\n",
        "    # Centred, alpha-normalised objective: mean of (f(p) - f(params) - y)**2 / alpha**2.\n",
        "    loss_fn = jit(lambda p, X, y: jnp.mean( ( NN_func2(p, X,alpha)- NN_func2(params,X,alpha) - y )**2 / alpha**2 ))\n",
        "    # Fraction of points where the centred output and the label have the same sign.\n",
        "    acc_fn = jit(lambda p, X, y: jnp.mean( ( y * ( NN_func2(p, X,alpha)- NN_func2(params,X,alpha)) ) > 0.0 ))\n",
        "    reg_loss = jit(lambda p, X, y: loss_fn(p,X,y) + lamb / alpha * optimizers.l2_norm(p)**2 )\n",
        "    grad_loss = jit(grad(reg_loss,0))\n",
        "\n",
        "    tr_losses, te_losses = [], []\n",
        "    tr_acc, te_acc = [], []\n",
        "    for t in range(60000):\n",
        "        opt_state = opt_update(t, grad_loss(get_params(opt_state), X, y), opt_state)\n",
        "\n",
        "        # Log every other step; losses are multiplied by alpha**2 to undo the normalisation.\n",
        "        if t % 2 == 0:\n",
        "            pars = get_params(opt_state)\n",
        "            train_loss = alpha**2*loss_fn(pars, X, y)\n",
        "            test_loss = alpha**2*loss_fn(pars, Xt, yt)\n",
        "            tr_losses.append(train_loss)\n",
        "            te_losses.append(test_loss)\n",
        "            tr_acc.append(acc_fn(pars, X, y))\n",
        "            te_acc.append(acc_fn(pars, Xt, yt))\n",
        "            sys.stdout.write(f'\\r t: {t} | train loss: {train_loss} | test loss: {test_loss}')\n",
        "        if t % 10000 == 0:\n",
        "            print(\" \")\n",
        "\n",
        "    all_tr_losses.append(tr_losses)\n",
        "    all_te_losses.append(te_losses)\n",
        "    all_acc_tr.append(tr_acc)\n",
        "    all_acc_te.append(te_acc)\n",
        "\n",
        "    # Relative squared parameter movement: ||theta_final - theta_0||**2 / ||theta_0||**2.\n",
        "    paramsf = get_params(opt_state)\n",
        "    moved = jnp.sum((paramsf[0]-params[0])**2) + jnp.sum((paramsf[1]-params[1])**2)\n",
        "    initial = jnp.sum( params[0]**2 ) + jnp.sum(params[1]**2)\n",
        "    dparam = moved / initial\n",
        "    param_movement.append(dparam)\n",
        "\n",
        "# Loss curves of the alpha sweep, each normalised by its first logged value\n",
        "# (dashed = train, solid = test). The last alpha is left out of the plot.\n",
        "plt.rcParams.update({'font.size': 14})\n",
        "plt.figure()\n",
        "for i,alpha in enumerate(alphas[:-1]):\n",
        "    print(alpha)\n",
        "    n_logged = len(all_tr_losses[i])\n",
        "    steps = jnp.linspace(1,n_logged,n_logged)\n",
        "    train_curve = jnp.array(all_tr_losses[i]) / all_tr_losses[i][0]\n",
        "    test_curve = jnp.array(all_te_losses[i]) / all_te_losses[i][0]\n",
        "    plt.plot(steps, train_curve, '--',  color = f'C{i}')\n",
        "    plt.plot(steps, test_curve,  color = f'C{i}', label = r'$\\alpha = 2^{%0.0f}$' % jnp.log2(alpha))\n",
        "\n",
        "plt.xscale('log')\n",
        "plt.xlabel(r'$t$',fontsize = 20)\n",
        "plt.ylabel('Loss',fontsize = 20)\n",
        "plt.legend(loc='upper left', bbox_to_anchor=(1, 1))\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "\n",
        "weight_norms = [0.125,0.25,0.5,1.0,2.0]\n",
        "alpha = 1.0\n",
        "\n",
        "eta = 0.5 * N\n",
        "lamb = 0.0\n",
        "\n",
        "# One entry per weight scale, appended by the loop below.\n",
        "all_tr_losses_w = []\n",
        "all_te_losses_w = []\n",
        "all_acc_tr_w = []\n",
        "all_acc_te_w = []\n",
        "\n",
        "param_movement_w = []\n",
        "\n",
        "for i, wscale in enumerate(weight_norms):\n",
        "\n",
        "    # Same base draw for every wscale, only rescaled. The keys differ from the ones used for\n",
        "    # X / Xt / beta: JAX keys must not be reused, or the weights are not independent of the data.\n",
        "    a = wscale * random.normal(random.PRNGKey(4), (N, ))\n",
        "    W = wscale * random.normal(random.PRNGKey(3), (N, D))\n",
        "    params = [a, W]\n",
        "\n",
        "    # Step size is rescaled by 1 / wscale**2.\n",
        "    opt_init, opt_update, get_params = optimizers.sgd( eta / wscale**2 )\n",
        "    opt_state = opt_init(params)\n",
        "\n",
        "\n",
        "    # Centred objective: mean of (f(p) - f(params) - y)**2 / alpha**2.\n",
        "    loss_fn = jit(lambda p, X, y: jnp.mean( ( NN_func2(p, X,alpha)- NN_func2(params,X,alpha) - y )**2 / alpha**2 ))\n",
        "    acc_fn = jit(lambda p, X, y: jnp.mean( ( y * ( NN_func2(p, X,alpha)- NN_func2(params,X,alpha)) ) > 0.0 ))\n",
        "    reg_loss = jit(lambda p, X, y: loss_fn(p,X,y) + lamb / alpha * optimizers.l2_norm(p)**2 )\n",
        "\n",
        "    grad_loss = jit(grad(reg_loss,0))\n",
        "\n",
        "    tr_losses = []\n",
        "    te_losses = []\n",
        "    tr_acc = []\n",
        "    te_acc = []\n",
        "    for t in range(50000):\n",
        "        opt_state = opt_update(t, grad_loss(get_params(opt_state), X, y), opt_state)\n",
        "\n",
        "        # Log every other step; losses are multiplied by alpha**2 to undo the normalisation.\n",
        "        if t % 2 == 0:\n",
        "            pars = get_params(opt_state)\n",
        "            train_loss = alpha**2*loss_fn(pars, X, y)\n",
        "            test_loss = alpha**2*loss_fn(pars, Xt, yt)\n",
        "            tr_losses += [train_loss]\n",
        "            te_losses += [test_loss]\n",
        "            tr_acc += [ acc_fn(pars, X, y) ]\n",
        "            te_acc += [ acc_fn(pars, Xt, yt) ]\n",
        "            sys.stdout.write(f'\\r t: {t} | train loss: {train_loss} | test loss: {test_loss}')\n",
        "        if t % 10000 == 0:\n",
        "            print(\" \")\n",
        "\n",
        "    all_tr_losses_w += [tr_losses]\n",
        "    all_te_losses_w += [te_losses]\n",
        "    all_acc_tr_w += [tr_acc]\n",
        "    all_acc_te_w += [te_acc]\n",
        "\n",
        "    # Relative squared parameter movement: ||theta_final - theta_0||**2 / ||theta_0||**2.\n",
        "    paramsf = get_params(opt_state)\n",
        "    dparam = (jnp.sum((paramsf[0]-params[0])**2) + jnp.sum((paramsf[1]-params[1])**2)) / ( jnp.sum( params[0]**2 ) + jnp.sum(params[1]**2) )\n",
        "    param_movement_w += [  dparam ]\n",
        "\n",
        "# Normalised train (dashed) and test (solid) loss curves, one colour per initial weight scale.\n",
        "plt.rcParams.update({'font.size': 14})\n",
        "plt.figure()\n",
        "for i,wscale in enumerate(weight_norms):\n",
        "    # Report the swept weight scale; alpha is constant in this sweep, so printing it carried no information.\n",
        "    print(wscale)\n",
        "    plt.plot(jnp.linspace(1,len(all_tr_losses_w[i]),len(all_tr_losses_w[i])), jnp.array(all_tr_losses_w[i]) / all_tr_losses_w[i][0], '--',  color = f'C{i}')\n",
        "    plt.plot(jnp.linspace(1,len(all_tr_losses_w[i]),len(all_tr_losses_w[i])), jnp.array(all_te_losses_w[i]) / all_te_losses_w[i][0],  color = f'C{i}', label = r'$\\sigma = 2^{%0.0f}$' % jnp.log2(wscale))\n",
        "plt.xscale('log')\n",
        "plt.xlabel('t',fontsize = 20)\n",
        "plt.ylabel('Loss',fontsize = 20)\n",
        "plt.legend()\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "2ejOSDMIWmti",
        "outputId": "0befc8ef-0be1-4da4-9f9c-6e820ef1cec5"
      },
      "outputs": [],
      "source": [
        "# Figure 3(e). Empirical phase diagram sweep.\n",
        "# When True, NN_func2 weights each hidden unit by its readout a before averaging.\n",
        "TRAIN_READOUTS = False\n",
        "# NOTE(review): CENTER_LOSS is not read anywhere in the visible part of this cell.\n",
        "CENTER_LOSS = False\n",
        "\n",
        "# eps is the coefficient of the quadratic term in phi; alpha multiplies the network output.\n",
        "eps = 0.02\n",
        "alpha = 1.0\n",
        "# X has shape (D, P) and W has shape (N, D) below.\n",
        "D = 100\n",
        "P = 550\n",
        "N = 500\n",
        "\n",
        "def target_fn(beta, X):\n",
        "    \"\"\"Quadratic target: half the squared projection of each column of X onto beta.\"\"\"\n",
        "    projection = X.T @ beta\n",
        "    return projection**2/2.0\n",
        "\n",
        "# Train inputs (D, P) and test inputs (D, 1000), both scaled by 1/sqrt(D); beta defines the target.\n",
        "X = random.normal(random.PRNGKey(0), (D,P))/ jnp.sqrt(D)\n",
        "Xt = random.normal(random.PRNGKey(1), (D,1000))/ jnp.sqrt(D)\n",
        "beta = random.normal(random.PRNGKey(2), (D,))\n",
        "\n",
        "y = target_fn(beta, X)\n",
        "yt = target_fn(beta,Xt)\n",
        "\n",
        "# NOTE(review): W and a reuse PRNGKey(0), the same key that generated X - confirm this is intended.\n",
        "W = random.normal(random.PRNGKey(0), (N, D))\n",
        "\n",
        "a = random.normal(random.PRNGKey(0), (N, ))\n",
        "# Module-level initial parameters; fn() below centres its loss on the output at these params.\n",
        "params = [a, W]\n",
        "\n",
        "# NOTE(review): fn() defines its own eta, so this value looks unused in the visible code.\n",
        "eta = 1 * N\n",
        "\n",
        "# Re-assigns the alpha and eps set at the top of the cell (alpha becomes the int 1).\n",
        "alpha=1; eps = 0.02\n",
        "\n",
        "def NN_func2(params,X):\n",
        "    \"\"\"Two-layer network output for a batch of inputs.\n",
        "\n",
        "    Reads the module-level alpha, eps and TRAIN_READOUTS at call time, so the result\n",
        "    follows whatever those globals are bound to when the function is traced.\n",
        "\n",
        "    Args:\n",
        "      params: [a, W]; a is reshaped to a column, W is multiplied with X.T.\n",
        "      X: inputs laid out so that W @ X.T is defined (callers in this cell pass X.T of a (D, P) array).\n",
        "\n",
        "    Returns:\n",
        "      alpha times the mean over hidden units of phi(W @ X.T, eps), optionally\n",
        "      weighted by the readouts a when TRAIN_READOUTS is True.\n",
        "    \"\"\"\n",
        "    a, W = params\n",
        "    h = W @ X.T\n",
        "\n",
        "    # jnp.mean keeps the reduction inside JAX (this function is traced by jit, grad and the NTK code).\n",
        "    if TRAIN_READOUTS:\n",
        "        # E^g decomp wrong for this\n",
        "        f = alpha * jnp.mean(a.reshape(-1, 1) * phi(h,eps),axis=0)\n",
        "    else:\n",
        "        # w/o readouts\n",
        "        f = alpha * jnp.mean(phi(h,eps),axis=0)\n",
        "\n",
        "    return f\n",
        "\n",
        "\n",
        "def phi(z, eps = 0.25):\n",
        "    \"\"\"Linear activation plus a quadratic correction of strength eps.\"\"\"\n",
        "    # can also try other functional forms for phi and for the target\n",
        "    quadratic_part = 0.5*eps*z**2\n",
        "    return z + quadratic_part\n",
        "\n",
        "\n",
        "# Empirical NTK of NN_func2 with respect to params; inputs are passed with the batch on axis 0 (vmap_axes=0).\n",
        "ntk_fn = nt.empirical_ntk_fn(\n",
        "    NN_func2, vmap_axes=0, trace_axes=(),\n",
        "    implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)\n",
        "\n",
        "# Train/train kernel at the initial params. NOTE(review): ntk is not referenced again in the visible code.\n",
        "ntk = ntk_fn(X.T, None, params)\n",
        "\n",
        "def kernel_regression(X, y, Xt, yt, params, which='test'):\n",
        "    \"\"\"Mean squared error of ridgeless kernel regression with the empirical NTK at params.\n",
        "\n",
        "    Uses the module-level ntk_fn. The kernel between the evaluation points and the\n",
        "    training points is computed once; the previous version rebuilt the full test/train\n",
        "    kernel for every evaluation point and always predicted on Xt, so which='train'\n",
        "    compared test predictions against train labels.\n",
        "\n",
        "    Args:\n",
        "      X, y: training inputs (columns are samples) and labels.\n",
        "      Xt, yt: test inputs (columns are samples) and labels.\n",
        "      params: parameters at which the empirical NTK is evaluated.\n",
        "      which: 'test' evaluates on (Xt, yt); any other value evaluates on (X, y).\n",
        "\n",
        "    Returns:\n",
        "      Scalar MSE between the kernel predictions and the selected labels.\n",
        "    \"\"\"\n",
        "    K_train = ntk_fn(X.T, None, params)\n",
        "    coeffs = jnp.linalg.solve(K_train, y)\n",
        "\n",
        "    if which == 'test':\n",
        "        X_eval, labels = Xt, yt\n",
        "    else:\n",
        "        X_eval, labels = X, y\n",
        "\n",
        "    # Kernel between evaluation points (rows) and training points (columns).\n",
        "    k_eval_train = jnp.squeeze(ntk_fn(X_eval.T, X.T, params))\n",
        "    estimates = jnp.dot(k_eval_train, coeffs)\n",
        "    return jnp.mean((estimates - labels) ** 2)\n",
        "\n",
        "\n",
        "def kalignment(K, train_y):\n",
        "    \"\"\"Alignment between a kernel matrix and the labels, after mean-centring both.\n",
        "\n",
        "    The labels are centred as a column vector and K has its column means subtracted;\n",
        "    the result is y^T K y divided by ||K||_F * ||y||^2 on the centred quantities.\n",
        "    \"\"\"\n",
        "    y_col = train_y.reshape(-1, 1)\n",
        "    y_centered = y_col - y_col.mean(axis=0)\n",
        "    K_centered = K - K.mean(axis=0)\n",
        "    quad_form = jnp.dot(jnp.dot(y_centered.T, K_centered), y_centered)\n",
        "    normaliser = jnp.linalg.norm(K_centered) * (jnp.linalg.norm(y_centered)**2)\n",
        "    return jnp.trace(quad_form)/normaliser\n",
        "\n",
        "# Test MSE of NTK regression at the initial params (which defaults to 'test').\n",
        "kmse = kernel_regression(X, y, Xt, yt, params)\n",
        "\n",
        "def find_first_index(tr_losses, x):\n",
        "    \"\"\"Return the index of the first entry of tr_losses strictly below x, or -1 if there is none.\"\"\"\n",
        "    for position, loss_value in enumerate(tr_losses):\n",
        "        if loss_value < x:\n",
        "            return position\n",
        "    return -1\n",
        "\n",
        "def fn(alpha, eps):\n",
        "  \"\"\"Train one network with full-batch SGD until the losses plateau and summarise the run.\n",
        "\n",
        "  NOTE: NN_func2 reads the module-level alpha and eps, not these arguments, so the\n",
        "  caller must bind those globals to the same values (the sweep loop does this through\n",
        "  its loop variables). Inside this function alpha only sets the learning rate and the\n",
        "  (zero-weight) regulariser scale.\n",
        "\n",
        "  Returns:\n",
        "    (final_train, final_test, conv_tr, conv_te, init_train_loss, init_test_loss,\n",
        "    worst_test_loss). conv_tr / conv_te map each threshold (and each\n",
        "    'within_X%_final' key) to the first step whose train / test loss is below it,\n",
        "    or -1 when that never happens.\n",
        "  \"\"\"\n",
        "  W = random.normal(random.PRNGKey(0), (N, D))\n",
        "\n",
        "  a = random.normal(random.PRNGKey(0), (N, ))\n",
        "  pars = [a, W]\n",
        "\n",
        "  lamb = 0\n",
        "  eta = 20*N/alpha**2\n",
        "\n",
        "  optimizer = optax.sgd(learning_rate=eta)\n",
        "  opt_state = optimizer.init(pars)\n",
        "\n",
        "  # The loss is centred on the output at the module-level initial params.\n",
        "  loss_fn = jit(lambda p, X, y: jnp.mean( ( NN_func2(p, X.T)- NN_func2(params,X.T) - y )**2))\n",
        "\n",
        "  reg_loss = jit(lambda p, X, y: loss_fn(p,X,y) + lamb / alpha * optimizers.l2_norm(p)**2 )\n",
        "  grad_loss = jit(grad(reg_loss,0))\n",
        "\n",
        "  tr_losses = []\n",
        "  te_losses = []\n",
        "\n",
        "  # Convergence parameters\n",
        "  delta_thresh = 1e-5  # Convergence threshold\n",
        "  lookback = 5  # Number of previous steps to consider\n",
        "  max_steps = 110000  # Hard cap on the number of SGD steps\n",
        "  converged = False\n",
        "\n",
        "  # Initial losses (epoch 0)\n",
        "  init_train_loss = loss_fn(params, X, y)\n",
        "  init_test_loss = loss_fn(params, Xt, yt)\n",
        "\n",
        "  worst_test_loss = init_test_loss  # Initialize with epoch 0 value\n",
        "  t = 0\n",
        "  while not converged:\n",
        "      grads = grad_loss(pars, X, y)\n",
        "      updates, opt_state = optimizer.update(grads, opt_state, pars)\n",
        "      pars = optax.apply_updates(pars, updates)\n",
        "\n",
        "      train_loss = loss_fn(pars, X, y)\n",
        "      test_loss = loss_fn(pars, Xt, yt)\n",
        "      tr_losses += [train_loss]\n",
        "      te_losses += [test_loss]\n",
        "\n",
        "      # Update worst test loss\n",
        "      worst_test_loss = max(worst_test_loss, test_loss)\n",
        "\n",
        "      # Stop once both losses are flat AND the test loss has dropped below its initial\n",
        "      # value (so a plateau before generalisation does not end the run), or at the step cap.\n",
        "      if len(tr_losses) > lookback and len(te_losses) > lookback:\n",
        "          tr_diff = jnp.abs(tr_losses[-1] - tr_losses[-lookback])\n",
        "          te_diff = jnp.abs(te_losses[-1] - te_losses[-lookback])\n",
        "          plateaued = tr_diff < delta_thresh and te_diff < delta_thresh and te_losses[-1] < init_test_loss\n",
        "          if plateaued or t > max_steps:\n",
        "              converged = True\n",
        "      t += 1\n",
        "\n",
        "  final_train = tr_losses[-1]\n",
        "  final_test = te_losses[-1]\n",
        "\n",
        "  conv_tr, conv_te = {}, {}\n",
        "  threshes = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4]\n",
        "  for thresh in threshes:\n",
        "    conv_tr[thresh] = find_first_index(tr_losses, thresh)\n",
        "    conv_te[thresh] = find_first_index(te_losses, thresh)\n",
        "\n",
        "  conv_tr['within_5%_final'] = find_first_index(tr_losses, (1.05 * final_train))\n",
        "  conv_tr['within_10%_final'] = find_first_index(tr_losses, (1.1 * final_train))\n",
        "  conv_tr['within_20%_final'] = find_first_index(tr_losses, (1.2 * final_train))\n",
        "\n",
        "  # Test-side entries search the TEST curve (they previously searched tr_losses).\n",
        "  conv_te['within_5%_final'] = find_first_index(te_losses, (1.05 * final_test))\n",
        "  conv_te['within_10%_final'] = find_first_index(te_losses, (1.1 * final_test))\n",
        "  conv_te['within_20%_final'] = find_first_index(te_losses, (1.2 * final_test))\n",
        "\n",
        "  plt.figure()\n",
        "  plt.plot(tr_losses)\n",
        "  plt.plot(te_losses)\n",
        "  plt.xscale('log')\n",
        "  plt.show()\n",
        "\n",
        "  print(f'final_train={final_train}, final_test={final_test}, tr_conv_when={conv_tr[0.05]}, te_conv_when={conv_te[0.05]}, init_train_loss={init_train_loss}, init_test_loss={init_test_loss}, worst_test_loss={worst_test_loss}')\n",
        "  return final_train, final_test, conv_tr, conv_te, init_train_loss, init_test_loss, worst_test_loss\n",
        "\n",
        "# Sweep grid: 14 values each, alpha = 0.4 * i and eps = 0.02 * i for i = 1..14.\n",
        "alphas = [0.4 * i for i in range(1, 15)]\n",
        "epsilons = [0.02 * i for i in range(1, 15)]\n",
        "\n",
        "# Maps (alpha, eps) to the tuple returned by fn.\n",
        "results = {}\n",
        "\n",
        "# The loop variables are module-level names, so they also set the alpha and eps that NN_func2 reads as globals.\n",
        "for alpha in tqdm(alphas):\n",
        "  print(f'Beginning alpha={alpha}')\n",
        "  for eps in tqdm(epsilons):\n",
        "    # should serialize these somewhere as this sweep can take 10+ hours\n",
        "    results[(alpha, eps)] = fn(alpha, eps)\n",
        "\n",
        "# Rows index alpha, columns index eps.\n",
        "matrix = np.zeros((len(alphas), len(epsilons)))\n",
        "\n",
        "# Populate the matrix using the 'results' dictionary\n",
        "for i, alpha in enumerate(alphas):\n",
        "    for j, eps in enumerate(epsilons):\n",
        "        conv_tr, conv_te = results[(alpha, eps)][2], results[(alpha, eps)][3]\n",
        "        # Gap (in steps) between test and train first dropping below 0.1, clipped at 0.\n",
        "        # NOTE(review): find_first_index returns -1 when a threshold is never reached, so a run whose\n",
        "        # test loss never gets below 0.1 is recorded as 0 here - confirm that is the intended reading.\n",
        "        matrix[i, j] = max(0, conv_te[0.1] - conv_tr[0.1])\n",
        "\n",
        "# Interpolate to get a higher resolution matrix for smooth visualization\n",
        "interp_factor = 2  # The factor by which we increase the resolution\n",
        "# np.kron with a block of ones repeats every entry interp_factor times along both axes.\n",
        "matrix_high_res = np.kron(matrix, np.ones((interp_factor, interp_factor)))  # Kronecker product does the trick\n",
        "\n",
        "# Plot the heatmap\n",
        "plt.figure()\n",
        "# The extent puts eps on the x-axis and alpha on the y-axis.\n",
        "plt.imshow(matrix_high_res, interpolation='bilinear', origin='lower', extent=[min(epsilons), max(epsilons), min(alphas), max(alphas)], aspect='auto', cmap='OrRd')\n",
        "plt.colorbar(label='Amount of Grokking')\n",
        "# NOTE(review): the x-axis values are eps but the label reads Kernel Alignment - confirm the label.\n",
        "plt.xlabel(r'Kernel Alignment', fontsize=20)\n",
        "plt.ylabel(r'Laziness', fontsize=20)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "zTkT_-CFc5fo",
        "outputId": "7287a5df-3799-415e-e61f-a32f0700ca79"
      },
      "outputs": [],
      "source": [
        "# Figure 4. Loss Decomposition.\n",
        "# X has shape (D, P) and W has shape (N, D) in the loop below.\n",
        "D = 40\n",
        "N = 2\n",
        "alpha = 1\n",
        "eps = 0.25\n",
        "P = 120\n",
        "# Values of alpha to run; the loop below rebinds the module-level alpha to each one.\n",
        "alphas = [1]\n",
        "\n",
        "# Number of SGD steps per run.\n",
        "epochs = 50000\n",
        "\n",
        "# NOTE(review): ntk_interval is not read in the visible part of this cell.\n",
        "ntk_interval = 200\n",
        "# The decomposition terms are evaluated every vars_compute_interval steps.\n",
        "vars_compute_interval = 10\n",
        "\n",
        "# One list of per-step losses per alpha.\n",
        "all_tr_losses_eps = []\n",
        "all_te_losses_eps = []\n",
        "# NOTE(review): the alignment lists are not filled in the visible part of this cell.\n",
        "all_alignments, all_alignmentst = [], []\n",
        "\n",
        "# For each alpha: build the data and the model, train with SGD, track a three-term\n",
        "# decomposition (t1m, t2m, t3m) computed from the first-layer weights, and plot the results.\n",
        "for alpha in alphas:\n",
        "  def phi(z, eps = 0.25):\n",
        "      return z + 0.5*eps*z**2\n",
        "\n",
        "  # Network output: alpha times the mean over hidden units of phi(W @ X, eps).\n",
        "  # a is unpacked but not used (no readouts); D and N are assigned but not used.\n",
        "  # NOTE(review): np.mean is applied to JAX values inside a jitted loss; the Figure 5 cell uses jnp.mean for the same reduction.\n",
        "  def NN_func2(params,X, alpha, eps = 0.25):\n",
        "      a, W = params\n",
        "      D = W.shape[1]\n",
        "      N = a.shape[0]\n",
        "      h = W @ X\n",
        "\n",
        "      f = alpha * np.mean(phi(h,eps),axis=0) # w/o readouts\n",
        "\n",
        "      return f\n",
        "\n",
        "  def target_fn(beta, X):\n",
        "      return (X.T @ beta)**2/2.0\n",
        "\n",
        "  # Train inputs (D, P) and test inputs (D, 1000), both scaled by 1/sqrt(D).\n",
        "  X = random.normal(random.PRNGKey(0), (D,P))/ jnp.sqrt(D)\n",
        "  Xt = random.normal(random.PRNGKey(1), (D,1000))/ jnp.sqrt(D)\n",
        "  beta = random.normal(random.PRNGKey(2), (D,))\n",
        "\n",
        "  y = target_fn(beta, X)\n",
        "  yt = target_fn(beta,Xt)\n",
        "\n",
        "  # NOTE(review): W and a reuse PRNGKey(0), the same key that generated X - confirm this is intended.\n",
        "  W = random.normal(random.PRNGKey(0), (N, D))\n",
        "\n",
        "  a = random.normal(random.PRNGKey(0), (N, ))\n",
        "  params = [a, W]\n",
        "\n",
        "  eta = 0.5 * N / alpha**2\n",
        "  lamb = 0\n",
        "  opt_init, opt_update, get_params = optimizers.sgd(eta)\n",
        "\n",
        "  opt_state = opt_init(params)\n",
        "  # Unlike the other sweeps in this notebook, this loss is not centred on the initial output.\n",
        "  loss_fn = jit(lambda p, X, y: jnp.mean( ( NN_func2(p, X,alpha, eps) - y )**2 ))\n",
        "  reg_loss = jit(lambda p, X, y: loss_fn(p,X,y) + lamb / alpha * optimizers.l2_norm(p)**2 )\n",
        "\n",
        "  grad_loss = jit(grad(reg_loss,0))\n",
        "\n",
        "  tr_losses = []\n",
        "  te_losses = []\n",
        "\n",
        "  epochs_to_plot = []\n",
        "\n",
        "  # Only t1sm, t2sm, t3sm, ts_summ and epochs_to_compute are filled below; the other lists stay empty.\n",
        "  t1s, t2s, t3s, epochs_to_compute = [], [], [], []\n",
        "  t1sm, t2sm, t3sm, ts_summ = [], [], [], []\n",
        "  ts_sum = []\n",
        "\n",
        "\n",
        "  for t in tqdm(range(epochs)):\n",
        "    opt_state = opt_update(t, grad_loss(get_params(opt_state), X, y), opt_state)\n",
        "    pars = get_params(opt_state)\n",
        "\n",
        "    train_loss = loss_fn(pars, X, y)\n",
        "    test_loss = loss_fn(pars, Xt, yt)\n",
        "    tr_losses += [train_loss]\n",
        "    te_losses += [test_loss]\n",
        "\n",
        "    # Evaluate the decomposition terms from the current first-layer weights.\n",
        "    if t % vars_compute_interval == 0:\n",
        "        W = pars[1]\n",
        "        # A is computed but not used below.\n",
        "        A = W @ beta / D\n",
        "        M = W.T @ W / N\n",
        "\n",
        "        # t1m: squared mismatch between the scaled trace of M and mean(beta**2)/2.\n",
        "        # t2m: half the mean squared entry of alpha*eps*M - outer(beta, beta).\n",
        "        # t3m: alpha**2 / D times the squared norm of the mean row of W.\n",
        "        t1m = (alpha * eps/2.0 * 1.0/D * np.trace(M) - np.mean(beta**2)/2)**2\n",
        "        t2m = (1/2) * np.mean( ( alpha*eps* M - np.outer(beta, beta) )**2 )\n",
        "        t3m = alpha**2 * (1/D) * np.linalg.norm(np.mean(W, axis=0))**2\n",
        "\n",
        "        t1sm.append(t1m)\n",
        "        t2sm.append(t2m)\n",
        "        t3sm.append(t3m)\n",
        "        ts_summ.append(t1m + t2m + t3m)\n",
        "\n",
        "        epochs_to_compute.append(t)\n",
        "\n",
        "  all_tr_losses_eps += [tr_losses]\n",
        "  all_te_losses_eps += [te_losses]\n",
        "\n",
        "  T_compute = np.array(epochs_to_compute)\n",
        "\n",
        "  # Cubic splines through the sparsely computed terms, evaluated at every epoch index for plotting.\n",
        "  t1sm_spline = make_interp_spline(T_compute, np.array(t1sm), k=3)\n",
        "  t2sm_spline = make_interp_spline(T_compute, np.array(t2sm), k=3)\n",
        "  t3sm_spline = make_interp_spline(T_compute, np.array(t3sm), k=3)\n",
        "\n",
        "  t1sm_smooth = t1sm_spline(np.arange(epochs))\n",
        "  t2sm_smooth = t2sm_spline(np.arange(epochs))\n",
        "  t3sm_smooth = t3sm_spline(np.arange(epochs))\n",
        "  ts_summ_smooth = make_interp_spline(T_compute, np.array(ts_summ), k=3)(np.arange(epochs))\n",
        "\n",
        "  # NOTE(review): tit is built but not passed to any plot in the visible code.\n",
        "  tit = f'N={N}, P={P}, D={D}, eps={eps}, lr={round(eta, 2)}, alpha={round(alpha, 2)}, lamb={lamb}'\n",
        "  # let's check if our decomp here with A,M is equiv to what BB wrote in current manuscript, bc i suspect that (rather than wbar' or M') is the prob\n",
        "\n",
        "  # NOTE(review): get_random_color is not called in the visible code.\n",
        "  def get_random_color():\n",
        "      return np.random.rand(3,)\n",
        "\n",
        "  # Figure 1: per-step train and test loss.\n",
        "  plt.figure(figsize=(9, 6))\n",
        "  plt.plot(tr_losses, label='Train Loss')\n",
        "  plt.plot(te_losses, label='Test Loss')\n",
        "  plt.xscale('log')\n",
        "  plt.xlabel('Epochs', fontsize=20)\n",
        "  plt.ylabel('MSE', fontsize=20)\n",
        "  plt.legend(fontsize=14)\n",
        "  plt.tight_layout()\n",
        "  plt.show()\n",
        "\n",
        "  # Figure 2: the three terms and their sum. The curve labelled Full Test Loss is t1m + t2m + t3m, not te_losses.\n",
        "  plt.figure(figsize=(9, 6))\n",
        "  plt.plot(ts_summ_smooth, label='Full Test Loss', color='darkorange')\n",
        "  plt.plot(t1sm_smooth, label='Variance error component', color='red', linestyle='--')\n",
        "  plt.plot(t2sm_smooth, label='Alignment error component', color='green', linestyle='--')\n",
        "  plt.plot(t3sm_smooth, label='Linear term error component', color='black', linestyle='--')\n",
        "  plt.xscale('log')\n",
        "  plt.ylabel('MSE', fontsize=20)\n",
        "  plt.xlabel('Epochs', fontsize=20)\n",
        "  plt.legend(fontsize=14)\n",
        "  plt.tight_layout()\n",
        "  plt.show()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "YhfTASQrrqya",
        "outputId": "d3e354b9-1ee7-4437-a02d-807d0c2e6e60"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "twN-3I7sbBF-",
        "outputId": "30eb654e-0f4e-4790-9623-0a39f7b31e8f"
      },
      "outputs": [],
      "source": [
        "# Re-draw the learning curves of the last run and save them to readouts_lcs.pdf.\n",
        "plt.figure(figsize=(9, 6))\n",
        "for _curve, _label in ((tr_losses, 'Train Loss'), (te_losses, 'Test Loss')):\n",
        "    plt.plot(_curve, label=_label)\n",
        "plt.xscale('log')\n",
        "plt.xlabel('Epochs', fontsize=20)\n",
        "plt.ylabel('MSE', fontsize=20)\n",
        "plt.legend(fontsize=16)\n",
        "plt.tight_layout()\n",
        "plt.savefig('readouts_lcs.pdf')\n",
        "plt.show()\n",
        "\n",
        "# Re-draw the loss decomposition (sum first, then the three dashed terms) and save it to readouts_decomp.pdf.\n",
        "plt.figure(figsize=(9, 6))\n",
        "plt.plot(ts_summ_smooth, label='Full Test Loss', color='darkorange')\n",
        "_components = (\n",
        "    (t1sm_smooth, 'Variance error component', 'red'),\n",
        "    (t2sm_smooth, 'Alignment error component', 'green'),\n",
        "    (t3sm_smooth, 'Linear term error component', 'black'),\n",
        ")\n",
        "for _curve, _label, _color in _components:\n",
        "    plt.plot(_curve, label=_label, color=_color, linestyle='--')\n",
        "plt.xscale('log')\n",
        "plt.ylabel('MSE', fontsize=20)\n",
        "plt.xlabel('Epochs', fontsize=20)\n",
        "plt.legend(fontsize=16)\n",
        "plt.tight_layout()\n",
        "plt.savefig('readouts_decomp.pdf')\n",
        "plt.show()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 874
        },
        "id": "kehrftc3Xzcv",
        "outputId": "9c09c15f-5eca-4316-d89a-4e12c57c4480"
      },
      "outputs": [],
      "source": [
        "# Figure 5. Grokking on arbitrary datasets but setting labels to goldilocks eigenvectors of kernel matrix K(X, X).\n",
        "# NOTE(review): these two lists are reset here but not appended to in the visible part of this cell.\n",
        "all_tr_losses_eps = []\n",
        "all_te_losses_eps = []\n",
        "\n",
        "def phi(z, eps = 0.25):\n",
        "    \"\"\"Linear activation plus a quadratic correction of strength eps.\"\"\"\n",
        "    quadratic_part = 0.5*eps*z**2\n",
        "    return z + quadratic_part\n",
        "def phi2(z, eps = 0.25):\n",
        "    \"\"\"Purely quadratic activation: 0.5 * eps * z**2.\"\"\"\n",
        "    coefficient = 0.5*eps\n",
        "    return coefficient*z**2\n",
        "\n",
        "def phi3(z, eps = 0.25):\n",
        "    \"\"\"Purely cubic activation: 0.5 * eps * z**3.\"\"\"\n",
        "    coefficient = 0.5*eps\n",
        "    return coefficient*z**3\n",
        "\n",
        "def phi4(z, eps = 0.25):\n",
        "  \"\"\"ReLU activation plus a quartic correction of strength eps.\"\"\"\n",
        "  rectified = jax.numpy.maximum(0, z)\n",
        "  return rectified + 0.5*eps*z**4\n",
        "\n",
        "def phi5(z, eps = 0.25):\n",
        "    \"\"\"Purely quartic activation: 0.5 * eps * z**4.\"\"\"\n",
        "    coefficient = 0.5*eps\n",
        "    return coefficient*z**4\n",
        "\n",
        "def plot(P=100, N=2, D=100, k=75, alpha=1, phi=phi, eps=0.25, epochs=25000, path='test'):\n",
        "  \"\"\"Train a small network on labels taken from an eigenvector of its initial kernel and plot the losses.\n",
        "\n",
        "  The labels for 2*P random inputs are the k-th eigenvector (1-based, after reversing the\n",
        "  order returned by eigh) of the mean-subtracted Gram matrix of the network Jacobian with\n",
        "  respect to W. The first P points are used for training and the last P for testing.\n",
        "\n",
        "  Args:\n",
        "    P: number of training points (and of test points).\n",
        "    N: number of hidden units.\n",
        "    D: input dimension.\n",
        "    k: 1-based index of the eigenvector used as labels.\n",
        "    alpha: output scale of the network; also sets the learning rate 0.5 * N / alpha**2.\n",
        "    phi: activation, called as phi(h, eps). The default is bound when this def is executed.\n",
        "    eps: second argument passed to phi.\n",
        "    epochs: number of SGD steps.\n",
        "    path: not used in the body.\n",
        "\n",
        "  Returns:\n",
        "    None; the function prints a run description and shows one figure.\n",
        "  \"\"\"\n",
        "  def get_K(params, X):\n",
        "    def ntk(X, params, NN_func2, alpha, eps):\n",
        "      # f is computed but not used below.\n",
        "      f = NN_func2(params, X, alpha, eps)\n",
        "      # Jacobian of the outputs with respect to W only (the readouts params[0] are held fixed).\n",
        "      grad_f = jacrev(lambda W, X: NN_func2([params[0], W], X, alpha, eps))(params[1], X)\n",
        "      # Contract the two weight axes to get a sample-by-sample Gram matrix.\n",
        "      return jnp.tensordot(grad_f, grad_f, axes=((1, 2), (1, 2)))\n",
        "\n",
        "    return ntk(X, params, NN_func2, alpha, eps)\n",
        "\n",
        "  # Network output: alpha times the mean over hidden units of phi(W @ X / sqrt(D), eps).\n",
        "  # a is unpacked only to read N; N itself is not used afterwards.\n",
        "  def NN_func2(params,X, alpha, eps = eps):\n",
        "      a, W = params\n",
        "\n",
        "      D = W.shape[1]\n",
        "      N = a.shape[0]\n",
        "\n",
        "      h = W @ X / jnp.sqrt(D)\n",
        "      f = alpha * jnp.mean( phi(h, eps), axis = 0)\n",
        "      return f\n",
        "\n",
        "  # NOTE(review): target_fn is defined but not called in this function.\n",
        "  def target_fn(beta, X):\n",
        "      return (X.T @ beta / jnp.sqrt(D))**2\n",
        "\n",
        "  def generate_label_vector(X, k, params):\n",
        "      K = get_K(params, X)\n",
        "      # Subtract the overall mean entry of K.\n",
        "      K -= np.mean(K)\n",
        "\n",
        "      # Reverse the order returned by eigh so that column k - 1 is the k-th eigenvector counted from the end.\n",
        "      eigenvalues, eigenvectors = jnp.linalg.eigh(K)\n",
        "      eigenvalues = eigenvalues[::-1]\n",
        "      eigenvectors = eigenvectors[:, ::-1]\n",
        "      y = eigenvectors[:, k - 1]\n",
        "\n",
        "      return y\n",
        "\n",
        "  # NOTE(review): the inputs and both parameter arrays are all drawn from PRNGKey(0) - confirm this is intended.\n",
        "  X_total = random.normal(random.PRNGKey(0), (D, 2 * P))\n",
        "  params_total = [random.normal(random.PRNGKey(0), (N, )), random.normal(random.PRNGKey(0), (N, D))]\n",
        "  y_total = generate_label_vector(X_total, k, params_total)\n",
        "\n",
        "  # First P columns / labels are the training set, the remaining P are the test set.\n",
        "  X = X_total[:, :P]\n",
        "  Xt = X_total[:, P:]\n",
        "  y = y_total[:P]\n",
        "  yt = y_total[P:]\n",
        "\n",
        "  a = random.normal(random.PRNGKey(0), (N, ))\n",
        "  W = random.normal(random.PRNGKey(0), (N, D))\n",
        "  params = [a, W]\n",
        "\n",
        "  eta = 0.5 * N / alpha**2\n",
        "  lamb = 0\n",
        "  opt_init, opt_update, get_params = optimizers.sgd(eta)\n",
        "\n",
        "  tit = f'Eigvec grokking: p={P}, D={D}, N={N}, scale={alpha}, lr={eta}, eps={eps}, eigk={k}, phi={phi}, lamb={lamb}'\n",
        "  print(tit)\n",
        "  opt_state = opt_init(params)\n",
        "\n",
        "\n",
        "  # The loss is centred on the output at the initial params; acc_fn is defined but not called below.\n",
        "  loss_fn = jit(lambda p, X, y: jnp.mean( ( NN_func2(p, X,alpha, eps)- NN_func2(params,X,alpha, eps) - y )**2 / alpha**2 ))\n",
        "  acc_fn = jit(lambda p, X, y: jnp.mean( ( y * ( NN_func2(p, X,alpha,eps)- NN_func2(params,X,alpha,eps)) ) > 0.0 ))\n",
        "  reg_loss = jit(lambda p, X, y: loss_fn(p,X,y) + lamb / alpha * optimizers.l2_norm(p)**2 )\n",
        "  grad_loss = jit(grad(reg_loss,0))\n",
        "\n",
        "  tr_losses = []\n",
        "  te_losses = []\n",
        "\n",
        "  # epochs_to_plot and dots are created but not used below.\n",
        "  epochs_to_plot = []\n",
        "  dots = []\n",
        "\n",
        "  for t in tqdm(range(epochs)):\n",
        "    # Losses are recorded at the parameters BEFORE this step's update; alpha**2 undoes the 1/alpha**2 in loss_fn.\n",
        "    params_current = get_params(opt_state)\n",
        "    opt_state = opt_update(t, grad_loss(params_current, X, y), opt_state)\n",
        "    train_loss = alpha**2*loss_fn(params_current, X, y)\n",
        "    test_loss = alpha**2*loss_fn(params_current, Xt, yt)\n",
        "    tr_losses += [train_loss]\n",
        "    te_losses += [test_loss]\n",
        "\n",
        "  plt.figure()\n",
        "  plt.plot(np.array(tr_losses), label='Train Loss', linestyle='--')\n",
        "  plt.plot(np.array(te_losses), label='Test Loss')\n",
        "  plt.xscale('log')\n",
        "  plt.xlabel('Epochs', fontsize=20)\n",
        "  plt.ylabel('MSE', fontsize=20)\n",
        "  plt.tight_layout()\n",
        "  plt.legend()\n",
        "\n",
        "  plt.show()\n",
        "\n",
        "# One run of plot() per eigenvector index k. plot() builds 2 * P = 200 points, so every k here is within range.\n",
        "# NOTE(review): the module-level alpha set here is not passed to plot(); scale is passed as its alpha argument.\n",
        "alpha = 1\n",
        "scale = 1\n",
        "P = 100\n",
        "D = 100\n",
        "N = 2\n",
        "ks = [1, 5, 70, 75, 100, 150]\n",
        "for k in ks:\n",
        "  plot(P=P, D=D, k=k, path='k', alpha=scale, N=N)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 794
        },
        "id": "ED7sb4O4NAHI",
        "outputId": "951dec5c-37e3-435e-ec78-4e5d995bdac0"
      },
      "outputs": [],
      "source": [
        "# Grokking on MNIST. Architectural details taken from Omnigrok paper: https://arxiv.org/pdf/2210.01117.pdf\n",
        "\n",
        "# Set device\n",
        "# NOTE(review): device is not referenced by the JAX training code in the visible part of this cell.\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "# Download MNIST dataset\n",
        "# Downloads the MNIST training split into ./data (download=True) and converts images with ToTensor.\n",
        "mnist_dataset = MNIST(root=\"./data\", train=True, download=True, transform=ToTensor())\n",
        "\n",
        "\n",
        "def MLP(params, x):\n",
        "  \"\"\"ReLU MLP forward pass with 1/sqrt(fan_in) scaling per layer and a 1/width readout.\n",
        "\n",
        "  params is a list of weight matrices (first layer, hidden layers, readout); x holds one\n",
        "  sample per row. Returns an array with one row per sample.\n",
        "  \"\"\"\n",
        "  first, hidden_layers, readout = params[0], params[1:-1], params[-1]\n",
        "  # Pre-activations are kept as (units, samples).\n",
        "  h = first @ x.T / jnp.sqrt(first.shape[1]) # N x P\n",
        "  for W_hidden in hidden_layers:\n",
        "    act = h * (h > 0.0)\n",
        "    h = 1/jnp.sqrt(W_hidden.shape[1]) * W_hidden @ act\n",
        "\n",
        "  act = h * (h > 0.0)\n",
        "  return act.T @ readout.T / act.shape[0]\n",
        "\n",
        "def init_params(N, D, L, key, w_scale = 1.0):\n",
        "  \"\"\"Gaussian weights scaled by w_scale: one (N, D) layer, L - 1 layers of (N, N), and a (10, N) readout.\n",
        "\n",
        "  The key is advanced once per hidden layer; the readout reuses the key of the last layer drawn.\n",
        "  \"\"\"\n",
        "  layers = [ w_scale * random.normal(key, (N,D)) ]\n",
        "  for _ in range(L-1):\n",
        "    key = random.split(key)[0]\n",
        "    layers.append(w_scale * random.normal(key,(N,N)))\n",
        "\n",
        "  layers.append(w_scale * random.normal(key, (10,N)))\n",
        "  return layers\n",
        "\n",
        "\n",
        "# Split MNIST into a small training subset and a held-out remainder.\n",
        "subset_size = 1000\n",
        "train_set, test_set = random_split(mnist_dataset, [subset_size, len(mnist_dataset) - subset_size])\n",
        "\n",
        "# The train loader returns the whole subset as one batch; the test loader returns 2 * subset_size samples per batch.\n",
        "train_loader = DataLoader(train_set, batch_size=subset_size, shuffle=True)\n",
        "test_loader = DataLoader(test_set, batch_size=2*subset_size)\n",
        "\n",
        "# Take the first batch of the train loader, move it to JAX and flatten the last two image axes.\n",
        "X, y = next(iter(train_loader))\n",
        "X = jnp.array( X.numpy() )\n",
        "y = jnp.array( y.numpy() )\n",
        "X = X.reshape((X.shape[0], X.shape[-2] * X.shape[-1]))\n",
        "\n",
        "# Same for the first batch of the test loader.\n",
        "Xte, yte = next(iter(test_loader))\n",
        "Xte = jnp.array(Xte.numpy())\n",
        "yte = jnp.array(yte.numpy())\n",
        "Xte= Xte.reshape((Xte.shape[0], Xte.shape[-2] * Xte.shape[-1]))\n",
        "\n",
        "# One-hot encode the labels by indexing rows of a 10 x 10 identity matrix.\n",
        "y = jnp.eye(10)[y]\n",
        "yte = jnp.eye(10)[yte]\n",
        "\n",
        "\n",
        "# Output scales to sweep, and the per-scale accuracy histories filled by the loop below.\n",
        "scales = [1e-3, 0.01, 0.05, 0.5]\n",
        "traccs, taccs = [], []\n",
        "\n",
        "# Train one MLP per output scale with AdamW on mini-batches and record accuracies every compute_every steps.\n",
        "for scale in tqdm(scales):\n",
        "  # Constants and hyper-params\n",
        "  torch.manual_seed(42)\n",
        "  N = 200\n",
        "  D = 784\n",
        "  L = 3\n",
        "  weight_scale = 150.0  # Scaling factor for Kaiming initialization, replicated from the Omnigrok paper\n",
        "  lr = 1e-3\n",
        "  wd = 0.01\n",
        "  T = 250000\n",
        "  batch = 200\n",
        "  key = random.PRNGKey(0)\n",
        "\n",
        "  # Model: same forward pass as the module-level MLP, with the output multiplied by the current scale.\n",
        "  def MLP(params, x):\n",
        "      h = params[0] @ x.T / jnp.sqrt(params[0].shape[1])  # N x P\n",
        "      for l, Wl in enumerate(params[1:-1]):\n",
        "          h = h * (h > 0.0)\n",
        "          h = Wl @ h / jnp.sqrt(Wl.shape[1])\n",
        "      h = h * (h > 0.0)\n",
        "      f = h.T @ params[-1].T / h.shape[0]\n",
        "      return f*scale\n",
        "\n",
        "  # Initialization with Kaiming scaling.\n",
        "  # Defined BEFORE it is called: previously the call came first, so the first scale was\n",
        "  # initialised with the earlier non-Kaiming init_params and only later scales used this one.\n",
        "  def init_params(N, D, L, key, weight_scale=1.0):\n",
        "      params = [weight_scale * random.normal(key, (N, D)) * jnp.sqrt(2. / D)]\n",
        "      for l in range(L-1):\n",
        "          key, _ = random.split(key)\n",
        "          params += [weight_scale * random.normal(key, (N, N)) * jnp.sqrt(2. / N)]\n",
        "      params += [weight_scale * random.normal(key, (10, N)) * jnp.sqrt(2. / N)]\n",
        "      return params\n",
        "\n",
        "  params = init_params(N, D, L, key, weight_scale)\n",
        "\n",
        "  loss_fn = jit(lambda p, X, y: jnp.mean((MLP(p, X) - y)**2))\n",
        "  grad_fn = jit(grad(loss_fn))\n",
        "  optimizer = optax.adamw(learning_rate=lr, weight_decay=wd)\n",
        "  opt_state = optimizer.init(params)\n",
        "\n",
        "  train_loss = []\n",
        "  test_loss = []\n",
        "  train_accuracy = []\n",
        "  test_accuracy = []\n",
        "\n",
        "  def compute_accuracy(predictions, targets):\n",
        "      return jnp.mean(jnp.argmax(predictions, axis=1) == jnp.argmax(targets, axis=1))\n",
        "\n",
        "\n",
        "  compute_every = 100\n",
        "  for t in range(T):\n",
        "      if t % compute_every == 0:\n",
        "          train_pred = MLP(params, X)\n",
        "          test_pred = MLP(params, Xte)\n",
        "\n",
        "          # Compute and store train & test loss\n",
        "          train_loss.append(loss_fn(params, X, y))\n",
        "          test_loss.append(loss_fn(params, Xte, yte))\n",
        "\n",
        "          # Compute and store train & test accuracy\n",
        "          train_accuracy.append(compute_accuracy(train_pred, y))\n",
        "          test_accuracy.append(compute_accuracy(test_pred, yte))\n",
        "\n",
        "      ind = batch * t % subset_size # take a slice of 200 out of our 1000 size dataset and cycle like that\n",
        "      grads = grad_fn(params, X[ind:ind+batch], y[ind:ind+batch])\n",
        "      updates, opt_state = optimizer.update(grads, opt_state, params)\n",
        "      params = optax.apply_updates(params, updates)\n",
        "\n",
        "  traccs += [ train_accuracy ]\n",
        "  taccs += [ test_accuracy ]\n",
        "\n",
        "\n",
        "# Smoothed train (dashed) and test (solid) accuracy curves, one colour per output scale.\n",
        "plt.figure(figsize=(9, 6))\n",
        "colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'orange', 'purple', 'pink']\n",
        "for i, acc in enumerate(traccs):\n",
        "    # Steps at which accuracies were recorded (every compute_every steps, starting at 0).\n",
        "    x = np.arange(0, compute_every * len(acc), compute_every)\n",
        "    # NOTE(review): 250000 is hard-coded here and matches T in the training loop; the grid\n",
        "    # extends slightly past the last recorded step, so the spline is evaluated beyond its data there.\n",
        "    xnew = np.linspace(0, 250000, 250000)\n",
        "    spl = make_interp_spline(x, traccs[i], k=3) # k=3 for cubic spline\n",
        "    y_train_smooth = spl(xnew)\n",
        "    spl = make_interp_spline(x, taccs[i], k=3)\n",
        "    y_test_smooth = spl(xnew)\n",
        "\n",
        "    # colors[-i] picks colors[0] for i = 0 and then counts back from the end of the list.\n",
        "    plt.plot(xnew, y_train_smooth*100, label=rf'Train Accuracy, $\\alpha$={scales[i]}', linestyle='--', color=colors[-i])\n",
        "    plt.plot(xnew, y_test_smooth*100, label=rf'Test Accuracy, $\\alpha$={scales[i]}', color=colors[-i])\n",
        "\n",
        "plt.xlabel('Epochs', fontsize=20)\n",
        "plt.xscale('log')\n",
        "plt.ylabel('Accuracy', fontsize=20)\n",
        "plt.legend(fontsize=14)\n",
        "plt.tight_layout()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 885
        },
        "id": "xB-7hguANAJn",
        "outputId": "9c87a2af-b2e2-452b-c420-f9e07908fd25"
      },
      "outputs": [],
      "source": [
        "# Grokking in a One-Layer Transformer on a modular arithmetic task\n",
        "# Replicated then modified to induce lazy dynamics from Neel Nanda's MI paper: https://arxiv.org/pdf/2301.05217.pdf and code at https://colab.research.google.com/drive/1F6_1_cWXE5M7WocUcpQWp3v8z4b1jL20\n",
        "# NOTE(review): this rebinds the name random from jax.random (imported at the top of the notebook)\n",
        "# to the standard-library module, so earlier cells that call random.normal / random.PRNGKey\n",
        "# fail if they are re-run after this cell.\n",
        "import random\n",
        "import torch.nn.functional as F\n",
        "\n",
        "!git clone https://github.com/neelnanda-io/Grokking.git\n",
        "root = Path('/content/Grokking/saved_runs')\n",
        "large_root = Path('/content/Grokking/large_files')\n",
        "# Create the download directory if it is missing. An existing directory is fine; any other\n",
        "# failure (e.g. permissions) now surfaces instead of being swallowed by a bare except.\n",
        "large_root.mkdir(parents=True, exist_ok=True)\n",
        "!pip install gdown\n",
        "!gdown 12pmgxpTHLDzSNMbMCuAMXP1lE_XiCQRy -O /content/Grokking/large_files/full_run_data.pth\n",
        "\n",
        "class HookPoint(nn.Module):\n",
        "    # Identity module: forward returns its input unchanged, so wrapping an\n",
        "    # activation in a HookPoint gives PyTorch hooks a named place to attach.\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        # Handles returned by register_*_hook, kept so the hooks can be removed later\n",
        "        self.fwd_hooks = []\n",
        "        self.bwd_hooks = []\n",
        "\n",
        "    def give_name(self, name):\n",
        "        # Called by the model at initialisation\n",
        "        self.name = name\n",
        "\n",
        "    def add_hook(self, hook, dir='fwd'):\n",
        "        # Registers hook on this module in direction dir ('fwd' or 'bwd');\n",
        "        # any other value raises ValueError.\n",
        "        # Hook format is fn(activation, hook_name)\n",
        "        # Change it into PyTorch hook format (this includes input and output,\n",
        "        # which are the same for a HookPoint)\n",
        "        def full_hook(module, module_input, module_output):\n",
        "            # Only the third argument is forwarded; self.name must already have\n",
        "            # been set through give_name.\n",
        "            return hook(module_output, name=self.name)\n",
        "        if dir=='fwd':\n",
        "            handle = self.register_forward_hook(full_hook)\n",
        "            self.fwd_hooks.append(handle)\n",
        "        elif dir=='bwd':\n",
        "            handle = self.register_backward_hook(full_hook)\n",
        "            self.bwd_hooks.append(handle)\n",
        "        else:\n",
        "            raise ValueError(f\"Invalid direction {dir}\")\n",
        "\n",
        "    def remove_hooks(self, dir='fwd'):\n",
        "        # Removes the hooks registered through add_hook for dir\n",
        "        # ('fwd', 'bwd' or 'both'); any other value raises ValueError.\n",
        "        if (dir=='fwd') or (dir=='both'):\n",
        "            for hook in self.fwd_hooks:\n",
        "                hook.remove()\n",
        "            self.fwd_hooks = []\n",
        "        if (dir=='bwd') or (dir=='both'):\n",
        "            for hook in self.bwd_hooks:\n",
        "                hook.remove()\n",
        "            self.bwd_hooks = []\n",
        "        if dir not in ['fwd', 'bwd', 'both']:\n",
        "            raise ValueError(f\"Invalid direction {dir}\")\n",
        "\n",
        "    def forward(self, x):\n",
        "        return x\n",
        "\n",
        "\n",
        "# Embed & Unembed\n",
        "class Embed(nn.Module):\n",
        "    # Token embedding stored as a (d_model, d_vocab) matrix, initialised from a\n",
        "    # standard normal scaled by 1/sqrt(d_model).\n",
        "    def __init__(self, d_vocab, d_model):\n",
        "        super().__init__()\n",
        "        self.W_E = nn.Parameter(torch.randn(d_model, d_vocab)/np.sqrt(d_model))\n",
        "\n",
        "    def forward(self, x):\n",
        "        # Select the embedding column of every token index in x, then move the\n",
        "        # embedding axis last: (d, b, p) -> (b, p, d).\n",
        "        return torch.einsum('dbp -> bpd', self.W_E[:, x])\n",
        "\n",
        "class Unembed(nn.Module):\n",
        "    # Output projection stored as a (d_model, d_vocab) matrix, initialised from\n",
        "    # a standard normal scaled by 1/sqrt(d_vocab).\n",
        "    def __init__(self, d_vocab, d_model):\n",
        "        super().__init__()\n",
        "        self.W_U = nn.Parameter(torch.randn(d_model, d_vocab)/np.sqrt(d_vocab))\n",
        "\n",
        "    def forward(self, x):\n",
        "        # Matrix product over the last axis of x (d_model) to produce d_vocab logits\n",
        "        return (x @ self.W_U)\n",
        "\n",
        "# Positional Embeddings\n",
        "class PosEmbed(nn.Module):\n",
        "    # Learned positional embedding of shape (max_ctx, d_model), initialised from\n",
        "    # a standard normal scaled by 1/sqrt(d_model).\n",
        "    def __init__(self, max_ctx, d_model):\n",
        "        super().__init__()\n",
        "        self.W_pos = nn.Parameter(torch.randn(max_ctx, d_model)/np.sqrt(d_model))\n",
        "\n",
        "    def forward(self, x):\n",
        "        # Add the first x.shape[-2] rows of W_pos, i.e. one row per position of x\n",
        "        return x+self.W_pos[:x.shape[-2]]\n",
        "\n",
        "# LayerNorm\n",
        "class LayerNorm(nn.Module):\n",
        "    def __init__(self, d_model, epsilon = 1e-4, model=[None]):\n",
        "        super().__init__()\n",
        "        self.model = model\n",
        "        self.w_ln = nn.Parameter(torch.ones(d_model))\n",
        "        self.b_ln = nn.Parameter(torch.zeros(d_model))\n",
        "        self.epsilon = epsilon\n",
        "\n",
        "    def forward(self, x):\n",
        "        if self.model[0].use_ln:\n",
        "            x = x - x.mean(axis=-1)[..., None]\n",
        "            x = x / (x.std(axis=-1)[..., None] + self.epsilon)\n",
        "            x = x * self.w_ln\n",
        "            x = x + self.b_ln\n",
        "            return x\n",
        "        else:\n",
        "            return x\n",
        "\n",
        "# Attention\n",
        "class Attention(nn.Module):\n",
        "    # Multi-head causal self-attention with explicit weight tensors and a\n",
        "    # HookPoint on every intermediate activation.\n",
        "    #\n",
        "    # d_model: width of the residual stream\n",
        "    # num_heads, d_head: number of heads and per-head width\n",
        "    # n_ctx: maximum sequence length (size of the causal mask)\n",
        "    # model: one-element list holding the owning model\n",
        "    def __init__(self, d_model, num_heads, d_head, n_ctx, model):\n",
        "        super().__init__()\n",
        "        self.model = model\n",
        "        # Take the output scale from the owning model rather than from a global\n",
        "        # named scale, which only exists while the training loop is running and\n",
        "        # made construction anywhere else fail with a NameError.\n",
        "        self.scale = getattr(model[0], 'scale', 1.0)\n",
        "        self.W_K = nn.Parameter(torch.randn(num_heads, d_head, d_model)/np.sqrt(d_model))\n",
        "        self.W_Q = nn.Parameter(torch.randn(num_heads, d_head, d_model)/np.sqrt(d_model))\n",
        "        self.W_V = nn.Parameter(torch.randn(num_heads, d_head, d_model)/np.sqrt(d_model))\n",
        "        self.W_O = nn.Parameter(torch.randn(d_model, d_head * num_heads)/np.sqrt(d_model))\n",
        "        # Lower-triangular ones: position q may attend to positions p <= q\n",
        "        self.register_buffer('mask', torch.tril(torch.ones((n_ctx, n_ctx))))\n",
        "        self.d_head = d_head\n",
        "        self.hook_k = HookPoint()\n",
        "        self.hook_q = HookPoint()\n",
        "        self.hook_v = HookPoint()\n",
        "        self.hook_z = HookPoint()\n",
        "        self.hook_attn = HookPoint()\n",
        "        self.hook_attn_pre = HookPoint()\n",
        "\n",
        "    def forward(self, x):\n",
        "        # x: (batch, pos, d_model) -> output of the same shape\n",
        "        # Per-head projections, each (batch, head, pos, d_head)\n",
        "        k = self.hook_k(torch.einsum('ihd,bpd->biph', self.W_K, x))\n",
        "        q = self.hook_q(torch.einsum('ihd,bpd->biph', self.W_Q, x))\n",
        "        v = self.hook_v(torch.einsum('ihd,bpd->biph', self.W_V, x))\n",
        "        # Scores indexed (batch, head, query pos, key pos)\n",
        "        attn_scores_pre = torch.einsum('biph,biqh->biqp', k, q)\n",
        "        # Zero the future positions, then push them to -1e10 so softmax ignores them\n",
        "        attn_scores_masked = torch.tril(attn_scores_pre) - 1e10 * (1 - self.mask[:x.shape[-2], :x.shape[-2]])\n",
        "        # Scale by sqrt(d_head) and normalise over key positions\n",
        "        attn_matrix = self.hook_attn(F.softmax(self.hook_attn_pre(attn_scores_masked/np.sqrt(self.d_head)), dim=-1))\n",
        "        # Attention-weighted sum of values, (batch, head, query pos, d_head)\n",
        "        z = self.hook_z(torch.einsum('biph,biqp->biqh', v, attn_matrix))\n",
        "        # Concatenate the heads and project back to d_model\n",
        "        z_flat = einops.rearrange(z, 'b i q h -> b q (i h)')\n",
        "        out = torch.einsum('df,bqf->bqd', self.W_O, z_flat)\n",
        "        return out\n",
        "\n",
        "# MLP Layers\n",
        "class MLP(nn.Module):\n",
        "    # Two-layer feed-forward block: d_model -> d_mlp -> d_model, with a ReLU or\n",
        "    # GeLU in between and HookPoints around the non-linearity.\n",
        "    def __init__(self, d_model, d_mlp, act_type, model):\n",
        "        super().__init__()\n",
        "        self.model = model\n",
        "        # Parameters are created in the same order as before so the random draws match\n",
        "        self.W_in = nn.Parameter(torch.randn(d_mlp, d_model)/np.sqrt(d_model))\n",
        "        self.b_in = nn.Parameter(torch.zeros(d_mlp))\n",
        "        self.W_out = nn.Parameter(torch.randn(d_model, d_mlp)/np.sqrt(d_model))\n",
        "        self.b_out = nn.Parameter(torch.zeros(d_model))\n",
        "        self.act_type = act_type\n",
        "        # self.ln = LayerNorm(d_mlp, model=self.model)\n",
        "        self.hook_pre = HookPoint()\n",
        "        self.hook_post = HookPoint()\n",
        "        assert act_type in ['ReLU', 'GeLU']\n",
        "\n",
        "    def forward(self, x):\n",
        "        # x: (batch, pos, d_model) -> output of the same shape\n",
        "        hidden = self.hook_pre(torch.einsum('md,bpd->bpm', self.W_in, x) + self.b_in)\n",
        "        # Look the non-linearity up by name; an unrecognised name applies none\n",
        "        activation = {'ReLU': F.relu, 'GeLU': F.gelu}.get(self.act_type)\n",
        "        if activation is not None:\n",
        "            hidden = activation(hidden)\n",
        "        hidden = self.hook_post(hidden)\n",
        "        return torch.einsum('dm,bpm->bpd', self.W_out, hidden) + self.b_out\n",
        "\n",
        "# Transformer Block\n",
        "class TransformerBlock(nn.Module):\n",
        "    # One residual block: attention followed by an MLP, each added back onto\n",
        "    # the residual stream, with HookPoints on every intermediate value.\n",
        "    def __init__(self, d_model, d_mlp, d_head, num_heads, n_ctx, act_type, model):\n",
        "        super().__init__()\n",
        "        # Resets the global torch RNG on every construction, so the attention and\n",
        "        # MLP weights drawn below are identical each time a block is built.\n",
        "        # Anything created before this call is not covered by the seed.\n",
        "        torch.manual_seed(1)\n",
        "        torch.cuda.manual_seed(1)\n",
        "        self.model = model\n",
        "        # self.ln1 = LayerNorm(d_model, model=self.model)\n",
        "        self.attn = Attention(d_model, num_heads, d_head, n_ctx, model=self.model)\n",
        "        # self.ln2 = LayerNorm(d_model, model=self.model)\n",
        "        self.mlp = MLP(d_model, d_mlp, act_type, model=self.model)\n",
        "        self.hook_attn_out = HookPoint()\n",
        "        self.hook_mlp_out = HookPoint()\n",
        "        self.hook_resid_pre = HookPoint()\n",
        "        self.hook_resid_mid = HookPoint()\n",
        "        self.hook_resid_post = HookPoint()\n",
        "\n",
        "    def forward(self, x):\n",
        "        # resid_mid = x + attn(x); resid_post = resid_mid + mlp(resid_mid)\n",
        "        x = self.hook_resid_mid(x + self.hook_attn_out(self.attn((self.hook_resid_pre(x)))))\n",
        "        x = self.hook_resid_post(x + self.hook_mlp_out(self.mlp((x))))\n",
        "        return x\n",
        "\n",
        "# Full transformer\n",
        "class Transformer(nn.Module):\n",
        "    # Embedding, positional embedding, num_layers TransformerBlocks and an\n",
        "    # unembedding. The logits are multiplied by scale before being returned.\n",
        "    def __init__(self, num_layers, d_vocab, d_model, d_mlp, d_head, num_heads, n_ctx, act_type, use_cache=False, use_ln=True, scale=1.0):\n",
        "        super().__init__()\n",
        "        # cache and use_cache are stored here; nothing in this class reads them\n",
        "        self.cache = {}\n",
        "        self.scale = scale\n",
        "        self.use_cache = use_cache\n",
        "\n",
        "        self.embed = Embed(d_vocab, d_model)\n",
        "        self.pos_embed = PosEmbed(n_ctx, d_model)\n",
        "        # The model is handed down wrapped in a list, so sub-modules hold a\n",
        "        # reference without registering it as a child module\n",
        "        self.blocks = nn.ModuleList([TransformerBlock(d_model, d_mlp, d_head, num_heads, n_ctx, act_type, model=[self]) for i in range(num_layers)])\n",
        "        # self.ln = LayerNorm(d_model, model=[self])\n",
        "        self.unembed = Unembed(d_vocab, d_model)\n",
        "        # Read by LayerNorm.forward through model[0].use_ln\n",
        "        self.use_ln = use_ln\n",
        "\n",
        "        # Give every HookPoint its qualified module name\n",
        "        for name, module in self.named_modules():\n",
        "            if type(module)==HookPoint:\n",
        "                module.give_name(name)\n",
        "\n",
        "    def forward(self, x):\n",
        "        # x holds token indices; returns logits for every position, times scale\n",
        "        x = self.embed(x)\n",
        "        x = self.pos_embed(x)\n",
        "        for block in self.blocks:\n",
        "            x = block(x)\n",
        "        # x = self.ln(x)\n",
        "        x = self.unembed(x)\n",
        "        return x*self.scale\n",
        "\n",
        "    def set_use_cache(self, use_cache):\n",
        "        self.use_cache = use_cache\n",
        "\n",
        "    def hook_points(self):\n",
        "        # Every sub-module whose qualified name contains 'hook'\n",
        "        return [module for name, module in self.named_modules() if 'hook' in name]\n",
        "\n",
        "    def remove_all_hooks(self):\n",
        "        for hp in self.hook_points():\n",
        "            hp.remove_hooks('fwd')\n",
        "            hp.remove_hooks('bwd')\n",
        "\n",
        "    def cache_all(self, cache, incl_bwd=False):\n",
        "        # Caches all activations wrapped in a HookPoint\n",
        "        # cache is a dict filled in place: activations under the hook name and,\n",
        "        # when incl_bwd is set, the first gradient entry under name + '_grad'\n",
        "        def save_hook(tensor, name):\n",
        "            cache[name] = tensor.detach()\n",
        "        def save_hook_back(tensor, name):\n",
        "            cache[name+'_grad'] = tensor[0].detach()\n",
        "        for hp in self.hook_points():\n",
        "            hp.add_hook(save_hook, 'fwd')\n",
        "            if incl_bwd:\n",
        "                hp.add_hook(save_hook_back, 'bwd')\n",
        "\n",
        "# Helper functions\n",
        "def cuda_memory():\n",
        "    print(torch.cuda.memory_allocated()/1e9)\n",
        "\n",
        "def cross_entropy_high_precision(logits, labels):\n",
        "    logprobs = F.log_softmax(logits.to(torch.float64), dim=-1)\n",
        "    prediction_logprobs = torch.gather(logprobs, index=labels[:, None], dim=-1)\n",
        "    loss = -torch.mean(prediction_logprobs)\n",
        "    return loss\n",
        "\n",
        "def full_loss(model, data):\n",
        "    logits = model(data)[:, -1]\n",
        "    labels = torch.tensor([fn(i, j) for i, j, _ in data]).to('cuda')\n",
        "    return cross_entropy_high_precision(logits, labels)\n",
        "\n",
        "def test_logits(logits, bias_correction=False, original_logits=None, mode='all'):\n",
        "    if logits.shape[1]==p*p:\n",
        "        logits = logits.T\n",
        "    if logits.shape==torch.Size([p*p, p+1]):\n",
        "        logits = logits[:, :-1]\n",
        "    logits = logits.reshape(p*p, p)\n",
        "    if bias_correction:\n",
        "        # Applies bias correction - we correct for any missing bias terms,\n",
        "        # independent of the input, by centering the new logits along the batch\n",
        "        # dimension, and then adding the average original logits across all inputs\n",
        "        logits = einops.reduce(original_logits - logits, 'batch ... -> ...', 'mean') + logits\n",
        "    if mode=='train':\n",
        "        return cross_entropy_high_precision(logits[is_train], labels[is_train])\n",
        "    elif mode=='test':\n",
        "        return cross_entropy_high_precision(logits[is_test], labels[is_test])\n",
        "    elif mode=='all':\n",
        "        return cross_entropy_high_precision(logits, labels)\n",
        "\n",
        "\n",
        "# Helper functions\n",
        "def cuda_memory():\n",
        "    print(torch.cuda.memory_allocated()/1e9)\n",
        "\n",
        "def cross_entropy_high_precision(logits, labels):\n",
        "    # Shapes: batch x vocab, batch\n",
        "    # Cast logits to float64 because log_softmax has a float32 underflow on overly\n",
        "    # confident data and can only return multiples of 1.2e-7 (the smallest float x\n",
        "    # such that 1+x is different from 1 in float32). This leads to loss spikes\n",
        "    # and dodgy gradients\n",
        "    logprobs = F.log_softmax(logits.to(torch.float64), dim=-1)\n",
        "    prediction_logprobs = torch.gather(logprobs, index=labels[:, None], dim=-1)\n",
        "    loss = -torch.mean(prediction_logprobs)\n",
        "    return loss\n",
        "\n",
        "def full_loss(model, data):\n",
        "    # Take the final position only\n",
        "    logits = model(data)[:, -1]\n",
        "    labels = torch.tensor([fn(i, j) for i, j, _ in data]).to('cuda')\n",
        "    return cross_entropy_high_precision(logits, labels)\n",
        "\n",
        "def test_logits(logits, bias_correction=False, original_logits=None, mode='all'):\n",
        "    # Calculates cross entropy loss of logits representing a batch of all p^2\n",
        "    # possible inputs\n",
        "    # Batch dimension is assumed to be first\n",
        "    if logits.shape[1]==p*p:\n",
        "        logits = logits.T\n",
        "    if logits.shape==torch.Size([p*p, p+1]):\n",
        "        logits = logits[:, :-1]\n",
        "    logits = logits.reshape(p*p, p)\n",
        "    if bias_correction:\n",
        "        # Applies bias correction - we correct for any missing bias terms,\n",
        "        # independent of the input, by centering the new logits along the batch\n",
        "        # dimension, and then adding the average original logits across all inputs\n",
        "        logits = einops.reduce(original_logits - logits, 'batch ... -> ...', 'mean') + logits\n",
        "    if mode=='train':\n",
        "        return cross_entropy_high_precision(logits[is_train], labels[is_train])\n",
        "    elif mode=='test':\n",
        "        return cross_entropy_high_precision(logits[is_test], labels[is_test])\n",
        "    elif mode=='all':\n",
        "        return cross_entropy_high_precision(logits, labels)\n",
        "\n",
        "# Plotting helpers: thin wrappers that coax Plotly into the figures used here\n",
        "def to_numpy(tensor, flat=False):\n",
        "    if type(tensor)!=torch.Tensor:\n",
        "        return tensor\n",
        "    if flat:\n",
        "        return tensor.flatten().detach().cpu().numpy()\n",
        "    else:\n",
        "        return tensor.detach().cpu().numpy()\n",
        "def imshow(tensor, xaxis=None, yaxis=None, animation_name='Snapshot', **kwargs):\n",
        "    # Show tensor as a heatmap with px.imshow; extra kwargs are passed through.\n",
        "    # A leading axis of length p*p is first unflattened into p x p, then\n",
        "    # singleton axes are squeezed away.\n",
        "    # NOTE(review): px is presumably plotly.express; its import is not in this cell - confirm\n",
        "    if tensor.shape[0]==p*p:\n",
        "        tensor = unflatten_first(tensor)\n",
        "    tensor = torch.squeeze(tensor)\n",
        "    px.imshow(to_numpy(tensor, flat=False),\n",
        "              labels={'x':xaxis, 'y':yaxis, 'animation_name':animation_name},\n",
        "              **kwargs).show()\n",
        "# Set default colour scheme\n",
        "# (this rebinds the name imshow to the partial, so later calls get the default)\n",
        "imshow = partial(imshow, color_continuous_scale='Blues')\n",
        "# Creates good defaults for showing divergent colour scales (ie with both\n",
        "# positive and negative values, where 0 is white)\n",
        "imshow_div = partial(imshow, color_continuous_scale='RdBu', color_continuous_midpoint=0.0)\n",
        "# Presets a bunch of defaults to imshow to make it suitable for showing heatmaps\n",
        "# of activations with x axis being input 1 and y axis being input 2.\n",
        "inputs_heatmap = partial(imshow, xaxis='Input 1', yaxis='Input 2', color_continuous_scale='RdBu', color_continuous_midpoint=0.0)\n",
        "def line(x, y=None, hover=None, xaxis='', yaxis='', **kwargs):\n",
        "    # Plot a single line with px.line; torch tensors are flattened to NumPy\n",
        "    # first. xaxis and yaxis become the axis titles.\n",
        "    if type(y)==torch.Tensor:\n",
        "        y = to_numpy(y, flat=True)\n",
        "    if type(x)==torch.Tensor:\n",
        "        x=to_numpy(x, flat=True)\n",
        "    fig = px.line(x, y=y, hover_name=hover, **kwargs)\n",
        "    fig.update_layout(xaxis_title=xaxis, yaxis_title=yaxis)\n",
        "    fig.show()\n",
        "def scatter(x, y, **kwargs):\n",
        "    # Scatter plot of x against y; tensor inputs are flattened to NumPy first\n",
        "    px.scatter(x=to_numpy(x, flat=True), y=to_numpy(y, flat=True), **kwargs).show()\n",
        "def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):\n",
        "    # Helper function to plot multiple lines\n",
        "    # lines_list is a list of sequences, or a tensor that is split along its\n",
        "    # first axis. x defaults to 0..len(first line)-1 and is shared by all\n",
        "    # traces. labels, when given, names trace c as labels[c]; otherwise the\n",
        "    # index c is used.\n",
        "    if type(lines_list)==torch.Tensor:\n",
        "        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]\n",
        "    if x is None:\n",
        "        x=np.arange(len(lines_list[0]))\n",
        "    fig = go.Figure(layout={'title':title})\n",
        "    fig.update_xaxes(title=xaxis)\n",
        "    fig.update_yaxes(title=yaxis)\n",
        "    # The loop variable shadows the module-level line() helper inside this function only\n",
        "    for c, line in enumerate(lines_list):\n",
        "        if type(line)==torch.Tensor:\n",
        "            line = to_numpy(line)\n",
        "        if labels is not None:\n",
        "            label = labels[c]\n",
        "        else:\n",
        "            label = c\n",
        "        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))\n",
        "    if log_y:\n",
        "        fig.update_layout(yaxis_type=\"log\")\n",
        "    fig.show()\n",
        "def line_marker(x, **kwargs):\n",
        "    # Plot the single series x through lines(), drawn with both lines and markers\n",
        "    lines([x], mode='lines+markers', **kwargs)\n",
        "def animate_lines(lines_list, snapshot_index = None, snapshot='snapshot', hover=None, xaxis='x', yaxis='y', **kwargs):\n",
        "    # Animated line plot with one frame per entry along the first axis of\n",
        "    # lines_list (indexed as snapshot x position). A list input is stacked with\n",
        "    # torch.stack first. snapshot_index labels the frames and defaults to 0..n-1.\n",
        "    if type(lines_list)==list:\n",
        "        lines_list = torch.stack(lines_list, axis=0)\n",
        "    lines_list = to_numpy(lines_list, flat=False)\n",
        "    if snapshot_index is None:\n",
        "        snapshot_index = np.arange(lines_list.shape[0])\n",
        "    if hover is not None:\n",
        "        # Repeat the hover labels once per snapshot to line up with the rows below\n",
        "        hover = [i for j in range(len(snapshot_index)) for i in hover]\n",
        "    print(lines_list.shape)\n",
        "    # Long-format rows: (value, snapshot label, position)\n",
        "    rows=[]\n",
        "    for i in range(lines_list.shape[0]):\n",
        "        for j in range(lines_list.shape[1]):\n",
        "            rows.append([lines_list[i][j], snapshot_index[i], j])\n",
        "    df = pd.DataFrame(rows, columns=[yaxis, snapshot, xaxis])\n",
        "    # The y range is fixed to the global min/max so it does not change between frames\n",
        "    px.line(df, x=xaxis, y=yaxis, animation_frame=snapshot, range_y=[lines_list.min(), lines_list.max()], hover_name=hover,**kwargs).show()\n",
        "\n",
        "def imshow_fourier(tensor, title='', animation_name='snapshot', facet_labels=[], **kwargs):\n",
        "    # Set nice defaults for plotting functions in the 2D fourier basis\n",
        "    # tensor is assumed to already be in the Fourier Basis\n",
        "    # Both axes are labelled with the module-level fourier_basis_names.\n",
        "    # facet_labels, when non-empty, overwrite the text of the figure annotations in order.\n",
        "    if tensor.shape[0]==p*p:\n",
        "        tensor = unflatten_first(tensor)\n",
        "    tensor = torch.squeeze(tensor)\n",
        "    fig=px.imshow(to_numpy(tensor),\n",
        "            x=fourier_basis_names,\n",
        "            y=fourier_basis_names,\n",
        "            labels={'x':'x Component',\n",
        "                    'y':'y Component',\n",
        "                    'animation_frame':animation_name},\n",
        "            title=title,\n",
        "            color_continuous_midpoint=0.,\n",
        "            color_continuous_scale='RdBu',\n",
        "            **kwargs)\n",
        "    fig.update(data=[{'hovertemplate':\"%{x}x * %{y}y<br>Value:%{z:.4f}\"}])\n",
        "    if facet_labels:\n",
        "        for i, label in enumerate(facet_labels):\n",
        "            fig.layout.annotations[i]['text'] = label\n",
        "    fig.show()\n",
        "\n",
        "def animate_multi_lines(lines_list, y_index=None, snapshot_index = None, snapshot='snapshot', hover=None, swap_y_animate=False, **kwargs):\n",
        "    # Can plot an animation of lines with multiple lines on the plot.\n",
        "    # lines_list is indexed as snapshot x line x position; swap_y_animate\n",
        "    # swaps the first two axes. y_index names the lines and defaults to\n",
        "    # '0', '1', ...; snapshot_index labels the frames and defaults to 0..n-1.\n",
        "    if type(lines_list)==list:\n",
        "        lines_list = torch.stack(lines_list, axis=0)\n",
        "    lines_list = to_numpy(lines_list, flat=False)\n",
        "    if swap_y_animate:\n",
        "        lines_list = lines_list.transpose(1, 0, 2)\n",
        "    if snapshot_index is None:\n",
        "        snapshot_index = np.arange(lines_list.shape[0])\n",
        "    if y_index is None:\n",
        "        y_index = [str(i) for i in range(lines_list.shape[1])]\n",
        "    if hover is not None:\n",
        "        hover = [i for j in range(len(snapshot_index)) for i in hover]\n",
        "    print(lines_list.shape)\n",
        "    # One row per (snapshot, position): all line values, then the snapshot label and position\n",
        "    rows=[]\n",
        "    for i in range(lines_list.shape[0]):\n",
        "        for j in range(lines_list.shape[2]):\n",
        "            rows.append(list(lines_list[i, :, j])+[snapshot_index[i], j])\n",
        "    df = pd.DataFrame(rows, columns=y_index+[snapshot, 'x'])\n",
        "    px.line(df, x='x', y=y_index, animation_frame=snapshot, range_y=[lines_list.min(), lines_list.max()], hover_name=hover, **kwargs).show()\n",
        "\n",
        "def animate_scatter(lines_list, snapshot_index = None, snapshot='snapshot', hover=None, yaxis='y', xaxis='x', color=None, color_name = 'color', **kwargs):\n",
        "    # Can plot an animated scatter plot\n",
        "    # lines_list has shape snapshot x 2 x line\n",
        "    # (index 0 along the middle axis is plotted on x, index 1 on y)\n",
        "    # color is optional: a 1-D array is repeated for every snapshot, and when\n",
        "    # omitted every point gets the value 1.\n",
        "    if type(lines_list)==list:\n",
        "        lines_list = torch.stack(lines_list, axis=0)\n",
        "    lines_list = to_numpy(lines_list, flat=False)\n",
        "    if snapshot_index is None:\n",
        "        snapshot_index = np.arange(lines_list.shape[0])\n",
        "    if hover is not None:\n",
        "        hover = [i for j in range(len(snapshot_index)) for i in hover]\n",
        "    if color is None:\n",
        "        color = np.ones(lines_list.shape[-1])\n",
        "    if type(color)==torch.Tensor:\n",
        "        color = to_numpy(color)\n",
        "    if len(color.shape)==1:\n",
        "        color = einops.repeat(color, 'x -> snapshot x', snapshot=lines_list.shape[0])\n",
        "    print(lines_list.shape)\n",
        "    # Long-format rows: (x value, y value, snapshot label, colour value)\n",
        "    rows=[]\n",
        "    for i in range(lines_list.shape[0]):\n",
        "        for j in range(lines_list.shape[2]):\n",
        "            rows.append([lines_list[i, 0, j].item(), lines_list[i, 1, j].item(), snapshot_index[i], color[i, j]])\n",
        "    print([lines_list[:, 0].min(), lines_list[:, 0].max()])\n",
        "    print([lines_list[:, 1].min(), lines_list[:, 1].max()])\n",
        "    df = pd.DataFrame(rows, columns=[xaxis, yaxis, snapshot, color_name])\n",
        "    # Both axis ranges are fixed to the global extremes so they do not change between frames\n",
        "    px.scatter(df, x=xaxis, y=yaxis, animation_frame=snapshot, range_x=[lines_list[:, 0].min(), lines_list[:, 0].max()], range_y=[lines_list[:, 1].min(), lines_list[:, 1].max()], hover_name=hover, color=color_name, **kwargs).show()\n",
        "\n",
        "def unflatten_first(tensor):\n",
        "    if tensor.shape[0]==p*p:\n",
        "        return einops.rearrange(tensor, '(x y) ... -> x y ...', x=p, y=p)\n",
        "    else:\n",
        "        return tensor\n",
        "def cos(x, y):\n",
        "    return (x.dot(y))/x.norm()/y.norm()\n",
        "def mod_div(a, b):\n",
        "    return (a*pow(b, p-2, p))%p\n",
        "def normalize(tensor, axis=0):\n",
        "    return tensor/(tensor).pow(2).sum(keepdim=True, axis=axis).sqrt()\n",
        "def extract_freq_2d(tensor, freq):\n",
        "    # Takes in a pxpx... or batch x ... tensor, returns a 3x3x... tensor of the\n",
        "    # Linear and quadratic terms of frequency freq\n",
        "    tensor = unflatten_first(tensor)\n",
        "    # Extracts the linear and quadratic terms corresponding to frequency freq\n",
        "    index_1d = [0, 2*freq-1, 2*freq]\n",
        "    # Some dumb manipulation to use fancy array indexing rules\n",
        "    # Gets the rows and columns in index_1d\n",
        "    return tensor[[[i]*3 for i in index_1d], [index_1d]*3]\n",
        "def get_cov(tensor, norm=True):\n",
        "    # Calculate covariance matrix\n",
        "    if norm:\n",
        "        tensor = normalize(tensor, axis=1)\n",
        "    return tensor @ tensor.T\n",
        "def is_close(a, b):\n",
        "    return ((a-b).pow(2).sum()/(a.pow(2).sum().sqrt())/(b.pow(2).sum().sqrt())).item()\n",
        "\n",
        "def unflatten_first(tensor):\n",
        "    if tensor.shape[0]==p*p:\n",
        "        return einops.rearrange(tensor, '(x y) ... -> x y ...', x=p, y=p)\n",
        "    else:\n",
        "        return tensor\n",
        "def cos(x, y):\n",
        "    return (x.dot(y))/x.norm()/y.norm()\n",
        "def mod_div(a, b):\n",
        "    return (a*pow(b, p-2, p))%p\n",
        "def normalize(tensor, axis=0):\n",
        "    return tensor/(tensor).pow(2).sum(keepdim=True, axis=axis).sqrt()\n",
        "def extract_freq_2d(tensor, freq):\n",
        "    # Takes in a pxpx... or batch x ... tensor, returns a 3x3x... tensor of the\n",
        "    # Linear and quadratic terms of frequency freq\n",
        "    tensor = unflatten_first(tensor)\n",
        "    # Extracts the linear and quadratic terms corresponding to frequency freq\n",
        "    index_1d = [0, 2*freq-1, 2*freq]\n",
        "    # Some dumb manipulation to use fancy array indexing rules\n",
        "    # Gets the rows and columns in index_1d\n",
        "    return tensor[[[i]*3 for i in index_1d], [index_1d]*3]\n",
        "def get_cov(tensor, norm=True):\n",
        "    # Calculate covariance matrix\n",
        "    if norm:\n",
        "        tensor = normalize(tensor, axis=1)\n",
        "    return tensor @ tensor.T\n",
        "def is_close(a, b):\n",
        "    return ((a-b).pow(2).sum()/(a.pow(2).sum().sqrt())/(b.pow(2).sum().sqrt())).item()\n",
        "\n",
        "\n",
        "# Experiment configuration. Several values are re-assigned further down in\n",
        "# this cell before training starts: lr, num_epochs and save_models.\n",
        "lr=1e-3 #@param\n",
        "weight_decay = 1.0 #@param\n",
        "p=113 #@param\n",
        "d_model = 128 #@param\n",
        "fn_name = 'add' #@param ['add', 'subtract', 'x2xyy2','rand']\n",
        "frac_train = 0.4 #@param\n",
        "num_epochs = 50000 #@param\n",
        "save_models = True #@param\n",
        "save_every = 100 #@param\n",
        "# Stop training when test loss is <stopping_thresh\n",
        "stopping_thresh = -1 #@param\n",
        "seed = 0 #@param\n",
        "\n",
        "num_layers = 1\n",
        "batch_style = 'full'\n",
        "# One entry more than p: every input triple carries p itself as its third token\n",
        "d_vocab = p+1\n",
        "n_ctx = 3\n",
        "d_mlp = 4*d_model\n",
        "num_heads = 4\n",
        "assert d_model % num_heads == 0\n",
        "d_head = d_model//num_heads\n",
        "act_type = 'ReLU' #@param ['ReLU', 'GeLU']\n",
        "# batch_size = 512\n",
        "use_ln = False\n",
        "# Lookup table behind the 'rand' task\n",
        "# NOTE(review): drawn from the global NumPy RNG, which this cell does not seed - confirm it is seeded earlier\n",
        "random_answers = np.random.randint(low=0, high=p, size=(p, p))\n",
        "fns_dict = {'add': lambda x,y:(x+y)%p, 'subtract': lambda x,y:(x-y)%p, 'x2xyy2':lambda x,y:(x**2+x*y+y**2)%p, 'rand':lambda x,y:random_answers[x][y]}\n",
        "fn = fns_dict[fn_name]\n",
        "\n",
        "train_model = True #@param\n",
        "\n",
        "def gen_train_test(frac_train, num, seed=0):\n",
        "    # Generate train and test split\n",
        "    pairs = [(i, j, num) for i in range(num) for j in range(num)]\n",
        "    random.seed(seed)\n",
        "    random.shuffle(pairs)\n",
        "    div = int(frac_train*len(pairs))\n",
        "    return pairs[:div], pairs[div:]\n",
        "\n",
        "train, test = gen_train_test(frac_train, p, seed)\n",
        "\n",
        "# Creates an array of Boolean indices according to whether each data point is in\n",
        "# train or test\n",
        "# Used to index into the big batch of all possible data\n",
        "# A set gives constant-time membership tests; scanning the train list for each\n",
        "# of the p*p pairs was quadratic. The third token is p, not a hard-coded 113,\n",
        "# so the masks stay correct when p is changed.\n",
        "train_set = set(train)\n",
        "is_train = np.array([(x, y, p) in train_set for x in range(p) for y in range(p)])\n",
        "# Every pair that is not in train belongs to test\n",
        "is_test = ~is_train\n",
        "\n",
        "\n",
        "def centered_loss(model, model_init, data): # modified\n",
        "    # Take the final position only\n",
        "    logits = model(data)[:, -1]\n",
        "    with torch.no_grad():\n",
        "      logits_init = model_init(data)[:, -1]\n",
        "    labels = torch.tensor([fn(i, j) for i, j, _ in data]).to('cuda')\n",
        "    return cross_entropy_high_precision(logits-logits_init, labels) # train the model to produce logits relative to init\n",
        "\n",
        "from tqdm import tqdm\n",
        "from copy import deepcopy\n",
        "\n",
        "num_epochs = 200000\n",
        "\n",
        "# Output scales (alpha) to sweep over; one model is trained per scale\n",
        "scales = [0.05 * i for i in range(1, 11)]\n",
        "tls = []  # per-scale list of test losses, one entry per epoch\n",
        "trs = []  # per-scale list of train losses, one entry per epoch\n",
        "save_models = True\n",
        "\n",
        "if train_model:\n",
        "  for scale in scales:\n",
        "      # Seed before construction so the embeddings, which are created before the\n",
        "      # seed set inside TransformerBlock, are also identical for every scale\n",
        "      torch.manual_seed(seed)\n",
        "      model = Transformer(num_layers=num_layers, d_vocab=d_vocab, d_model=d_model, d_mlp=d_mlp, d_head=d_head, num_heads=num_heads, n_ctx=n_ctx, act_type=act_type, use_cache=False, use_ln=use_ln, scale=scale)\n",
        "      # transformer has same random seed upon init\n",
        "      model_init = deepcopy(model) # we'll use this to center loss\n",
        "\n",
        "      model_init.to('cuda')\n",
        "      model.to('cuda')\n",
        "\n",
        "      lr=6e-6 #@param\n",
        "      # The learning rate is divided by scale**2 for each run\n",
        "      optimizer = optim.AdamW(model.parameters(), lr=lr/scale**2, weight_decay=weight_decay, betas=(0.9, 0.98))\n",
        "      # Linear warm-up over the first 10 steps\n",
        "      scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(step/10, 1))\n",
        "      run_name = f\"grok_{int(time.time())}\"\n",
        "      print(f'Run name {run_name}')\n",
        "      if save_models:\n",
        "          # parents=True also creates root when the cloned repository lacks it\n",
        "          (root/run_name).mkdir(parents=True, exist_ok=True)\n",
        "          save_dict = {'model':model.state_dict(), 'train_data':train, 'test_data':test}\n",
        "          torch.save(save_dict, root/run_name/'init.pth')\n",
        "      train_losses = []\n",
        "      test_losses = []\n",
        "      for epoch in tqdm(range(num_epochs)):\n",
        "          train_loss = centered_loss(model, model_init, train)\n",
        "          # The test loss is only logged, so skip building its autograd graph\n",
        "          with torch.no_grad():\n",
        "              test_loss = centered_loss(model, model_init, test)\n",
        "          train_losses.append(train_loss.item())\n",
        "          test_losses.append(test_loss.item())\n",
        "          train_loss.backward()\n",
        "          optimizer.step()\n",
        "          scheduler.step()\n",
        "          optimizer.zero_grad()\n",
        "          if test_loss.item() < stopping_thresh:\n",
        "              break\n",
        "\n",
        "      trs += [train_losses]\n",
        "      tls += [test_losses]\n",
        "\n",
        "# Train (dashed) and test (solid) loss curves, one colour per scale\n",
        "plt.figure()\n",
        "for i in range(len(trs)):\n",
        "    # Label each curve with its own scale: the loop variable scale left over from\n",
        "    # training only holds the last value, so every label used to show the same number\n",
        "    plt.plot(trs[i], linestyle='--', color=colors[i], label=rf'Train, $\\alpha$={scales[i]:.2f}')\n",
        "    plt.plot(tls[i], color=colors[i], label=rf'Test, $\\alpha$={scales[i]:.2f}')\n",
        "\n",
        "plt.xlabel('Epochs', fontsize=20)\n",
        "plt.ylabel('Cross Entropy Loss', fontsize=20)\n",
        "if trs:\n",
        "    plt.legend(fontsize=14)\n",
        "# Lay the figure out once the labels and legend exist so they are accounted for\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0xWLCXatNALW"
      },
      "outputs": [],
      "source": [
        "# Appendix Plots"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 304
        },
        "id": "DYmYk21O5r0e",
        "outputId": "cbb0d9e7-7d90-4acd-a688-8170e4c07c0a"
      },
      "outputs": [],
      "source": [
        "# Weight norm and laziness have the same effect.\n",
        "from jax import random # sometimes we use jax.random and sometimes python.random\n",
        "\n",
        "def phi(z, eps=0.25):\n",
        "    \"\"\"Identity activation plus a quadratic correction of strength `eps`.\"\"\"\n",
        "    quadratic = 0.5*eps*z**2\n",
        "    return z + quadratic\n",
        "\n",
        "def NN_func(params, X, alpha, eps=0.25):\n",
        "    \"\"\"Two-layer network with readouts: alpha/N * sum_n a_n * phi(w_n . x / sqrt(D)) per column x of X.\n",
        "\n",
        "    params is an (a, W) pair; N is the length of a and D the number of columns of W.\n",
        "    \"\"\"\n",
        "    readout, weights = params\n",
        "\n",
        "    input_dim = weights.shape[1]\n",
        "    n_hidden = readout.shape[0]\n",
        "\n",
        "    pre_activation = weights @ X / jnp.sqrt(input_dim)\n",
        "    return alpha/n_hidden * phi(pre_activation, eps=eps).T @ readout\n",
        "\n",
        "def NN_func2(params, X, alpha, eps=0.25):\n",
        "    \"\"\"Two-layer network without readouts: alpha * mean_n phi(w_n . x / sqrt(D)) per column x of X.\n",
        "\n",
        "    params is an (a, W) pair, matching NN_func's signature; the readout vector `a` is not used here.\n",
        "    D is the number of columns of W.\n",
        "    \"\"\"\n",
        "    _, W = params\n",
        "\n",
        "    h = W @ X / jnp.sqrt(W.shape[1])\n",
        "    return alpha * jnp.mean(phi(h, eps=eps), axis=0)\n",
        "\n",
        "\n",
        "def target_fn(beta, X):\n",
        "    \"\"\"Quadratic single-index target: (beta . x / sqrt(D))**2 for every column x of X.\n",
        "\n",
        "    D is read from X (its number of rows) rather than from the notebook-level `D`, which\n",
        "    other cells reassign; the result is the same whenever X has D rows.\n",
        "    \"\"\"\n",
        "    input_dim = X.shape[0]\n",
        "    return (X.T @ beta / jnp.sqrt(input_dim))**2\n",
        "\n",
        "\n",
        "# Problem sizes: input dimension D, number of training points P, hidden width N.\n",
        "D = 100\n",
        "P = 450\n",
        "N = 500\n",
        "\n",
        "# Gaussian inputs stored column-wise (D x P train, D x 1000 test) and the target direction beta.\n",
        "X = random.normal(random.PRNGKey(0), (D,P))\n",
        "Xt = random.normal(random.PRNGKey(1), (D,1000))\n",
        "beta = random.normal(random.PRNGKey(2), (D,))\n",
        "\n",
        "y = target_fn(beta, X)\n",
        "yt = target_fn(beta,Xt)\n",
        "\n",
        "# Unit-scale initial parameters; the sweep loops below re-create a, W and params per weight scale.\n",
        "a = random.normal(random.PRNGKey(0), (N, ))\n",
        "W = random.normal(random.PRNGKey(0), (N, D))\n",
        "params = [a, W]\n",
        "\n",
        "\n",
        "eta = 0.5 * N\n",
        "lamb = 0.0\n",
        "opt_init, opt_update, get_params = optimizers.sgd(eta)\n",
        "\n",
        "# Containers for a sweep over the output scale alpha.\n",
        "# NOTE(review): nothing in this cell appends to these lists -- confirm where the alpha sweep is run.\n",
        "alphas = [2**(-5),0.25,0.5,1.0,2.0,4.0,8.0,16,32]\n",
        "\n",
        "all_tr_losses = []\n",
        "all_te_losses = []\n",
        "all_acc_tr = []\n",
        "all_acc_te = []\n",
        "\n",
        "param_movement = []\n",
        "\n",
        "# Sweep over the initial weight scale sigma (`wscale`) at fixed alpha = 1.\n",
        "weight_norms = [0.125,0.25,0.5,1.0,2.0]\n",
        "alpha = 1.0\n",
        "\n",
        "eta = 0.5 * N\n",
        "lamb = 0.0\n",
        "\n",
        "all_tr_losses_w = []\n",
        "all_te_losses_w = []\n",
        "all_acc_tr_w = []\n",
        "all_acc_te_w = []\n",
        "\n",
        "param_movement_w = []\n",
        "\n",
        "# For each initial weight scale: train with full-batch SGD on the centred loss and record\n",
        "# the loss/accuracy curves plus the relative parameter movement.\n",
        "for i, wscale in enumerate(weight_norms):\n",
        "\n",
        "    # NOTE(review): a and W are both drawn with PRNGKey(0) -- confirm that sharing the key is intended.\n",
        "    a = wscale * random.normal(random.PRNGKey(0), (N, ))\n",
        "    W = wscale * random.normal(random.PRNGKey(0), (N, D))\n",
        "    params = [a, W]\n",
        "\n",
        "    # The SGD step size is divided by wscale**2.\n",
        "    opt_init, opt_update, get_params = optimizers.sgd( eta / wscale**2 )\n",
        "    opt_state = opt_init(params)\n",
        "\n",
        "\n",
        "    # The network output at the initial `params` is subtracted (centred loss); the squared error\n",
        "    # is divided by alpha**2 here and multiplied back when it is logged below.\n",
        "    loss_fn = jit(lambda p, X, y: jnp.mean( ( NN_func2(p, X,alpha)- NN_func2(params,X,alpha) - y )**2 / alpha**2 ))\n",
        "    # Fraction of points where the centred prediction has the same sign as y.\n",
        "    acc_fn = jit(lambda p, X, y: jnp.mean( ( y * ( NN_func2(p, X,alpha)- NN_func2(params,X,alpha)) ) > 0.0 ))\n",
        "    reg_loss = jit(lambda p, X, y: loss_fn(p,X,y) + lamb / alpha * optimizers.l2_norm(p)**2 )\n",
        "\n",
        "    grad_loss = jit(grad(reg_loss,0))\n",
        "\n",
        "    tr_losses = []\n",
        "    te_losses = []\n",
        "    tr_acc = []\n",
        "    te_acc = []\n",
        "    for t in range(50000):\n",
        "        opt_state = opt_update(t, grad_loss(get_params(opt_state), X, y), opt_state)\n",
        "\n",
        "        # Metrics are recorded every second step, so curve index k corresponds to step 2*k.\n",
        "        if t % 2 == 0:\n",
        "            train_loss = alpha**2*loss_fn(get_params(opt_state), X, y)\n",
        "            test_loss = alpha**2*loss_fn(get_params(opt_state), Xt, yt)\n",
        "            tr_losses += [train_loss]\n",
        "            te_losses += [test_loss]\n",
        "            tr_acc += [ acc_fn(get_params(opt_state), X, y) ]\n",
        "            te_acc += [ acc_fn(get_params(opt_state), Xt, yt) ]\n",
        "            sys.stdout.write(f'\\r t: {t} | train loss: {train_loss} | test loss: {test_loss}')\n",
        "        if t % 10000 == 0:\n",
        "            print(\" \")\n",
        "\n",
        "    all_tr_losses_w += [tr_losses]\n",
        "    all_te_losses_w += [te_losses]\n",
        "    all_acc_tr_w += [tr_acc]\n",
        "    all_acc_te_w += [te_acc]\n",
        "\n",
        "    # Squared relative movement |theta - theta_0|^2 / |theta_0|^2 over both a and W.\n",
        "    paramsf = get_params(opt_state)\n",
        "    dparam = (jnp.sum((paramsf[0]-params[0])**2) + jnp.sum((paramsf[1]-params[1])**2)) / ( jnp.sum( params[0]**2 ) + jnp.sum(params[1]**2) )\n",
        "    param_movement_w += [  dparam ]\n",
        "\n",
        "\n",
        "plt.rcParams.update({'font.size': 14})\n",
        "plt.figure()\n",
        "# Train (dashed) and test (solid) loss per initial weight scale, each divided by its first recorded value.\n",
        "for i,wscale in enumerate(weight_norms):\n",
        "    record_idx = jnp.linspace(1,len(all_tr_losses_w[i]),len(all_tr_losses_w[i]))\n",
        "    plt.plot(record_idx, jnp.array(all_tr_losses_w[i]) / all_tr_losses_w[i][0], '--',  color = f'C{i}')\n",
        "    plt.plot(record_idx, jnp.array(all_te_losses_w[i]) / all_te_losses_w[i][0],  color = f'C{i}', label = r'$\\sigma = 2^{%0.0f}$' % jnp.log2(wscale))\n",
        "plt.xscale('log')\n",
        "plt.xlabel('t',fontsize = 20)\n",
        "plt.ylabel('Loss',fontsize = 20)\n",
        "plt.legend()\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "# The alpha-sweep curves are drawn only when that sweep holds one result per entry of `alphas`.\n",
        "# With the containers still empty, plt.loglog(alphas, ...) raises on mismatched x/y lengths.\n",
        "has_alpha_sweep = len(param_movement) == len(alphas) and len(all_te_losses) == len(alphas)\n",
        "\n",
        "final_losses_w =  [ te_loss[-1]  for te_loss in all_te_losses_w ]\n",
        "final_losses = [ te_loss[-1] for te_loss in all_te_losses ]\n",
        "\n",
        "# Relative parameter movement against sigma^2 * alpha.\n",
        "if has_alpha_sweep:\n",
        "    plt.loglog(alphas, jnp.sqrt(jnp.array(param_movement)),'-o', label = r'vary $\\alpha$, $\\sigma=1$')\n",
        "plt.loglog(jnp.array(weight_norms)**2, jnp.array(param_movement_w)**(0.5) ,'--o', label =r'vary $\\sigma$, $\\alpha=1$')\n",
        "plt.xlabel(r'$\\sigma^2 \\alpha$', fontsize = 20)\n",
        "plt.ylabel(r'$|\\theta-\\theta_0|/|\\theta_0|$',fontsize = 20)\n",
        "plt.legend()\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "# Final test loss against sigma^2 * alpha.\n",
        "if has_alpha_sweep:\n",
        "    plt.loglog(alphas, final_losses, '-o', label = r'vary $\\alpha$, $\\sigma=1$')\n",
        "plt.loglog(jnp.array(weight_norms)**2, final_losses_w, '--o', label =r'vary $\\sigma$, $\\alpha=1$')\n",
        "plt.xlabel(r'$\\sigma^2 \\alpha$', fontsize = 20)\n",
        "plt.ylabel(r'Final Test Loss',fontsize = 20)\n",
        "plt.legend()\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "\n",
        "# Final test loss against relative parameter movement.\n",
        "if has_alpha_sweep:\n",
        "    plt.loglog(jnp.sqrt(jnp.array(param_movement)) , final_losses, '-o', label = r'vary $\\alpha$, $\\sigma=1$')\n",
        "plt.loglog(jnp.array(param_movement_w)**(0.5), final_losses_w, '--o', label =r'vary $\\sigma$, $\\alpha=1$')\n",
        "plt.xlabel(r'$|\\theta-\\theta_0|/|\\theta_0|$', fontsize = 20)\n",
        "plt.ylabel(r'Final Test Loss',fontsize = 20)\n",
        "plt.legend()\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "\n",
        "# Joint sweep: vary the initial weight scale sigma while setting alpha = alpha_base / sigma**2,\n",
        "# so that sigma**2 * alpha stays equal to alpha_base.\n",
        "weight_norms = [0.25,0.5,1.0,2.0]\n",
        "alpha_base = 1.0\n",
        "\n",
        "eta = 0.5 * N\n",
        "lamb = 0.0\n",
        "\n",
        "all_tr_losses_walpha = []\n",
        "all_te_losses_walpha = []\n",
        "all_acc_tr_walpha = []\n",
        "all_acc_te_walpha = []\n",
        "\n",
        "param_movement_walpha = []\n",
        "\n",
        "for i, wscale in enumerate(weight_norms):\n",
        "\n",
        "    a = wscale * random.normal(random.PRNGKey(0), (N, ))\n",
        "    W = wscale * random.normal(random.PRNGKey(0), (N, D))\n",
        "    params = [a, W]\n",
        "\n",
        "    alpha = alpha_base / wscale**2\n",
        "    # The SGD step size is divided by wscale**2.\n",
        "    opt_init, opt_update, get_params = optimizers.sgd( eta / wscale**2 )\n",
        "    opt_state = opt_init(params)\n",
        "\n",
        "\n",
        "    # Centred loss: the network output at the initial `params` is subtracted; the squared error is\n",
        "    # divided by alpha**2 here and multiplied back when it is logged below.\n",
        "    loss_fn = jit(lambda p, X, y: jnp.mean( ( NN_func2(p, X,alpha)- NN_func2(params,X,alpha) - y )**2 / alpha**2 ))\n",
        "    acc_fn = jit(lambda p, X, y: jnp.mean( ( y * ( NN_func2(p, X,alpha)- NN_func2(params,X,alpha)) ) > 0.0 ))\n",
        "    reg_loss = jit(lambda p, X, y: loss_fn(p,X,y) + lamb / alpha * optimizers.l2_norm(p)**2 )\n",
        "\n",
        "    grad_loss = jit(grad(reg_loss,0))\n",
        "\n",
        "    tr_losses = []\n",
        "    te_losses = []\n",
        "    tr_acc = []\n",
        "    te_acc = []\n",
        "    for t in range(50000):\n",
        "        opt_state = opt_update(t, grad_loss(get_params(opt_state), X, y), opt_state)\n",
        "\n",
        "        # Metrics are recorded every second step.\n",
        "        if t % 2 == 0:\n",
        "            train_loss = alpha**2*loss_fn(get_params(opt_state), X, y)\n",
        "            test_loss = alpha**2*loss_fn(get_params(opt_state), Xt, yt)\n",
        "            tr_losses += [train_loss]\n",
        "            te_losses += [test_loss]\n",
        "            tr_acc += [ acc_fn(get_params(opt_state), X, y) ]\n",
        "            te_acc += [ acc_fn(get_params(opt_state), Xt, yt) ]\n",
        "            sys.stdout.write(f'\\r t: {t} | train loss: {train_loss} | test loss: {test_loss}')\n",
        "        if t % 10000 == 0:\n",
        "            print(\" \")\n",
        "\n",
        "    all_tr_losses_walpha += [tr_losses]\n",
        "    all_te_losses_walpha += [te_losses]\n",
        "    all_acc_tr_walpha += [tr_acc]\n",
        "    all_acc_te_walpha += [te_acc]\n",
        "\n",
        "    # Squared relative movement |theta - theta_0|^2 / |theta_0|^2 over both a and W.\n",
        "    paramsf = get_params(opt_state)\n",
        "    dparam = (jnp.sum((paramsf[0]-params[0])**2) + jnp.sum((paramsf[1]-params[1])**2)) / ( jnp.sum( params[0]**2 ) + jnp.sum(params[1]**2) )\n",
        "    param_movement_walpha += [  dparam ]\n",
        "\n",
        "\n",
        "plt.rcParams.update({'font.size': 14})\n",
        "plt.figure()\n",
        "# Train (dashed) and test (solid) loss for the joint sweep, each divided by its first recorded value.\n",
        "for i,wscale in enumerate(weight_norms):\n",
        "    record_idx = jnp.linspace(1,len(all_tr_losses_walpha[i]),len(all_tr_losses_walpha[i]))\n",
        "    plt.plot(record_idx, jnp.array(all_tr_losses_walpha[i]) / all_tr_losses_walpha[i][0], '--',  color = f'C{i}')\n",
        "    plt.plot(record_idx, jnp.array(all_te_losses_walpha[i]) / all_te_losses_walpha[i][0],  color = f'C{i}', label = r'$\\sigma = 2^{%0.0f}$ | $\\alpha = 2^{%0.0f}$' % (jnp.log2(wscale),-2*jnp.log2(wscale)))\n",
        "plt.xscale('log')\n",
        "plt.xlabel('t',fontsize = 20)\n",
        "plt.ylabel('Loss',fontsize = 20)\n",
        "plt.legend()\n",
        "plt.tight_layout()\n",
        "plt.savefig('loss_curves_linear_quad_vary_w_alpha_jointly.pdf')\n",
        "plt.show()\n",
        "\n",
        "\n",
        "# Relative movement measured in the joint sweep (param_movement_walpha holds one entry per value in\n",
        "# weight_norms), compared with a flat sigma^0 reference line.\n",
        "plt.loglog(weight_norms, jnp.array(param_movement_walpha)**(0.5) ,'-o')\n",
        "plt.loglog(weight_norms, 2*jnp.array(weight_norms)**(-0.0), '--', label = r'$\\sigma^{0}$', color = 'black')\n",
        "plt.xlabel(r'$\\sigma$',fontsize = 20)\n",
        "plt.ylabel(r'$|\\theta-\\theta_0|/|\\theta_0|$',fontsize = 20)\n",
        "plt.legend()\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "04U7bG6e15as",
        "outputId": "861525c7-5702-4f41-ae82-0098407cd7b1"
      },
      "outputs": [],
      "source": [
        "# STUDENT-TEACHER -- showing that measuring \"accuracy\" on a regression task is misleading\n",
        "# Direct replication of the Omnigrok paper https://arxiv.org/abs/2210.01117.\n",
        "# Taken from their code repo and modified to show our parameter works out of the box.\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import torch\n",
        "from torch import nn\n",
        "import torch.nn.functional as F\n",
        "\n",
        "\n",
        "def L2(model):\n",
        "    \"\"\"Flatten every parameter of `model` and return the flat vector with its squared L2 norm.\n",
        "\n",
        "    Returns:\n",
        "        (params_flatten, l2): a 1-D tensor with all parameter entries in `model.parameters()` order,\n",
        "        and the scalar sum of their squares.\n",
        "    \"\"\"\n",
        "    # Each parameter tensor is concatenated exactly once, so the first layer is not double-counted.\n",
        "    params_flatten = torch.cat([p.reshape(-1) for p in model.parameters()])\n",
        "    l2 = torch.sum(params_flatten**2)\n",
        "    return params_flatten, l2\n",
        "\n",
        "def init(model, alpha):\n",
        "    \"\"\"Multiply the weights and biases of layers l1, l2 and l3 by `alpha` via the state dict.\"\"\"\n",
        "    layer_keys = (\"l1.weight\", \"l1.bias\", \"l2.weight\", \"l2.bias\", \"l3.weight\", \"l3.bias\")\n",
        "    scaled_state = model.state_dict()\n",
        "    scaled_state.update({key: scaled_state[key] * alpha for key in layer_keys})\n",
        "    model.load_state_dict(scaled_state)\n",
        "\n",
        "def init2(model, alpha):\n",
        "    \"\"\"Multiply the weights and biases of layers l1, l2 and l3 by `alpha` by assigning to `.data`.\"\"\"\n",
        "    for layer in (model.l1, model.l2, model.l3):\n",
        "        layer.weight.data = layer.weight * alpha\n",
        "        layer.bias.data = layer.bias * alpha\n",
        "\n",
        "\n",
        "def grad(model):\n",
        "    \"\"\"Return all parameters of `model` flattened into a single 1-D tensor.\n",
        "\n",
        "    The parameters are read from the `model` argument, not from a notebook-level variable.\n",
        "\n",
        "    NOTE(review): despite the name, this concatenates parameter values, not their `.grad` -- confirm intent.\n",
        "    NOTE(review): this definition rebinds the name `grad`, which the notebook also imports from jax.\n",
        "    \"\"\"\n",
        "    return torch.cat([p.reshape(-1) for p in model.parameters()])\n",
        "\n",
        "# Seed NumPy and PyTorch before the teacher network and the datasets are created.\n",
        "seed = 0\n",
        "np.random.seed(seed)\n",
        "torch.manual_seed(seed)\n",
        "\n",
        "# Input/output dimensions, dataset sizes, and hidden width `w` (the default width of Net).\n",
        "d_in = 5\n",
        "d_out = 5\n",
        "train_size = 100\n",
        "test_size = 100\n",
        "w = 100\n",
        "\n",
        "class Net(nn.Module):\n",
        "    \"\"\"Three-layer tanh MLP (d_in -> w -> w -> d_out) whose output is multiplied by `scale`.\"\"\"\n",
        "\n",
        "    def __init__(self, w=w, scale=1.0):\n",
        "        super().__init__()\n",
        "        self.scale = scale\n",
        "        self.l1 = nn.Linear(d_in, w)\n",
        "        self.l2 = nn.Linear(w, w)\n",
        "        self.l3 = nn.Linear(w, d_out)\n",
        "\n",
        "    def forward(self, x):\n",
        "        # The intermediate activations are stored on the module as x1, x2 and x3.\n",
        "        self.x1 = torch.tanh(self.l1(x))\n",
        "        self.x2 = torch.tanh(self.l2(self.x1))\n",
        "        self.x3 = self.l3(self.x2)\n",
        "        return self.x3 * self.scale\n",
        "\n",
        "# Teacher network that generates the regression targets.\n",
        "teacher = Net()\n",
        "alpha = 1.0\n",
        "init(teacher, alpha=alpha)\n",
        "# Sample the inputs directly: wrapping an existing tensor in torch.tensor(...) copy-constructs it\n",
        "# and emits a UserWarning. requires_grad is kept on, as before.\n",
        "inputs_train = torch.normal(0,1,size=(train_size, d_in)).requires_grad_(True)\n",
        "# Teacher outputs are fixed targets: detach them from the teacher's graph before marking them as leaves.\n",
        "labels_train = teacher(inputs_train).detach().clone().requires_grad_(True)\n",
        "\n",
        "inputs_test = torch.normal(0,1,size=(test_size, d_in))\n",
        "# Detached so the test targets do not keep the teacher's autograd graph alive.\n",
        "labels_test = teacher(inputs_test).detach()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 453
        },
        "id": "3Rku1opg2Ijq",
        "outputId": "6476862e-6b76-4292-cd95-db341ae0d1c9"
      },
      "outputs": [],
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "# `alpha` multiplies the student's initial weights (via init); `scale` multiplies the network output.\n",
        "alpha = 2.0\n",
        "\n",
        "seed = 1\n",
        "\n",
        "# Per-run accuracy curves, one entry per scale in the loop below.\n",
        "ov_tr = []\n",
        "ov_te = []\n",
        "\n",
        "# for scale in [0.1, 0.25, 0.5, 0.75, 1, 3]:\n",
        "for scale in [1]:\n",
        "  # Re-seed so every scale starts from the same student initialisation.\n",
        "  np.random.seed(seed)\n",
        "  torch.manual_seed(seed)\n",
        "  student = Net(scale=scale)\n",
        "\n",
        "  init(student, alpha=alpha)\n",
        "  # Note: this rebinds `scale` to the squared parameter norm returned by L2.\n",
        "  _, scale = L2(student)\n",
        "\n",
        "  epochs = 10000\n",
        "  log = 200\n",
        "  wd = 0.05\n",
        "\n",
        "  optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4, weight_decay = wd)\n",
        "\n",
        "  losses_train = []\n",
        "  losses_test = []\n",
        "  accs_train = []\n",
        "  accs_test = []\n",
        "\n",
        "  l2s = []\n",
        "  # A sample counts as correct when its per-sample MSE is below this threshold.\n",
        "  threshold = 0.001\n",
        "\n",
        "\n",
        "  for epoch in tqdm(range(epochs)):  # loop over the dataset multiple times\n",
        "\n",
        "      optimizer.zero_grad()\n",
        "\n",
        "      outputs_train = student(inputs_train)\n",
        "      loss_train_vec = torch.mean((outputs_train-labels_train)**2, dim=1)\n",
        "      loss_train = torch.mean(loss_train_vec)\n",
        "      train_acc = torch.sum(loss_train_vec < threshold)/train_size\n",
        "\n",
        "      outputs_test = student(inputs_test)\n",
        "      loss_test_vec = torch.mean((outputs_test-labels_test)**2, dim=1)\n",
        "      loss_test = torch.mean(loss_test_vec)\n",
        "      test_acc = torch.sum(loss_test_vec < threshold)/test_size\n",
        "\n",
        "      params, l2 = L2(student)\n",
        "\n",
        "      # Only the training loss is back-propagated; the test metrics are for logging.\n",
        "      loss_train.backward()\n",
        "\n",
        "      optimizer.step()\n",
        "\n",
        "      losses_train.append(loss_train.detach().numpy())\n",
        "      losses_test.append(loss_test.detach().numpy())\n",
        "      l2s.append(l2.detach().numpy())\n",
        "      accs_train.append(train_acc.detach().numpy())\n",
        "      accs_test.append(test_acc.detach().numpy())\n",
        "  ov_tr.append(accs_train)\n",
        "  ov_te.append(accs_test)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 632
        },
        "id": "d_sCHH-_2N_g",
        "outputId": "e82756f5-a94f-407f-8f5d-123420598a7d"
      },
      "outputs": [],
      "source": [
        "# `scales` must list, in order, the output scales the training cell actually looped over;\n",
        "# otherwise the curves are labelled with the wrong alpha. (Full sweep: [0.1, 0.25, 0.5, 0.75, 1, 3].)\n",
        "scales = [1]\n",
        "assert len(scales) == len(ov_tr) == len(ov_te), 'scales must match the runs stored in ov_tr / ov_te'\n",
        "plt.figure(figsize=(9, 6))\n",
        "colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'orange', 'purple', 'pink']\n",
        "# Dashed = train accuracy, solid = test accuracy.\n",
        "for i, _ in enumerate(ov_tr):\n",
        "  plt.plot(np.arange(epochs), ov_tr[i], color=colors[i], linestyle='--', label=rf'Train, $\\alpha={scales[i]}$')\n",
        "  plt.plot(np.arange(epochs), ov_te[i], color=colors[i], label=rf'Test, $\\alpha={scales[i]}$')\n",
        "plt.xlabel(\"Epochs\", fontsize=20)\n",
        "plt.ylabel(f\"Accuracy\", fontsize=20)\n",
        "plt.xscale('log')\n",
        "plt.legend(fontsize=16)\n",
        "plt.tight_layout()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 435
        },
        "id": "H2w0jTJd3VMO",
        "outputId": "d57fce9a-a414-4136-f03f-ec84b6d09829"
      },
      "outputs": [],
      "source": [
        "alpha = 2.0\n",
        "\n",
        "seed = 1\n",
        "\n",
        "# Per-sample MSE thresholds that define accuracy. A threshold only affects the accuracy bookkeeping,\n",
        "# never the optimisation, and every run would be re-seeded identically, so a single training\n",
        "# trajectory is scored against all thresholds instead of being retrained once per threshold.\n",
        "thresholds = [1e-4, 5e-4, 5e-3, 0.01, 0.1]\n",
        "\n",
        "np.random.seed(seed)\n",
        "torch.manual_seed(seed)\n",
        "student = Net(scale=1)\n",
        "\n",
        "init(student, alpha=alpha)\n",
        "_, scale = L2(student)\n",
        "\n",
        "epochs = 100000\n",
        "log = 200\n",
        "wd = 0.05\n",
        "\n",
        "optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4, weight_decay = wd)\n",
        "\n",
        "losses_train = []\n",
        "losses_test = []\n",
        "l2s = []\n",
        "\n",
        "# thresh_tr[k][epoch] / thresh_te[k][epoch]: train / test accuracy under thresholds[k].\n",
        "thresh_tr = [[] for _ in thresholds]\n",
        "thresh_te = [[] for _ in thresholds]\n",
        "\n",
        "for epoch in tqdm(range(epochs)):  # full-batch gradient steps\n",
        "\n",
        "    optimizer.zero_grad()\n",
        "\n",
        "    outputs_train = student(inputs_train)\n",
        "    loss_train_vec = torch.mean((outputs_train-labels_train)**2, dim=1)\n",
        "    loss_train = torch.mean(loss_train_vec)\n",
        "\n",
        "    outputs_test = student(inputs_test)\n",
        "    loss_test_vec = torch.mean((outputs_test-labels_test)**2, dim=1)\n",
        "    loss_test = torch.mean(loss_test_vec)\n",
        "\n",
        "    # Score the same per-sample losses against every threshold.\n",
        "    for k, threshold in enumerate(thresholds):\n",
        "        train_acc = torch.sum(loss_train_vec < threshold)/train_size\n",
        "        test_acc = torch.sum(loss_test_vec < threshold)/test_size\n",
        "        thresh_tr[k].append(train_acc.detach().numpy())\n",
        "        thresh_te[k].append(test_acc.detach().numpy())\n",
        "\n",
        "    params, l2 = L2(student)\n",
        "\n",
        "    loss_train.backward()\n",
        "\n",
        "    optimizer.step()\n",
        "\n",
        "    losses_train.append(loss_train.detach().numpy())\n",
        "    losses_test.append(loss_test.detach().numpy())\n",
        "    l2s.append(l2.detach().numpy())\n",
        "\n",
        "# Keep the accuracy names a per-threshold loop would leave behind (values for the last threshold).\n",
        "accs_train, accs_test = thresh_tr[-1], thresh_te[-1]\n",
        "\n",
        "\n",
        "threshes = [1e-4, 5e-4, 5e-3, 0.01, 0.1]\n",
        "plt.figure(figsize=(9, 6))\n",
        "colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'orange', 'purple', 'pink']\n",
        "# Dashed = train, solid = test; the legend entries say which is which.\n",
        "for i, _ in enumerate(thresh_tr):\n",
        "  plt.plot(np.arange(epochs), thresh_tr[i], color=colors[i], linestyle='--', label=rf'Train, threshold = {threshes[i]}')\n",
        "  plt.plot(np.arange(epochs), thresh_te[i], color=colors[i], label=rf'Test, threshold = {threshes[i]}')\n",
        "plt.xlabel(\"Epochs\", fontsize=20)\n",
        "plt.ylabel(f\"Accuracy\", fontsize=20)\n",
        "plt.xscale('log')\n",
        "plt.legend(fontsize=15)\n",
        "plt.tight_layout()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 539
        },
        "id": "2U26mPdP5348",
        "outputId": "2565cb64-ed4e-459b-c021-a9fd311b8955"
      },
      "outputs": [],
      "source": [
        "# INTERPOLATING FROM RANDOM TO ALIGNED SOLUTION WEIGHTS\n",
        "# Re-import jax's grad: an earlier cell defines its own function named `grad`.\n",
        "from jax import grad\n",
        "# TRAIN_READOUTS: weight each hidden unit by its readout a_n inside NN_func2.\n",
        "# CENTER_LOSS: subtract the network output at initialisation inside the loss.\n",
        "TRAIN_READOUTS = False\n",
        "CENTER_LOSS = False\n",
        "\n",
        "# Input dimension D, hidden width N, output scale alpha, quadratic strength eps, training points P.\n",
        "D = 40\n",
        "N = 2\n",
        "alpha = 1\n",
        "eps = 0.25 # large eps leads to lazy -> sharp drop\n",
        "P = 120\n",
        "alphas = [1]\n",
        "\n",
        "epochs = 20000\n",
        "\n",
        "# How often (in steps) the NTK alignment and the M/A loss-decomposition terms are computed.\n",
        "ntk_interval = 200\n",
        "vars_compute_interval = 10\n",
        "\n",
        "\n",
        "if TRAIN_READOUTS: print('We ARE training readouts, a.')\n",
        "else: print('We are NOT training readouts, a.')\n",
        "\n",
        "# Results collected across the interpolation sweep.\n",
        "all_tr_losses_eps = []\n",
        "all_te_losses_eps = []\n",
        "all_alignments, all_alignmentst = [], []\n",
        "\n",
        "for interp in [0, 0.05, 0.2, 0.5, 0.75, 0.95]:\n",
        "  def phi(z, eps=0.25):\n",
        "      \"\"\"Identity activation plus a quadratic correction of strength `eps`.\"\"\"\n",
        "      # can also try other functional forms for phi and for the target\n",
        "      quadratic = 0.5*eps*z**2\n",
        "      return z + quadratic\n",
        "\n",
        "  def NN_func2(params, X, alpha, eps = 0.25):\n",
        "      \"\"\"Network output for each column of X: alpha * mean over hidden units of phi(W @ X).\n",
        "\n",
        "      params is an (a, W) pair. When TRAIN_READOUTS is set, each hidden unit is weighted by its\n",
        "      readout a_n; otherwise `a` does not enter the output.\n",
        "      \"\"\"\n",
        "      a, W = params\n",
        "      h = W @ X\n",
        "\n",
        "      # jnp.mean rather than np.mean: this function is traced by jit / grad / jacrev / vmap.\n",
        "      if TRAIN_READOUTS:\n",
        "          return alpha * jnp.mean(a.reshape(-1, 1) * phi(h,eps),axis=0)\n",
        "      return alpha * jnp.mean(phi(h,eps),axis=0) # w/o readouts\n",
        "\n",
        "  def target_fn(beta, X):\n",
        "      \"\"\"Half the squared projection of each column of X onto beta.\"\"\"\n",
        "      projection = X.T @ beta\n",
        "      return projection**2/2.0\n",
        "\n",
        "  def ntk(X1, X2, params, NN_func2, alpha, eps):\n",
        "    \"\"\"Empirical neural tangent kernel between the columns of X1 and the columns of X2.\n",
        "\n",
        "    K[i, j] = sum over all parameters theta in (a, W) of df(x1_i)/dtheta * df(x2_j)/dtheta,\n",
        "    where f = NN_func2(params, ., alpha, eps). Returns an array of shape (P1, P2).\n",
        "    Each readout gradient enters the sum once (it is not tiled along the input dimension).\n",
        "    \"\"\"\n",
        "    a, W = params\n",
        "\n",
        "    def flat_jacobian(X):\n",
        "      # Jacobians of the outputs w.r.t. a and W, flattened to one row of parameter gradients per point.\n",
        "      jac_a = jacrev(lambda a_, X_: NN_func2([a_, W], X_, alpha, eps))(a, X)\n",
        "      jac_W = jacrev(lambda W_, X_: NN_func2([a, W_], X_, alpha, eps))(W, X)\n",
        "      n_points = jac_W.shape[0]\n",
        "      return jnp.concatenate([jac_a.reshape(n_points, -1), jac_W.reshape(n_points, -1)], axis=1)\n",
        "\n",
        "    return flat_jacobian(X1) @ flat_jacobian(X2).T\n",
        "\n",
        "  def get_K(params, X):\n",
        "      \"\"\"NTK Gram matrix of X with itself at `params`, using the cell-level alpha and eps.\"\"\"\n",
        "      gram = ntk(X, X, params, NN_func2, alpha, eps)\n",
        "      return gram\n",
        "\n",
        "  def kernel_regression(X, y, Xt, yt, params):\n",
        "    \"\"\"Test MSE of kernel regression with the NTK at `params` (1e-6 added to the Gram diagonal).\"\"\"\n",
        "    gram = ntk(X, X, params, NN_func2, alpha, eps)\n",
        "    ridge = 1e-6 * jnp.eye(gram.shape[0])\n",
        "    dual_coef = jnp.linalg.solve(gram + ridge, y)\n",
        "\n",
        "    def predict_one(x_query):\n",
        "      # Kernel between every training point and this single test point.\n",
        "      cross = ntk(X, x_query.reshape(1, -1).T, params, NN_func2, alpha, eps)\n",
        "      return jnp.dot(jnp.squeeze(cross), dual_coef)\n",
        "\n",
        "    predictions = vmap(predict_one)(Xt.T)\n",
        "    return jnp.mean((predictions - yt) ** 2)\n",
        "\n",
        "  def kalignment(K, train_y):\n",
        "    \"\"\"Alignment tr(y^T K y) / (||K|| * ||y||^2), with K and train_y mean-centred along axis 0.\"\"\"\n",
        "    y_centred = train_y - train_y.mean(axis=0)\n",
        "    K_centred = K - K.mean(axis=0)\n",
        "    numerator = jnp.trace(jnp.dot(jnp.dot(y_centred.T, K_centred), y_centred))\n",
        "    denominator = jnp.linalg.norm(K_centred) * (jnp.linalg.norm(y_centred)**2)\n",
        "    return numerator/denominator\n",
        "\n",
        "  # Gaussian inputs scaled by 1/sqrt(D), stored column-wise, and the target direction beta.\n",
        "  X = random.normal(random.PRNGKey(0), (D,P))/ jnp.sqrt(D)\n",
        "  Xt = random.normal(random.PRNGKey(1), (D,1000))/ jnp.sqrt(D)\n",
        "  beta = random.normal(random.PRNGKey(2), (D,))\n",
        "\n",
        "  y = target_fn(beta, X)\n",
        "  yt = target_fn(beta,Xt)\n",
        "\n",
        "  # Blend a fixed Gaussian draw with the direction beta. Rows 0 and 1 are replaced by\n",
        "  # +beta/sqrt(eps) and -beta/sqrt(eps) blended with the same draw; at interp = 1 (readouts off)\n",
        "  # the linear parts of phi cancel across these two rows and the quadratic parts reproduce target_fn.\n",
        "  W = interp * beta + (1 - interp) * random.normal(random.PRNGKey(0), (N, D)) # interp determines alignment\n",
        "  row1 = interp * beta * (1/np.sqrt(eps)) + (1 - interp) * random.normal(random.PRNGKey(0), (N, D))[0:1, :]\n",
        "  row2 = interp * beta * (-1/np.sqrt(eps)) + (1 - interp) * random.normal(random.PRNGKey(0), (N, D))[1:2, :]\n",
        "\n",
        "  # Create the new W array\n",
        "  W = jnp.concatenate([row1, row2, W[2:, :]], axis=0)\n",
        "\n",
        "  a = random.normal(random.PRNGKey(0), (N, ))\n",
        "  params = [a, W]\n",
        "\n",
        "  eta = 0.5 * N / alpha**2\n",
        "  lamb = 0\n",
        "  opt_init, opt_update, get_params = optimizers.sgd(eta)\n",
        "\n",
        "  opt_state = opt_init(params)\n",
        "  # Mean squared error, optionally centred by the network output at the initial `params`.\n",
        "  if CENTER_LOSS:\n",
        "    loss_fn = jit(lambda p, X, y: jnp.mean( ( NN_func2(p, X,alpha, eps)- NN_func2(params,X,alpha, eps) - y )**2 ))\n",
        "  else:\n",
        "    loss_fn = jit(lambda p, X, y: jnp.mean( ( NN_func2(p, X,alpha, eps) - y )**2 ))\n",
        "  reg_loss = jit(lambda p, X, y: loss_fn(p,X,y) + lamb / alpha * optimizers.l2_norm(p)**2 )\n",
        "\n",
        "  grad_loss = jit(grad(reg_loss,0))\n",
        "\n",
        "  tr_losses = []\n",
        "  te_losses = []\n",
        "\n",
        "  alignments, alignmentst = [], []\n",
        "  epochs_to_plot = []\n",
        "\n",
        "  t1s, t2s, t3s, epochs_to_compute = [], [], [], []\n",
        "  t1sm, t2sm, t3sm, ts_summ = [], [], [], []\n",
        "  ts_sum = []\n",
        "  alignments, alignmentst = [], []\n",
        "\n",
        "  # Test MSE of NTK kernel regression at initialisation.\n",
        "  # NOTE(review): kmse is not read again in this cell -- confirm it is still needed.\n",
        "  kmse = kernel_regression(X, y, Xt, yt, get_params(opt_state))\n",
        "\n",
        "  # Full-batch SGD; train/test loss are recorded after every step.\n",
        "  for t in tqdm(range(epochs)):\n",
        "    opt_state = opt_update(t, grad_loss(get_params(opt_state), X, y), opt_state)\n",
        "    pars = get_params(opt_state)\n",
        "\n",
        "    train_loss = loss_fn(pars, X, y)\n",
        "    test_loss = loss_fn(pars, Xt, yt)\n",
        "    tr_losses += [train_loss]\n",
        "    te_losses += [test_loss]\n",
        "\n",
        "    # Every vars_compute_interval steps: summary statistics of the current weights\n",
        "    # (A = W @ beta / D, M = W.T @ W / N) and the three terms t1m, t2m, t3m built from them.\n",
        "    if t % vars_compute_interval == 0:\n",
        "        W = pars[1]\n",
        "        A = W @ beta / D\n",
        "        M = W.T @ W / N\n",
        "\n",
        "        # in terms of M and A\n",
        "        t1m = ((alpha * eps/(2*N)) * (N/D) * np.trace(M) - np.mean(beta**2)/2)**2\n",
        "        t2m = (1/2) * (np.mean(beta**2)**2 - 2 * alpha * eps * (1/N) * np.sum(A**2) + alpha**2 *(eps/N)**2 * (N**2/D**2)*np.trace(M@M)  )\n",
        "        t3m = alpha**2 * (1/D) * np.linalg.norm(np.mean(W, axis=0))**2\n",
        "\n",
        "        t1sm.append(t1m)\n",
        "        t2sm.append(t2m)\n",
        "        t3sm.append(t3m)\n",
        "        ts_summ.append(t1m + t2m + t3m)\n",
        "\n",
        "        epochs_to_compute.append(t)\n",
        "\n",
        "    # Every ntk_interval steps: kernel-target alignment of the current NTK on train and test inputs.\n",
        "    if t % ntk_interval == 0:\n",
        "          K = get_K(get_params(opt_state), X)\n",
        "          align = kalignment(K, y[:, None]) # Reshape y to 2D here\n",
        "          Kt = get_K(get_params(opt_state), Xt)\n",
        "          alignt = kalignment(Kt, yt[:, None]) # Reshape y to 2D here\n",
        "          alignments.append(align)\n",
        "          alignmentst.append(alignt)\n",
        "          epochs_to_plot.append(t)\n",
        "\n",
        "  # Store this interpolation value's loss curves.\n",
        "  all_tr_losses_eps += [tr_losses]\n",
        "  all_te_losses_eps += [te_losses]\n",
        "\n",
        "  from scipy.interpolate import make_interp_spline, BSpline\n",
        "\n",
        "  # Steps at which the t1m/t2m/t3m terms were sampled.\n",
        "  T_compute = np.array(epochs_to_compute)\n",
        "\n",
        "  # Cubic interpolating splines through the sampled terms ...\n",
        "  t1sm_spline = make_interp_spline(T_compute, np.array(t1sm), k=3)\n",
        "  t2sm_spline = make_interp_spline(T_compute, np.array(t2sm), k=3)\n",
        "  t3sm_spline = make_interp_spline(T_compute, np.array(t3sm), k=3)\n",
        "\n",
        "  # ... evaluated at every step 0 .. epochs-1.\n",
        "  t1sm_smooth = t1sm_spline(np.arange(epochs))\n",
        "  t2sm_smooth = t2sm_spline(np.arange(epochs))\n",
        "  t3sm_smooth = t3sm_spline(np.arange(epochs))\n",
        "  ts_summ_smooth = make_interp_spline(T_compute, np.array(ts_summ), k=3)(np.arange(epochs))\n",
        "\n",
        "  def ftl(align_smooth, thresh):\n",
        "    \"\"\"Index of the first entry whose ratio to align_smooth[0] exceeds `thresh`; -1 if there is none.\"\"\"\n",
        "    crossings = (idx for idx, value in enumerate(align_smooth) if value / align_smooth[0] > thresh)\n",
        "    # -1 is the sentinel for a threshold that is never crossed\n",
        "    return next(crossings, -1)\n",
        "\n",
        "  # Steps at which the NTK alignment was sampled, and the sampled train/test values.\n",
        "  T = np.array(epochs_to_plot)\n",
        "  align_new = np.array(alignments)\n",
        "  align_new_test = np.array(alignmentst)\n",
        "\n",
        "  # Create the splines\n",
        "  align_spline = make_interp_spline(T, align_new, k=3)\n",
        "  align_spline_test = make_interp_spline(T, align_new_test, k=3)\n",
        "\n",
        "  # Create the smooth curves\n",
        "  align_smooth = align_spline(np.arange(epochs))\n",
        "  align_smooth_test = align_spline_test(np.arange(epochs))\n",
        "  # Only the test-alignment curve is stored across the sweep.\n",
        "  all_alignmentst.append(align_smooth_test)\n",
        "\n",
        "  def get_random_color():\n",
        "      return np.random.rand(3,)\n",
        "\n",
        "colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'orange', 'purple', 'pink']\n",
        "\n",
        "plt.figure(figsize=(15, 6))\n",
        "interps = [0, 0.05, 0.2, 0.5, 0.75, 0.95]\n",
        "i = 0\n",
        "for tr_losses, te_losses in zip(all_tr_losses_eps, all_te_losses_eps):\n",
        "  if i != 1:\n",
        "    plt.plot(tr_losses, label=f'Train, init weights {round(interps[i]*100, 2)}% solution, {round((1-interps[i])*100, 2)}% random', linestyle='--', color=colors[i])\n",
        "    plt.plot(te_losses, label=f'Test,  init weights {round(interps[i]*100, 2)}% solution, {round((1-interps[i])*100, 2)}% random', color=colors[i])\n",
        "  i += 1\n",
        "plt.xscale('log')\n",
        "plt.xlabel('Epochs', fontsize=20)\n",
        "plt.ylabel('MSE', fontsize=20)\n",
        "plt.legend(fontsize=14, loc='upper left', bbox_to_anchor=(1, 1))\n",
        "plt.tight_layout()\n",
        "plt.savefig('interp_sweep.pdf')\n",
        "plt.show()\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "iXZdnGMwNBv9",
        "outputId": "0d3b01a6-1e9d-4388-f1e1-a0577800e4da"
      },
      "outputs": [],
      "source": [
        "# GOLDILOCKS DATA\n",
        "D = 40\n",
        "N = 100\n",
        "alpha = 1\n",
        "eps = 0.5\n",
        "Ps = [50, 120, 500]\n",
        "\n",
        "epochs = 50000\n",
        "\n",
        "ntk_interval = 200\n",
        "vars_compute_interval = 10\n",
        "\n",
        "\n",
        "all_tr_losses_eps = []\n",
        "all_te_losses_eps = []\n",
        "all_alignments, all_alignmentst = [], []\n",
        "\n",
        "for P in Ps:\n",
        "  def phi(z, eps = 0.25):\n",
        "      return z + 0.5*eps*z**2\n",
        "\n",
        "  def NN_func2(params,X, alpha, eps = 0.25):\n",
        "      a, W = params\n",
        "      D = W.shape[1]\n",
        "      N = a.shape[0]\n",
        "      h = W @ X\n",
        "\n",
        "      f = alpha * np.mean(phi(h,eps),axis=0) # w/o readouts\n",
        "\n",
        "      return f\n",
        "\n",
        "  def target_fn(beta, X):\n",
        "      return (X.T @ beta)**2/2.0\n",
        "\n",
        "  def ntk(X1, X2, params, NN_func2, alpha, eps):\n",
        "    a, W = params\n",
        "    D = W.shape[1]\n",
        "\n",
        "    f1 = NN_func2(params, X1, alpha, eps)\n",
        "    f2 = NN_func2(params, X2, alpha, eps)\n",
        "\n",
        "    grad_f_wrt_W_1 = jacrev(lambda W, X: NN_func2([params[0], W], X, alpha, eps))(params[1], X1)\n",
        "    grad_f_wrt_a_1 = jacrev(lambda a, X: NN_func2([a, params[1]], X, alpha, eps))(params[0], X1)\n",
        "\n",
        "    grad_f_wrt_W_2 = jacrev(lambda W, X: NN_func2([params[0], W], X, alpha, eps))(params[1], X2)\n",
        "    grad_f_wrt_a_2 = jacrev(lambda a, X: NN_func2([a, params[1]], X, alpha, eps))(params[0], X2)\n",
        "\n",
        "    grad_f_wrt_a_1 = grad_f_wrt_a_1.reshape(*grad_f_wrt_a_1.shape, 1).repeat(D, axis=-1) if grad_f_wrt_a_1.ndim < 3 else grad_f_wrt_a_1\n",
        "    grad_f_wrt_a_2 = grad_f_wrt_a_2.reshape(*grad_f_wrt_a_2.shape, 1).repeat(D, axis=-1) if grad_f_wrt_a_2.ndim < 3 else grad_f_wrt_a_2\n",
        "\n",
        "    grad_f_1 = jnp.concatenate([grad_f_wrt_a_1, grad_f_wrt_W_1], axis=1).reshape(*grad_f_wrt_a_1.shape[:2], -1)\n",
        "    grad_f_2 = jnp.concatenate([grad_f_wrt_a_2, grad_f_wrt_W_2], axis=1).reshape(*grad_f_wrt_a_2.shape[:2], -1)\n",
        "\n",
        "    return jnp.tensordot(grad_f_1, grad_f_2, axes=((1, 2), (1, 2)))\n",
        "\n",
        "  def get_K(params, X):\n",
        "      return ntk(X, X, params, NN_func2, alpha, eps)\n",
        "\n",
        "  def kernel_regression(X, y, Xt, yt, params):\n",
        "    K_train = ntk(X, X, params, NN_func2, alpha, eps)\n",
        "    K_train_reg = K_train + 1e-6 * jnp.eye(K_train.shape[0])\n",
        "\n",
        "    a = jnp.linalg.solve(K_train_reg, y)\n",
        "\n",
        "    def estimate(xt):\n",
        "      k_test_train = ntk(X, xt.reshape(1, -1).T, params, NN_func2, alpha, eps)\n",
        "      k_test_train_squeezed = jnp.squeeze(k_test_train)\n",
        "      return jnp.dot(k_test_train_squeezed, a)\n",
        "\n",
        "    estimates = vmap(estimate)(Xt.T)\n",
        "    mse = jnp.mean((estimates - yt) ** 2)\n",
        "    return mse\n",
        "\n",
        "  def kalignment(K, train_y):\n",
        "    train_yc = train_y - train_y.mean(axis=0)\n",
        "    Kc = K - K.mean(axis=0)\n",
        "    top = jnp.dot(jnp.dot(train_yc.T, Kc), train_yc)\n",
        "    bottom = jnp.linalg.norm(Kc) * (jnp.linalg.norm(train_yc)**2)\n",
        "    return jnp.trace(top)/bottom\n",
        "\n",
        "  X = random.normal(random.PRNGKey(10), (D,P))/ jnp.sqrt(D)\n",
        "  Xt = random.normal(random.PRNGKey(11), (D,1000))/ jnp.sqrt(D)\n",
        "  beta = random.normal(random.PRNGKey(12), (D,))\n",
        "\n",
        "  y = target_fn(beta, X)\n",
        "  yt = target_fn(beta,Xt)\n",
        "\n",
        "  W = random.normal(random.PRNGKey(0), (N, D))\n",
        "  a = random.normal(random.PRNGKey(0), (N, ))\n",
        "  params = [a, W]\n",
        "\n",
        "  eta = 50\n",
        "  lamb = 0\n",
        "  opt_init, opt_update, get_params = optimizers.sgd(eta)\n",
        "\n",
        "  opt_state = opt_init(params)\n",
        "  loss_fn = jit(lambda p, X, y: jnp.mean( ( NN_func2(p, X,alpha, eps)- NN_func2(params,X,alpha, eps) - y )**2 ))\n",
        "  reg_loss = jit(lambda p, X, y: loss_fn(p,X,y) + lamb / alpha * optimizers.l2_norm(p)**2 )\n",
        "\n",
        "  grad_loss = jit(grad(reg_loss,0))\n",
        "\n",
        "  tr_losses = []\n",
        "  te_losses = []\n",
        "\n",
        "  alignments, alignmentst = [], []\n",
        "  epochs_to_plot = []\n",
        "\n",
        "  t1s, t2s, t3s, epochs_to_compute = [], [], [], []\n",
        "  t1sm, t2sm, t3sm, ts_summ = [], [], [], []\n",
        "  ts_sum = []\n",
        "  alignments, alignmentst = [], []\n",
        "\n",
        "  kmse = kernel_regression(X, y, Xt, yt, get_params(opt_state))\n",
        "\n",
        "  for t in tqdm(range(epochs)):\n",
        "    opt_state = opt_update(t, grad_loss(get_params(opt_state), X, y), opt_state)\n",
        "    pars = get_params(opt_state)\n",
        "\n",
        "    train_loss = loss_fn(pars, X, y)\n",
        "    test_loss = loss_fn(pars, Xt, yt)\n",
        "    tr_losses += [train_loss]\n",
        "    te_losses += [test_loss]\n",
        "\n",
        "    if t % vars_compute_interval == 0:\n",
        "        W = pars[1]\n",
        "        A = W @ beta / D\n",
        "        M = W.T @ W / N\n",
        "\n",
        "        # in terms of M and A\n",
        "        t1m = ((alpha * eps/(2*N)) * (N/D) * np.trace(M) - np.mean(beta**2)/2)**2\n",
        "        t2m = (1/2) * (np.mean(beta**2)**2 - 2 * alpha * eps * (1/N) * np.sum(A**2) + alpha**2 *(eps/N)**2 * (N**2/D**2)*np.trace(M@M)  )\n",
        "        t3m = alpha**2 * (1/D) * np.linalg.norm(np.mean(W, axis=0))**2\n",
        "\n",
        "        t1sm.append(t1m)\n",
        "        t2sm.append(t2m)\n",
        "        t3sm.append(t3m)\n",
        "        ts_summ.append(t1m + t2m + t3m)\n",
        "\n",
        "        epochs_to_compute.append(t)\n",
        "\n",
        "    if t % ntk_interval == 0:\n",
        "          K = get_K(get_params(opt_state), X)\n",
        "          align = kalignment(K, y[:, None]) # Reshape y to 2D here\n",
        "          Kt = get_K(get_params(opt_state), Xt)\n",
        "          alignt = kalignment(Kt, yt[:, None]) # Reshape y to 2D here\n",
        "          alignments.append(align)\n",
        "          alignmentst.append(alignt)\n",
        "          epochs_to_plot.append(t)\n",
        "\n",
        "  T_compute = np.array(epochs_to_compute)\n",
        "\n",
        "  t1sm_spline = make_interp_spline(T_compute, np.array(t1sm), k=3)\n",
        "  t2sm_spline = make_interp_spline(T_compute, np.array(t2sm), k=3)\n",
        "  t3sm_spline = make_interp_spline(T_compute, np.array(t3sm), k=3)\n",
        "\n",
        "  t1sm_smooth = t1sm_spline(np.arange(epochs))\n",
        "  t2sm_smooth = t2sm_spline(np.arange(epochs))\n",
        "  t3sm_smooth = t3sm_spline(np.arange(epochs))\n",
        "  ts_summ_smooth = make_interp_spline(T_compute, np.array(ts_summ), k=3)(np.arange(epochs))\n",
        "\n",
        "  def ftl(align_smooth, thresh):\n",
        "    for i, value in enumerate(align_smooth):\n",
        "        if value / align_smooth[0] > thresh:\n",
        "            return i\n",
        "    # return a sentinel value if the condition is never met\n",
        "    return -1\n",
        "\n",
        "  tit = f'N={N}, P={P}, D={D}, eps={eps}, lr={round(eta, 2)}, alpha={round(alpha, 2)}, lamb={lamb}'\n",
        "\n",
        "  def get_random_color():\n",
        "      return np.random.rand(3,)\n",
        "\n",
        "  plt.figure(figsize=(9, 6))\n",
        "  plt.plot(tr_losses, label='Train Loss')\n",
        "  plt.plot(te_losses, label='Test Loss')\n",
        "  plt.xscale('log')\n",
        "  plt.xlabel('Epochs', fontsize=20)\n",
        "  plt.ylabel('MSE', fontsize=20)\n",
        "  plt.legend(fontsize=16)\n",
        "  plt.tight_layout()\n",
        "  plt.show()\n",
        "\n",
        "  plt.figure(figsize=(9, 6))\n",
        "  plt.plot(ts_summ_smooth, label='Full Test Loss', color='darkorange')\n",
        "  plt.plot(t1sm_smooth, label='Variance error component', color='red', linestyle='--')\n",
        "  plt.plot(t2sm_smooth, label='Alignment error component', color='green', linestyle='--')\n",
        "  plt.plot(t3sm_smooth, label='Linear term error component', color='black', linestyle='--')\n",
        "  plt.xscale('log')\n",
        "  plt.ylabel('MSE', fontsize=20)\n",
        "  plt.xlabel('Epochs', fontsize=20)\n",
        "  plt.legend(fontsize=16)\n",
        "  plt.tight_layout()\n",
        "  plt.show()\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_zxMix3X9Pk6"
      },
      "outputs": [],
      "source": [
        "# alter loss to train readouts\n",
        "\n",
        "# Figure 12. Loss Decomposition WITH READOUTS.\n",
        "D = 40\n",
        "N = 2\n",
        "alpha = 1\n",
        "eps = 0.25\n",
        "P = 120\n",
        "alphas = [1]\n",
        "\n",
        "epochs = 100000\n",
        "\n",
        "ntk_interval = 200\n",
        "vars_compute_interval = 10\n",
        "\n",
        "all_tr_losses_eps = []\n",
        "all_te_losses_eps = []\n",
        "all_alignments, all_alignmentst = [], []\n",
        "\n",
        "for alpha in alphas:\n",
        "  def phi(z, eps = 0.25):\n",
        "      return z + 0.5*eps*z**2\n",
        "\n",
        "  def target_fn(beta, X):\n",
        "    return (X.T @ beta)**2/2.0\n",
        "\n",
        "  def NN_func2(params,X, alpha, eps = 0.25):\n",
        "      a, W = params\n",
        "      D = W.shape[1]\n",
        "      N = a.shape[0]\n",
        "      h = W @ X\n",
        "\n",
        "      ap = a.reshape(-1, 1)\n",
        "      f = alpha * np.mean(ap * phi(h,eps),axis=0) # train readouts\n",
        "      # f = alpha * np.mean(phi(h,eps),axis=0) # w/o readouts\n",
        "\n",
        "      return f\n",
        "\n",
        "\n",
        "  X = random.normal(random.PRNGKey(0), (D,P))/ jnp.sqrt(D)\n",
        "  Xt = random.normal(random.PRNGKey(1), (D,1000))/ jnp.sqrt(D)\n",
        "  beta = random.normal(random.PRNGKey(2), (D,))\n",
        "\n",
        "  y = target_fn(beta, X)\n",
        "  yt = target_fn(beta,Xt)\n",
        "\n",
        "  W = random.normal(random.PRNGKey(0), (N, D))\n",
        "\n",
        "  a = random.normal(random.PRNGKey(0), (N, ))\n",
        "  params = [a, W]\n",
        "\n",
        "  eta = 5e-2 * N / alpha**2\n",
        "  lamb = 0\n",
        "  opt_init, opt_update, get_params = optimizers.sgd(eta)\n",
        "\n",
        "  opt_state = opt_init(params)\n",
        "  # loss_fn = jit(lambda p, X, y: jnp.mean( ( NN_func2(p, X,alpha, eps)- NN_func2(params,X,alpha, eps) - y )**2 ))\n",
        "  loss_fn = jit(lambda p, X, y: jnp.mean( ( NN_func2(p, X,alpha, eps)- y )**2 ))\n",
        "  reg_loss = jit(lambda p, X, y: loss_fn(p,X,y) + lamb / alpha * optimizers.l2_norm(p)**2 )\n",
        "\n",
        "  grad_loss = jit(grad(reg_loss,0))\n",
        "\n",
        "  tr_losses = []\n",
        "  te_losses = []\n",
        "\n",
        "  epochs_to_plot = []\n",
        "\n",
        "  t1s, t2s, t3s, epochs_to_compute = [], [], [], []\n",
        "  t1sm, t2sm, t3sm, ts_summ = [], [], [], []\n",
        "  ts_sum = []\n",
        "\n",
        "  wbar0 = np.dot(W.T, a) / N\n",
        "  M0 = np.dot(W.T, np.dot(np.diag(a), W)) / N\n",
        "\n",
        "  for t in tqdm(range(epochs)):\n",
        "    opt_state = opt_update(t, grad_loss(get_params(opt_state), X, y), opt_state)\n",
        "    pars = get_params(opt_state)\n",
        "\n",
        "    train_loss = loss_fn(pars, X, y)\n",
        "    test_loss = loss_fn(pars, Xt, yt)\n",
        "    tr_losses += [train_loss]\n",
        "    te_losses += [test_loss]\n",
        "\n",
        "    if t % vars_compute_interval == 0:\n",
        "        a, W = pars\n",
        "\n",
        "        wbar = W.T @ a / N\n",
        "        M = W.T @ np.diag(a) @ W / N\n",
        "\n",
        "        t1m = (alpha * eps/2.0 * 1.0/D * np.trace(M) - np.mean(beta**2)/2)**2\n",
        "        t2m = (1/2) * np.mean( ( alpha*eps* M - np.outer(beta, beta) )**2 )\n",
        "        t3m = alpha**2 * (1/D) * np.linalg.norm(wbar)**2\n",
        "\n",
        "        t1sm.append(t1m)\n",
        "        t2sm.append(t2m)\n",
        "        t3sm.append(t3m)\n",
        "        ts_summ.append(t1m + t2m + t3m)\n",
        "\n",
        "        epochs_to_compute.append(t)\n",
        "\n",
        "  all_tr_losses_eps += [tr_losses]\n",
        "  all_te_losses_eps += [te_losses]\n",
        "\n",
        "  T_compute = np.array(epochs_to_compute)\n",
        "\n",
        "  t1sm_spline = make_interp_spline(T_compute, np.array(t1sm), k=3)\n",
        "  t2sm_spline = make_interp_spline(T_compute, np.array(t2sm), k=3)\n",
        "  t3sm_spline = make_interp_spline(T_compute, np.array(t3sm), k=3)\n",
        "\n",
        "  t1sm_smooth = t1sm_spline(np.arange(epochs))\n",
        "  t2sm_smooth = t2sm_spline(np.arange(epochs))\n",
        "  t3sm_smooth = t3sm_spline(np.arange(epochs))\n",
        "  ts_summ_smooth = make_interp_spline(T_compute, np.array(ts_summ), k=3)(np.arange(epochs))\n",
        "\n",
        "  T = np.array(epochs_to_plot)\n",
        "\n",
        "  def get_random_color():\n",
        "      return np.random.rand(3,)\n",
        "\n",
        "  plt.figure(figsize=(9, 6))\n",
        "  plt.plot(tr_losses, label='Train Loss')\n",
        "  plt.plot(te_losses, label='Test Loss')\n",
        "  plt.xscale('log')\n",
        "  plt.xlabel('Epochs', fontsize=20)\n",
        "  plt.ylabel('MSE', fontsize=20)\n",
        "  plt.legend(fontsize=14)\n",
        "  plt.tight_layout()\n",
        "  plt.show()\n",
        "\n",
        "  plt.figure(figsize=(9, 6))\n",
        "  plt.plot(ts_summ_smooth, label='Full Test Loss', color='darkorange')\n",
        "  plt.plot(t1sm_smooth, label='Variance error component', color='red', linestyle='--')\n",
        "  plt.plot(t2sm_smooth, label='Alignment error component', color='green', linestyle='--')\n",
        "  plt.plot(t3sm_smooth, label='Linear term error component', color='black', linestyle='--')\n",
        "  plt.xscale('log')\n",
        "  plt.ylabel('MSE', fontsize=20)\n",
        "  plt.xlabel('Epochs', fontsize=20)\n",
        "  plt.legend(fontsize=14)\n",
        "  plt.tight_layout()\n",
        "  plt.show()\n",
        "\n",
        "  # more acc for large N\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dSRfDrA2Cu02"
      },
      "outputs": [],
      "source": [
        "# Role of weight deay\n",
        "  # All three plots computed by re-running code for Figure 6(b) with scale (alpha) and weight decay as given in manuscript."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EfZy8pP3Cwew"
      },
      "outputs": [],
      "source": [
        "# Role of adaptive optimizers"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FrMhw7AuDDP6"
      },
      "outputs": [],
      "source": [
        "# (a)\n",
        "# Same as that for figure 1.\n",
        "  '''\n",
        "  Just replace the line: opt_init, opt_update, get_params = optimizers.sgd(eta)\n",
        "    with: opt_init, opt_update, get_params = optimizers.momentum(eta, 0.95)\n",
        "      and the line: opt_init_lin, opt_update_lin, get_params_lin = optimizers.sgd(eta)\n",
        "        with: opt_init_lin, opt_update_lin, get_params_lin = optimizers.momentum(eta, 0.95)\n",
        "\n",
        "  '''\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GhnbdqmfDEQp"
      },
      "outputs": [],
      "source": [
        "# (b)\n",
        "def kernel_regression(X, y, Xt, yt, params):\n",
        "    K_train = ntk_fn(X, None, params)\n",
        "    a = jnp.linalg.solve(K_train, y)\n",
        "\n",
        "    def estimate(xt):\n",
        "      k_test_train = ntk_fn(X, xt.reshape(1, -1).T, params)\n",
        "      k_test_train_squeezed = jnp.squeeze(k_test_train)\n",
        "      return jnp.dot(k_test_train_squeezed, a)\n",
        "\n",
        "    estimates = vmap(estimate)(Xt.T)\n",
        "    mse = jnp.mean((estimates - yt) ** 2)\n",
        "    return mse\n",
        "\n",
        "def L2(model_params):\n",
        "    params_flatten = [param.flatten() for param in model_params]\n",
        "    params_flatten = jnp.concatenate(params_flatten)\n",
        "    l2 = jnp.sum(params_flatten ** 2)\n",
        "    return params_flatten, l2\n",
        "\n",
        "alpha = 1\n",
        "def MLP(params, X):\n",
        "  w1, w2, w3 = params\n",
        "  z1 = jnp.tanh(w1 @ X)\n",
        "  z2 = jnp.tanh(w2 @ z1)\n",
        "  return (w3 @ z2)*alpha\n",
        "\n",
        "D = 5\n",
        "P = 100\n",
        "wd = 0.08\n",
        "thresh = 0.001\n",
        "lr = 3e-4\n",
        "\n",
        "X = random.normal(random.PRNGKey(0), (D, P))\n",
        "Xt = random.normal(random.PRNGKey(1), (D,P))\n",
        "\n",
        "w1t, w2t, w3t = random.normal(random.PRNGKey(0), (100, 5)), random.normal(random.PRNGKey(1), (100, 100)), random.normal(random.PRNGKey(2), (5, 100))\n",
        "teacher_params = [w1t, w2t, w3t]\n",
        "y = MLP(teacher_params, X)\n",
        "yt = MLP(teacher_params, Xt)\n",
        "\n",
        "ntk_fn = nt.empirical_ntk_fn(\n",
        "    MLP, vmap_axes=0, trace_axes=(),\n",
        "    implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)\n",
        "\n",
        "\n",
        "ov_tr, ov_te = [], []\n",
        "ovl_tr, ovl_te = [], []\n",
        "wds = [0.1]\n",
        "for wd in wds:\n",
        "\n",
        "  w1s, w2s, w3s = random.normal(random.PRNGKey(3), (100, 5)), random.normal(random.PRNGKey(4), (100, 100)), random.normal(random.PRNGKey(5), (5, 100))\n",
        "  student_params = [w1s, w2s, w3s]\n",
        "  student_params_init = deepcopy([w1s, w2s, w3s])\n",
        "  student_params_lin = deepcopy(student_params)\n",
        "\n",
        "  optimizer = optax.adamw(learning_rate=lr*10, weight_decay=wd)\n",
        "  opt_state = optimizer.init(student_params)\n",
        "\n",
        "  f_lin = nt.linearize(MLP, student_params)\n",
        "  optimizer_lin = optax.sgd(learning_rate=lr)\n",
        "  opt_state_lin = optimizer.init(student_params_lin)\n",
        "\n",
        "\n",
        "  loss_fn = jit(lambda p, X, y: jnp.mean( ( MLP(p, X) - y )**2 ))\n",
        "  loss_lin = jit(lambda p, X, y: jnp.mean( ( f_lin(p, X) - y )**2 ))\n",
        "\n",
        "  grad_loss = jit(grad(loss_fn, 0))\n",
        "  grad_loss_lin = jit(grad(loss_lin, 0))\n",
        "\n",
        "  epochs = 20000\n",
        "  tr_losses, te_losses = [], []\n",
        "  lin_tr_losses, lin_te_losses = [], []\n",
        "\n",
        "  for t in tqdm(range(epochs)):\n",
        "    grads = grad_loss(student_params, X, y)\n",
        "    updates, opt_state = optimizer.update(grads, opt_state, student_params)\n",
        "    student_params = optax.apply_updates(student_params, updates)\n",
        "\n",
        "    grads_lin = grad_loss_lin(student_params_lin, X, y)\n",
        "    updates_lin, opt_state_lin = optimizer.update(grads_lin, opt_state_lin, student_params_lin)\n",
        "    student_params_lin = optax.apply_updates(student_params_lin, updates_lin)\n",
        "\n",
        "    train_loss = loss_fn(student_params, X, y)/alpha\n",
        "    test_loss = loss_fn(student_params, Xt, yt)/alpha\n",
        "    tr_losses += [train_loss]\n",
        "    te_losses += [test_loss]\n",
        "\n",
        "    lin_train_loss = loss_lin(student_params_lin, X, y)/alpha\n",
        "    lin_test_loss = loss_lin(student_params_lin, Xt, yt)/alpha\n",
        "\n",
        "    lin_tr_losses += [lin_train_loss]\n",
        "    lin_te_losses += [lin_test_loss]\n",
        "\n",
        "  plt.figure()\n",
        "  t = [i for i in range(epochs)]\n",
        "  plt.plot(t, tr_losses, label='Train Loss')\n",
        "  plt.plot(t, te_losses, label='Test Loss')\n",
        "  plt.plot(t, lin_tr_losses, label='Linearized Train')\n",
        "  plt.plot(t, lin_te_losses, label='Linearized Test')\n",
        "  plt.title(f'Student-Teacher MLP, AdamW with wd={wd}')\n",
        "  plt.legend()\n",
        "  plt.show()\n",
        "  ov_tr += [tr_losses]; ov_te += [te_losses]\n",
        "  ovl_tr += [lin_tr_losses]; ovl_te += [lin_te_losses]\n",
        "\n",
        "  plt.figure()\n",
        "  t = [i for i in range(epochs)]\n",
        "  plt.plot(t, tr_losses, label='Train Loss')\n",
        "  plt.plot(t, te_losses, label='Test Loss')\n",
        "  plt.plot(t, lin_tr_losses, label='Linearized Train')\n",
        "  plt.plot(t, lin_te_losses, label='Linearized Test')\n",
        "  plt.xscale('log')\n",
        "  plt.legend()\n",
        "  plt.tight_layout()\n",
        "  plt.savefig('')\n",
        "  plt.show()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Fat4OQWLCx2j"
      },
      "outputs": [],
      "source": [
        "# (c) Reimplementing Nanda's Transformer in JAX with his hyperparams.\n",
        "lr=1e-3 #@param\n",
        "weight_decay = 1.0 #@param\n",
        "P = 113\n",
        "d_model = 128 #@param\n",
        "fn_name = 'add' #@param ['add', 'subtract', 'x2xyy2','rand']\n",
        "frac_train = 0.4 #@param\n",
        "num_epochs = 50000 #@param\n",
        "save_models = True #@param\n",
        "save_every = 100 #@param\n",
        "stopping_thresh = -1 #@param\n",
        "seed = 0 #@param\n",
        "\n",
        "num_layers = 1\n",
        "batch_style = 'full'\n",
        "d_vocab = P+1\n",
        "n_ctx = 3\n",
        "d_mlp = 4*d_model\n",
        "num_heads = 4\n",
        "assert d_model % num_heads == 0\n",
        "d_head = d_model//num_heads\n",
        "act_type = 'ReLU' #@param ['ReLU', 'GeLU']\n",
        "use_ln = False\n",
        "random_answers = np.random.randint(low=0, high=P, size=(P, P))\n",
        "fns_dict = {'add': lambda x,y:(x+y)%P, 'subtract': lambda x,y:(x-y)%P, 'x2xyy2':lambda x,y:(x**2+x*y+y**2)%p, 'rand':lambda x,y:random_answers[x][y]}\n",
        "fn = fns_dict[fn_name]\n",
        "\n",
        "P = 113\n",
        "seed = 1\n",
        "\n",
        "import random as nrandom\n",
        "def gen_train_test(frac_train, num, seed=0):\n",
        "    # Generate train and test split\n",
        "    pairs = [(i, j, num) for i in range(num) for j in range(num)]\n",
        "    nrandom.seed(seed)\n",
        "    nrandom.shuffle(pairs)\n",
        "    div = int(frac_train*len(pairs))\n",
        "    return pairs[:div], pairs[div:]\n",
        "\n",
        "train, test = gen_train_test(frac_train, P, seed)\n",
        "\n",
        "\n",
        "# Implementation of Nanda's 1-layer-transformer as a functional JAX function\n",
        "  # so that we can linearize it as nt.linearize and track kernel GD\n",
        "def transformer_forward(p, x):\n",
        "  (WE, Wpos, WK, WQ, WV, WO, Win, bin, Wout, bout, WU, scale) = p\n",
        "\n",
        "  # Embed + PosEmbed\n",
        "  x = jnp.einsum('dbp -> bpd', WE[:, x])  # [5107, 3, 114]\n",
        "  x = x + Wpos[:x.shape[-2]] # [5107, 3, 114]\n",
        "\n",
        "  # Attention\n",
        "  k = jnp.einsum('ihd,bpd->biph', WK, x) # [5107, 4, 3, 32]\n",
        "  q = jnp.einsum('ihd,bpd->biph', WQ, x)\n",
        "  v = jnp.einsum('ihd,bpd->biph', WV, x)\n",
        "  attn_scores_pre = jnp.einsum('biph, biqh-> biqp', k, q) # [5107, 4, 3, 3]\n",
        "  lower_tri = jnp.tril(attn_scores_pre)\n",
        "  mask = jnp.tril(jnp.ones((n_ctx, n_ctx)))\n",
        "  inverted_mask = 1 - mask[:x.shape[-2], :x.shape[-2]]\n",
        "  large_negative = -1e10 * inverted_mask\n",
        "  attn_scores_masked = lower_tri + large_negative\n",
        "  attn_matrix = jax.nn.softmax(attn_scores_masked/np.sqrt(d_head), axis=-1)\n",
        "  z = jnp.einsum('biph,biqp->biqh', v, attn_matrix)\n",
        "  z_flat = einops.rearrange(z, 'b i q h -> b q (i h)')\n",
        "  out = jnp.einsum('df,bqf->bqd', WO, z_flat)\n",
        "  x = x + out\n",
        "\n",
        "  # MLP\n",
        "  mlp_out = jnp.einsum('md,bpd->bpm', Win, x) + bin\n",
        "  mlp_out = jax.nn.relu(mlp_out)\n",
        "  mlp_out = jnp.einsum('dm,bpm->bpd', Wout, mlp_out) + bout\n",
        "  x = x + mlp_out # [5107, 3, 128]\n",
        "\n",
        "  # Unembed\n",
        "  x = x @ WU # [5107, 3, 114]\n",
        "\n",
        "  return x*scale\n",
        "\n",
        "scales = [0.15, 1]\n",
        "for scale in scales:\n",
        "  d_vocab = P+1\n",
        "  alpha = 1\n",
        "  d_model = 128\n",
        "  d_mlp = 4*d_model\n",
        "  n_ctx = 3\n",
        "  num_heads = 4\n",
        "  d_head = d_model // num_heads\n",
        "  key = random.PRNGKey(0)\n",
        "  WE = random.normal(key, (d_model, d_vocab))/np.sqrt(d_model)\n",
        "  Wpos = random.normal(key, (n_ctx, d_model))/np.sqrt(d_model)\n",
        "  WK = random.normal(key, (num_heads, d_head, d_model))/np.sqrt(d_model)\n",
        "  WQ = random.normal(key, (num_heads, d_head, d_model))/np.sqrt(d_model)\n",
        "  WV = random.normal(key, (num_heads, d_head, d_model))/np.sqrt(d_model)\n",
        "  WO = random.normal(key, (d_model, d_head * num_heads))/np.sqrt(d_model)\n",
        "\n",
        "\n",
        "  Win = random.normal(key, (d_mlp, d_model))/np.sqrt(d_model)\n",
        "  bin = jnp.zeros(d_mlp)\n",
        "\n",
        "  Wout = random.normal(key, (d_model, d_mlp))/np.sqrt(d_model)\n",
        "  bout = jnp.zeros(d_model)\n",
        "\n",
        "  WU = random.normal(key, (d_model, d_vocab))/np.sqrt(d_vocab)\n",
        "\n",
        "  params = [WE, Wpos, WK, WQ, WV, WO, Win, bin, Wout, bout, WU, scale]\n",
        "  params_lin = deepcopy(params)\n",
        "  params_init = deepcopy(params)\n",
        "\n",
        "  def xent(logits, labels):\n",
        "      logits = jnp.array(logits, dtype=jnp.float32)\n",
        "      log_probs = jax.nn.log_softmax(logits)\n",
        "      prediction_logprobs = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)\n",
        "      loss = -jnp.mean(prediction_logprobs)\n",
        "      return loss\n",
        "\n",
        "\n",
        "  def centered_loss(p, X):\n",
        "    # get logits from model\n",
        "    logits = transformer_forward(p, X)[:, -1]\n",
        "    labels = jnp.array([(i+j)%P for i, j, _ in X])\n",
        "    logits_init = transformer_forward(params_init, X)[:, -1]\n",
        "    return xent(logits-logits_init, labels)\n",
        "\n",
        "  def centered_loss_lin(p, X):\n",
        "    logits = f_lin(p, X)[:, -1]\n",
        "    labels = jnp.array([(i+j)%P for i, j, _ in X])\n",
        "    logits_init = f_lin0(params_init, X)[:, -1]\n",
        "    return xent(logits-logits_init, labels)\n",
        "\n",
        "grad_fn = grad(centered_loss)\n",
        "\n",
        "optimizer = optax.adamw(\n",
        "    learning_rate=lr / scale**2,\n",
        "    b1=0.9,\n",
        "    b2=0.98,\n",
        "    weight_decay=weight_decay*2\n",
        ")\n",
        "optimizer_lin = optax.adamw(\n",
        "    learning_rate=lr / scale**2,\n",
        "    b1=0.9,\n",
        "    b2=0.98,\n",
        "    weight_decay=weight_decay*2\n",
        ")\n",
        "\n",
        "opt_state = optimizer.init(params)\n",
        "opt_state_lin = optimizer_lin.init(params)\n",
        "\n",
        "train_losses, test_losses = [], []\n",
        "lin_train_losses, lin_test_losses = [], []\n",
        "f_lin = nt.linearize(transformer_forward, params_init)\n",
        "f_lin0 = nt.linearize(transformer_forward, params_init)\n",
        "\n",
        "num_epochs = 15000\n",
        "for epoch in tqdm(range(num_epochs)):\n",
        "  loss, grads = jax.value_and_grad(centered_loss)(params, train)\n",
        "  updates, opt_state = optimizer.update(grads, opt_state, params)\n",
        "  params = optax.apply_updates(params, updates)\n",
        "\n",
        "  lin_loss, grads_lin = jax.value_and_grad(centered_loss_lin)(params_lin, train)\n",
        "  updates_lin, opt_state_lin = optimizer_lin.update(grads_lin, opt_state_lin, params_lin)\n",
        "  params_lin = optax.apply_updates(params_lin, updates_lin)\n",
        "\n",
        "  train_losses += [loss]\n",
        "  lin_train_losses += [lin_loss]\n",
        "  often = 100\n",
        "\n",
        "  if epoch % often == 0:\n",
        "\n",
        "    test_loss = centered_loss(params, test)\n",
        "    test_losses += [test_loss]\n",
        "\n",
        "    test_loss_lin = centered_loss_lin(params_lin, test)\n",
        "    lin_test_losses += [test_loss_lin]\n",
        "\n",
        "    if epoch % 2000 == 0 and epoch > 0:\n",
        "      print(f'Epoch {epoch}: train loss={loss} & test loss={test_loss}')\n",
        "      print(f'Epoch {epoch}: LINEAR train loss={lin_loss} & test loss={test_loss_lin}')\n",
        "\n",
        "      x_test_epochs = np.arange(0, epoch, often)\n",
        "      x_all_epochs = np.arange(epoch)\n",
        "      interp_test_losses = interp1d(x_test_epochs, test_losses[:(epoch//often)], kind='cubic', fill_value='extrapolate')(x_all_epochs)\n",
        "      interp_lin_test_losses = interp1d(x_test_epochs, lin_test_losses[:(epoch//often)], kind='cubic', fill_value='extrapolate')(x_all_epochs)\n",
        "\n",
        "      plt.figure()\n",
        "      plt.plot(train_losses, label='Train Loss', linestyle='--', color='r')\n",
        "      plt.plot(lin_train_losses, label='Linearized Train', color='g', linestyle='--')\n",
        "      plt.plot(interp_test_losses, label='Test Loss', color='r')\n",
        "      plt.plot(interp_lin_test_losses, label='Linearized Test Loss', color='g')\n",
        "      plt.title('1LT modar Losses')\n",
        "      plt.xlabel('Epochs')\n",
        "      plt.yscale('log')\n",
        "      plt.legend()\n",
        "      plt.show()\n",
        "\n",
        "\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
