{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and general settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import thingsvision.vision as vision\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "#import seaborn as sns\n",
    "import itertools\n",
    "import torch.nn as nn\n",
    "import torch.nn.parallel\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torch.distributed as dist\n",
    "import torch.optim\n",
    "import sys\n",
    "import torch.multiprocessing as mp\n",
    "import torch.utils.data\n",
    "import torch.utils.data.distributed\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.models as models\n",
    "import matplotlib.pyplot as plt\n",
    "from rsa_helpers import prep_condition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set base network\n",
    "base_network = \"Res18\"\n",
    "\n",
    "# Set device\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# Set base network and load it\n",
    "folder_name = f\"./results/{base_network}/\"\n",
    "figure_path = f\"./figures/{base_network}/RSA/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Condition suffixes; their order fixes the position of each condition\n",
    "# in every array that is later indexed by condition\n",
    "condition_suffixes = (\n",
    "    \"Base_condition\",\n",
    "    \"Plus_1ep\",\n",
    "    \"Plus_10ep\",\n",
    "    \"Different_optimizer\",\n",
    "    \"Different_batchsize\",\n",
    "    \"Different_initialisation\",\n",
    "    \"Different_LR\",\n",
    "    \"CUDA_nondeterministic\",\n",
    "    \"Different_dataorder\",\n",
    "    \"Different_architecture\",\n",
    "    \"Different_data\",\n",
    "    \"Half_data\",\n",
    "    \"Combined_condition\",\n",
    ")\n",
    "\n",
    "# Make condition list: every suffix prefixed with the base network name\n",
    "conditions = [f\"{base_network}_{suffix}\" for suffix in condition_suffixes]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare validation loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load ImageNet validation dataset\n",
    "valdir = os.path.join('/scratch_local/datasets/ImageNet2012/', 'val')\n",
    "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "val_dataset = datasets.ImageFolder(\n",
    "        valdir,\n",
    "        transforms.Compose([\n",
    "            transforms.Resize(256),\n",
    "            transforms.CenterCrop(224),\n",
    "            transforms.ToTensor(),\n",
    "            normalize,\n",
    "        ]))\n",
    "\n",
    "# Split dataset; the seeded generator makes the split identical on every run\n",
    "nImages = 5000\n",
    "split_val_dataset = torch.utils.data.dataset.random_split(val_dataset, [nImages, len(val_dataset)-nImages],\n",
    "                                                          generator=torch.Generator().manual_seed(1312))\n",
    "\n",
    "# Cap the number of workers at the CPUs this process may use. Requesting more\n",
    "# (this was a fixed 20) makes the DataLoader warn that it may run slowly or freeze.\n",
    "if hasattr(os, 'sched_getaffinity'):\n",
    "    usable_cpus = len(os.sched_getaffinity(0))\n",
    "else:\n",
    "    usable_cpus = os.cpu_count() or 1\n",
    "num_workers = min(20, usable_cpus)\n",
    "\n",
    "# Init dataloader on the first split (nImages images); shuffle=False keeps the image order fixed\n",
    "val_loader = torch.utils.data.DataLoader(split_val_dataset[0], batch_size=50, shuffle=False,\n",
    "                                         num_workers=num_workers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define modules that should be analysed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the Res18 family is configured here; stop for any other base network\n",
    "if base_network != \"Res18\":\n",
    "    raise Exception(\"No correct base network defined.\")\n",
    "\n",
    "# Base architecture, the modules of interest and the checkpoint epochs to analyse\n",
    "model_name = \"resnet18\"\n",
    "modules = [\"fc\"]\n",
    "epochs = [90]\n",
    "\n",
    "# Architecture used for the different-architecture conditions;\n",
    "# diff_modules is indexed in parallel with modules\n",
    "diff_arch = \"densenet121\"\n",
    "diff_modules = [\"classifier\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make RDMs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 90, Module: fc, Condition: Res18_Base_condition\n",
      "...Features successfully extracted for all 5000 images in the database.\n",
      "Epoch: 90, Module: fc, Condition: Res18_Plus_1ep\n",
      "...Features successfully extracted for all 5000 images in the database.\n",
      "Epoch: 90, Module: fc, Condition: Res18_Plus_10ep\n",
      "...Features successfully extracted for all 5000 images in the database.\n",
      "Epoch: 90, Module: fc, Condition: Res18_Different_optimizer\n",
      "...Features successfully extracted for all 5000 images in the database.\n",
      "Epoch: 90, Module: fc, Condition: Res18_Different_batchsize\n",
      "...Features successfully extracted for all 5000 images in the database.\n",
      "Epoch: 90, Module: fc, Condition: Res18_Different_initialisation\n",
      "...Features successfully extracted for all 5000 images in the database.\n",
      "Epoch: 90, Module: fc, Condition: Res18_Different_LR\n",
      "...Features successfully extracted for all 5000 images in the database.\n",
      "Epoch: 90, Module: fc, Condition: Res18_CUDA_nondeterministic\n",
      "...Features successfully extracted for all 5000 images in the database.\n",
      "Epoch: 90, Module: fc, Condition: Res18_Different_dataorder\n",
      "...Features successfully extracted for all 5000 images in the database.\n",
      "Epoch: 90, Module: fc, Condition: Res18_Different_architecture\n",
      "...Features successfully extracted for all 5000 images in the database.\n",
      "Epoch: 90, Module: fc, Condition: Res18_Different_data\n",
      "...Features successfully extracted for all 5000 images in the database.\n",
      "Epoch: 90, Module: fc, Condition: Res18_Half_data\n",
      "...Features successfully extracted for all 5000 images in the database.\n",
      "Epoch: 90, Module: fc, Condition: Res18_Combined_condition\n",
      "...Features successfully extracted for all 5000 images in the database.\n"
     ]
    }
   ],
   "source": [
    "# Make RDMs for all conditions\n",
    "\n",
    "# Conditions that are loaded with the alternative architecture (diff_arch / diff_modules)\n",
    "diff_arch_conditions = (f\"{base_network}_Different_architecture\",\n",
    "                        f\"{base_network}_Combined_condition\")\n",
    "\n",
    "# Create the output folder up front so that saving an RDM cannot fail on a missing directory\n",
    "rdm_dir = f'./RSA/{base_network}'\n",
    "os.makedirs(rdm_dir, exist_ok=True)\n",
    "\n",
    "for epoch in epochs:\n",
    "    for mod_ind, module in enumerate(modules):\n",
    "        for condition in conditions:\n",
    "\n",
    "            # Print which RDM is being calculated\n",
    "            print(f\"Epoch: {epoch}, Module: {module}, Condition: {condition}\")\n",
    "\n",
    "            # Set checkpoint path\n",
    "            if condition.endswith(\"Plus_1ep\"):\n",
    "                # Base condition run, one epoch later\n",
    "                path = folder_name + f\"{base_network}_Base_condition\" + \"/NUM1\" + f\"/MODEL_EP{epoch+1}\"\n",
    "            elif condition.endswith(\"Plus_10ep\"):\n",
    "                # Base condition run, ten epochs later\n",
    "                path = folder_name + f\"{base_network}_Base_condition\" + \"/NUM1\" + f\"/MODEL_EP{epoch+10}\"\n",
    "            elif condition.endswith(\"Different_data\"):\n",
    "                # NOTE(review): this is the same checkpoint that the Half_data condition gets from\n",
    "                # the else branch below, so both conditions use the same model file - confirm intended\n",
    "                path = folder_name + f\"{base_network}_Half_data\" + \"/NUM1\" + f\"/MODEL_EP{epoch}\"\n",
    "            else:\n",
    "                path = folder_name + condition + \"/NUM1\" + f\"/MODEL_EP{epoch}\"\n",
    "\n",
    "            # Pick architecture and module once; diff_modules is indexed in parallel with modules\n",
    "            if condition in diff_arch_conditions:\n",
    "                arch, layer = diff_arch, diff_modules[mod_ind]\n",
    "            else:\n",
    "                arch, layer = model_name, module\n",
    "\n",
    "            # Load model, passing the checkpoint path chosen above\n",
    "            model, _ = vision.load_model(model_name=arch, pretrained=True, device=device, model_path=path)\n",
    "\n",
    "            # Extract features\n",
    "            model.eval()\n",
    "            features, _ = vision.extract_features(model, val_loader, layer,\n",
    "                                                  batch_size=10, flatten_acts=True, device=device)\n",
    "\n",
    "            # Features are used as extracted (no centering or normalisation is applied)\n",
    "\n",
    "            # Compute RDM and save it; the file name always carries the base module name\n",
    "            rdm = vision.compute_rdm(features, \"correlation\")\n",
    "            np.save(f'{rdm_dir}/{condition}_{module}_rdm_ep{epoch}', rdm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculate correlation between base RDM and RDMs of conditions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Epoch and module whose saved RDMs are compared\n",
    "epoch = 90\n",
    "module = \"fc\"\n",
    "\n",
    "# Every RDM file name is <prefix><condition><suffix>\n",
    "rdm_prefix = f'./RSA/{base_network}/'\n",
    "rdm_suffix = f'_{module}_rdm_ep{epoch}.npy'\n",
    "\n",
    "# Reference RDM of the base condition\n",
    "base_rdm = np.load(rdm_prefix + f'{base_network}_Base_condition' + rdm_suffix)\n",
    "\n",
    "# One Pearson correlation per condition, stored in the order of `conditions`\n",
    "corrs = np.zeros(len(conditions))\n",
    "for position, condition in enumerate(conditions):\n",
    "\n",
    "    # Each condition has its own RDM file, named after the condition\n",
    "    cond_rdm = np.load(rdm_prefix + condition + rdm_suffix)\n",
    "    corrs[position] = vision.correlate_rdms(base_rdm, cond_rdm, \"pearson\")\n",
    "\n",
    "np.save(f'{base_network}_rsa_corrs', corrs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
