{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import Tensor\n",
    "#from .utils import load_state_dict_from_url\n",
    "from typing import Type, Any, Callable, Union, List, Optional\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import models, transforms\n",
    "import torch.backends.cudnn as cudnn\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from torch.autograd import Variable\n",
    "import time\n",
    "import pandas as pd\n",
    "import torch.nn.init as init\n",
    "import math\n",
    "from scipy.spatial import ConvexHull\n",
    "import scipy.io\n",
    "\n",
    "cwd = os.getcwd()\n",
    "%matplotlib inline\n",
    "filescwd = cwd + \"\\\\files_npys\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit( model, optimizer, criterion, my_lr_scheduler, start_epoch, numEpochs, trainloader, testloader):\n",
    "    #train_loss = []\n",
    "    #train_acc = []\n",
    "    #test_loss = []\n",
    "    #test_acc = []\n",
    "    \n",
    "    for epoch in range( start_epoch, start_epoch + numEpochs):\n",
    "        _ = train( model, optimizer, criterion, epoch, trainloader)\n",
    "        #train_loss.append( _[0])\n",
    "        #train_acc.append( _[1])\n",
    "        #_t_ = test( model, criterion, epoch, testloader)\n",
    "        #test_loss.append( _t_[0])\n",
    "        #test_acc.append( _t_[1])\n",
    "        if my_lr_scheduler != None:\n",
    "            my_lr_scheduler.step()\n",
    "    return None#train_loss, train_acc#, test_loss, test_acc\n",
    "\n",
    "\n",
    "def train( model, optimizer, criterion, epochNo, loader):\n",
    "    model.train()\n",
    "    train_loss, correct, total = 0, 0, 0\n",
    "    \n",
    "    for batch_idx, (inputs, targets) in enumerate( loader):\n",
    "        inputs, targets = inputs.to(device), targets.to(device) # gpu\n",
    "        optimizer.zero_grad() # zero the parameter gradients\n",
    "        \n",
    "        # forward + backward + optimize\n",
    "        outputs = model(inputs)\n",
    "        if isinstance( outputs, list): # if we take intermediate outputs too\n",
    "            outputs = outputs[-1]\n",
    "        \n",
    "        loss = criterion( outputs, targets)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_loss += loss.item()\n",
    "        _, predicted = outputs.max(1)\n",
    "        total += targets.size(0)\n",
    "        correct += predicted.eq(targets).sum().item()\n",
    "    \n",
    "    train_loss = train_loss / (batch_idx+1)\n",
    "    train_acc = 100*correct/total\n",
    "    print('[Epoch no: %d] Training loss: %.5f  |  Training accuracy: %.3f%%, with Correct: %d, Total: %d' % (epochNo, train_loss, train_acc, correct, total))    \n",
    "    return (train_loss, train_acc)\n",
    "\n",
    "\n",
    "def test( model, criterion, epochNo, loader):\n",
    "    model.eval() \n",
    "    test_loss, correct, total = 0, 0, 0\n",
    "    \n",
    "    with torch.no_grad(): # testing\n",
    "        for batch_idx, (inputs, targets) in enumerate( loader):\n",
    "            inputs, targets = inputs.to(device), targets.to(device) # gpu\n",
    "\n",
    "            outputs = model(inputs)\n",
    "            if isinstance( outputs, list): # if we take intermediate outputs too\n",
    "                outputs = outputs[-1]\n",
    "            loss = criterion( outputs, targets)\n",
    "\n",
    "            test_loss += loss.item()\n",
    "            _, predicted = outputs.max(1) # max_value, the column index\n",
    "            total += targets.size(0)\n",
    "            correct += predicted.eq(targets).sum().item()\n",
    "\n",
    "    test_loss = test_loss / (batch_idx + 1)\n",
    "    test_acc = 100*correct/total\n",
    "    print('[Epoch no: %d]. Test loss: %.5f  |  Test accuracy: %.3f%%, with Correct: %d, Total: %d' % (epochNo, test_loss, test_acc, correct, total))\n",
    "    return (test_loss, test_acc)\n",
    "\n",
    "\n",
    "def predict( model, loader):\n",
    "    model.eval() \n",
    "    correct = 0\n",
    "    total = 0\n",
    "    input_list = []\n",
    "    target_list = []\n",
    "    preds_list = []\n",
    "    list_of_intermediate_outs = [] # elements are a list\n",
    "    initFlag = True\n",
    "    interFlag = False\n",
    "\n",
    "    \n",
    "    with torch.no_grad(): # testing\n",
    "        for batch_idx, (inputs, targets) in enumerate( loader):\n",
    "            inputs, targets = inputs.to(device), targets.to(device) # gpu\n",
    "            input_list.append( inputs)\n",
    "            target_list.append( targets)\n",
    "            \n",
    "            outputs = model(inputs)\n",
    "            if isinstance( outputs, list):\n",
    "                if initFlag:\n",
    "                    initFlag = False\n",
    "                    for i in range( len( outputs)):\n",
    "                        list_of_intermediate_outs.append( [])\n",
    "                interFlag = True\n",
    "                for i,layerOutput in enumerate( outputs):\n",
    "                    if i == len( outputs) - 1:\n",
    "                        list_of_intermediate_outs[i].append( F.softmax( layerOutput, dim=1))\n",
    "                    else:\n",
    "                        list_of_intermediate_outs[i].append( layerOutput)\n",
    "                \n",
    "                outputs = outputs[-1]\n",
    "            out_softmax = F.softmax( outputs, dim=1)\n",
    "            preds_list.append( out_softmax)\n",
    "            \n",
    "            _, predicted = outputs.max(1) # max_value, the column index\n",
    "            total += targets.size(0)\n",
    "            correct += predicted.eq(targets).sum().item()\n",
    "    \n",
    "    pred_acc = 100*correct/total\n",
    "    print('Predict accuracy: %.3f%%, with Correct: %d, Total: %d' % (pred_acc, correct, total))\n",
    "    \n",
    "    input_array = input_list[0].cpu().numpy()\n",
    "    for i in range(1, len(input_list)):\n",
    "        input_array = np.concatenate( (input_array, input_list[i].cpu().numpy()), axis=0)\n",
    "    \n",
    "    target_array = target_list[0].cpu().numpy()\n",
    "    for i in range( 1, len( target_list)):\n",
    "        target_array = np.concatenate( (target_array, target_list[i].cpu().numpy()), axis=0)\n",
    "    \n",
    "    preds_array = preds_list[0].cpu().numpy()\n",
    "    for i in range( 1, len( preds_list)):\n",
    "        preds_array = np.concatenate( (preds_array, preds_list[i].cpu().numpy()), axis=0)\n",
    "    \n",
    "    list_of_intermediate_outs_array = []\n",
    "    if interFlag:\n",
    "        for list_layerOutput in list_of_intermediate_outs:\n",
    "            layer_array = list_layerOutput[0].cpu().numpy()\n",
    "            for i in range( 1, len( list_layerOutput)):\n",
    "                layer_array = np.concatenate( (layer_array, list_layerOutput[i].cpu().numpy()), axis=0)\n",
    "            list_of_intermediate_outs_array.append( layer_array)\n",
    "    \n",
    "    return (preds_array, target_array, list_of_intermediate_outs_array)\n",
    "\n",
    "\n",
    "def to_categorical(y, num_classes):\n",
    "    return np.eye(num_classes, dtype='uint8')[y]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getIntermediateOut( listOfLayers, loader, last=False):\n",
    "    intermediate_outs = []\n",
    "    initFlag = True\n",
    "    interFlag = False\n",
    "    \n",
    "    with torch.no_grad(): # testing\n",
    "        for batch_idx, (inputs, targets) in enumerate( loader):\n",
    "            inputs, targets = inputs.to(device), targets.to(device) # gpu\n",
    "            \n",
    "            inputToLayer = inputs\n",
    "            for k,layer in enumerate( listOfLayers):\n",
    "                if last and k == len(listOfLayers)-1:\n",
    "                    inputToLayer = torch.flatten( inputToLayer, 1)\n",
    "                out = layer( inputToLayer)\n",
    "                inputToLayer = out\n",
    "            \n",
    "            intermediate_outs.append( out)\n",
    "    \n",
    "    \n",
    "    intermediate_outs_array = intermediate_outs[0].cpu().numpy()\n",
    "    for i in range( 1, len( intermediate_outs)):\n",
    "        intermediate_outs_array = np.concatenate( (intermediate_outs_array, intermediate_outs[i].cpu().numpy()), axis=0)\n",
    "    \n",
    "    return intermediate_outs_array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def distanceOf( features, classFeatureMeans):\n",
    "    \"\"\"\n",
    "    features: no_samples x num_filters x height x width\n",
    "    classFeatureMeans: num_classes x num_filters x height x width\n",
    "    \n",
    "    returns: distance of each sample to each classFeatureMean -> no_samples x num_classes\n",
    "    \"\"\"\n",
    "    dist_of_samples_to_class_means = np.empty( (features.shape[0], classFeatureMeans.shape[0]))\n",
    "    \n",
    "    b = classFeatureMeans.reshape( num_classes, -1)\n",
    "    for i in range( features.shape[0]):\n",
    "        a = features[i].reshape( -1)\n",
    "        dist_of_samples_to_class_means[i] = np.linalg.norm( a - b, axis=1)\n",
    "    \n",
    "    return dist_of_samples_to_class_means\n",
    "\n",
    "def classFeatureMeansOf( layerOutput, ground_truths):\n",
    "    classFeatureMeans = np.zeros( (num_classes,) + layerOutput.shape[1:])\n",
    "    for i in range( num_classes):\n",
    "        classFeatureMeans[i] = np.mean( layerOutput[ np.where( ground_truths == i)[0]], axis=0)\n",
    "    return classFeatureMeans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',\n",
    "           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',\n",
    "           'wide_resnet50_2', 'wide_resnet101_2']\n",
    "\n",
    "model_urls = {\n",
    "    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',\n",
    "    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',\n",
    "    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',\n",
    "    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',\n",
    "    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',\n",
    "    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',\n",
    "    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',\n",
    "    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',\n",
    "    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',\n",
    "}\n",
    "\n",
    "\n",
    "def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:\n",
    "    \"\"\"3x3 convolution with padding\"\"\"\n",
    "    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n",
    "                     padding=dilation, groups=groups, bias=False, dilation=dilation)\n",
    "\n",
    "\n",
    "def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:\n",
    "    \"\"\"1x1 convolution\"\"\"\n",
    "    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n",
    "\n",
    "\n",
    "class BasicBlock(nn.Module):\n",
    "    expansion: int = 1\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        inplanes: int,\n",
    "        planes: int,\n",
    "        stride: int = 1,\n",
    "        downsample: Optional[nn.Module] = None,\n",
    "        groups: int = 1,\n",
    "        base_width: int = 64,\n",
    "        dilation: int = 1,\n",
    "        norm_layer: Optional[Callable[..., nn.Module]] = None\n",
    "    ) -> None:\n",
    "        super(BasicBlock, self).__init__()\n",
    "        if norm_layer is None:\n",
    "            norm_layer = nn.BatchNorm2d\n",
    "        if groups != 1 or base_width != 64:\n",
    "            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n",
    "        if dilation > 1:\n",
    "            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n",
    "        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n",
    "        self.conv1 = conv3x3(inplanes, planes, stride)\n",
    "        self.bn1 = norm_layer(planes)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.conv2 = conv3x3(planes, planes)\n",
    "        self.bn2 = norm_layer(planes)\n",
    "        self.downsample = downsample\n",
    "        self.stride = stride\n",
    "\n",
    "    def forward(self, x: Tensor) -> Tensor:\n",
    "        identity = x\n",
    "        \n",
    "        #print(x.shape)\n",
    "        out = self.conv1(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "        #print(out.shape)\n",
    "        \n",
    "        out = self.conv2(out)\n",
    "        out = self.bn2(out)\n",
    "        #print(out.shape)\n",
    "        if self.downsample is not None:\n",
    "            identity = self.downsample(x)\n",
    "\n",
    "        out += identity\n",
    "        out = self.relu(out)\n",
    "        return out\n",
    "\n",
    "\n",
    "class Bottleneck(nn.Module):\n",
    "    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)\n",
    "    # while original implementation places the stride at the first 1x1 convolution(self.conv1)\n",
    "    # according to \"Deep residual learning for image recognition\"https://arxiv.org/abs/1512.03385.\n",
    "    # This variant is also known as ResNet V1.5 and improves accuracy according to\n",
    "    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.\n",
    "\n",
    "    expansion: int = 4\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        inplanes: int,\n",
    "        planes: int,\n",
    "        stride: int = 1,\n",
    "        downsample: Optional[nn.Module] = None,\n",
    "        groups: int = 1,\n",
    "        base_width: int = 64,\n",
    "        dilation: int = 1,\n",
    "        norm_layer: Optional[Callable[..., nn.Module]] = None\n",
    "    ) -> None:\n",
    "        super(Bottleneck, self).__init__()\n",
    "        if norm_layer is None:\n",
    "            norm_layer = nn.BatchNorm2d\n",
    "        width = int(planes * (base_width / 64.)) * groups\n",
    "        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n",
    "        self.conv1 = conv1x1(inplanes, width)\n",
    "        self.bn1 = norm_layer(width)\n",
    "        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n",
    "        self.bn2 = norm_layer(width)\n",
    "        self.conv3 = conv1x1(width, planes * self.expansion)\n",
    "        self.bn3 = norm_layer(planes * self.expansion)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.downsample = downsample\n",
    "        self.stride = stride\n",
    "\n",
    "    def forward(self, x: Tensor) -> Tensor:\n",
    "        identity = x\n",
    "\n",
    "        out = self.conv1(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv2(out)\n",
    "        out = self.bn2(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv3(out)\n",
    "        out = self.bn3(out)\n",
    "\n",
    "        if self.downsample is not None:\n",
    "            identity = self.downsample(x)\n",
    "\n",
    "        out += identity\n",
    "        out = self.relu(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "class ResNet(nn.Module):\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        block: Type[Union[BasicBlock, Bottleneck]],\n",
    "        layers: List[int],\n",
    "        num_classes: int = 10,\n",
    "        zero_init_residual: bool = False,\n",
    "        groups: int = 1,\n",
    "        width_per_group: int = 64,\n",
    "        replace_stride_with_dilation: Optional[List[bool]] = None,\n",
    "        norm_layer: Optional[Callable[..., nn.Module]] = None,\n",
    "        wantIntermediateOutputs: bool = False\n",
    "    ) -> None:\n",
    "        super(ResNet, self).__init__()\n",
    "        if norm_layer is None:\n",
    "            norm_layer = nn.BatchNorm2d\n",
    "        self._norm_layer = norm_layer\n",
    "        \n",
    "        self.wantIntermediateOutputs = True\n",
    "        self.inplanes = 64\n",
    "        self.dilation = 1\n",
    "        if replace_stride_with_dilation is None:\n",
    "            # each element in the tuple indicates if we should replace\n",
    "            # the 2x2 stride with a dilated convolution instead\n",
    "            replace_stride_with_dilation = [False, False, False]\n",
    "        if len(replace_stride_with_dilation) != 3:\n",
    "            raise ValueError(\"replace_stride_with_dilation should be None \"\n",
    "                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n",
    "        self.groups = groups\n",
    "        self.base_width = width_per_group\n",
    "        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,\n",
    "                               bias=False)\n",
    "        self.bn1 = norm_layer(self.inplanes)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    "        self.layer1 = self._make_layer(block, 64, layers[0])\n",
    "        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,\n",
    "                                       dilate=replace_stride_with_dilation[0])\n",
    "        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,\n",
    "                                       dilate=replace_stride_with_dilation[1])\n",
    "        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,\n",
    "                                       dilate=replace_stride_with_dilation[2])\n",
    "        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
    "        self.fc = nn.Linear(512 * block.expansion, num_classes)\n",
    "\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv2d):\n",
    "                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n",
    "            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n",
    "                nn.init.constant_(m.weight, 1)\n",
    "                nn.init.constant_(m.bias, 0)\n",
    "\n",
    "        # Zero-initialize the last BN in each residual branch,\n",
    "        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n",
    "        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n",
    "        if zero_init_residual:\n",
    "            for m in self.modules():\n",
    "                if isinstance(m, Bottleneck):\n",
    "                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]\n",
    "                elif isinstance(m, BasicBlock):\n",
    "                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]\n",
    "\n",
    "    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,\n",
    "                    stride: int = 1, dilate: bool = False) -> nn.Sequential:\n",
    "        norm_layer = self._norm_layer\n",
    "        downsample = None\n",
    "        previous_dilation = self.dilation\n",
    "        if dilate:\n",
    "            self.dilation *= stride\n",
    "            stride = 1\n",
    "        if stride != 1 or self.inplanes != planes * block.expansion:\n",
    "            downsample = nn.Sequential(\n",
    "                conv1x1(self.inplanes, planes * block.expansion, stride),\n",
    "                norm_layer(planes * block.expansion),\n",
    "            )\n",
    "\n",
    "        layers = []\n",
    "        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n",
    "                            self.base_width, previous_dilation, norm_layer))\n",
    "        self.inplanes = planes * block.expansion\n",
    "        for _ in range(1, blocks):\n",
    "            layers.append(block(self.inplanes, planes, groups=self.groups,\n",
    "                                base_width=self.base_width, dilation=self.dilation,\n",
    "                                norm_layer=norm_layer))\n",
    "\n",
    "        return nn.Sequential(*layers)\n",
    "\n",
    "    def _forward_impl(self, x: Tensor) -> Tensor:\n",
    "        # See note [TorchScript super()]\n",
    "        layerOutputs = []\n",
    "        #print(x.shape)\n",
    "        x = self.conv1(x)\n",
    "        #print(x.shape)\n",
    "        x = self.bn1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.maxpool(x)\n",
    "        #layerOutputs.append( x)\n",
    "        \n",
    "        #print(x.shape)\n",
    "        x = self.layer1(x)\n",
    "        #layerOutputs.append( x)\n",
    "        #print(x.shape)\n",
    "        \n",
    "        x = self.layer2(x)\n",
    "        #layerOutputs.append( x)\n",
    "        #print(x.shape)\n",
    "        \n",
    "        x = self.layer3(x)\n",
    "        #layerOutputs.append( x)\n",
    "        #print(x.shape)\n",
    "        \n",
    "        x = self.layer4(x)\n",
    "        #layerOutputs.append( x)\n",
    "        #print(x.shape)\n",
    "        \n",
    "        x = self.avgpool(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        #print(x.shape)\n",
    "        x = self.fc(x)\n",
    "        #print(x.shape)\n",
    "        layerOutputs.append( x)\n",
    "        \n",
    "        if self.wantIntermediateOutputs:\n",
    "            return layerOutputs\n",
    "        else:\n",
    "            return x\n",
    "\n",
    "    def forward(self, x: Tensor) -> Tensor:\n",
    "        return self._forward_impl(x)\n",
    "\n",
    "    def set_mid( self, val):\n",
    "        self.wantIntermediateOutputs = val\n",
    "\n",
    "def _resnet(\n",
    "    arch: str,\n",
    "    block: Type[Union[BasicBlock, Bottleneck]],\n",
    "    layers: List[int],\n",
    "    pretrained: bool,\n",
    "    progress: bool,\n",
    "    **kwargs: Any\n",
    ") -> ResNet:\n",
    "    model = ResNet(block, layers, **kwargs)\n",
    "    if pretrained:\n",
    "        state_dict = load_state_dict_from_url(model_urls[arch],\n",
    "                                              progress=progress)\n",
    "        model.load_state_dict(state_dict)\n",
    "    return model\n",
    "\n",
    "\n",
    "def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n",
    "    r\"\"\"ResNet-152 model from\n",
    "    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n",
    "    Args:\n",
    "        pretrained (bool): If True, returns a model pre-trained on ImageNet\n",
    "        progress (bool): If True, displays a progress bar of the download to stderr\n",
    "    \"\"\"\n",
    "    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, \n",
    "                   **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dir = cwd + \"\\\\cifar10_dataset\\\\train\" \n",
    "test_dir = cwd + \"\\\\cifar10_dataset\\\\test\"\n",
    "filescwd = cwd + \"\\\\files_npys\"\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "batch_size = 256\n",
    "num_classes = 10\n",
    "no_sams_train = 50000\n",
    "no_sams_test = 10000\n",
    "noExperiments = 10000\n",
    "\n",
    "transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])\n",
    "trainset = torchvision.datasets.CIFAR10(root=train_dir, train=True, download=True, transform=transform_train)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size, shuffle=False, num_workers=0)\n",
    "transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])\n",
    "testset = torchvision.datasets.CIFAR10(root=test_dir, train=False, download=True, transform=transform_test)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size, shuffle=False, num_workers=0)\n",
    "\n",
    "if device == 'cuda':\n",
    "    cudnn.benchmark = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "####################################################################\n",
    "\"\"\"1. Train the network on CIFAR-10.\"\"\"\n",
    "####################################################################\n",
    "\n",
    "resnet152 = resnet152().to(device)\n",
    "criterion_r = nn.CrossEntropyLoss()\n",
    "optimizer_r = optim.SGD( resnet152.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)\n",
    "my_lr_scheduler_r = optim.lr_scheduler.MultiStepLR( optimizer_r, milestones=[150,250], gamma=0.1)\n",
    "fit( resnet152, optimizer_r, criterion_r, my_lr_scheduler_r, 1, 350, trainloader, testloader)\n",
    "#test( resnet152, criterion_r, 351, testloader) # 0.573688 loss, 88.8 accuracy\n",
    "torch.save( resnet152, filescwd + \"\\\\resnet152\")\n",
    "torch.save( resnet152.state_dict(), filescwd + \"\\\\resnet152_state_dict\")\n",
    "#resnet152.load_state_dict( torch.load( filescwd + \"\\\\resnet152_state_dict\"))\n",
    "test( resnet152, criterion_r, 351, testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "0: conv1\n",
    "4: bottleneck x3\n",
    "5: bottleneck x8\n",
    "6: bottleneck x36\n",
    "7: bottleneck x3\n",
    "9: linear\n",
    "\"\"\"\n",
    "for i,l in enumerate( list( resnet152.children())):\n",
    "    print(i,l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "####################################################################\n",
    "\"\"\"2. Using the network, save the intermediate outputs for training set and test set.\n",
    "      Beware: This step requires a high usage of RAM.\"\"\"\n",
    "####################################################################\n",
    "\n",
    "layerNumbersToConsider = [0,4,5,6,7,9]\n",
    "# for saving intermediate outputs\n",
    "\n",
    "layerlist = list(resnet152.children())\n",
    "for layerNo in layerNumbersToConsider:\n",
    "    print(\"layerNo:\", layerNo)\n",
    "    if isinstance( layerlist[layerNo], nn.Sequential):\n",
    "        layerlist_for_sequential_sublayer = list( layerlist[layerNo].children())\n",
    "        for i in range( len( layerlist_for_sequential_sublayer)):\n",
    "            print(\"i:\", i)\n",
    "            layersToConsider = layerlist[0:layerNo] + layerlist_for_sequential_sublayer[0:i+1]\n",
    "            intermediateOut = getIntermediateOut( layersToConsider, trainloader)\n",
    "            np.save( filescwd + \"\\\\\" + \"resnet152_\" + str(layerNo) + \"_\" + str(i) + \"_train\", intermediateOut)\n",
    "            print(\"shape:\", intermediateOut.shape)\n",
    "            del intermediateOut\n",
    "    else:\n",
    "        layersToConsider = layerlist[0:layerNo+1]\n",
    "        last = (layerNo == (len(layerlist) - 1))\n",
    "        intermediateOut = getIntermediateOut( layersToConsider, trainloader, last)\n",
    "        np.save( filescwd + \"\\\\\" + \"resnet152_\" + str(layerNo) + \"_train\", intermediateOut)\n",
    "        print(\"shape:\", intermediateOut.shape)\n",
    "        del intermediateOut\n",
    "\n",
    "\n",
    "\n",
    "layerlist = list(resnet152.children())\n",
    "for layerNo in layerNumbersToConsider:\n",
    "    print(\"layerNo:\", layerNo)\n",
    "    if isinstance( layerlist[layerNo], nn.Sequential):\n",
    "        layerlist_for_sequential_sublayer = list( layerlist[layerNo].children())\n",
    "        for i in range( len( layerlist_for_sequential_sublayer)):\n",
    "            print(\"i:\", i)\n",
    "            layersToConsider = layerlist[0:layerNo] + layerlist_for_sequential_sublayer[0:i+1]\n",
    "            intermediateOut = getIntermediateOut( layersToConsider, testloader)\n",
    "            np.save( filescwd + \"\\\\\" + \"resnet152_\" + str(layerNo) + \"_\" + str(i) + \"_test\", intermediateOut)\n",
    "            print(\"shape:\", intermediateOut.shape)\n",
    "            del intermediateOut\n",
    "    else:\n",
    "        layersToConsider = layerlist[0:layerNo+1]\n",
    "        last = (layerNo == (len(layerlist) - 1))\n",
    "        intermediateOut = getIntermediateOut( layersToConsider, testloader, last)\n",
    "        np.save( filescwd + \"\\\\\" + \"resnet152_\" + str(layerNo) + \"_test\", intermediateOut)\n",
    "        print(\"shape:\", intermediateOut.shape)\n",
    "        del intermediateOut\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## architecture variables\n",
    "no_layers = 52\n",
    "\n",
    "# FLOP estimate of every distinct unit, spelled out once as plain integer arithmetic.\n",
    "# Names follow the spatial size that appears in each expression; `first` is the entry\n",
    "# that opens a group (it carries one extra term), `repeat` is the entry that follows it.\n",
    "flops_16x16 = 3*(7*7*7*7+48)*16*16*64 + 10*(64*16*16 + 64*16*16 + 64*16*16-1)\n",
    "flops_8x8_first = 64*(1*1*1*1)*8*8*64 + 64*(3*3*3*3+8)*8*8*64 + 64*(1*1*1*1)*8*8*256 + 64*(1*1*1*1)*8*8*256 + 10*(256*8*8 + 256*8*8 + 256*8*8-1)\n",
    "flops_8x8_repeat = 256*(1*1*1*1)*8*8*64 + 64*(3*3*3*3+8)*8*8*64 + 64*(1*1*1*1)*8*8*256 + 10*(256*8*8 + 256*8*8 + 256*8*8-1)\n",
    "flops_4x4_first = 256*(1*1*1*1)*8*8*128 + 128*(3*3*3*3+8)*4*4*128 + 128*(1*1*1*1)*4*4*512 + 256*(1*1*1*1)*4*4*512 + 10*(512*4*4 + 512*4*4 + 512*4*4-1)\n",
    "flops_4x4_repeat = 512*(1*1*1*1)*4*4*128 + 128*(3*3*3*3+8)*4*4*128 + 128*(1*1*1*1)*4*4*512 + 10*(512*4*4 + 512*4*4 + 512*4*4-1)\n",
    "flops_2x2_first = 512*(1*1*1*1)*4*4*256 + 256*(3*3*3*3+8)*2*2*256 + 256*(1*1*1*1)*2*2*1024 + 512*(1*1*1*1)*2*2*1024 + 10*(1024*2*2 + 1024*2*2 + 1024*2*2-1)\n",
    "flops_2x2_repeat = 1024*(1*1*1*1)*2*2*256 + 256*(3*3*3*3+8)*2*2*256 + 256*(1*1*1*1)*2*2*1024 + 10*(1024*2*2 + 1024*2*2 + 1024*2*2-1)\n",
    "flops_1x1_first = 1024*(1*1*1*1)*2*2*512 + 512*(3*3*3*3+8)*1*1*512 + 512*(1*1*1*1)*1*1*2048 + 1024*(1*1*1*1)*1*1*2048 + 10*(2048*1*1 + 2048*1*1 + 2048*1*1-1)\n",
    "flops_1x1_repeat = 2048*(1*1*1*1)*1*1*512 + 512*(3*3*3*3+8)*1*1*512 + 512*(1*1*1*1)*1*1*2048 + 10*(2048*1*1 + 2048*1*1 + 2048*1*1-1)\n",
    "flops_last = (2048+2047)*10\n",
    "\n",
    "# 1 + 3 + 8 + 36 + 3 + 1 = 52 entries, in the same order as before.\n",
    "individual_layerFlops = np.asarray(\n",
    "    [flops_16x16]\n",
    "    + [flops_8x8_first] + [flops_8x8_repeat] * 2\n",
    "    + [flops_4x4_first] + [flops_4x4_repeat] * 7\n",
    "    + [flops_2x2_first] + [flops_2x2_repeat] * 35\n",
    "    + [flops_1x1_first] + [flops_1x1_repeat] * 2\n",
    "    + [flops_last]\n",
    ")\n",
    "\n",
    "# Rescale by entry 1 before summing (to prevent overflow), remembering the scale so the\n",
    "# absolute total can still be reported as `divider`.\n",
    "savethis = individual_layerFlops[1]\n",
    "individual_layerFlops = individual_layerFlops / savethis\n",
    "\n",
    "relative_total = np.sum( individual_layerFlops)\n",
    "divider = relative_total * savethis\n",
    "individual_layerFlops = individual_layerFlops / relative_total\n",
    "\n",
    "# Entry k holds the summed share of entries 0..k.\n",
    "thresholding_layerFlops = np.array(\n",
    "    [np.sum( individual_layerFlops[:upto + 1]) for upto in range( no_layers)]\n",
    ")\n",
    "\n",
    "print(\"resnet152\")\n",
    "print(\"Individual:\", individual_layerFlops)\n",
    "print(\"Thresholding:\", thresholding_layerFlops)\n",
    "print(\"divider:\", divider)\n",
    "\n",
    "\n",
    "predicts_train, ground_truths_train, list_of_intermediate_outs_array_train = predict( resnet152, trainloader)\n",
    "predicts_test, ground_truths_test, list_of_intermediate_outs_array_test = predict( resnet152, testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layerOutSizes = [(64,16,16),\n",
    "                (256,8,8),\n",
    "                (256,8,8),\n",
    "                (256,8,8),\n",
    "                (512,4,4),\n",
    "                (512,4,4),\n",
    "                (512,4,4),\n",
    "                (512,4,4),\n",
    "                (512,4,4),\n",
    "                (512,4,4),\n",
    "                (512,4,4),\n",
    "                (512,4,4),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (1024,2,2),\n",
    "                (2048,1,1),\n",
    "                (2048,1,1),\n",
    "                (2048,1,1),\n",
    "                (2048)\n",
    "                ]\n",
    "len( layerOutSizes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "####################################################################\n",
    "\"\"\"3. On training set, calculate the means and distances of samples to those means.\"\"\"\n",
    "####################################################################\n",
    "\n",
    "\n",
    "## get the outputs\n",
    "list_of_layerouts_train = []  # stays empty: raw activations are not retained\n",
    "list_of_classFeatureMeans_train = []\n",
    "list_of_dist_of_layerouts_to_class_means_train = []\n",
    "\n",
    "layerlist = list( resnet152.children())\n",
    "\n",
    "for layerNo in layerNumbersToConsider:\n",
    "    print(\"layerNo:\", layerNo)\n",
    "    # Every activation file of this child module shares the same path prefix.\n",
    "    file_prefix = filescwd + \"\\\\\" + \"resnet152_\" + str(layerNo)\n",
    "    if not isinstance( layerlist[layerNo], nn.Sequential):\n",
    "        # Plain module: a single saved activation file.\n",
    "        intermediateOut = np.load( file_prefix + \"_train.npy\")\n",
    "        classMeans = classFeatureMeansOf( intermediateOut, ground_truths_train)\n",
    "        list_of_classFeatureMeans_train.append( classMeans)\n",
    "        list_of_dist_of_layerouts_to_class_means_train.append( distanceOf( intermediateOut, classMeans))\n",
    "        del intermediateOut\n",
    "        continue\n",
    "\n",
    "    # Sequential container: one saved activation file per sub-module.\n",
    "    layerlist_for_sequential_sublayer = list( layerlist[layerNo].children())\n",
    "    for i, sublayer in enumerate( layerlist_for_sequential_sublayer):\n",
    "        print(\"i:\", i)\n",
    "        intermediateOut = np.load( file_prefix + \"_\" + str(i) + \"_train.npy\")\n",
    "        classMeans = classFeatureMeansOf( intermediateOut, ground_truths_train)\n",
    "        list_of_classFeatureMeans_train.append( classMeans)\n",
    "        list_of_dist_of_layerouts_to_class_means_train.append( distanceOf( intermediateOut, classMeans))\n",
    "        del intermediateOut\n",
    "\n",
    "\n",
    "# Persist both result lists, one .npy file per entry, indexed by position in the list.\n",
    "for i, class_means in enumerate( list_of_classFeatureMeans_train):\n",
    "    np.save( filescwd + \"\\\\\" + \"resnet152_list_of_classFeatureMeans_train_\" + str(i), class_means)\n",
    "\n",
    "for i, layer_dists in enumerate( list_of_dist_of_layerouts_to_class_means_train):\n",
    "    np.save( filescwd + \"\\\\\" + \"resnet152_dist_of_layerouts_to_class_means_train_\" + str(i), layer_dists)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "actuallayernum = list( np.arange( no_layers))\n",
    "# Disabled reload path: reads back the per-entry class means and distances written by\n",
    "# step 3. It is kept inside a string literal, so none of it executes.\n",
    "\"\"\"\n",
    "list_of_classFeatureMeans_train = []\n",
    "for i in actuallayernum:#range( no_layers):\n",
    "    list_of_classFeatureMeans_train.append( np.load( filescwd + \"\\\\\" + \"resnet152_list_of_classFeatureMeans_train_\" + str(i) + \".npy\"))\n",
    "\n",
    "\n",
    "list_of_dist_of_layerouts_to_class_means_train = []\n",
    "for i in actuallayernum:#range( no_layers):\n",
    "    list_of_dist_of_layerouts_to_class_means_train.append( np.load( filescwd + \"\\\\\" + \"resnet152_dist_of_layerouts_to_class_means_train_\" + str(i) + \".npy\"))\n",
    "\"\"\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row k holds 1 where the nearest class mean at entry k matches the label, else 0.\n",
    "layers_accuracy_train = np.empty([ no_layers, no_sams_train])\n",
    "for i in range( no_layers):\n",
    "    nearest_class = np.argmin( list_of_dist_of_layerouts_to_class_means_train[i], axis=1)\n",
    "    is_correct = nearest_class == ground_truths_train\n",
    "    layers_accuracy_train[i] = is_correct.astype(int)\n",
    "\n",
    "# Number of correctly classified training samples per entry.\n",
    "print( np.sum( layers_accuracy_train, axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "####################################################################\n",
    "\"\"\"4. Apply normalization and softmax.\"\"\"\n",
    "####################################################################\n",
    "\n",
    "# Normalise each entry's distances by the per-class column sums over the training set.\n",
    "# The results go into new arrays: dividing the stored arrays in place would make a second\n",
    "# run of this cell recompute meanOfDistsClasses from already-normalised distances (~1.0),\n",
    "# and step 7 uses meanOfDistsClasses to scale the test distances.\n",
    "meanOfDistsClasses = []\n",
    "normalized_dists_train = []\n",
    "for dists in list_of_dist_of_layerouts_to_class_means_train:\n",
    "    mean = np.sum( dists, axis=0)#np.mean( dists, axis=0)\n",
    "    meanOfDistsClasses.append( mean)\n",
    "    # Broadcasting divides every class column by its own sum, whatever the class count.\n",
    "    normalized_dists_train.append( dists / mean)\n",
    "\n",
    "def softmax( distances):\n",
    "    \"\"\"Convert distances to per-row probabilities with a softmax over the negated values.\n",
    "\n",
    "    distances: 2-D array, one row per sample and one column per class\n",
    "        (50000x10 for the training set).\n",
    "    Returns a float64 array of the same shape; each row sums to 1 and the\n",
    "    smallest distance in a row receives the largest probability.\n",
    "    \"\"\"\n",
    "    exp_neg = np.exp( -1*distances)\n",
    "    probs = np.empty( distances.shape)\n",
    "    # keepdims keeps the row sums as a column so the division broadcasts row by row.\n",
    "    probs[...] = exp_neg / np.sum( exp_neg, axis=1, keepdims=True)\n",
    "    return probs\n",
    "\n",
    "softmax_list_of_dist_of_layerouts_to_class_means_train = [softmax( dists) for dists in normalized_dists_train]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "####################################################################\n",
    "\"\"\"5. Perform thresholding on training set to choose the best threshold values.\"\"\"\n",
    "####################################################################\n",
    "\n",
    "# time.clock() no longer exists from Python 3.8 on; perf_counter() is the replacement.\n",
    "start_time = time.perf_counter()\n",
    "\n",
    "flop_granularity = 0.001\n",
    "maxstepcount = 40\n",
    "# One row per target FLOP value: the thresholds followed by [target, actual, accuracy, steps].\n",
    "thresholds_targetflops_actualflops_acc_numsteps = np.zeros([int(round(1/flop_granularity)) + 1, 4+no_layers-1])\n",
    "noExperiments = thresholds_targetflops_actualflops_acc_numsteps.shape[0]\n",
    "thresholds_flops_accuracies = np.empty( [noExperiments, no_layers-1+2])\n",
    "accuracies_train = np.zeros(noExperiments)\n",
    "flops_train = np.zeros(noExperiments)\n",
    "\n",
    "# Highest class probability of every sample at each non-final entry. It does not depend on\n",
    "# the thresholds, so it is computed once instead of on every search step.\n",
    "max_confidence_train = np.stack(\n",
    "    [np.max( softmax_list_of_dist_of_layerouts_to_class_means_train[i], axis=1) for i in range( no_layers - 1)]\n",
    ")\n",
    "\n",
    "for expNo in range( noExperiments):\n",
    "    # Derive the target from the integer index. Recovering the index with\n",
    "    # int(targetflop / flop_granularity) truncates (0.003 / 0.001 == 2.9999999999999996 -> 2),\n",
    "    # which overwrites one row and leaves another unfilled.\n",
    "    index = expNo\n",
    "    targetflop = expNo * flop_granularity\n",
    "    print(\"targetflop\", targetflop)\n",
    "    threshold = np.random.uniform(0, 1, no_layers-1)\n",
    "    lower_th = np.zeros( len(threshold))\n",
    "    upper_th = np.ones( len(threshold))\n",
    "    stepcount = 0\n",
    "\n",
    "    while True:\n",
    "        # A sample trusts an entry when its confidence there exceeds that entry's threshold;\n",
    "        # the final entry is always trusted.\n",
    "        layers_trust_train = np.empty([ no_layers, no_sams_train])\n",
    "        layers_trust_train[:no_layers - 1] = max_confidence_train > threshold[:, np.newaxis]\n",
    "        layers_trust_train[no_layers - 1] = np.ones( no_sams_train)\n",
    "\n",
    "        # argmax returns the first maximum, i.e. the earliest trusted entry of each sample.\n",
    "        # This equals the former pairwise masking loop at a fraction of the cost.\n",
    "        to_layer_num_train = np.argmax( layers_trust_train, axis=0)\n",
    "        to_layer_num_train = to_categorical( to_layer_num_train, no_layers)\n",
    "        flops = np.sum( to_layer_num_train*thresholding_layerFlops) / no_sams_train\n",
    "        correct = np.sum( layers_accuracy_train*np.transpose( to_layer_num_train, (1,0))) / no_sams_train\n",
    "        stepcount += 1\n",
    "\n",
    "        if abs(flops - targetflop) < flop_granularity or stepcount == maxstepcount:\n",
    "            thresholds_targetflops_actualflops_acc_numsteps[index,:len(threshold)] = threshold\n",
    "            thresholds_targetflops_actualflops_acc_numsteps[index,len(threshold):] = [targetflop, flops, correct, stepcount]\n",
    "            break\n",
    "        elif flops > targetflop:\n",
    "            # Too expensive: bisect every threshold towards its lower bound.\n",
    "            upper_th = threshold.copy()\n",
    "            threshold = (lower_th + threshold) / 2\n",
    "        elif flops < targetflop:\n",
    "            # Too cheap: bisect every threshold towards its upper bound.\n",
    "            lower_th = threshold.copy()\n",
    "            threshold = (threshold + upper_th) / 2\n",
    "\n",
    "    flops_train[expNo] = flops\n",
    "    accuracies_train[expNo] = correct\n",
    "    thresholds_flops_accuracies[expNo, 0:2] = flops_train[expNo], accuracies_train[expNo]\n",
    "    thresholds_flops_accuracies[expNo, 2:] = threshold\n",
    "    \n",
    "end = time.perf_counter()\n",
    "print(end-start_time, \"time elapsed\")\n",
    "\n",
    "\"\"\"random thresholding, choosing bests\n",
    "was an alternative approach\n",
    "thresholds_flops_accuracies = np.empty( [noExperiments, no_layers-1+2])\n",
    "accuracies_train = np.zeros(noExperiments)\n",
    "flops_train = np.zeros(noExperiments)\n",
    "\n",
    "for expNo in range( noExperiments):\n",
    "    print(expNo)\n",
    "    thresholds = np.random.uniform(0, 1, no_layers-1)\n",
    "    \n",
    "    layers_trust_train = np.empty([ no_layers, no_sams_train])\n",
    "    for i in range( no_layers - 1):\n",
    "        layers_trust_train[i] = np.max( softmax_list_of_dist_of_layerouts_to_class_means_train[i], axis=1) > thresholds[i]\n",
    "    layers_trust_train[no_layers - 1] = np.ones( no_sams_train)\n",
    "\n",
    "    layers_choice_mask_train = layers_trust_train\n",
    "    for i in range(1, no_layers):\n",
    "        for j in range( i):\n",
    "            layers_choice_mask_train[i] = layers_choice_mask_train[i] * (1 - layers_choice_mask_train[j])\n",
    "    \n",
    "    to_layer_num_train = np.argmax( layers_choice_mask_train, axis=0)\n",
    "    to_layer_num_train = to_categorical( to_layer_num_train, no_layers)\n",
    "    flops = np.sum( to_layer_num_train*thresholding_layerFlops) / no_sams_train\n",
    "    correct = np.sum( layers_accuracy_train*np.transpose( to_layer_num_train, (1,0))) / no_sams_train\n",
    "\n",
    "    flops_train[expNo] = flops\n",
    "    accuracies_train[expNo] = correct\n",
    "    thresholds_flops_accuracies[expNo, 0:2] = flops_train[expNo], accuracies_train[expNo]\n",
    "    thresholds_flops_accuracies[expNo, 2:] = thresholds\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# convex hull for threshold\n",
    "points = np.column_stack( (flops_train, accuracies_train))\n",
    "hull = ConvexHull(points)\n",
    "# Hull vertex indices in ascending experiment order; reused by the test-set plot.\n",
    "descending_side = np.sort(hull.vertices)\n",
    "\n",
    "# All experiments as faint red dots, hull vertices as a connected green line.\n",
    "fig, ax = plt.subplots()\n",
    "ax.grid(True)\n",
    "ax.plot( flops_train, accuracies_train, 'ro', alpha=0.2)\n",
    "ax.plot( points[descending_side,0], points[descending_side,1], 'go-', lw=3)\n",
    "ax.set_ylabel('Accuracy')\n",
    "ax.set_xlabel('FLOPs')\n",
    "ax.set_title(\"On training set\")\n",
    "#plt.xlim(0, 1)\n",
    "#plt.ylim(0, 1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "####################################################################\n",
    "\"\"\"6. On test set, calculate distances of samples to the means calculated on training set.\"\"\"\n",
    "####################################################################\n",
    "\n",
    "## get the outputs\n",
    "# Raw activations are not retained, matching the training pass in step 3: no later cell\n",
    "# reads them, and a reference held in this list would keep every loaded array alive\n",
    "# despite the `del` below.\n",
    "list_of_layerouts_test = []\n",
    "list_of_classFeatureMeans_test = list_of_classFeatureMeans_train\n",
    "list_of_dist_of_layerouts_to_class_means_test = []\n",
    "\n",
    "layerlist = list(resnet152.children())\n",
    "layerNumbersToConsider = [0,4,5,6,7,9]\n",
    "# Running position in list_of_classFeatureMeans_test, advanced once per loaded file.\n",
    "count = 0\n",
    "for layerNo in layerNumbersToConsider:\n",
    "    print(\"layerNo:\", layerNo)\n",
    "    if isinstance( layerlist[layerNo], nn.Sequential):\n",
    "        # Sequential container: one saved activation file per sub-module.\n",
    "        layerlist_for_sequential_sublayer = list( layerlist[layerNo].children())\n",
    "        for i in range( len( layerlist_for_sequential_sublayer)):\n",
    "            print(\"i:\", i)\n",
    "            intermediateOut = np.load( filescwd + \"\\\\\" + \"resnet152_\" + str(layerNo) + \"_\" + str(i) + \"_test.npy\")\n",
    "            #list_of_layerouts_test.append( intermediateOut)\n",
    "            list_of_dist_of_layerouts_to_class_means_test.append( distanceOf( intermediateOut, list_of_classFeatureMeans_test[count]))\n",
    "            count += 1\n",
    "            del intermediateOut\n",
    "    else:\n",
    "        intermediateOut = np.load( filescwd + \"\\\\\" + \"resnet152_\" + str(layerNo) + \"_test.npy\")\n",
    "        #list_of_layerouts_test.append( intermediateOut)\n",
    "        list_of_dist_of_layerouts_to_class_means_test.append( distanceOf( intermediateOut, list_of_classFeatureMeans_test[count]))\n",
    "        count += 1\n",
    "        del intermediateOut\n",
    "\n",
    "# Row k holds 1 where the nearest training class mean at entry k matches the label, else 0.\n",
    "layers_accuracy_test = np.empty([ no_layers, no_sams_test])\n",
    "for i in range( no_layers):\n",
    "    layers_accuracy_test[i] = (np.argmin( list_of_dist_of_layerouts_to_class_means_test[i], axis=1) == ground_truths_test).astype(int)\n",
    "\n",
    "print( np.sum( layers_accuracy_test, axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "####################################################################\n",
    "\"\"\"7. Apply normalization and softmax.\"\"\"\n",
    "####################################################################\n",
    "\n",
    "# Scale the test distances with the per-class sums computed on the training set (step 4).\n",
    "# New arrays are built instead of dividing in place, so running this cell again does not\n",
    "# divide the stored test distances a second time. Broadcasting handles any class count.\n",
    "normalized_dists_test = [\n",
    "    dists / meanOfDistsClasses[i]\n",
    "    for i, dists in enumerate( list_of_dist_of_layerouts_to_class_means_test)\n",
    "]\n",
    "\n",
    "softmax_list_of_dist_of_layerouts_to_class_means_test = [softmax( dists) for dists in normalized_dists_test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "####################################################################\n",
    "\"\"\"8. Perform thresholding on test set using the threshold values chosen on training set.\"\"\"\n",
    "####################################################################\n",
    "\n",
    "## thresholding\n",
    "accuracies_test = np.zeros(noExperiments)\n",
    "flops_test = np.zeros(noExperiments)\n",
    "\n",
    "# Highest class probability of every test sample at each non-final entry. It is the same\n",
    "# for all experiments, so it is computed once outside the loop.\n",
    "max_confidence_test = np.stack(\n",
    "    [np.max( softmax_list_of_dist_of_layerouts_to_class_means_test[i], axis=1) for i in range( no_layers - 1)]\n",
    ")\n",
    "\n",
    "for expNo in range( noExperiments):\n",
    "    print(expNo)\n",
    "    # Columns 0-1 of each row are the training FLOPs/accuracy; the thresholds follow.\n",
    "    thresholds = thresholds_flops_accuracies[expNo, 2:]\n",
    "    \n",
    "    # A sample trusts an entry when its confidence there exceeds that entry's threshold;\n",
    "    # the final entry is always trusted.\n",
    "    layers_trust_test = np.empty([ no_layers, no_sams_test])\n",
    "    layers_trust_test[:no_layers - 1] = max_confidence_test > thresholds[:, np.newaxis]\n",
    "    layers_trust_test[no_layers - 1] = np.ones( no_sams_test)\n",
    "\n",
    "    # argmax returns the first maximum, i.e. the earliest trusted entry of each sample.\n",
    "    # This equals the former pairwise masking loop at a fraction of the cost.\n",
    "    to_layer_num_test = np.argmax( layers_trust_test, axis=0)\n",
    "    to_layer_num_test = to_categorical( to_layer_num_test, no_layers)\n",
    "    flops = np.sum( to_layer_num_test*thresholding_layerFlops) / no_sams_test\n",
    "    correct = np.sum( layers_accuracy_test*np.transpose( to_layer_num_test, (1,0))) / no_sams_test\n",
    "    \n",
    "    flops_test[expNo] = flops\n",
    "    accuracies_test[expNo] = correct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# convex hull for threshold\n",
    "points = np.column_stack( (flops_test, accuracies_test))\n",
    "\n",
    "# All experiments as faint red dots; the green line connects the experiments whose\n",
    "# indices were stored in descending_side by the training-set plot.\n",
    "fig, ax = plt.subplots()\n",
    "ax.grid(True)\n",
    "ax.plot( flops_test, accuracies_test, 'ro', alpha=0.2)\n",
    "ax.plot( points[descending_side,0], points[descending_side,1], 'go-', lw=3)\n",
    "ax.set_ylabel('Accuracy')\n",
    "ax.set_xlabel('FLOPs')\n",
    "ax.set_title(\"On test set\")\n",
    "#plt.xlim(0, 1)\n",
    "#plt.ylim(0, 1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Caption for the test-set plot in the preceding cell; as the cell's only expression,\n",
    "# the string is shown as this cell's output.\n",
    "\"\"\"The figure corresponds to the green curve in Figure 2 in the paper.\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Raw Cell Format",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
