{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51202438",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules (e.g. nesim) automatically before each cell executes,\n",
    "# so edits to the library are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d4003d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.models as models\n",
    "import torch\n",
    "from nesim.utils.json_stuff import load_json_as_dict\n",
    "from nesim.utils.getting_modules import get_module_by_name\n",
    "from nesim.utils.hook import ForwardHook\n",
    "from nesim.utils.grid_size import find_rectangle_dimensions\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dae327c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Absolute paths into the resnet18 ImageNet training directory -- edit for your machine.\n",
    "layer_names_filename = '/research/XXXX-1/nesim/training/imagenet/resnet18/possible_nesim_layers.json'\n",
    "checkpoints_root = '/research/XXXX-1/nesim/training/imagenet/resnet18/checkpoints/imagenet/'\n",
    "\n",
    "# NOTE(review): dataset_root is not referenced anywhere else in this notebook.\n",
    "dataset_root = './datasets/curated'\n",
    "\n",
    "# Compute device. NOTE(review): the visualize_layers calls below pass 'cuda:0' directly.\n",
    "device = 'cuda:0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdb1651b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture only; the trained weights are loaded from checkpoints further down.\n",
    "model = models.resnet18(weights=None)\n",
    "\n",
    "# Each training run directory keeps its best checkpoint at best/best_model.ckpt.\n",
    "baseline_checkpoint_filename = os.path.join(\n",
    "    checkpoints_root,\n",
    "    'baseline_shrink_factor_[5.0]_loss_scale_None_layers_all_conv_layers__bimt_scale_None_from_pretrained_False_apply_every_20_steps',\n",
    "    'best/best_model.ckpt'\n",
    ")\n",
    "\n",
    "# Same run configuration as the baseline, but with loss_scale 150 instead of None.\n",
    "our_checkpoint_filename = os.path.join(\n",
    "    checkpoints_root,\n",
    "    'shrink_factor_[5.0]_loss_scale_150_layers_all_conv_layers__bimt_scale_None_from_pretrained_False_apply_every_20_steps',\n",
    "    'best/best_model.ckpt'\n",
    ")\n",
    "\n",
    "checkpoint_map = {\n",
    "    'baseline': baseline_checkpoint_filename,\n",
    "    'ours': our_checkpoint_filename\n",
    "}\n",
    "\n",
    "# Fail early and name the missing file, rather than failing later inside torch.load.\n",
    "for run_name, path in checkpoint_map.items():\n",
    "    assert os.path.exists(path), f'{run_name} checkpoint not found: {path}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "719b55bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Names of all conv layers; visualize_layers draws one subplot per name.\n",
    "# (checkpoint_map is defined together with the checkpoint paths and is not redefined here.)\n",
    "target_layer_names = load_json_as_dict(layer_names_filename)['all_conv_layers']\n",
    "assert len(target_layer_names) > 0, 'no layer names found under all_conv_layers'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16190ce8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# WeightsPCAViewer is used by visualize_layers further down.\n",
    "from nesim.vis.weights_pca import WeightsPCAViewer\n",
    "\n",
    "# The training checkpoint keeps the weights under 'state_dict' with keys prefixed by\n",
    "# 'model.'. Strip only that leading prefix; str.replace would also rewrite any\n",
    "# 'model.' occurring later inside a key.\n",
    "key_prefix = 'model.'\n",
    "checkpoint_filename = checkpoint_map['ours']\n",
    "state_dict = torch.load(checkpoint_filename)['state_dict']\n",
    "state_dict_with_fixed_keys = {\n",
    "    (key[len(key_prefix):] if key.startswith(key_prefix) else key): value\n",
    "    for key, value in state_dict.items()\n",
    "}\n",
    "\n",
    "# load_state_dict (strict by default) checks the renamed keys against resnet18, then export.\n",
    "model.load_state_dict(state_dict_with_fixed_keys)\n",
    "torch.save(state_dict_with_fixed_keys, 'our_model.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4f4c255",
   "metadata": {},
   "outputs": [],
   "source": [
    "# WeightsPCAViewer is used by visualize_layers further down.\n",
    "from nesim.vis.weights_pca import WeightsPCAViewer\n",
    "\n",
    "# The training checkpoint keeps the weights under 'state_dict' with keys prefixed by\n",
    "# 'model.'. Strip only that leading prefix; str.replace would also rewrite any\n",
    "# 'model.' occurring later inside a key.\n",
    "key_prefix = 'model.'\n",
    "checkpoint_filename = checkpoint_map['baseline']\n",
    "state_dict = torch.load(checkpoint_filename)['state_dict']\n",
    "state_dict_with_fixed_keys = {\n",
    "    (key[len(key_prefix):] if key.startswith(key_prefix) else key): value\n",
    "    for key, value in state_dict.items()\n",
    "}\n",
    "\n",
    "# load_state_dict (strict by default) checks the renamed keys against resnet18, then export.\n",
    "model.load_state_dict(state_dict_with_fixed_keys)\n",
    "torch.save(state_dict_with_fixed_keys, 'baseline.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e02a995",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model.state_dict().keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6ab46f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of layers that visualize_layers will plot (one subplot each).\n",
    "len(target_layer_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c89a516e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from typing import List\n",
    "from nesim.utils.getting_modules import get_module_by_name\n",
    "from nesim.utils.grid_size import find_rectangle_dimensions\n",
    "from nesim.vis.weights_pca import WeightsPCAViewer\n",
    "\n",
    "def visualize_layers(\n",
    "    model,\n",
    "    target_layer_names: List[str],\n",
    "    checkpoint_filename: str,\n",
    "    device: str = 'cuda:0',\n",
    "    figsize: tuple = (15, 15),\n",
    "    layer_name_fontsize: int = 18\n",
    "):\n",
    "    \"\"\"Plot the WeightsPCAViewer image of every target layer on one grid of subplots.\n",
    "\n",
    "    Args:\n",
    "        model: torch module in which each layer is looked up by name.\n",
    "        target_layer_names: names of the layers to plot; one subplot per name.\n",
    "        checkpoint_filename: weights file handed to ``WeightsPCAViewer``.\n",
    "        device: device string forwarded to ``WeightsPCAViewer``.\n",
    "        figsize: matplotlib figure size.\n",
    "        layer_name_fontsize: font size of the figure title and the subplot titles.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``target_layer_names`` is empty.\n",
    "    \"\"\"\n",
    "    num_images = len(target_layer_names)\n",
    "    if num_images == 0:\n",
    "        # Otherwise num_rows would be 0 and the column count would divide by zero.\n",
    "        raise ValueError('target_layer_names must contain at least one layer name')\n",
    "    num_rows = int(np.sqrt(num_images))\n",
    "    num_cols = int(np.ceil(num_images / num_rows))\n",
    "\n",
    "    # squeeze=False always returns a 2-D array of axes. Without it a single-row grid\n",
    "    # (2 or 3 layers) yields a 1-D array and axes[row, col] raises IndexError.\n",
    "    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)\n",
    "    fig.suptitle('Visualising the top 3 principal components of the\\nConvolutional Conceptual Grid', fontsize=layer_name_fontsize)\n",
    "\n",
    "    for idx, name in enumerate(target_layer_names):\n",
    "        ax = axes[idx // num_cols, idx % num_cols]\n",
    "\n",
    "        # Grid area = size of dim 0 of the layer weight (output channels for a conv layer).\n",
    "        layer = get_module_by_name(module=model, name=name)\n",
    "        area = layer.weight.data.shape[0]\n",
    "        size = find_rectangle_dimensions(area=area)\n",
    "\n",
    "        # NOTE(review): resize_height comes from size.width and resize_width from\n",
    "        # size.height -- confirm this transpose is intended.\n",
    "        viewer = WeightsPCAViewer(\n",
    "            model=model,\n",
    "            checkpoint_filenames=[checkpoint_filename],\n",
    "            layer_name=name,\n",
    "            device=device,\n",
    "            resize_height=int(size.width * 4),\n",
    "            resize_width=int(size.height * 4),\n",
    "            scale_by_magnitude=True,\n",
    "            load_from_brain_inspired_layers=True\n",
    "        )\n",
    "\n",
    "        ax.imshow(viewer[0])\n",
    "        ax.set_title(f'layer: {name}\\ngrid size: {(size.height, size.width)}', fontsize=layer_name_fontsize)\n",
    "        ax.axis('off')\n",
    "\n",
    "    # Hide the unused trailing cells of the grid (row-major order, same as above).\n",
    "    for ax in axes.flat[num_images:]:\n",
    "        ax.axis('off')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "448c412a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# PCA grid for our model, using the weights exported to 'our_model.pth'.\n",
    "visualize_layers(\n",
    "    model=model,\n",
    "    target_layer_names=target_layer_names,\n",
    "    checkpoint_filename='our_model.pth',\n",
    "    device=device,\n",
    "    figsize=(15, 18),\n",
    "    layer_name_fontsize=18\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddbc62c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PCA grid for the baseline model, using the weights exported to 'baseline.pth'.\n",
    "visualize_layers(\n",
    "    model=model,\n",
    "    target_layer_names=target_layer_names,\n",
    "    checkpoint_filename='baseline.pth',\n",
    "    device=device,\n",
    "    figsize=(15, 18),\n",
    "    layer_name_fontsize=18\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4e64e01",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
