{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `Orientation` Feature Analysis in VGG-16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pickle\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "\n",
    "from torchvision.models import vgg16\n",
    "\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data import Subset\n",
    "from torchvision.datasets import ImageNet\n",
    "\n",
    "from utils.infonets import InfoNet\n",
    "from utils.miscellaneous import get_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from random import sample\n",
    "\n",
    "# Prepare the Experiment hyperparameters\n",
    "img_size = 224\n",
    "pop_size = 250\n",
    "num_imgs = 1500\n",
    "num_exps = 2\n",
    "\n",
    "batch_size = 100\n",
    "\n",
    "img_mean, img_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]\n",
    "pipeline = transforms.Compose([\n",
    "    transforms.Resize(img_size),\n",
    "    transforms.CenterCrop(img_size),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean = img_mean, std = img_std)\n",
    "])\n",
    "\n",
    "# * Here we load the ImageNet dataset\n",
    "imagenet = ImageNet('path_to_ImageNet', split = 'val', transform = pipeline)\n",
    "\n",
    "# Define the indices of ImageNet images we want to use for the experiments\n",
    "tot_size = len(imagenet)\n",
    "Exp_Idxs = [sample(range(tot_size), num_imgs)] * num_exps# for _ in range(num_exps)]\n",
    "\n",
    "# Take a subset of ImageNet containing only the selected classes\n",
    "datasets = [Subset(imagenet, idxs) for idxs in Exp_Idxs]\n",
    "\n",
    "# Construct the loaders for the ImageNet subsets and sample the images\n",
    "loaders = [DataLoader(dataset, batch_size = batch_size, shuffle = False) for dataset in datasets]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/pmurator/SISSA/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)\n",
      "  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n"
     ]
    }
   ],
   "source": [
    "from numpy.random import randint\n",
    "\n",
    "# Generate the unit indices and compute their receptive fields for all experiments \n",
    "vgg = vgg16(pretrained = True)\n",
    "tmp = InfoNet(vgg)\n",
    "\n",
    "# Get network layer shapes for the selected input\n",
    "excluded_layers = (nn.Dropout, nn.BatchNorm2d, nn.AdaptiveAvgPool2d)\n",
    "info = get_info(vgg, (1, 3, img_size, img_size), exclude = excluded_layers)\n",
    "\n",
    "# Use network shapes to select a random population of fixed size\n",
    "Exp_unit = []\n",
    "Exp_uRFs = []\n",
    "\n",
    "# NOTE: In this setup every experiment should measure from the same units\n",
    "units = [randint(0, shape, size = (pop_size, len(shape))) for shape in info.shapes]\n",
    "untrf = tmp.measure_RF(keys = info.names, units = units)\n",
    "\n",
    "for _ in range(num_exps):\n",
    "    Exp_unit += [units]\n",
    "    Exp_uRFs += [untrf]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from random import randint\n",
    "from utils.miscellaneous import getnet\n",
    "\n",
    "# We prepare the network we want to use for the analysis\n",
    "transfer = [[]] + ['all']\n",
    "assert len(transfer) == num_exps\n",
    "\n",
    "seed = randint(0, 10000)\n",
    "nets = (getnet(tls, seed = seed)  for tls in transfer)\n",
    "\n",
    "# * Prepare the experiments\n",
    "Exps = (InfoNet(net) for net in nets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing Orientation:   3%|███                                                                                                             | 1/36 [04:21<2:32:21, 261.19s/it]"
     ]
    }
   ],
   "source": [
    "from utils.measures import orientation\n",
    "from utils.algorithms import iprofile\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Here we sample the dataset\n",
    "Data = [[batch for batch in loader] for loader in loaders]\n",
    "\n",
    "Feat = []\n",
    "Msis = []\n",
    "Ents = []\n",
    "\n",
    "Acts = []\n",
    "\n",
    "Inet = []\n",
    "\n",
    "\n",
    "# Collect the experiment images for feature extraction \n",
    "imgs = torch.cat([tmp[0] for tmp in Data[0]]).unsqueeze(0).numpy()\n",
    "\n",
    "# Compute image luminosity, the feature in our experiment\n",
    "ang, msi, ent = orientation(imgs, Exp_uRFs[0])\n",
    "\n",
    "Feat += [ang]\n",
    "Msis += [msi]\n",
    "Ents += [ent]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "Acts = []\n",
    "\n",
    "for exp, data, units, uRFs in tqdm(zip(Exps, Data, Exp_unit, Exp_uRFs), total = num_exps, desc = 'Experiments', leave = False):\n",
    "    # Record the network activations from the images\n",
    "    exp.record(units = units, rec_imgs = data)\n",
    "\n",
    "    # Extract the unit activations from the experiment\n",
    "    ls, Ls = info.names, exp.recorder.keys\n",
    "\n",
    "    Acts += [{l : exp.features[L].T for l, L in zip(ls, Ls)}]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.algorithms import filtbest\n",
    "\n",
    "pop_num = 200\n",
    "img_num = 500\n",
    "\n",
    "f_Feat = []\n",
    "f_Acts = []\n",
    "\n",
    "for acts in Acts:\n",
    "    f_feat, f_acts = filtbest(Feat[0], acts, Msis[0], pop_num, img_num)\n",
    "    f_Feat += [f_feat]\n",
    "    f_Acts += [f_acts]\n",
    "    \n",
    "# Compute the mutual information profile using the filtered Features and Activations\n",
    "f_Inet = [iprofile(f_feat, f_acts, nbins = (20, 20), bias = 'pt') for f_feat, f_acts in zip(f_Feat, f_Acts)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here we extract the individual entropy measures from all the experiments\n",
    "layers = [l for l in info.names]\n",
    "\n",
    "Hx  = [np.array([[I[l]['Hx']   if I[l] is not None else [np.nan] * pop_num for I in P] for l in layers]) for P in f_Inet]\n",
    "Hy  = [np.array([[I[l]['Hy']   if I[l] is not None else [np.nan] * pop_num for I in P] for l in layers]) for P in f_Inet]\n",
    "Hxy = [np.array([[I[l]['Hx|y'] if I[l] is not None else [np.nan] * pop_num for I in P] for l in layers]) for P in f_Inet]\n",
    "MI  = [np.array([[I[l]['Ix,y'] if I[l] is not None else [np.nan] * pop_num for I in P] for l in layers]) for P in f_Inet]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from utils.plotting import light_theme, shade\n",
    "\n",
    "plt.rcParams.update ({'font.size' : 14, \"text.usetex\": False})\n",
    "light_theme()\n",
    "\n",
    "layers = info.names\n",
    "\n",
    "convs = [True if 'conv'  in l else False for l in info.names]\n",
    "relus = [True if 'relu'  in l else False for l in info.names]\n",
    "pools = [True if 'mpool' in l else False for l in info.names]\n",
    "fulls = [True if 'fc'    in l else False for l in info.names]\n",
    "\n",
    "gmask = np.logical_or(convs, fulls)\n",
    "\n",
    "fig, ax = plt.subplots(figsize = (5, 4))\n",
    "\n",
    "# * Compute the average across channel colors\n",
    "colors = ['darkseagreen'] + ['cornflowerblue']\n",
    "\n",
    "# Plot the lines of the information profiles\n",
    "with np.testing.suppress_warnings() as sup:\n",
    "    sup.filter(RuntimeWarning)\n",
    "    ls = [ax.plot (np.arange(sum(gmask)), np.nanmean(mi / hx, axis = (1, 2))[gmask], c = c, lw = 2)[0] \n",
    "            for mi, hx, c in zip(MI, Hx, colors)]\n",
    "\n",
    "markers = ('o', '*')\n",
    "sizes   = (40, 120)\n",
    "convs = [True if 'conv'  in l else False for l in layers if 'conv' in l or 'fc' in l]\n",
    "fulls = [True if 'fc'    in l else False for l in layers if 'conv' in l or 'fc' in l]\n",
    "\n",
    "masks   = (convs, fulls)\n",
    "\n",
    "sizes   = (70, 160)\n",
    "for mask, m, s in zip(masks, markers, sizes):\n",
    "    par = {'marker' : m, 'ec' : 'k', 's' : s, 'zorder' : 10}\n",
    "    rpar = {'alpha' : 0.7, 's' : s, 'zorder' : 10, 'marker' : m, 'ec' : 'w'}\n",
    "\n",
    "    with np.testing.suppress_warnings() as sup:\n",
    "        sup.filter(RuntimeWarning)\n",
    "        [ax.scatter(np.arange(sum(gmask))[mask], np.nanmean(mi / hx, axis = (1, 2))[gmask][mask], color = c, **par) \n",
    "            for mi, hx, c in zip((MI), (Hx), (colors))]\n",
    "    \n",
    "sizes = (9, 13)\n",
    "\n",
    "ax.set_xlabel('layers')\n",
    "ax.set_xlim(left = -0.5)\n",
    "\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "ax.spines['bottom'].set_bounds(0., 15)\n",
    "ax.set_xticks(np.arange(sum(gmask)))\n",
    "ax.set_xticklabels([])\n",
    "\n",
    "ax.xaxis.set_tick_params(width = 2)\n",
    "ax.yaxis.set_tick_params(width = 2)\n",
    "for s in ax.spines.values(): s.set_linewidth(2)\n",
    "\n",
    "ax.set_ylabel('Normalized MI Orientation')\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig('results/Orientation_Experiment.png', dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "storepath = f'path_where_to_store/Experiment.pkl'\n",
    "\n",
    "bundle = {\n",
    "    'info' : info,\n",
    "    'seed' : seed,\n",
    "\n",
    "    'img_size' : img_size,\n",
    "    'pop_size' : pop_size,\n",
    "    'num_imgs' : num_imgs,\n",
    "    'num_exps' : num_exps,\n",
    "    'transfer' : transfer,\n",
    "\n",
    "    'batch_size' : batch_size,\n",
    "\n",
    "    'Exp_Idxs' : Exp_Idxs,\n",
    "    'Exp_unit' : Exp_unit,\n",
    "    'Exp_uRFs' : Exp_uRFs,\n",
    "    \n",
    "    'Feat' : Feat,\n",
    "    'Msis' : Msis,\n",
    "    'Ents' : Ents,\n",
    "    'Acts' : Acts,\n",
    "    \n",
    "    'f_nimg' : img_num,\n",
    "    'f_nunt' : pop_num,\n",
    "    'f_Feat' : f_Feat,\n",
    "    'f_Acts' : f_Acts,\n",
    "    \n",
    "    'f_Inet' : f_Inet\n",
    "}\n",
    "\n",
    "with open(storepath, 'wb') as f:\n",
    "    pickle.dump(bundle, f)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "d463e5fc653bab9d8e4f24359b8c55f1b66bd68a84a58e4a6e27c30ef7709220"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit ('SISSA': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
