{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "from preprocessing import *\n",
    "from word_swap import *\n",
    "from activations import *\n",
    "import torch\n",
    "from transformers import AutoTokenizer,AutoModel,pipeline\n",
    "from nltk.corpus import brown\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from einops import rearrange, reduce\n",
    "from nesim.utils.grid_size import find_rectangle_dimensions\n",
    "# Laod a custom color map for better visualization\n",
    "from scipy.io import loadmat\n",
    "colormap = loadmat('colormap-custom-lightblue-to-yellow1.mat')['cmap']\n",
    "colormap = matplotlib.colors.ListedColormap(colormap)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reproducing analyses"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This section walks through the reproduction of some of the figures. The remaining figures can be reproduced by making small changes to the code below, e.g., changing the model type, the swap procedure, the structure length (e.g., 12-word sentences), or the structure type."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overall integration"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we instantiate the model (here, GPT-2), tokenizer, corpus, and swapper. For the overall integration analysis, we select swaps randomly from a pool of words having the same POS-tag (i.e., using the `RandomPosWordSwap` swapper). We then extract natural sequences of 40 words."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "# model_type = \"gpt2\"\n",
    "# tokenizer = AutoTokenizer.from_pretrained(\n",
    "#     model_type, add_special_token=False,add_prefix_space=True)\n",
    "# model = AutoModel.from_pretrained(model_type)\n",
    "\n",
    "from nesim.utils.checkpoint import get_checkpoint_path_gpt_neo_125m\n",
    "from nesim.experiments.gpt_neo_125m import get_checkpoint, get_untrained_model_and_tokenizer\n",
    "\n",
    "topo_scales = [1,5,10,50]\n",
    "global_step = 10500\n",
    "\n",
    "checkpoint_dir = \"/home/mdeb6/repos/nesim/training/gpt_neo_125m/checkpoints\"\n",
    "device = \"cuda:0\"\n",
    "\n",
    "checkpoints_map = {\n",
    "    \"untrained\": None,\n",
    "    # \"pretrained\": \"pretrained\",\n",
    "    \"baseline\": get_checkpoint_path_gpt_neo_125m(\n",
    "        checkpoints_dir=checkpoint_dir, \n",
    "        topo_scale=0, \n",
    "        global_step=global_step\n",
    "    ),\n",
    "}\n",
    "\n",
    "for topo_scale in topo_scales:\n",
    "\n",
    "    checkpoints_map[f\"topo_{topo_scale}\"] = get_checkpoint_path_gpt_neo_125m(\n",
    "        checkpoints_dir=checkpoint_dir, \n",
    "        topo_scale=topo_scale, \n",
    "        global_step=global_step\n",
    "    )\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We perform the swap procedure, obtain activations, and calulate the difference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "from activations import get_activations, get_activations_hacked\n",
    "from analysis import *\n",
    "from plots import *\n",
    "\n",
    "def calculate_differences(swapped_seqs,original_seqs,tokenizer,model,device='cuda', hacked = False):\n",
    "    \"\"\"Calculates the differences for a list of swapped and original sequences.\n",
    "\n",
    "    Args:\n",
    "        swapped_seqs (list): list of lists of swapped sequences\n",
    "        original_seqs (list): list of original sequences\n",
    "        tokenizer: Huggingface tokenizer\n",
    "        model: Huggingface model\n",
    "        device (str): device to use for the model (e.g., 'cpu', 'cuda')\n",
    "\n",
    "    Returns:\n",
    "        out (torch.Tensor): tensor of differences with shape (n_layers, swap position, measured position, n_features)\n",
    "    \"\"\"\n",
    "\n",
    "    out = None\n",
    "    n = 0\n",
    "    for zz,(swapped,original) in tqdm(enumerate(zip(swapped_seqs,original_seqs))):\n",
    "        try:\n",
    "            swapped_dfs = []\n",
    "            for s in swapped:\n",
    "                df = text_to_df(s,tokenizer=tokenizer)\n",
    "                swapped_dfs.append(df)\n",
    "            original_df = text_to_df(original,tokenizer=tokenizer)\n",
    "            if not hacked:\n",
    "                swapped_activations = get_activations(swapped_dfs,model=model,device=device)\n",
    "                original_activations = get_activations([original_df],model=model,device=device)\n",
    "            else:\n",
    "                swapped_activations = get_activations_hacked(swapped_dfs,model=model,device=device)\n",
    "                original_activations = get_activations_hacked([original_df],model=model,device=device)\n",
    "\n",
    "            if (swapped_activations is not None) and (original_activations is not None):\n",
    "                difference = torch.abs(swapped_activations - original_activations)\n",
    "                if out is None:\n",
    "                    out = difference\n",
    "                else:\n",
    "                    out += difference\n",
    "                n += 1\n",
    "        except ValueError:\n",
    "            print(\"Error: \",zz,original)\n",
    "            continue\n",
    "    print(\"Finished calculating difference tensor for \",n,\" sequences\")\n",
    "    return out/n\n",
    "\n",
    "\n",
    "def average_over_units_indiv_plots(layers,fitobj,D_delta):  \n",
    "    # plot differences and fits\n",
    "    max_window_size = 21\n",
    "    idx = 0\n",
    "    t = np.arange(0,max_window_size)\n",
    "    D_delta_all = np.stack(D_delta,0).mean(0)\n",
    "    fig,axes = plt.subplots(1,len(layers))\n",
    "    for k,layer in enumerate(layers):\n",
    "        ax = axes.flat[k]\n",
    "\n",
    "\n",
    "        for u in D_delta:\n",
    "            ax.plot(t, np.median(u[:, :, layer],axis=1), color='gray',label = 'Indiv.' '\\n' 'Units',linewidth=.5)\n",
    "        ax.plot(t, np.median(D_delta_all[:, :, layer], axis=1), color = 'black', linestyle=\"-\", linewidth=2,label=\"Mean\")\n",
    "        ax.set_ylim([0, 1.2])\n",
    "        \n",
    "        \n",
    "        ax.set_xlim([0, max_window_size - 1])\n",
    "\n",
    "        #     ax.set_xticks([])\n",
    "        # ax.tick_params(labelright= False,labeltop= False,labelleft= False, labelbottom= False)\n",
    "        ax.set_title(f\"Layer {layer}\")\n",
    "        ax.grid(False)        \n",
    "        ax.invert_xaxis()\n",
    "        if k == 0:\n",
    "            ax.set_ylabel(r\"$\\theta_{norm}[\\Delta]$\")\n",
    "        idx += 1\n",
    "    fig.tight_layout()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "checkpoint_names = [\n",
    "    \"baseline\",\n",
    "    \"topo_1\",\n",
    "    \"topo_5\",\n",
    "    \"topo_10\",\n",
    "    \"topo_50\"\n",
    "]\n",
    "\n",
    "for checkpoint_name in checkpoint_names:\n",
    "    print(f\"{checkpoint_name}\")\n",
    "    model, tokenizer = get_checkpoint(checkpoints_map[checkpoint_name], device=device)\n",
    "    tokenizer.add_prefix_space = True\n",
    "    overall_integration_corpus = Corpus(brown,single_token_words=False,tokenizer=tokenizer)\n",
    "    overall_integration_swapper = RandomPosWordSwap(overall_integration_corpus.word_lookup,\n",
    "                                                    overall_integration_corpus.pos_dict,\n",
    "                                                    tokenizer)\n",
    "    natural_sequences_40 = overall_integration_corpus.get_natural_sequences_of_length(40)\n",
    "    len(natural_sequences_40)\n",
    "\n",
    "    overall_integration_swapper(natural_sequences_40)\n",
    "\n",
    "    differences = calculate_differences(overall_integration_swapper.swapped,\n",
    "                                        overall_integration_swapper.original_sequences,\n",
    "                                        tokenizer,model,device='cuda', hacked = True)\n",
    "    \n",
    "    print(\"differences shape: \",differences.shape)\n",
    "    D = np.transpose(differences.numpy(), (3, 1, 2, 0))\n",
    "    print(\"D shape: \",D.shape)\n",
    "    n_features, n_stim_time, n_model_time, n_layers = D.shape\n",
    "    all_fits, all_D_delta = fit_curves(D)\n",
    "    stacked_fits = [np.stack(fit) for fit in all_fits]\n",
    "    stacked_fits = np.stack(stacked_fits)\n",
    "    # average_over_units_indiv_plots([1,3,6,9,11],all_fits,all_D_delta)\n",
    "\n",
    "    for parameter_index, parameter_name in zip([2,0,1],['c','a','b']):\n",
    "\n",
    "        # Prepare the figure for 3 rows and 4 columns\n",
    "        fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(20, 15))\n",
    "        fig.suptitle(f\"{checkpoint_name}\")\n",
    "        # Flatten the 2D array of axes for easy iteration\n",
    "        axes = axes.flatten()\n",
    "\n",
    "        for layer_index in range(12):\n",
    "            print(layer_index)\n",
    "            c_values = stacked_fits[:, layer_index, parameter_index]\n",
    "            size = find_rectangle_dimensions(c_values.shape[0])\n",
    "            \n",
    "            # Plot in the corresponding subplot with aspect ratio preserved\n",
    "            ax = axes[layer_index]\n",
    "            im = ax.imshow(c_values.reshape(size.height, size.width), aspect='equal', vmin=0, vmax=1, cmap = \"coolwarm\")\n",
    "            \n",
    "            # Disable the grid and remove ticks\n",
    "            ax.grid(False)\n",
    "            ax.set_xticks([])\n",
    "            ax.set_yticks([])\n",
    "            \n",
    "            fig.colorbar(im, ax=ax)\n",
    "            ax.set_title(f\"Layer: {layer_index}\")\n",
    "\n",
    "        # Adjust layout to avoid overlap\n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "        filename = f\"assets/{checkpoint_name}_{parameter_name}_values.pdf\"\n",
    "        fig.savefig(filename)\n",
    "        print(f\"saved: {filename}\")\n",
    "        # plt.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "plot_fit_params(stacked_fits)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mayukh plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "!du -sh send.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Prepare the figure for 3 rows and 4 columns\n",
    "fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(20, 15))\n",
    "\n",
    "# Flatten the 2D array of axes for easy iteration\n",
    "axes = axes.flatten()\n",
    "\n",
    "for layer_index in range(12):\n",
    "    print(layer_index)\n",
    "    c_values = stacked_fits[:, layer_index, 0]\n",
    "    size = find_rectangle_dimensions(c_values.shape[0])\n",
    "    \n",
    "    # Plot in the corresponding subplot with aspect ratio preserved\n",
    "    ax = axes[layer_index]\n",
    "    im = ax.imshow(c_values.reshape(size.height, size.width), aspect='equal', vmin=0, vmax=1, cmap = \"coolwarm\")\n",
    "    \n",
    "    # Disable the grid and remove ticks\n",
    "    ax.grid(False)\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    \n",
    "    fig.colorbar(im, ax=ax)\n",
    "    ax.set_title(f\"Layer: {layer_index}\")\n",
    "\n",
    "# Adjust layout to avoid overlap\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "# fig.savefig(\"send.png\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "",
   "version": ""
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
