{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compress Last Layer Feature Map Activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "import time\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from ANNs import CNN\n",
    "from functions import *\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "netC = load_cnn('mnist_normal/cnn.pth')\n",
    "# netC = load_cnn('mnist_compressed_fams/cnn.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, test_loader = load_dataloaders(batch_size=32,shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 50\n",
    "DEVICE = torch.device(\"cuda\" if (torch.cuda.is_available() and ngpu > 0) else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Find NB Feature Maps for Each Class (less is better)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "mse_loss = torch.nn.MSELoss()\n",
    "l1_loss = torch.nn.L1Loss()\n",
    "cce_loss = torch.nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(netC.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_target_C(C, x, weights, preds):\n",
    "    \n",
    "    target_C = torch.zeros(C.shape)\n",
    "\n",
    "    # Iterate batch\n",
    "    for i in range(len(C)):\n",
    "\n",
    "        # Get contributions\n",
    "        c = x[i] * weights[preds[i].item()]  \n",
    "        # Get top 10 features\n",
    "        topn = np.argsort(c.clone().detach().cpu().numpy())[-10:]\n",
    "\n",
    "        # Select most diverse features\n",
    "        dists = np.zeros(( len(topn) ))\n",
    "        most_nb_vec = argmax_loc = (C[i][topn[-1]]==torch.max(C[i][topn[-1]])).nonzero()[0]\n",
    "        for j in range(0, len(topn) ):\n",
    "            argmax_loc = (C[i][topn[j]]==torch.max(C[i][topn[j]])).nonzero()[0]\n",
    "            dist = sum(abs(most_nb_vec - argmax_loc))\n",
    "            dists[j] = dist\n",
    "\n",
    "        # Two most diverse features\n",
    "        top3 = np.argsort(dists)[-3:]\n",
    "        top3[0] = 9\n",
    "\n",
    "        for j in top3:\n",
    "            argmax_loc = (C[i][topn[j]]==torch.max(C[i][topn[j]])).nonzero()[0]\n",
    "            \n",
    "    return target_C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CCE Loss   : 0.0057822903618216515\n",
      "Kernel Loss: 0.2906612157821655\n",
      " \n",
      "CCE Loss   : 0.01589447632431984\n",
      "Kernel Loss: 0.08464045077562332\n",
      " \n",
      "CCE Loss   : 0.01805870421230793\n",
      "Kernel Loss: 0.06932646036148071\n",
      " \n",
      "CCE Loss   : 0.025127127766609192\n",
      "Kernel Loss: 0.05859549343585968\n",
      " \n",
      "CCE Loss   : 0.005459838546812534\n",
      "Kernel Loss: 0.05601439252495766\n",
      " \n",
      "0\n",
      "CCE Loss   : 0.01212992426007986\n",
      "Kernel Loss: 0.05053924024105072\n",
      " \n",
      "CCE Loss   : 0.004735975060611963\n",
      "Kernel Loss: 0.0475502535700798\n",
      " \n",
      "CCE Loss   : 0.008270442485809326\n",
      "Kernel Loss: 0.04427185282111168\n",
      " \n",
      "CCE Loss   : 0.0065114451572299\n",
      "Kernel Loss: 0.04010869935154915\n",
      " \n",
      "CCE Loss   : 0.01191624067723751\n",
      "Kernel Loss: 0.0393734872341156\n",
      " \n",
      "1\n",
      "CCE Loss   : 0.0073858024552464485\n",
      "Kernel Loss: 0.03887660428881645\n",
      " \n",
      "CCE Loss   : 0.005084926262497902\n",
      "Kernel Loss: 0.03431219607591629\n",
      " \n",
      "CCE Loss   : 0.024286985397338867\n",
      "Kernel Loss: 0.034470636397600174\n",
      " \n",
      "CCE Loss   : 0.006291162222623825\n",
      "Kernel Loss: 0.033759988844394684\n",
      " \n",
      "CCE Loss   : 0.00219138921238482\n",
      "Kernel Loss: 0.029365768656134605\n",
      " \n",
      "2\n",
      "CCE Loss   : 0.0031818770803511143\n",
      "Kernel Loss: 0.029704639688134193\n",
      " \n",
      "CCE Loss   : 0.011056376621127129\n",
      "Kernel Loss: 0.028827786445617676\n",
      " \n",
      "CCE Loss   : 0.008607449010014534\n",
      "Kernel Loss: 0.026967087760567665\n",
      " \n",
      "CCE Loss   : 0.0236019566655159\n",
      "Kernel Loss: 0.02600903809070587\n",
      " \n",
      "CCE Loss   : 0.0037756466772407293\n",
      "Kernel Loss: 0.0235774964094162\n",
      " \n",
      "3\n",
      "CCE Loss   : 0.05188600346446037\n",
      "Kernel Loss: 0.023883095011115074\n",
      " \n",
      "CCE Loss   : 0.01548693422228098\n",
      "Kernel Loss: 0.02334311045706272\n",
      " \n",
      "CCE Loss   : 0.00459334859624505\n",
      "Kernel Loss: 0.02341477945446968\n",
      " \n",
      "CCE Loss   : 0.003949545323848724\n",
      "Kernel Loss: 0.020236700773239136\n",
      " \n",
      "CCE Loss   : 0.0028348935302346945\n",
      "Kernel Loss: 0.020195458084344864\n",
      " \n",
      "4\n",
      "CCE Loss   : 0.001587970880791545\n",
      "Kernel Loss: 0.02123882994055748\n",
      " \n",
      "CCE Loss   : 0.005715507082641125\n",
      "Kernel Loss: 0.017735600471496582\n",
      " \n",
      "CCE Loss   : 0.0030346475541591644\n",
      "Kernel Loss: 0.018562210723757744\n",
      " \n",
      "CCE Loss   : 0.002175625180825591\n",
      "Kernel Loss: 0.01800083927810192\n",
      " \n",
      "CCE Loss   : 0.0022566088009625673\n",
      "Kernel Loss: 0.01781928725540638\n",
      " \n",
      "5\n",
      "CCE Loss   : 0.003809885587543249\n",
      "Kernel Loss: 0.017647558823227882\n",
      " \n",
      "CCE Loss   : 0.0021212161518633366\n",
      "Kernel Loss: 0.01584063097834587\n",
      " \n",
      "CCE Loss   : 0.0022528539411723614\n",
      "Kernel Loss: 0.017921553924679756\n",
      " \n",
      "CCE Loss   : 0.002787994686514139\n",
      "Kernel Loss: 0.015604302287101746\n",
      " \n",
      "CCE Loss   : 0.0008199077565222979\n",
      "Kernel Loss: 0.016976529732346535\n",
      " \n",
      "6\n",
      "CCE Loss   : 0.0015212728176265955\n",
      "Kernel Loss: 0.015305058099329472\n",
      " \n",
      "CCE Loss   : 0.0020181939471513033\n",
      "Kernel Loss: 0.01418464258313179\n",
      " \n",
      "CCE Loss   : 0.0019844933412969112\n",
      "Kernel Loss: 0.013724906370043755\n",
      " \n",
      "CCE Loss   : 0.0022058880422264338\n",
      "Kernel Loss: 0.012910568155348301\n",
      " \n",
      "CCE Loss   : 0.0008936263038776815\n",
      "Kernel Loss: 0.01458556205034256\n",
      " \n",
      "7\n",
      "CCE Loss   : 0.0025331107899546623\n",
      "Kernel Loss: 0.014982299879193306\n",
      " \n",
      "CCE Loss   : 0.0009814855875447392\n",
      "Kernel Loss: 0.013377255760133266\n",
      " \n",
      "CCE Loss   : 0.0022117521148175\n",
      "Kernel Loss: 0.013360758312046528\n",
      " \n",
      "CCE Loss   : 0.002986871637403965\n",
      "Kernel Loss: 0.011939839459955692\n",
      " \n",
      "CCE Loss   : 0.001226545311510563\n",
      "Kernel Loss: 0.012523158453404903\n",
      " \n",
      "8\n",
      "CCE Loss   : 0.0012576243607327342\n",
      "Kernel Loss: 0.01270021591335535\n",
      " \n",
      "CCE Loss   : 0.0005252861883491278\n",
      "Kernel Loss: 0.012072804383933544\n",
      " \n",
      "CCE Loss   : 0.001171339419670403\n",
      "Kernel Loss: 0.011676796711981297\n",
      " \n",
      "CCE Loss   : 0.0022540355566889048\n",
      "Kernel Loss: 0.011165455915033817\n",
      " \n",
      "CCE Loss   : 0.0016921937931329012\n",
      "Kernel Loss: 0.012075149454176426\n",
      " \n",
      "9\n",
      "CCE Loss   : 0.001579879317432642\n",
      "Kernel Loss: 0.011579615995287895\n",
      " \n",
      "CCE Loss   : 0.0011583005543798208\n",
      "Kernel Loss: 0.01129170972853899\n",
      " \n",
      "CCE Loss   : 0.009684680961072445\n",
      "Kernel Loss: 0.009812450036406517\n",
      " \n",
      "CCE Loss   : 0.0013864071806892753\n",
      "Kernel Loss: 0.010608463548123837\n",
      " \n",
      "CCE Loss   : 0.001962203998118639\n",
      "Kernel Loss: 0.010643843561410904\n",
      " \n",
      "10\n",
      "CCE Loss   : 0.0007882254431024194\n",
      "Kernel Loss: 0.010006546974182129\n",
      " \n",
      "CCE Loss   : 0.0007641204865649343\n",
      "Kernel Loss: 0.00960945338010788\n",
      " \n",
      "CCE Loss   : 0.001582521595992148\n",
      "Kernel Loss: 0.00945986621081829\n",
      " \n",
      "CCE Loss   : 0.0005124307936057448\n",
      "Kernel Loss: 0.009450532495975494\n",
      " \n",
      "CCE Loss   : 0.0020423531532287598\n",
      "Kernel Loss: 0.008356934413313866\n",
      " \n",
      "11\n",
      "CCE Loss   : 0.00035644019953906536\n",
      "Kernel Loss: 0.009481186978518963\n",
      " \n",
      "CCE Loss   : 0.0009583397186361253\n",
      "Kernel Loss: 0.009004740975797176\n",
      " \n",
      "CCE Loss   : 0.0034958263859152794\n",
      "Kernel Loss: 0.008295424282550812\n",
      " \n",
      "CCE Loss   : 0.0029956966172903776\n",
      "Kernel Loss: 0.00904660765081644\n",
      " \n",
      "CCE Loss   : 0.0010957629419863224\n",
      "Kernel Loss: 0.00820154044777155\n",
      " \n",
      "12\n",
      "CCE Loss   : 0.00037389498902484775\n",
      "Kernel Loss: 0.008477809838950634\n",
      " \n",
      "CCE Loss   : 0.0010833065025508404\n",
      "Kernel Loss: 0.00815822184085846\n",
      " \n",
      "CCE Loss   : 0.001120914937928319\n",
      "Kernel Loss: 0.007823582738637924\n",
      " \n",
      "CCE Loss   : 0.001787592307664454\n",
      "Kernel Loss: 0.007758957799524069\n",
      " \n",
      "CCE Loss   : 0.000905616965610534\n",
      "Kernel Loss: 0.008535162545740604\n",
      " \n",
      "13\n",
      "CCE Loss   : 0.0005126005271449685\n",
      "Kernel Loss: 0.007809911388903856\n",
      " \n",
      "CCE Loss   : 0.0006399326375685632\n",
      "Kernel Loss: 0.007748277857899666\n",
      " \n",
      "CCE Loss   : 0.0010653891367837787\n",
      "Kernel Loss: 0.007492117118090391\n",
      " \n",
      "CCE Loss   : 0.0014703834895044565\n",
      "Kernel Loss: 0.007405000738799572\n",
      " \n",
      "CCE Loss   : 0.0013809020165354013\n",
      "Kernel Loss: 0.007678327616304159\n",
      " \n",
      "14\n",
      "CCE Loss   : 0.0016530543798580766\n",
      "Kernel Loss: 0.006947215646505356\n",
      " \n",
      "CCE Loss   : 0.0009749533492140472\n",
      "Kernel Loss: 0.006946421228349209\n",
      " \n",
      "CCE Loss   : 0.0007306698244065046\n",
      "Kernel Loss: 0.0065880729816854\n",
      " \n",
      "CCE Loss   : 0.0010732756927609444\n",
      "Kernel Loss: 0.007163474801927805\n",
      " \n",
      "CCE Loss   : 0.0006200730567798018\n",
      "Kernel Loss: 0.006889921613037586\n",
      " \n",
      "15\n",
      "CCE Loss   : 0.0008469216991215944\n",
      "Kernel Loss: 0.006723220460116863\n",
      " \n",
      "CCE Loss   : 0.0015505322953686118\n",
      "Kernel Loss: 0.006576248444616795\n",
      " \n",
      "CCE Loss   : 0.00040985288796946406\n",
      "Kernel Loss: 0.0064538102596998215\n",
      " \n",
      "CCE Loss   : 0.0007251882343553007\n",
      "Kernel Loss: 0.006570355035364628\n",
      " \n",
      "CCE Loss   : 0.001219714991748333\n",
      "Kernel Loss: 0.006321922410279512\n",
      " \n",
      "16\n",
      "CCE Loss   : 0.000898428144864738\n",
      "Kernel Loss: 0.006013329606503248\n",
      " \n",
      "CCE Loss   : 0.0006943570333532989\n",
      "Kernel Loss: 0.005856270901858807\n",
      " \n",
      "CCE Loss   : 0.00037747385795228183\n",
      "Kernel Loss: 0.006297582294791937\n",
      " \n",
      "CCE Loss   : 0.0005445267888717353\n",
      "Kernel Loss: 0.005937430541962385\n",
      " \n",
      "CCE Loss   : 0.0005386441480368376\n",
      "Kernel Loss: 0.006060242187231779\n",
      " \n",
      "17\n",
      "CCE Loss   : 0.0006233342573978007\n",
      "Kernel Loss: 0.0058662258088588715\n",
      " \n",
      "CCE Loss   : 0.00029263977194204926\n",
      "Kernel Loss: 0.006066382396966219\n",
      " \n",
      "CCE Loss   : 0.000631846662145108\n",
      "Kernel Loss: 0.005601509474217892\n",
      " \n",
      "CCE Loss   : 0.00404263474047184\n",
      "Kernel Loss: 0.006261159665882587\n",
      " \n",
      "CCE Loss   : 0.00022675012587569654\n",
      "Kernel Loss: 0.006148564163595438\n",
      " \n",
      "18\n",
      "CCE Loss   : 0.0002506376476958394\n",
      "Kernel Loss: 0.0060976785607635975\n",
      " \n",
      "CCE Loss   : 0.0003896918788086623\n",
      "Kernel Loss: 0.0062386831268668175\n",
      " \n",
      "CCE Loss   : 0.0005321060307323933\n",
      "Kernel Loss: 0.0052724359557032585\n",
      " \n",
      "CCE Loss   : 0.0004119337536394596\n",
      "Kernel Loss: 0.005322157870978117\n",
      " \n",
      "CCE Loss   : 0.0006398400291800499\n",
      "Kernel Loss: 0.005136095918715\n",
      " \n",
      "19\n",
      "CCE Loss   : 0.0006126485532149673\n",
      "Kernel Loss: 0.0050456468015909195\n",
      " \n",
      "CCE Loss   : 0.000701147539075464\n",
      "Kernel Loss: 0.004829637240618467\n",
      " \n",
      "CCE Loss   : 0.0005688633536919951\n",
      "Kernel Loss: 0.004931910894811153\n",
      " \n",
      "CCE Loss   : 0.0006875675171613693\n",
      "Kernel Loss: 0.005087000317871571\n",
      " \n",
      "CCE Loss   : 0.0011562601430341601\n",
      "Kernel Loss: 0.004846036434173584\n",
      " \n",
      "20\n",
      "CCE Loss   : 0.0007043540244922042\n",
      "Kernel Loss: 0.004744454752653837\n",
      " \n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-35-1f973e6e9f7e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     11\u001b[0m         \u001b[0mpreds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m         \u001b[0mtarget_C\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_target_C\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mC\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclone\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclone\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweights\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpreds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     14\u001b[0m         \u001b[0mkernel_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0ml1_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mC\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_C\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-34-9592d259f722>\u001b[0m in \u001b[0;36mget_target_C\u001b[0;34m(C, x, weights, preds)\u001b[0m\n\u001b[1;32m     15\u001b[0m         \u001b[0mmost_nb_vec\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0margmax_loc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mC\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtopn\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mC\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtopn\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnonzero\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtopn\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m             \u001b[0margmax_loc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mC\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtopn\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mC\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtopn\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnonzero\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     18\u001b[0m             \u001b[0mdist\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mabs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmost_nb_vec\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0margmax_loc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m             \u001b[0mdists\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdist\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "for epoch in range(num_epochs):\n",
    "    \n",
    "    for i, data in enumerate(train_loader):\n",
    "        \n",
    "        netC.zero_grad()\n",
    "\n",
    "        weights = netC.classifier[0].weight\n",
    "        imgs, labels = data\n",
    "        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)\n",
    "        logits, x, C = netC(imgs)\n",
    "        preds = torch.argmax(logits, axis=1)\n",
    "        \n",
    "        target_C = get_target_C(C.clone().detach(), x.clone().detach(), weights, preds)\n",
    "        kernel_loss = l1_loss(C, target_C)\n",
    "\n",
    "#         kernel_loss = C.flatten().mean()\n",
    "        classifier_loss = cce_loss(logits, labels)\n",
    "\n",
    "        loss = kernel_loss + classifier_loss\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        if i % 400 == 0:\n",
    "            print(\"CCE Loss   :\", classifier_loss.item())\n",
    "            print(\"Kernel Loss:\", kernel_loss.item())\n",
    "            print(\" \")\n",
    "            \n",
    "    print(epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " Validation Accuracy After Training: 0.9957\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    netC = netC.eval()\n",
    "    total_correct = 0\n",
    "\n",
    "    for i, data in enumerate(test_loader):\n",
    "\n",
    "        imgs, labels = data\n",
    "        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)\n",
    "        logits, _, C = netC(imgs)\n",
    "        preds = torch.argmax(logits, axis=1)\n",
    "        total_correct += torch.sum(preds==labels)\n",
    "\n",
    "    print( \"\\n Validation Accuracy After Training:\",\n",
    "          (total_correct.item() / test_loader.dataset.targets.shape[0]) )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9985052614795918"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(C == 0).flatten().sum().item() / C.flatten().shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(target_C == 0).flatten().sum().item() / target_C.flatten().shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(netC.state_dict(), \"mnist_novel_cnn/cnn.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(519.7405)"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "C.flatten().sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(33.2146)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "C[0].sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6272"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "7*7*128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# l1_loss(C, target_C)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [],
   "source": [
    "# target_C = torch.zeros(C.shape)\n",
    "\n",
    "\n",
    "# # Iterate batch\n",
    "# for i in range(len(C)):\n",
    "    \n",
    "#     # Get contributions\n",
    "#     c = x[i] * weights[preds[i].item()]  \n",
    "#     # Get top 10 features\n",
    "#     topn = np.argsort(c.clone().detach().cpu().numpy())[-10:]\n",
    "\n",
    "#     # Select most diverse features\n",
    "#     dists = np.zeros(( len(topn) ))\n",
    "#     most_nb_vec = argmax_loc = (C[i][topn[-1]]==torch.max(C[i][topn[-1]])).nonzero()[0]\n",
    "#     for j in range(0, len(topn) ):\n",
    "#         argmax_loc = (C[i][topn[j]]==torch.max(C[i][topn[j]])).nonzero()[0]\n",
    "#         dist = sum(abs(most_nb_vec - argmax_loc))\n",
    "#         dists[j] = dist\n",
    "\n",
    "#     # Two most diverse features\n",
    "#     top3 = np.argsort(dists)[-3:]\n",
    "#     top3[0] = 9\n",
    "    \n",
    "#     for j in top3:\n",
    "#         argmax_loc = (C[i][topn[j]]==torch.max(C[i][topn[j]])).nonzero()[0]\n",
    "#         target_C[i][topn[j]][argmax_loc[0]][argmax_loc[1]] = x[i][topn[j]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "img_env",
   "language": "python",
   "name": "img_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
