{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 146,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "import time\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "\n",
    "from functions import *\n",
    "from ANNs import *\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single place to name the device used by the cells below.\n",
    "DEVICE = 'cpu'\n",
    "\n",
    "# Instantiate the CNN, restore the saved weights onto DEVICE, then switch to eval mode.\n",
    "# NOTE(review): torch.load unpickles the checkpoint - only load files you trust.\n",
    "netC = CNN()\n",
    "netC.load_state_dict(\n",
    "    torch.load('weights/cnn.pth', map_location=torch.device(DEVICE))\n",
    ")\n",
    "netC = netC.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a handle on the `weight` attribute of netC's `linear` layer.\n",
    "# NOTE(review): none of the later cells visible here read `weights` - confirm it is still needed.\n",
    "weights = netC.linear.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the train/test loaders through the project helper (batch size 32, shuffle flag on).\n",
    "train_loader, test_loader = load_dataloaders(\n",
    "    batch_size=32,\n",
    "    shuffle=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Blackout tables from other experiment variants (swap the file name below to use one):\n",
    "#   'blackout_data.csv'            -> top 3 features combined  (100 nns)\n",
    "#   'blackout_data_feature_1.csv'  -> number 1 feature  (10 nns)\n",
    "#   'blackout_data_feature_5.csv'  -> number 5 feature  (10 nns)\n",
    "\n",
    "# NOTE(review): this line used to carry the label 'number 5 feature (10 nns)', which does not\n",
    "# match the file name - presumably this is the feature-1 activations table; confirm.\n",
    "blackout_df = pd.read_csv(\n",
    "    'blackout_data_feature_1_activations.csv'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "metadata": {},
   "outputs": [],
   "source": [
    "# blackout_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(50000, 2)\n",
      "(37557, 2)\n"
     ]
    }
   ],
   "source": [
    "# Remove exact duplicate rows and report the table shape before and after.\n",
    "shape_before = blackout_df.shape\n",
    "blackout_df = blackout_df.drop_duplicates()\n",
    "print(shape_before, blackout_df.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Blackout and Filter Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### blackout df data\n",
    "# For every row of blackout_df, zero one unit x unit patch of the referenced\n",
    "# training instance. The training data is modified in place.\n",
    "\n",
    "unit = 4  # side length of the square patch that gets zeroed\n",
    "\n",
    "for idx, coord in zip(blackout_df['instance'], blackout_df['coord']):\n",
    "    # coord is read by character position: index 1 -> patch row, index 4 -> patch column.\n",
    "    # assumes a string such as '(i, j)' with single-digit entries - TODO confirm\n",
    "    i, j = int(coord[1]), int(coord[4])\n",
    "    top, left = i * unit, j * unit\n",
    "\n",
    "    temp = train_loader.dataset.data[idx]\n",
    "    temp[top: top + unit, left: left + unit] = 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {},
   "outputs": [],
   "source": [
    "### blackout random data\n",
    "# For every row of blackout_df, zero one randomly placed unit x unit patch of the\n",
    "# referenced training instance (the row's coord is ignored). Data is modified in place.\n",
    "# NOTE(review): `random` is not imported in this notebook's import cell; presumably it\n",
    "# arrives through one of the star imports - confirm. No seed is set, so runs differ.\n",
    "\n",
    "unit = 4  # side length of the square patch that gets zeroed\n",
    "\n",
    "for idx in blackout_df['instance']:\n",
    "    # If `random` is the stdlib module, randint(0, 6) includes both ends,\n",
    "    # i.e. 7 possible patch positions per axis - TODO confirm.\n",
    "    i = random.randint(0, 6)\n",
    "    j = random.randint(0, 6)\n",
    "\n",
    "    temp = train_loader.dataset.data[idx]\n",
    "    temp[i*unit: i*unit+unit, j*unit: j*unit+unit] = 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boolean mask with one entry per training target, initially all False.\n",
    "# np.zeros with dtype=bool stays boolean even when the dataset is empty\n",
    "# (np.array([]) built from an empty list would be float64).\n",
    "keep_idxs = np.zeros(train_loader.dataset.targets.shape[0], dtype=bool)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flag every training instance that appears in blackout_df.\n",
    "# A single integer-array assignment sets all of them at once; repeated indices are harmless.\n",
    "# astype(int) keeps the index array integer-typed even if the table is empty.\n",
    "keep_idxs[blackout_df['instance'].astype(int).values] = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restrict the training set to the instances flagged in keep_idxs.\n",
    "# Targets and data are filtered with the same mask so they stay aligned.\n",
    "# NOTE: the dataset is overwritten in place; keep_idxs was sized for the unfiltered\n",
    "# dataset, so running this cell a second time will not work.\n",
    "train_set = train_loader.dataset\n",
    "train_set.targets = train_set.targets[keep_idxs]\n",
    "train_set.data = train_set.data[keep_idxs]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Retrain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimiser over all of netC's parameters, plus the cross-entropy objective\n",
    "# used by the retraining loop.\n",
    "optimizer = optim.Adam(\n",
    "    params=netC.parameters(),\n",
    "    lr=0.001,\n",
    ")\n",
    "cce_loss = torch.nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.15711598098278046\n",
      "0.011580591090023518\n",
      "0.004043043591082096\n",
      "0.0025863272603601217\n",
      "0.0020912489853799343\n",
      "CCE Loss: 0.037392260486064186\n",
      "0.00522867776453495\n",
      "0.01583453267812729\n",
      "0.0015341582475230098\n",
      "0.00042523042066022754\n",
      "0.000984199927188456\n",
      "CCE Loss: 0.017023013247608992\n",
      "0.01711840182542801\n",
      "0.11882060766220093\n",
      "0.0033372899051755667\n",
      "0.0009629745618440211\n",
      "0.018731161952018738\n",
      "CCE Loss: 0.01100002572723827\n"
     ]
    }
   ],
   "source": [
    "# Retrain netC on whatever train_loader currently holds.\n",
    "# Prints the batch loss every LOG_EVERY batches and the mean batch loss after each epoch.\n",
    "N_EPOCHS = 3\n",
    "LOG_EVERY = 200\n",
    "\n",
    "netC = netC.train()\n",
    "\n",
    "for epoch in range(N_EPOCHS):\n",
    "\n",
    "    running_loss = 0.0\n",
    "\n",
    "    for i, (imgs, labels) in enumerate(train_loader):\n",
    "\n",
    "        # Clear gradients left over from the previous step (optimizer holds netC's parameters).\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)\n",
    "        # The model returns three values; only the logits feed the loss.\n",
    "        logits, _, _ = netC(imgs)\n",
    "        loss = cce_loss(logits, labels)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        running_loss += loss.item()\n",
    "\n",
    "        if i % LOG_EVERY == 0:\n",
    "            print(loss.item())\n",
    "    print(\"CCE Loss:\", running_loss / len(train_loader))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test Network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Control"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " Test Accuracy: 0.9972\n"
     ]
    }
   ],
   "source": [
    "# Accuracy of netC on the test loader: correct predictions / number of test targets.\n",
    "netC = netC.eval()\n",
    "total_correct = 0  # accumulated as a plain Python int\n",
    "\n",
    "with torch.no_grad():  # inference only, no gradients needed\n",
    "    for imgs, labels in test_loader:\n",
    "        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)\n",
    "        # The model returns three values; only the logits are used here.\n",
    "        logits, _, _ = netC(imgs)\n",
    "        preds = torch.argmax(logits, dim=1)\n",
    "        total_correct += (preds == labels).sum().item()\n",
    "\n",
    "print(\"\\n Test Accuracy:\", total_correct / len(test_loader.dataset.targets))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " Test Accuracy: 0.9962\n"
     ]
    }
   ],
   "source": [
    "# Accuracy of netC on the test loader: correct predictions / number of test targets.\n",
    "netC = netC.eval()\n",
    "total_correct = 0  # accumulated as a plain Python int\n",
    "\n",
    "with torch.no_grad():  # inference only, no gradients needed\n",
    "    for imgs, labels in test_loader:\n",
    "        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)\n",
    "        # The model returns three values; only the logits are used here.\n",
    "        logits, _, _ = netC(imgs)\n",
    "        preds = torch.argmax(logits, dim=1)\n",
    "        total_correct += (preds == labels).sum().item()\n",
    "\n",
    "print(\"\\n Test Accuracy:\", total_correct / len(test_loader.dataset.targets))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Targeted Blackout Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " Test Accuracy: 0.9949\n"
     ]
    }
   ],
   "source": [
    "# Accuracy of netC on the test loader: correct predictions / number of test targets.\n",
    "netC = netC.eval()\n",
    "total_correct = 0  # accumulated as a plain Python int\n",
    "\n",
    "with torch.no_grad():  # inference only, no gradients needed\n",
    "    for imgs, labels in test_loader:\n",
    "        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)\n",
    "        # The model returns three values; only the logits are used here.\n",
    "        logits, _, _ = netC(imgs)\n",
    "        preds = torch.argmax(logits, dim=1)\n",
    "        total_correct += (preds == labels).sum().item()\n",
    "\n",
    "print(\"\\n Test Accuracy:\", total_correct / len(test_loader.dataset.targets))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "img_env",
   "language": "python",
   "name": "img_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
