{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "\n",
    "# Directory that contains the Finite-groups checkout; change if necessary.\n",
    "# expanduser('~') still honours $HOME when it is set, but unlike\n",
    "# os.environ['HOME'] it does not raise KeyError where HOME is missing (e.g. Windows).\n",
    "HOME = os.path.expanduser('~')\n",
    "SRC_DIR = os.path.join(HOME, 'Finite-groups', 'src')\n",
    "\n",
    "# Guard the append so re-running this cell does not stack duplicate entries.\n",
    "if SRC_DIR not in sys.path:\n",
    "    sys.path.append(SRC_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Syntax warning: Unbound global variable in /usr/share/gap/pkg/browse/PackageIn\\\n",
      "fo.g:73\n",
      "  if not IsKernelExtensionAvailable(\"Browse\", \"ncurses\") then\n",
      "         ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "Syntax warning: Unbound global variable in /usr/share/gap/pkg/edim/PackageInfo\\\n",
      ".g:60\n",
      "  if not IsKernelExtensionAvailable(\"EDIM\",\"ediv\") then\n",
      "         ^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
     ]
    }
   ],
   "source": [
    "import torch as t\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "import json\n",
    "from itertools import product\n",
    "from jaxtyping import Float\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm.notebook import tqdm\n",
    "import plotly.graph_objects as go\n",
    "import copy\n",
    "import math\n",
    "from itertools import product\n",
    "import pandas as pd\n",
    "from typing import Union\n",
    "from einops import repeat\n",
    "from huggingface_hub import snapshot_download\n",
    "from huggingface_hub.utils import disable_progress_bars\n",
    "\n",
    "\n",
    "from model import MLP3, MLP4, InstancedModule\n",
    "from utils import *\n",
    "from group_data import *\n",
    "from model_utils import *\n",
    "from group_utils import *\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Intersection size: 2809/2809 (1.00)\n",
      "Added 2809 elements from intersection\n",
      "Added 0 elements from group 0: Z(53)\n",
      "Taking random subset: 1123/2809 (0.40)\n",
      "Train set size: 1123/2809 (0.40)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXX-1/Finite-groups/src/model_utils.py:50: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See XXXX for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model.load_state_dict(t.load(model_path, map_location=device))\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Run on the GPU whenever torch can see one, otherwise stay on the CPU.\n",
    "if t.cuda.is_available():\n",
    "    device = t.device(\"cuda\")\n",
    "else:\n",
    "    device = t.device(\"cpu\")\n",
    "\n",
    "# Alternative run: '2024-09-20_20-20-22_MLP2_Z_53_'\n",
    "MODEL_DIR = '2024-09-20_20-50-54_MLP2_Z_53_'\n",
    "local_dir = f'{HOME}/models/{MODEL_DIR}'\n",
    "\n",
    "# load_models yields the saved models together with the params that\n",
    "# GroupData needs to rebuild the matching dataset.\n",
    "models, params = load_models(local_dir)\n",
    "data = GroupData(params)\n",
    "group = data.groups[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyse the last entry of the list returned by load_models.\n",
    "model = models[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'G0_loss': tensor(0.0711), 'G0_acc': tensor(0.9955)}"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def mean(d):\n",
    "    '''Return a dict with the same keys as `d`, each value replaced by its `.mean()`.'''\n",
    "    return {k: v.mean() for k, v in d.items()}\n",
    "\n",
    "# `mean` is defined here, before its first use: the notebook's own `mean`\n",
    "# helper cell sits further down, so a fresh Restart & Run All would\n",
    "# otherwise hit a NameError on the last line of this cell.\n",
    "loss_dict = test_loss(models[-1].to(device), data)\n",
    "mean(loss_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# instance = loss_dict['G0_loss'].argmin().item()\n",
    "# print(loss_dict[f'G0_loss'][instance], loss_dict[f'G0_acc'][instance], instance)\n",
    "# model = models[-1][instance].to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mean(d):\n",
    "    '''Map every value of the metrics dict `d` to its `.mean()`, keeping the keys.'''\n",
    "    averaged = {}\n",
    "    for name, values in d.items():\n",
    "        averaged[name] = values.mean()\n",
    "    return averaged"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Swap embeds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'G0_loss': tensor(4.1553), 'G0_acc': tensor(0.0385)}"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Work on a copy so `model` itself keeps its trained weights.\n",
    "model2 = copy.deepcopy(model)\n",
    "\n",
    "# Exchange the two embedding tables in one step (left <- right, right <- left);\n",
    "# the linear layers are left where they are.\n",
    "model2.embedding_left, model2.embedding_right = (\n",
    "    nn.Parameter(model.embedding_right),\n",
    "    nn.Parameter(model.embedding_left),\n",
    ")\n",
    "mean(test_loss(model2, data))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Change embedding signs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "({'G0_loss': tensor(17.1501), 'G0_acc': tensor(0.)},\n",
       " {'G0_loss': tensor(17.1500), 'G0_acc': tensor(0.)},\n",
       " {'G0_loss': tensor(0.0663), 'G0_acc': tensor(0.9969)})"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Three independent copies of the trained model, one per sign-flip variant.\n",
    "model2, model3, model4 = (copy.deepcopy(model) for _ in range(3))\n",
    "\n",
    "# model2: negate only the left embedding.\n",
    "model2.embedding_left = nn.Parameter(model.embedding_left.neg())\n",
    "# model3: negate only the right embedding.\n",
    "model3.embedding_right = nn.Parameter(model.embedding_right.neg())\n",
    "# model4: negate both embeddings.\n",
    "model4.embedding_left = nn.Parameter(model.embedding_left.neg())\n",
    "model4.embedding_right = nn.Parameter(model.embedding_right.neg())\n",
    "\n",
    "tuple(mean(test_loss(flipped, data)) for flipped in (model2, model3, model4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Absolute value nonlinearity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'G0_loss': tensor(0.0045), 'G0_acc': tensor(0.9988)}"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Abs(nn.Module):\n",
    "    '''Activation module computing `|x| * scale` elementwise.'''\n",
    "\n",
    "    def __init__(self, scale=1.):\n",
    "        super().__init__()\n",
    "        self.scale = scale  # constant multiplier applied after the absolute value\n",
    "\n",
    "    def forward(self, input: t.Tensor) -> t.Tensor:\n",
    "        magnitude = input.abs()\n",
    "        return magnitude * self.scale\n",
    "\n",
    "# Replace the nonlinearity with |x| on a copy of the trained model and re-evaluate.\n",
    "model2 = copy.deepcopy(model)\n",
    "model2.activation = Abs()\n",
    "mean(test_loss(model2, data))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Add noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Should've used transformerlens....\n",
    "class MLP2Noise(InstancedModule):\n",
    "    '''\n",
    "    Forward pass of a trained two-embedding MLP with noise injected.\n",
    "\n",
    "    The weights come from a deep copy of `model`. On every forward pass,\n",
    "    Gaussian noise `randn * std + mean` is added to the summed hidden\n",
    "    pre-activations, right before the activation function is applied.\n",
    "    '''\n",
    "    def __init__(self, model, mean, std):\n",
    "        super().__init__()\n",
    "        source = copy.deepcopy(model)  # work from a private copy of the given model\n",
    "        self.params = source.params\n",
    "        self.N = source.N\n",
    "\n",
    "        # Re-wrap each copied weight as a parameter of this module.\n",
    "        for weight_name in (\n",
    "            'embedding_left', 'embedding_right',\n",
    "            'linear_left', 'linear_right',\n",
    "            'unembedding',\n",
    "        ):\n",
    "            setattr(self, weight_name, nn.Parameter(getattr(source, weight_name)))\n",
    "\n",
    "        # The unembedding bias is optional on the source model.\n",
    "        bias = source.unembed_bias\n",
    "        self.unembed_bias = None if bias is None else nn.Parameter(bias)\n",
    "\n",
    "        self.activation = source.activation\n",
    "        self.mean = mean  # offset added to the noise\n",
    "        self.std = std    # scale applied to the standard-normal samples\n",
    "\n",
    "    def _forward(\n",
    "        self, a: Int[t.Tensor, \"batch_size entries\"]\n",
    "    ) -> Float[t.Tensor, \"batch_size instances d_vocab\"]:\n",
    "        embed_pattern = \"batch_size instances d_vocab, instances d_vocab embed_dim -> batch_size instances embed_dim\"\n",
    "        linear_pattern = \"batch_size instances embed_dim, instances embed_dim hidden -> batch_size instances hidden\"\n",
    "\n",
    "        # Give every instance its own copy of the input pairs.\n",
    "        pairs = einops.repeat(\n",
    "            a, \" batch_size entries -> batch_size n entries\", n=self.num_instances(),\n",
    "        )  # batch_size instances entries\n",
    "        left_onehot = F.one_hot(pairs[..., 0], num_classes=self.N).float()\n",
    "        right_onehot = F.one_hot(pairs[..., 1], num_classes=self.N).float()\n",
    "\n",
    "        left_embed = einops.einsum(left_onehot, self.embedding_left, embed_pattern)\n",
    "        right_embed = einops.einsum(right_onehot, self.embedding_right, embed_pattern)\n",
    "\n",
    "        left_hidden = einops.einsum(left_embed, self.linear_left, linear_pattern)\n",
    "        right_hidden = einops.einsum(right_embed, self.linear_right, linear_pattern)\n",
    "        pre_activation = left_hidden + right_hidden\n",
    "\n",
    "        # Perturb the pre-activations: standard-normal samples scaled by std, shifted by mean.\n",
    "        noise = t.randn_like(pre_activation) * self.std + self.mean\n",
    "        pre_activation += noise\n",
    "\n",
    "        logits = einops.einsum(\n",
    "            self.activation(pre_activation),\n",
    "            self.unembedding,\n",
    "            \"batch_size instances hidden, instances hidden d_vocab-> batch_size instances d_vocab \",\n",
    "        )\n",
    "        if self.unembed_bias is None:\n",
    "            return logits\n",
    "\n",
    "        # Broadcast the per-instance bias over the batch dimension.\n",
    "        logits += einops.repeat(\n",
    "            self.unembed_bias,\n",
    "            'instances d_vocab -> batch_size instances d_vocab',\n",
    "            batch_size=logits.shape[0]\n",
    "        )\n",
    "        return logits\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "({'G0_loss': tensor(0.8285), 'G0_acc': tensor(0.7690)},\n",
       " {'G0_loss': tensor(0.0752), 'G0_acc': tensor(0.9951)},\n",
       " {'G0_loss': tensor(1.7851), 'G0_acc': tensor(0.4980)},\n",
       " {'G0_loss': tensor(0.7797), 'G0_acc': tensor(0.8317)})"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# MLP2Noise draws its noise from torch's global RNG, so fix the seed here;\n",
    "# without it these numbers change on every run of the cell.\n",
    "NOISE_SEED = 0\n",
    "t.manual_seed(NOISE_SEED)\n",
    "\n",
    "# Arguments after `model` are the (mean, std) of the injected noise.\n",
    "model2 = MLP2Noise(model, 0., 1.)\n",
    "model3 = MLP2Noise(model, 0., .1)\n",
    "model4 = MLP2Noise(model, 1., 1.)\n",
    "model5 = MLP2Noise(model, -1., 1.)\n",
    "mean(test_loss(model2, data)), mean(test_loss(model3, data)), mean(test_loss(model4, data)), mean(test_loss(model5, data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "group_addition-BDyFuYvs-py3.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
