{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "12310869-4953-4f6d-ae52-6123d7f1992e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import importlib\n",
    "from tqdm import tqdm\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import equinox as eqx\n",
    "import jax\n",
    "\n",
    "import jax.numpy as jnp\n",
    "from copy import deepcopy\n",
    "from time import time\n",
    "\n",
    "import main_utils\n",
    "import opto_ as opto\n",
    "import samplers\n",
    "importlib.reload(samplers)\n",
    "import samplers\n",
    "import metrics\n",
    "importlib.reload(metrics)\n",
    "from metrics import ce, mse\n",
    "from main import eval_step, run_with_opts\n",
    "from additional_models import cqk_cpv_fit, gd\n",
    "from analyse import analyse2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e691f2a",
   "metadata": {},
   "source": [
    "## Training\n",
    "In case the Colab notebook does not work for some reason, we provide an alternative here. The environment is specified in the `environment.yml` file. Due to compute constraints, the code here trains for a shorter time and with a higher learning rate---however, this seems to hurt the performance of softmax self-attention, especially in the case $d=2$ (as can be seen below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5ab2271",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory where training writes checkpoints and where the analysis below reads them back.\n",
    "# Fill in the default here, or set the SAVE_FOLDER environment variable.\n",
    "SAVE_FOLDER = os.environ.get(\"SAVE_FOLDER\", \"\")\n",
    "# Fail fast: with an empty folder, run paths like f'{SAVE_FOLDER}/{run}' point at the filesystem root.\n",
    "assert SAVE_FOLDER, \"Set SAVE_FOLDER to a directory for saving model checkpoints\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa5becf7-7f73-4b55-8eed-ce3a602ec0a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "parser = main_utils.create_parser()\n",
    "opto.add_args_to_parser(parser)\n",
    "\n",
    "INIT_SEED=5\n",
    "MAIN_RUN_ITERS = 40960000 # 2048 * 20000\n",
    "LONG_RUN_ITERS = 409600000 # not used by the sweep below\n",
    "\n",
    "# Base linear-attention config; the sweep below overrides run name, seeds, sizes, lr and train_iters.\n",
    "lin_opts = parser.parse_args(\n",
    "    [\n",
    "        '--no_embed',\n",
    "        '--no_unembed',\n",
    "        '--not_causal',\n",
    "        '--exclude_query_token',\n",
    "        '--no_softmax',\n",
    "        '--balanced_classes_queries',\n",
    "        '--init_rescale', '0.002',\n",
    "        '--classification',\n",
    "        '--pos_embedding_type', 'none',\n",
    "        '--n_labels', '5',\n",
    "        '--train_context_len', '100',\n",
    "        '--X_dim', '13',\n",
    "        '--eval_context_len', '100',\n",
    "        '--eval_n_samples', '512',\n",
    "        '--no_norm',\n",
    "        '--optimizer', 'adam',\n",
    "        '--grad_clip_val', '0.001',\n",
    "        '--depth', '1',\n",
    "        '--train_iters', str(MAIN_RUN_ITERS),\n",
    "        '--train_bs', '2048',\n",
    "        '--eval_every', '204800',\n",
    "        '--lr', '0.0005',\n",
    "        '--d_model', '18',\n",
    "        '--num_heads', '1',\n",
    "        '--run', 'lin_c100_lr0.0005',\n",
    "        '--ckpt_every', '2048000',\n",
    "        '--init_seed', str(INIT_SEED),\n",
    "        '--base_folder', SAVE_FOLDER,\n",
    "        '--raw_name',\n",
    "        '--no_proj_bias'\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Softmax variant of the base config.\n",
    "sfmx_opts = deepcopy(lin_opts)\n",
    "sfmx_opts.lr = 0.002\n",
    "sfmx_opts.init_rescale = 0.002\n",
    "sfmx_opts.no_softmax = False\n",
    "sfmx_opts.grad_clip_val = 1.0\n",
    "\n",
    "# Settings shared by every run of the X_dim sweep.\n",
    "cont_len = 100\n",
    "n_labels = 5\n",
    "SWEEP_TRAIN_SEED = 11\n",
    "SWEEP_INIT_SEED = 8  # overrides INIT_SEED for all sweep runs\n",
    "LIN_LR = 0.0001\n",
    "\n",
    "# Per-X_dim softmax settings: (X_dim, lr, train_iters).\n",
    "# Linear runs always use LIN_LR and MAIN_RUN_ITERS.\n",
    "SFMX_SWEEP = [\n",
    "    (2, 0.005, MAIN_RUN_ITERS * 2),\n",
    "    (3, 0.001, MAIN_RUN_ITERS),\n",
    "    (5, 0.002, MAIN_RUN_ITERS),\n",
    "    (10, 0.005, MAIN_RUN_ITERS),\n",
    "]\n",
    "\n",
    "def make_sweep_opts(base_opts, prefix, X_dim, lr, train_iters):\n",
    "    \"\"\"Return a copy of `base_opts` configured as run f'{prefix}_xdim_{X_dim}'.\"\"\"\n",
    "    opts = deepcopy(base_opts)\n",
    "    opts.run = f'{prefix}_xdim_{X_dim}'\n",
    "    opts.train_seed = SWEEP_TRAIN_SEED\n",
    "    opts.init_seed = SWEEP_INIT_SEED\n",
    "    opts.n_labels = n_labels\n",
    "    opts.X_dim = X_dim\n",
    "    opts.train_context_len = cont_len\n",
    "    opts.eval_context_len = cont_len\n",
    "    # Model width equals X_dim + n_labels (presumably the raw token width, since --no_embed is set).\n",
    "    opts.d_model = X_dim + n_labels\n",
    "    opts.lr = lr\n",
    "    opts.train_iters = train_iters\n",
    "    return opts\n",
    "\n",
    "# Runs are trained in this order: softmax, then linear, for each X_dim.\n",
    "all_opts = []\n",
    "for X_dim, sfmx_lr, sfmx_iters in SFMX_SWEEP:\n",
    "    all_opts.append(make_sweep_opts(sfmx_opts, 'sfmx', X_dim, sfmx_lr, sfmx_iters))\n",
    "    all_opts.append(make_sweep_opts(lin_opts, 'lin', X_dim, LIN_LR, MAIN_RUN_ITERS))\n",
    "\n",
    "start = time()\n",
    "for opts in tqdm(all_opts):\n",
    "    run_with_opts(opts)\n",
    "print(\"Total time (min):\", (time() - start) / 60)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef83e721-20f2-40c6-a2bc-ede25b0c3806",
   "metadata": {},
   "source": [
    "# Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "306450a1-14ad-452b-8067-418f31f5982f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One shared text size for every figure element, plus thicker default lines.\n",
    "matplotlib.rcParams.update({\n",
    "    'font.size': 18,\n",
    "    'lines.linewidth': 2,\n",
    "    'axes.titlesize': 18,\n",
    "    'axes.labelsize': 18,\n",
    "    'xtick.labelsize': 18,\n",
    "    'ytick.labelsize': 18,\n",
    "    'legend.fontsize': 18,\n",
    "    'figure.titlesize': 18,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "1f5271c1-fe23-41ba-8ae8-23d6ebeb7627",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The analysis helpers below look up each run as f'{base_folder}/{run}'.\n",
    "base_folder = SAVE_FOLDER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "d1de5869-7a41-49d2-beb7-63b789e65e33",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_run_defaults(run_dir):\n",
    "    \"\"\"Rebuild what is needed to load checkpoints of the run stored in `run_dir`.\n",
    "\n",
    "    Reads f'{run_dir}/config.json', builds a fresh model, forward function and\n",
    "    optimizer state from those options, and packs them into a template pytree\n",
    "    shaped like a saved checkpoint (used with `eqx.tree_deserialise_leaves`).\n",
    "\n",
    "    Returns:\n",
    "        fwd_fn: forward function built from the run's options.\n",
    "        ckpt_fmt: checkpoint template dict with keys 'iter', 'seeds',\n",
    "            'opt_state' and 'model'; its values are placeholders.\n",
    "        opts: the run's parsed options.\n",
    "    \"\"\"\n",
    "    opts = main_utils.get_opts_from_json_file(f'{run_dir}/config.json')\n",
    "    model = main_utils.get_model_from_opts(opts)\n",
    "    fwd_fn = opto.make_fn_from_opts(opts)\n",
    "    opt_state = main_utils.get_optimizer_from_opts(opts).init(eqx.filter(model, eqx.is_array))\n",
    "\n",
    "    # Placeholder keys: only their structure matters for deserialisation.\n",
    "    seed_names = ('eval_model_seed', 'train_data_seed', 'train_model_seed')\n",
    "    ckpt_fmt = {'iter': -1,\n",
    "                'seeds': {name: jax.random.PRNGKey(0) for name in seed_names},\n",
    "                'opt_state': opt_state,\n",
    "                'model': model}\n",
    "    return fwd_fn, ckpt_fmt, opts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "b1d9698c-dfa8-487a-9c4a-790eb44eb91e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model(run, i=-1):\n",
    "    \"\"\"Load checkpoint number `i` (default: the last) of run `run` under `base_folder`.\n",
    "\n",
    "    Returns a dict with the deserialised 'model', the run's 'fwd_fn', and\n",
    "    'loss_fn' (`ce` for classification runs, `mse` otherwise).\n",
    "    \"\"\"\n",
    "    import re\n",
    "\n",
    "    run_dir = f'{base_folder}/{run}'\n",
    "    ckpt_dir = f'{run_dir}/checkpoints'\n",
    "    fwd_fn, ckpt_fmt, opts = get_run_defaults(run_dir)\n",
    "\n",
    "    def _natural_key(fname):\n",
    "        # Compare digit runs as integers so unpadded names sort by iteration\n",
    "        # ('2048000' before '10240000'); zero-padded names keep their plain order.\n",
    "        # NOTE(review): checkpoint file naming is not visible in this notebook - confirm.\n",
    "        parts = re.split(r'(\\d+)', fname)\n",
    "        return [int(p) if k % 2 else p for k, p in enumerate(parts)]\n",
    "\n",
    "    ckpt_fname = sorted(os.listdir(ckpt_dir), key=_natural_key)[i]\n",
    "    ckpt = eqx.tree_deserialise_leaves(f'{ckpt_dir}/{ckpt_fname}', ckpt_fmt)\n",
    "    loss_fn = ce if opts.classification else mse\n",
    "    return {'model': ckpt['model'], 'fwd_fn': fwd_fn, 'loss_fn': loss_fn}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "ecbc4c9f-06d6-485f-8ed7-270bf722692f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_eval_data(run):\n",
    "    \"\"\"Sample the balanced-classification eval batch for `run`, seeded from its config.\n",
    "\n",
    "    Returns:\n",
    "        ({'eval': batch}, eval_model_seed): `batch` comes from the balanced\n",
    "        classification-queries sampler with the run's eval settings, and\n",
    "        `eval_model_seed` is the third key split from PRNGKey(opts.eval_seed).\n",
    "    \"\"\"\n",
    "    # Only the options are needed, so read config.json directly instead of calling\n",
    "    # get_run_defaults (which also builds a model and optimizer state).\n",
    "    opts = main_utils.get_opts_from_json_file(f'{base_folder}/{run}/config.json')\n",
    "\n",
    "    # NOTE(review): split layout presumably mirrors how training derives eval seeds - confirm.\n",
    "    eval_data_seed, _, eval_model_seed = jax.random.split(jax.random.PRNGKey(opts.eval_seed), 3)\n",
    "    eval_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(\n",
    "        opts.n_labels, opts.eval_context_len, opts.eval_n_samples, opts.X_dim, opts.eval_noise_scale))\n",
    "    return {'eval': eval_sampler(eval_data_seed)}, eval_model_seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "1d6452c2-c28f-4465-81b2-4b13ab407d7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_attention(run_dir, ckpt_fnames, qk_vmin=-1, qk_vmax=1, ov_vmin=-5, ov_vmax=5, block=0):\n",
    "    \"\"\"Show attention weight products of one transformer block across checkpoints.\n",
    "\n",
    "    Top row: q^T k (the Wq^T Wk matrix that determines attention).\n",
    "    Bottom row: proj @ v (output/value product).\n",
    "    One column per file name in `ckpt_fnames`, each relative to f'{run_dir}/checkpoints/'.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `ckpt_fnames` is empty.\n",
    "    \"\"\"\n",
    "    if not ckpt_fnames:\n",
    "        raise ValueError('plot_attention needs at least one checkpoint file name')\n",
    "    _, ckpt_fmt, opts = get_run_defaults(run_dir)\n",
    "    n = len(ckpt_fnames)\n",
    "\n",
    "    # squeeze=False keeps `axs` 2-D even when n == 1, so axs[row, col] always works.\n",
    "    fig, axs = plt.subplots(2, n, figsize=(4 * n, 4 * 2), squeeze=False)\n",
    "\n",
    "    for i, ckpt_fname in enumerate(ckpt_fnames):\n",
    "        ckpt = eqx.tree_deserialise_leaves(f'{run_dir}/checkpoints/{ckpt_fname}', ckpt_fmt)\n",
    "        attn = ckpt['model'].transformer.blocks[block].attn\n",
    "        # The fused qkv weight is split into three (-1, d_model) matrices.\n",
    "        q, k, v = jnp.reshape(attn.qkv.weight, (3, -1, opts.d_model))\n",
    "        qk = q.transpose() @ k\n",
    "        pv = attn.proj.weight @ v\n",
    "\n",
    "        qk_img = axs[0, i].imshow(qk, vmin=qk_vmin, vmax=qk_vmax, cmap='coolwarm')\n",
    "        pv_img = axs[1, i].imshow(pv, vmin=ov_vmin, vmax=ov_vmax, cmap='coolwarm')\n",
    "\n",
    "    # One colorbar per row; every image in a row shares the same vmin/vmax.\n",
    "    cbar = fig.colorbar(qk_img, ax=axs[0, :], orientation='vertical', fraction=0.02, pad=0.04)\n",
    "    cbar2 = fig.colorbar(pv_img, ax=axs[1, :], orientation='vertical', fraction=0.02, pad=0.04)\n",
    "    cbar.set_label('Color Intensity')\n",
    "    cbar2.set_label('Color Intensity')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "5ca3357b-fe14-4a35-ab67-01f7326500ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_alg_differences(ax, model1, model2, fwd_fn1, fwd_fn2, eval_data, key, colors=('g', 'b', 'c', 'r'), metrics=('loss', 'entr'), act='Softmax', corr_vs_incorr=True):\n",
    "    \"\"\"Scatter per-example metrics of model1 (x-axis) against model2 (y-axis).\n",
    "\n",
    "    Both models are run through `eval_step` (loss_fn=ce, classification=True) on\n",
    "    eval_data['examples'] / eval_data['labels'] with the same `key`. Panel ax[j]\n",
    "    shows metric metrics[j]: the per-example points, dashed grey lines at each\n",
    "    model's mean, and a dashed black y = x reference.\n",
    "\n",
    "    If `corr_vs_incorr` is True, points are coloured (colors[0..3]) by the\n",
    "    combination of out['accs'] of both models and a legend is added to ax[-1].\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping each correct/incorrect label to its example count when\n",
    "        `corr_vs_incorr` is True, otherwise an empty dict.\n",
    "    \"\"\"\n",
    "    metric_names = {'loss': 'Loss', 'p_corr': 'P correct token', 'entr': 'Entropy'}\n",
    "\n",
    "    def _evaluate(model, fwd_fn):\n",
    "        return eval_step(model=model,\n",
    "                         fwd_fn=fwd_fn,\n",
    "                         loss_fn=ce,\n",
    "                         x=eval_data['examples'],\n",
    "                         y=eval_data['labels'],\n",
    "                         key=key,\n",
    "                         classification=True)\n",
    "\n",
    "    out1 = _evaluate(model1, fwd_fn1)\n",
    "    out2 = _evaluate(model2, fwd_fn2)\n",
    "\n",
    "    masks = {}\n",
    "    if corr_vs_incorr:\n",
    "        for val1 in (True, False):\n",
    "            for val2 in (True, False):\n",
    "                label = 'TR: {}, GD: {}'.format('Correct' if val1 else 'Incorrect', 'Correct' if val2 else 'Incorrect')\n",
    "                masks[label] = jnp.logical_and(out1['accs'] == val1, out2['accs'] == val2)\n",
    "\n",
    "    for j, metric in enumerate(metrics):\n",
    "        x_vals, y_vals = out1[metric], out2[metric]\n",
    "        if corr_vs_incorr:\n",
    "            for i, (label, mask) in enumerate(masks.items()):\n",
    "                ax[j].scatter(x_vals[mask], y_vals[mask], marker='.', label=label, color=colors[i], s=80)\n",
    "        else:\n",
    "            ax[j].scatter(x_vals, y_vals, marker='.', s=80)\n",
    "        # Decorations are drawn once per panel, not once per mask group.\n",
    "        ax[j].axvline(jnp.mean(x_vals), c='gray', ls='--')\n",
    "        ax[j].axhline(jnp.mean(y_vals), c='gray', ls='--')\n",
    "        lo, hi = jnp.min(x_vals), jnp.max(x_vals)\n",
    "        ax[j].plot([lo, hi], [lo, hi], c='k', scalex=False, scaley=False, ls='--')\n",
    "        ax[j].set_title(metric_names[metric])\n",
    "        ax[j].set_xlabel(f'{act} SA')\n",
    "        ax[j].set_ylabel('GD')\n",
    "        ax[j].set_box_aspect(1.0)\n",
    "\n",
    "    if corr_vs_incorr:\n",
    "        handles, labels = ax[-1].get_legend_handles_labels()\n",
    "        ax[-1].legend(handles, labels, scatterpoints=1, markerscale=2, bbox_to_anchor=(1, 1))\n",
    "    plt.show()\n",
    "    # `masks` is empty when corr_vs_incorr is False, so this returns {} in that case.\n",
    "    return {k: jnp.sum(v) for k, v in masks.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "d70f7958-043a-4334-acfb-05bca9e5e728",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_analysis(ax, dots=[], pred_norms=[], norms=[], factor=1):\n",
    "    \"\"\"Plot per-checkpoint difference curves on `ax`.\n",
    "\n",
    "    `pred_norms` (always drawn) and `norms` (drawn only if it has the same length)\n",
    "    share the left 'L2 Norm' axis. `dots` is drawn on a twin right axis labelled\n",
    "    'Cosine sim' when it has the same length. Entry j is placed at x = j * factor.\n",
    "    \"\"\"\n",
    "    count = len(pred_norms)\n",
    "    xs = jnp.linspace(0, (count - 1) * factor, count)\n",
    "\n",
    "    ax.plot(xs, pred_norms, label=\"Preds diff\", color=\"brown\")\n",
    "    if len(norms) == count:\n",
    "        ax.plot(xs, norms, label=\"Model diff\", color=\"orange\")\n",
    "    ax.set_xlabel(\"Training steps\")\n",
    "    ax.set_ylabel(\"L2 Norm\", color=\"black\")\n",
    "    ax.tick_params(axis=\"y\", labelcolor=\"black\")\n",
    "    ax.legend(loc=\"upper left\", bbox_to_anchor=(0.1, 0.95))\n",
    "\n",
    "    if len(dots) == count:\n",
    "        cos_ax = ax.twinx()\n",
    "        cos_ax.plot(xs, dots, label=\"Cos sim\", color=\"green\")\n",
    "        cos_ax.set_ylabel(\"Cosine sim\", color=\"black\")\n",
    "        cos_ax.tick_params(axis=\"y\", labelcolor=\"black\")\n",
    "        cos_ax.legend(loc=\"upper right\", bbox_to_anchor=(0.95, 0.5))\n",
    "    ax.set_box_aspect(0.8)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca48c7bb-397e-4f39-9233-42b2dc35161e",
   "metadata": {},
   "source": [
    "## Linear transformer\n",
    "The following snippets were used to generate Figure 4. Note again that the runs in Figure 4 were trained for longer, with a smaller learning rate, and with more checkpoints. However, linear self-attention is not noticeably affected by these changes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f4db98d-538d-4a65-b49e-d802a4b0a395",
   "metadata": {},
   "outputs": [],
   "source": [
    "run1 = 'lin_xdim_2'\n",
    "_, _, opts = get_run_defaults(f'{base_folder}/{run1}')\n",
    "\n",
    "# Large batch (10000 samples) passed to the reference GD model.\n",
    "eval_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(opts.n_labels, opts.eval_context_len, 10000, opts.X_dim, opts.eval_noise_scale))\n",
    "data = eval_sampler(jax.random.PRNGKey(120))\n",
    "\n",
    "# Small batch (100 samples) on which every checkpoint is compared with the reference.\n",
    "eval_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(opts.n_labels, opts.eval_context_len, 100, opts.X_dim, opts.eval_noise_scale))\n",
    "eval_data = eval_sampler(jax.random.PRNGKey(110))\n",
    "\n",
    "# The run's own eval batch and key, used by the per-example scatter plots.\n",
    "comp_data, key = get_eval_data(run1)\n",
    "\n",
    "model2 = gd(eval_data=data)\n",
    "fwd_fn2 = gd.fwd_fn\n",
    "\n",
    "metrics = ['loss', 'p_corr', 'entr']\n",
    "\n",
    "# 20 checkpoints = MAIN_RUN_ITERS / ckpt_every. The last one loaded stays in\n",
    "# model1 / fwd_fn1 for the plot cell. The comparison key is fixed, so build it once.\n",
    "analyse_key = jax.random.PRNGKey(10)\n",
    "dots, norms, pred_norms = [], [], []\n",
    "for i in tqdm(range(20)):\n",
    "    m1 = get_model(run1, i)\n",
    "    model1, fwd_fn1 = m1['model'], m1['fwd_fn']\n",
    "    dot, norm, pred_norm = analyse2(eval_data, model1, model2, fwd_fn1, fwd_fn2, key=analyse_key)\n",
    "    dots.append(dot)\n",
    "    norms.append(norm)\n",
    "    pred_norms.append(pred_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf90e9be-de12-492c-b860-846d4fc9422d",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(\n",
    "    1, len(metrics) + 1,\n",
    "    figsize=(7 + 4 * len(metrics), 3),\n",
    "    gridspec_kw={'width_ratios': [2.5] + [1.5] * len(metrics)},  # wider first panel for the curves\n",
    ")\n",
    "fig.subplots_adjust(wspace=0.1)\n",
    "plot_analysis(ax[0], dots, pred_norms, norms)\n",
    "# Trailing ';' hides the empty dict returned when corr_vs_incorr=False.\n",
    "plot_alg_differences(ax[1:], model1, model2, fwd_fn1, fwd_fn2, comp_data['eval'], key, metrics=metrics, act='Linear', corr_vs_incorr=False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfbc23a9-bcac-4893-bce7-e48b1b0853b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "run1 = 'lin_xdim_3'\n",
    "_, _, opts = get_run_defaults(f'{base_folder}/{run1}')\n",
    "\n",
    "# Large batch (10000 samples) passed to the reference GD model.\n",
    "eval_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(opts.n_labels, opts.eval_context_len, 10000, opts.X_dim, opts.eval_noise_scale))\n",
    "data = eval_sampler(jax.random.PRNGKey(120))\n",
    "\n",
    "# Small batch (100 samples) on which every checkpoint is compared with the reference.\n",
    "eval_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(opts.n_labels, opts.eval_context_len, 100, opts.X_dim, opts.eval_noise_scale))\n",
    "eval_data = eval_sampler(jax.random.PRNGKey(110))\n",
    "\n",
    "# The run's own eval batch and key, used by the per-example scatter plots.\n",
    "comp_data, key = get_eval_data(run1)\n",
    "\n",
    "model2 = gd(eval_data=data)\n",
    "fwd_fn2 = gd.fwd_fn\n",
    "\n",
    "metrics = ['loss', 'p_corr', 'entr']\n",
    "\n",
    "# 20 checkpoints = MAIN_RUN_ITERS / ckpt_every. The last one loaded stays in\n",
    "# model1 / fwd_fn1 for the plot cell. The comparison key is fixed, so build it once.\n",
    "analyse_key = jax.random.PRNGKey(10)\n",
    "dots, norms, pred_norms = [], [], []\n",
    "for i in tqdm(range(20)):\n",
    "    m1 = get_model(run1, i)\n",
    "    model1, fwd_fn1 = m1['model'], m1['fwd_fn']\n",
    "    dot, norm, pred_norm = analyse2(eval_data, model1, model2, fwd_fn1, fwd_fn2, key=analyse_key)\n",
    "    dots.append(dot)\n",
    "    norms.append(norm)\n",
    "    pred_norms.append(pred_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27153b30-07ff-4156-8ea2-03a9784cc93d",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(\n",
    "    1, len(metrics) + 1,\n",
    "    figsize=(7 + 4 * len(metrics), 3),\n",
    "    gridspec_kw={'width_ratios': [2.5] + [1.5] * len(metrics)},  # wider first panel for the curves\n",
    ")\n",
    "fig.subplots_adjust(wspace=0.1)\n",
    "plot_analysis(ax[0], dots, pred_norms, norms)\n",
    "# Trailing ';' hides the empty dict returned when corr_vs_incorr=False.\n",
    "plot_alg_differences(ax[1:], model1, model2, fwd_fn1, fwd_fn2, comp_data['eval'], key, metrics=metrics, act='Linear', corr_vs_incorr=False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89e5d68f-2862-435c-b1bc-ba6314a21782",
   "metadata": {},
   "outputs": [],
   "source": [
    "run1 = 'lin_xdim_5'\n",
    "_, _, opts = get_run_defaults(f'{base_folder}/{run1}')\n",
    "\n",
    "# Large batch (10000 samples) passed to the reference GD model.\n",
    "eval_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(opts.n_labels, opts.eval_context_len, 10000, opts.X_dim, opts.eval_noise_scale))\n",
    "data = eval_sampler(jax.random.PRNGKey(120))\n",
    "\n",
    "# Small batch (100 samples) on which every checkpoint is compared with the reference.\n",
    "eval_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(opts.n_labels, opts.eval_context_len, 100, opts.X_dim, opts.eval_noise_scale))\n",
    "eval_data = eval_sampler(jax.random.PRNGKey(110))\n",
    "\n",
    "# The run's own eval batch and key, used by the per-example scatter plots.\n",
    "comp_data, key = get_eval_data(run1)\n",
    "\n",
    "model2 = gd(eval_data=data)\n",
    "fwd_fn2 = gd.fwd_fn\n",
    "\n",
    "metrics = ['loss', 'p_corr', 'entr']\n",
    "\n",
    "# 20 checkpoints = MAIN_RUN_ITERS / ckpt_every. The last one loaded stays in\n",
    "# model1 / fwd_fn1 for the plot cell. The comparison key is fixed, so build it once.\n",
    "analyse_key = jax.random.PRNGKey(10)\n",
    "dots, norms, pred_norms = [], [], []\n",
    "for i in tqdm(range(20)):\n",
    "    m1 = get_model(run1, i)\n",
    "    model1, fwd_fn1 = m1['model'], m1['fwd_fn']\n",
    "    dot, norm, pred_norm = analyse2(eval_data, model1, model2, fwd_fn1, fwd_fn2, key=analyse_key)\n",
    "    dots.append(dot)\n",
    "    norms.append(norm)\n",
    "    pred_norms.append(pred_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07f957da-f944-4f39-9cfd-a3fd8ceffb04",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(\n",
    "    1, len(metrics) + 1,\n",
    "    figsize=(7 + 4 * len(metrics), 3),\n",
    "    gridspec_kw={'width_ratios': [2.5] + [1.5] * len(metrics)},  # wider first panel for the curves\n",
    ")\n",
    "fig.subplots_adjust(wspace=0.1)\n",
    "plot_analysis(ax[0], dots, pred_norms, norms)\n",
    "# Trailing ';' hides the empty dict returned when corr_vs_incorr=False.\n",
    "plot_alg_differences(ax[1:], model1, model2, fwd_fn1, fwd_fn2, comp_data['eval'], key, metrics=metrics, act='Linear', corr_vs_incorr=False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39fea9f5-92a4-44e6-9256-ba3183814f54",
   "metadata": {},
   "outputs": [],
   "source": [
    "run1 = 'lin_xdim_10'\n",
    "_, _, opts = get_run_defaults(f'{base_folder}/{run1}')\n",
    "\n",
    "# Large batch (10000 samples) passed to the reference GD model.\n",
    "eval_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(opts.n_labels, opts.eval_context_len, 10000, opts.X_dim, opts.eval_noise_scale))\n",
    "data = eval_sampler(jax.random.PRNGKey(120))\n",
    "\n",
    "# Small batch (100 samples) on which every checkpoint is compared with the reference.\n",
    "eval_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(opts.n_labels, opts.eval_context_len, 100, opts.X_dim, opts.eval_noise_scale))\n",
    "eval_data = eval_sampler(jax.random.PRNGKey(110))\n",
    "\n",
    "# The run's own eval batch and key, used by the per-example scatter plots.\n",
    "comp_data, key = get_eval_data(run1)\n",
    "\n",
    "model2 = gd(eval_data=data)\n",
    "fwd_fn2 = gd.fwd_fn\n",
    "\n",
    "metrics = ['loss', 'p_corr', 'entr']\n",
    "\n",
    "# 20 checkpoints = MAIN_RUN_ITERS / ckpt_every. The last one loaded stays in\n",
    "# model1 / fwd_fn1 for the plot cell. The comparison key is fixed, so build it once.\n",
    "analyse_key = jax.random.PRNGKey(10)\n",
    "dots, norms, pred_norms = [], [], []\n",
    "for i in tqdm(range(20)):\n",
    "    m1 = get_model(run1, i)\n",
    "    model1, fwd_fn1 = m1['model'], m1['fwd_fn']\n",
    "    dot, norm, pred_norm = analyse2(eval_data, model1, model2, fwd_fn1, fwd_fn2, key=analyse_key)\n",
    "    dots.append(dot)\n",
    "    norms.append(norm)\n",
    "    pred_norms.append(pred_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a4b8498-7419-4f44-b1bc-617d7c16b76a",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(\n",
    "    1, len(metrics) + 1,\n",
    "    figsize=(7 + 4 * len(metrics), 3),\n",
    "    gridspec_kw={'width_ratios': [2.5] + [1.5] * len(metrics)},  # wider first panel for the curves\n",
    ")\n",
    "fig.subplots_adjust(wspace=0.1)\n",
    "plot_analysis(ax[0], dots, pred_norms, norms)\n",
    "# Trailing ';' hides the empty dict returned when corr_vs_incorr=False.\n",
    "plot_alg_differences(ax[1:], model1, model2, fwd_fn1, fwd_fn2, comp_data['eval'], key, metrics=metrics, act='Linear', corr_vs_incorr=False);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5743be48-c1c2-463a-8f08-495c1b786d38",
   "metadata": {},
   "source": [
    "## Softmax transformer\n",
    "Softmax SA is more sensitive to the learning rate. We plot the difference metrics for the runs trained above, but note that the runs in the paper were trained 100 times longer and with a smaller learning rate (see Appendix G for all the details)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e95e9935-c9d6-4499-b945-c357985e9a1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "run1 = 'sfmx_xdim_2'\n",
    "_, _, opts = get_run_defaults(f'{base_folder}/{run1}')\n",
    "\n",
    "# Large batch (10000 samples) passed to cqk_cpv_fit for the reference model.\n",
    "eval_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(opts.n_labels, opts.eval_context_len, 10000, opts.X_dim, opts.eval_noise_scale))\n",
    "data = eval_sampler(jax.random.PRNGKey(120))\n",
    "\n",
    "# Small batch (100 samples) on which every checkpoint is compared with the reference.\n",
    "eval_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(opts.n_labels, opts.eval_context_len, 100, opts.X_dim, opts.eval_noise_scale))\n",
    "eval_data = eval_sampler(jax.random.PRNGKey(110))\n",
    "\n",
    "# The run's own eval batch and key, used by the per-example scatter plots.\n",
    "comp_data, key = get_eval_data(run1)\n",
    "\n",
    "model2 = cqk_cpv_fit(eval_data=data, cqk_low=0, cqk_high=3.5, cpv_low=-1, cpv_high=3)\n",
    "fwd_fn2 = model2.fwd_fn\n",
    "\n",
    "metrics = ['loss', 'p_corr', 'entr']\n",
    "\n",
    "# 40 checkpoints = 2 * MAIN_RUN_ITERS / ckpt_every. The last one loaded stays in\n",
    "# model1 / fwd_fn1 for the plot cell. The comparison key is fixed, so build it once.\n",
    "analyse_key = jax.random.PRNGKey(10)\n",
    "dots, norms, pred_norms = [], [], []\n",
    "for i in tqdm(range(40)):\n",
    "    m1 = get_model(run1, i)\n",
    "    model1, fwd_fn1 = m1['model'], m1['fwd_fn']\n",
    "    dot, norm, pred_norm = analyse2(eval_data, model1, model2, fwd_fn1, fwd_fn2, key=analyse_key)\n",
    "    dots.append(dot)\n",
    "    norms.append(norm)\n",
    "    pred_norms.append(pred_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3003a17-8bcf-4b7b-b6bf-3f27521ae4ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(\n",
    "    1, len(metrics) + 1,\n",
    "    figsize=(7 + 4 * len(metrics), 3),\n",
    "    gridspec_kw={'width_ratios': [2.5] + [1.5] * len(metrics)},  # wider first panel for the curves\n",
    ")\n",
    "fig.subplots_adjust(wspace=0.1)\n",
    "# Only the prediction-difference curve is drawn for this run.\n",
    "plot_analysis(ax[0], [], pred_norms, [], factor=100)\n",
    "# Trailing ';' hides the empty dict returned when corr_vs_incorr=False.\n",
    "plot_alg_differences(ax[1:], model1, model2, fwd_fn1, fwd_fn2, comp_data['eval'], key, metrics=metrics, act='Softmax', corr_vs_incorr=False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b7fe434-2718-4170-b677-5e51f2e79fdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "run1 = 'sfmx_xdim_3'\n",
    "_, _, opts = get_run_defaults(f'{base_folder}/{run1}')\n",
    "\n",
    "# Large batch (10000 samples) passed to cqk_cpv_fit for the reference model.\n",
    "eval_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(opts.n_labels, opts.eval_context_len, 10000, opts.X_dim, opts.eval_noise_scale))\n",
    "data = eval_sampler(jax.random.PRNGKey(120))\n",
    "\n",
    "# Small batch (100 samples) on which every checkpoint is compared with the reference.\n",
    "eval_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(opts.n_labels, opts.eval_context_len, 100, opts.X_dim, opts.eval_noise_scale))\n",
    "eval_data = eval_sampler(jax.random.PRNGKey(110))\n",
    "\n",
    "# The run's own eval batch and key, used by the per-example scatter plots.\n",
    "comp_data, key = get_eval_data(run1)\n",
    "\n",
    "model2 = cqk_cpv_fit(eval_data=data, cqk_low=-1, cqk_high=2, cpv_low=-1, cpv_high=2)\n",
    "fwd_fn2 = model2.fwd_fn\n",
    "\n",
    "metrics = ['loss', 'p_corr', 'entr']\n",
    "\n",
    "# 20 checkpoints = MAIN_RUN_ITERS / ckpt_every. The last one loaded stays in\n",
    "# model1 / fwd_fn1 for the plot cell. The comparison key is fixed, so build it once.\n",
    "analyse_key = jax.random.PRNGKey(10)\n",
    "dots, norms, pred_norms = [], [], []\n",
    "for i in tqdm(range(20)):\n",
    "    m1 = get_model(run1, i)\n",
    "    model1, fwd_fn1 = m1['model'], m1['fwd_fn']\n",
    "    dot, norm, pred_norm = analyse2(eval_data, model1, model2, fwd_fn1, fwd_fn2, key=analyse_key)\n",
    "    dots.append(dot)\n",
    "    norms.append(norm)\n",
    "    pred_norms.append(pred_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12f8201d-93aa-4ff3-95fd-5a6ac78a2475",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(\n",
    "    1, len(metrics) + 1,\n",
    "    figsize=(7 + 4 * len(metrics), 3),\n",
    "    gridspec_kw={'width_ratios': [2.5] + [1.5] * len(metrics)},  # wider first panel for the curves\n",
    ")\n",
    "fig.subplots_adjust(wspace=0.1)\n",
    "plot_analysis(ax[0], dots, pred_norms, norms, factor=100)\n",
    "# Trailing ';' hides the empty dict returned when corr_vs_incorr=False.\n",
    "plot_alg_differences(ax[1:], model1, model2, fwd_fn1, fwd_fn2, comp_data['eval'], key, metrics=metrics, act='Softmax', corr_vs_incorr=False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2034a8c9-c869-4fb5-8105-64b9f6fe2339",
   "metadata": {},
   "outputs": [],
   "source": [
    "run1 = 'sfmx_xdim_5'\n",
    "_, _, opts = get_run_defaults(f'{base_folder}/{run1}')\n",
    "\n",
    "# Large batch (10000 samples) passed to cqk_cpv_fit for the reference model.\n",
    "eval_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(opts.n_labels, opts.eval_context_len, 10000, opts.X_dim, opts.eval_noise_scale))\n",
    "data = eval_sampler(jax.random.PRNGKey(120))\n",
    "\n",
    "# Small batch (100 samples) on which every checkpoint is compared with the reference.\n",
    "eval_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(opts.n_labels, opts.eval_context_len, 100, opts.X_dim, opts.eval_noise_scale))\n",
    "eval_data = eval_sampler(jax.random.PRNGKey(110))\n",
    "\n",
    "# The run's own eval batch and key, used by the per-example scatter plots.\n",
    "comp_data, key = get_eval_data(run1)\n",
    "\n",
    "model2 = cqk_cpv_fit(eval_data=data, cqk_low=-1, cqk_high=2, cpv_low=-1, cpv_high=2)\n",
    "fwd_fn2 = model2.fwd_fn\n",
    "\n",
    "metrics = ['loss', 'p_corr', 'entr']\n",
    "\n",
    "# 20 checkpoints = MAIN_RUN_ITERS / ckpt_every. The last one loaded stays in\n",
    "# model1 / fwd_fn1 for the plot cell. The comparison key is fixed, so build it once.\n",
    "analyse_key = jax.random.PRNGKey(10)\n",
    "dots, norms, pred_norms = [], [], []\n",
    "for i in tqdm(range(20)):\n",
    "    m1 = get_model(run1, i)\n",
    "    model1, fwd_fn1 = m1['model'], m1['fwd_fn']\n",
    "    dot, norm, pred_norm = analyse2(eval_data, model1, model2, fwd_fn1, fwd_fn2, key=analyse_key)\n",
    "    dots.append(dot)\n",
    "    norms.append(norm)\n",
    "    pred_norms.append(pred_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "819d727a-a95f-48e7-a0f2-c1854a5ded33",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(\n",
    "    1, len(metrics) + 1,\n",
    "    figsize=(7 + 4 * len(metrics), 3),\n",
    "    gridspec_kw={'width_ratios': [2.5] + [1.5]*len(metrics)}  # Give more width to the first plot\n",
    ")\n",
    "plt.subplots_adjust(wspace=0.1)\n",
    "plot_analysis(ax[0], dots, pred_norms, norms, factor=100) \n",
    "plot_alg_differences(ax[1:], model1, model2, fwd_fn1, fwd_fn2, comp_data['eval'], key, metrics=metrics, act='Softmax', corr_vs_incorr=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2d78cfb",
   "metadata": {},
   "source": [
    "For the $d=10$ case, what happens here (as in all our other experiments in Appendix I) is that training has not converged yet---this suggests the need for a learning rate scheduler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "603a5722-fcc8-472d-9518-5e0645aa5af5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Softmax self-attention, d=10: compare the trained model with the `cqk_cpv_fit` model,\n",
    "# and run `analyse2` on model indices 0..19 (presumably saved training checkpoints).\n",
    "run1 = 'sfmx_xdim_10'\n",
    "m1 = get_model(run1)\n",
    "model1, fwd_fn1 = m1['model'], m1['fwd_fn']\n",
    "_, _, opts = get_run_defaults(f'{base_folder}/{run1}')\n",
    "\n",
    "# 10000-query batch, used only as the fitting data for `cqk_cpv_fit`.\n",
    "fit_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(opts.n_labels, opts.eval_context_len, 10000, opts.X_dim, opts.eval_noise_scale))\n",
    "data = fit_sampler(jax.random.PRNGKey(120))\n",
    "\n",
    "# 100-query batch passed to `analyse2` for every checkpoint.\n",
    "eval_sampler = jax.jit(samplers.make_balanced_classification_queries_sampler(opts.n_labels, opts.eval_context_len, 100, opts.X_dim, opts.eval_noise_scale))\n",
    "eval_data = eval_sampler(jax.random.PRNGKey(110))\n",
    "\n",
    "# The `key` used by the plotting cell below comes from here.\n",
    "comp_data, key = get_eval_data(run1)\n",
    "\n",
    "# Search ranges differ from the d=5 run.\n",
    "model2 = cqk_cpv_fit(eval_data=data, cqk_low=-3, cqk_high=1, cpv_low=1, cpv_high=4)\n",
    "fwd_fn2 = model2.fwd_fn\n",
    "\n",
    "# Deliberately not named `metrics`: that would shadow the `metrics` module imported in the\n",
    "# first cell and make its `importlib.reload(metrics)` raise on a re-run.\n",
    "metric_names = ['loss', 'p_corr', 'entr']\n",
    "\n",
    "dots, norms, pred_norms = [], [], []\n",
    "for i in tqdm(range(20)):\n",
    "    # Checkpoint-local name: `model1` / `fwd_fn1` must keep pointing at the model loaded\n",
    "    # above, which the plotting cell below compares against `model2`.\n",
    "    ckpt = get_model(run1, i)\n",
    "    dot, norm, pred_norm = analyse2(eval_data, ckpt['model'], model2, ckpt['fwd_fn'], fwd_fn2, key=jax.random.PRNGKey(10))\n",
    "    dots.append(dot)\n",
    "    norms.append(norm)\n",
    "    pred_norms.append(pred_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9f9106d-0c00-486a-bc89-b1f4234935dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left panel: analyse2 results across checkpoints; remaining panels: one per metric.\n",
    "fig, ax = plt.subplots(\n",
    "    1, len(metric_names) + 1,\n",
    "    figsize=(7 + 4 * len(metric_names), 3),\n",
    "    gridspec_kw={'width_ratios': [2.5] + [1.5]*len(metric_names)}  # Give more width to the first plot\n",
    ")\n",
    "plt.subplots_adjust(wspace=0.1)\n",
    "plot_analysis(ax[0], dots, pred_norms, norms, factor=100)\n",
    "plot_alg_differences(ax[1:], model1, model2, fwd_fn1, fwd_fn2, comp_data['eval'], key, metrics=metric_names, act='Softmax', corr_vs_incorr=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "iclmi",
   "language": "python",
   "name": "iclmi"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
