{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e9fb0d39",
   "metadata": {},
   "source": [
    "# `Contrast` Feature Analysis in VGG-16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "engaged-shell",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pickle\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "\n",
    "from torchvision.models import vgg16\n",
    "\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data import Subset\n",
    "from torchvision.datasets import ImageNet\n",
    "\n",
    "from utils.infonets import InfoNet\n",
    "from utils.miscellaneous import get_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "stylish-present",
   "metadata": {},
   "outputs": [],
   "source": [
    "from random import sample\n",
    "\n",
    "# Prepare the Experiment hyperparameters\n",
    "img_size = 224\n",
    "pop_size = 250\n",
    "num_imgs = 1500\n",
    "num_exps = 5\n",
    "\n",
    "batch_size = 100\n",
    "\n",
    "img_mean, img_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]\n",
    "pipeline = transforms.Compose([\n",
    "    transforms.Resize(img_size),\n",
    "    transforms.CenterCrop(img_size),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean = img_mean, std = img_std)\n",
    "])\n",
    "\n",
    "# * Here we load the ImageNet dataset\n",
    "imagenet = ImageNet('path_to_ImageNet', split = 'val', transform = pipeline)\n",
    "\n",
    "# Define the indices of ImageNet images we want to use for the experiments\n",
    "tot_size = len(imagenet)\n",
    "Exp_Idxs = [sample(range(tot_size), num_imgs) for _ in range(num_exps)]\n",
    "\n",
    "# Take a subset of ImageNet containing only the selected classes\n",
    "datasets = [Subset(imagenet, idxs) for idxs in Exp_Idxs]\n",
    "\n",
    "# Construct the loaders for the ImageNet subsets and sample the images\n",
    "loaders = [DataLoader(dataset, batch_size = batch_size, shuffle = False) for dataset in datasets]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "injured-management",
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpy.random import randint\n",
    "\n",
    "# Generate the unit indices and compute their receptive fields for all experiments \n",
    "vgg = vgg16(pretrained = True)\n",
    "tmp = InfoNet(vgg)\n",
    "\n",
    "# Get network layer shapes for the selected input\n",
    "excluded_layers = (nn.Dropout, nn.BatchNorm2d, nn.AdaptiveAvgPool2d)\n",
    "info = get_info(vgg, (1, 3, img_size, img_size), exclude = excluded_layers)\n",
    "\n",
    "# Use network shapes to select a random population of fixed size\n",
    "Exp_unit = []\n",
    "Exp_uRFs = []\n",
    "\n",
    "# NOTE: In this setup every experiment should measure from the same units\n",
    "units = [randint(0, shape, size = (pop_size, len(shape))) for shape in info.shapes]\n",
    "untrf = tmp.measure_RF(keys = info.names, units = units)\n",
    "\n",
    "for _ in range(num_exps):\n",
    "    Exp_unit += [units]\n",
    "    Exp_uRFs += [untrf]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "metropolitan-receipt",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.miscellaneous import getnet\n",
    "\n",
    "# We prepare the network we want to use for the analysis\n",
    "transfer = [i for i, name in enumerate(info.names) if 'conv' in name][:4] + ['all']\n",
    "assert len(transfer) == num_exps\n",
    "\n",
    "tnets = (getnet('all') for _ in range(num_exps))\n",
    "rnets = (getnet('rnd') for _ in range(num_exps))\n",
    "\n",
    "# * Prepare the experiments\n",
    "tExps = (InfoNet(net) for net in tnets)\n",
    "rExps = (InfoNet(net) for net in rnets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "minimal-affiliation",
   "metadata": {},
   "outputs": [],
   "source": [
    "def storeit(storepath : str, bundle : dict):\n",
    "    desc = \\\n",
    "        '''\n",
    "            This data bundle contains the data from the mutual information experiment\n",
    "            using Contrast as stimuli feature on VGG-16 and repeating the experiment\n",
    "            for different VGG-16 setups. We scan from a compleately random network\n",
    "            to a fully trained network, with intermediate cases being network with\n",
    "            layers [0, N] fully trained (equal to VGG-16 trained on ImageNet), and\n",
    "            layers [N, L] (with L being VGG-16 total number of layers) being fully\n",
    "            random. Each experiment is performed once for each configuration.\n",
    "        '''\n",
    "    \n",
    "    bundle['desc'] = desc\n",
    "    \n",
    "    with open(storepath, 'wb') as f:\n",
    "        pickle.dump(bundle, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "consolidated-gabriel",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5c39c68c106741ed82b22dd2b8f305ad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Experiments'), FloatProgress(value=0.0, max=5.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                             \r"
     ]
    }
   ],
   "source": [
    "from utils.measures import contrast\n",
    "from utils.algorithms import iprofile\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Here we sample the dataset\n",
    "Data = ([batch for batch in loader] for loader in loaders)\n",
    "\n",
    "miter = enumerate(tqdm(zip(tExps, rExps, Data, Exp_unit, Exp_uRFs), total = num_exps, desc = 'Experiments', leave = False))\n",
    "for i, (texp, rexp, data, units, uRFs) in miter:\n",
    "    # Record the network activations from the images\n",
    "    texp.record(units = units, rec_imgs = data)\n",
    "    rexp.record(units = units, rec_imgs = data)\n",
    "\n",
    "    # Collect the experiment images for feature extraction \n",
    "    imgs = torch.cat([tmp[0] for tmp in data]).numpy()\n",
    "\n",
    "    # Compute image contrast, the feature in our experiment\n",
    "    Feat = contrast(imgs, uRFs)\n",
    "\n",
    "    # Extract the unit activations from the experiment\n",
    "    ls, Ls = info.names, texp.recorder.keys\n",
    "\n",
    "    Acts = ({l : texp.features[L].T for l, L in zip(ls, Ls)},\n",
    "            {l : rexp.features[L].T for l, L in zip(ls, Ls)})\n",
    "\n",
    "    # Compute the mutual information profile using the latest Features and Activations\n",
    "    Inet = [iprofile(Feat, acts, nbins = (20, 20), bias = 'pt') for acts in Acts]\n",
    "    \n",
    "    # Store the computed experiments in long storage\n",
    "    storepath = f'path_where_to_store/Experiment_{i+1}.pkl'\n",
    "\n",
    "    bundle = {\n",
    "        'info' : info,\n",
    "        \n",
    "        'exp_id' : i + 1,\n",
    "\n",
    "        'img_size' : img_size,\n",
    "        'pop_size' : pop_size,\n",
    "        'num_imgs' : num_imgs,\n",
    "        'num_exps' : num_exps,\n",
    "        'transfer' : transfer,\n",
    "\n",
    "        'batch_size' : batch_size,\n",
    "\n",
    "        'Exp_Idxs' : Exp_Idxs,\n",
    "        'Exp_unit' : Exp_unit,\n",
    "        'Exp_uRFs' : Exp_uRFs,\n",
    "\n",
    "        'Feat' : Feat,\n",
    "        'Acts' : Acts,\n",
    "        'Inet' : Inet\n",
    "    }\n",
    "    \n",
    "    storeit(storepath, bundle)\n"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "d463e5fc653bab9d8e4f24359b8c55f1b66bd68a84a58e4a6e27c30ef7709220"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit ('SISSA': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
