{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "import time\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torchvision.transforms as transforms\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "\n",
    "import random\n",
    "\n",
    "import argparse\n",
    "import os\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.parallel\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torch.optim as optim\n",
    "import torch.utils.data\n",
    "import torchvision.datasets as dset\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.utils as vutils\n",
    "import torchvision\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "import time\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torchvision.models as models\n",
    "\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "DEVICE = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_dataloaders(b_size=32, shuffle=False):\n",
    "\n",
    "    test_transform = transforms.Compose(\n",
    "        [transforms.ToTensor(),\n",
    "         transforms.Normalize((0.5,), (0.5,), (0.5,))])\n",
    "\n",
    "    train_transform = test_transform\n",
    "\n",
    "    train_set = torchvision.datasets.CIFAR10(\n",
    "        root='./data/CIFAR10',\n",
    "        train=True,\n",
    "        download=True,\n",
    "        transform=train_transform)\n",
    "\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        train_set,\n",
    "        batch_size=b_size,\n",
    "        shuffle=shuffle,\n",
    "    num_workers=2)\n",
    "\n",
    "    test_set = torchvision.datasets.CIFAR10(\n",
    "        root='./data/CIFAR10',\n",
    "        train=False,\n",
    "        download=True,\n",
    "        transform=test_transform)\n",
    "\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        test_set,\n",
    "        batch_size=b_size,\n",
    "        shuffle=False,\n",
    "    num_workers=2)\n",
    "\n",
    "    return train_loader, test_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "train_loader, test_loader = load_dataloaders(b_size=1, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(CNN, self).__init__()\n",
    "\n",
    "        # input is Z, going into a convolution\n",
    "        self.conv1 = nn.Conv2d(3, 8, kernel_size=5, stride=1, padding=2)\n",
    "        self.bn1   = nn.BatchNorm2d(8)\n",
    "\n",
    "        self.conv2 = nn.Conv2d(8, 16, kernel_size=5, stride=2, padding=2)\n",
    "        self.bn2 = nn.BatchNorm2d(16)\n",
    "\n",
    "        self.conv3 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2)\n",
    "        self.bn3 = nn.BatchNorm2d(32)\n",
    "\n",
    "        self.conv4 = nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2)\n",
    "        self.bn4 = nn.BatchNorm2d(64)\n",
    "\n",
    "        self.conv5 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)\n",
    "        self.bn5 = nn.BatchNorm2d(128)\n",
    "\n",
    "        self.conv6 = nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0)\n",
    "        self.bn6 = nn.BatchNorm2d(128)\n",
    "        \n",
    "        self.avgpool = nn.AvgPool2d(8)\n",
    "        self.linear = nn.Linear(128, 10)    \n",
    "        self.relu = nn.ReLU()\n",
    "        \n",
    "\n",
    "        \n",
    "    def forward(self, I):   \n",
    "        x = self.relu(self.bn1(self.conv1(I)))\n",
    "        x = self.relu(self.bn2(self.conv2(x)))\n",
    "        x = self.relu(self.bn3(self.conv3(x)))\n",
    "        x = self.relu(self.bn4(self.conv4(x)))\n",
    "        x = self.relu(self.bn5(self.conv5(x)))\n",
    "        C = self.relu(self.bn6(self.conv6(x)))\n",
    "        \n",
    "        \n",
    "        x = self.avgpool(C)\n",
    "        x = x.view(x.shape[0], x.shape[1])\n",
    "        logits = self.linear(x)\n",
    "        return logits, x, C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "netC = CNN()\n",
    "netC.load_state_dict(torch.load('weights/cnn.pth', map_location=torch.device(DEVICE)))\n",
    "netC = netC.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "weights = netC.linear.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_c = list()\n",
    "X_test_c = list()\n",
    "\n",
    "X_train_x = list()\n",
    "X_test_x = list()\n",
    "\n",
    "X_train_y = list()\n",
    "X_test_y = list()\n",
    "\n",
    "X_train_C = list()\n",
    "X_test_C = list()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect Twin Data and Feature Map Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n",
      "0.01\n",
      "0.02\n",
      "0.03\n",
      "0.04\n",
      "0.05\n",
      "0.06\n",
      "0.07\n",
      "0.08\n",
      "0.09\n",
      "0.1\n",
      "0.11\n",
      "0.12\n",
      "0.13\n",
      "0.14\n",
      "0.15\n",
      "0.16\n",
      "0.17\n",
      "0.18\n",
      "0.19\n",
      "0.2\n",
      "0.21\n",
      "0.22\n",
      "0.23\n",
      "0.24\n",
      "0.25\n",
      "0.26\n",
      "0.27\n",
      "0.28\n",
      "0.29\n",
      "0.3\n",
      "0.31\n",
      "0.32\n",
      "0.33\n",
      "0.34\n",
      "0.35\n",
      "0.36\n",
      "0.37\n",
      "0.38\n",
      "0.39\n",
      "0.4\n",
      "0.41\n",
      "0.42\n",
      "0.43\n",
      "0.44\n",
      "0.45\n",
      "0.46\n",
      "0.47\n",
      "0.48\n",
      "0.49\n",
      "0.5\n",
      "0.51\n",
      "0.52\n",
      "0.53\n",
      "0.54\n",
      "0.55\n",
      "0.56\n",
      "0.57\n",
      "0.58\n",
      "0.59\n",
      "0.6\n",
      "0.61\n",
      "0.62\n",
      "0.63\n",
      "0.64\n",
      "0.65\n",
      "0.66\n",
      "0.67\n",
      "0.68\n",
      "0.69\n",
      "0.7\n",
      "0.71\n",
      "0.72\n",
      "0.73\n",
      "0.74\n",
      "0.75\n",
      "0.76\n",
      "0.77\n",
      "0.78\n",
      "0.79\n",
      "0.8\n",
      "0.81\n",
      "0.82\n",
      "0.83\n",
      "0.84\n",
      "0.85\n",
      "0.86\n",
      "0.87\n",
      "0.88\n",
      "0.89\n",
      "0.9\n",
      "0.91\n",
      "0.92\n",
      "0.93\n",
      "0.94\n",
      "0.95\n",
      "0.96\n",
      "0.97\n",
      "0.98\n",
      "0.99\n"
     ]
    }
   ],
   "source": [
    "#### Iterate just one image at a time\n",
    "\n",
    "for i, data in enumerate(train_loader):\n",
    "    img, label = data\n",
    "    img, label = img.to(DEVICE), label.to(DEVICE)\n",
    "    logits, x, C = netC(img)\n",
    "    y_hat = torch.argmax(logits).item()\n",
    "    c = torch.mul(x[0], weights[y_hat])\n",
    "    \n",
    "    X_train_c.append(c.cpu().detach().numpy().tolist())\n",
    "    X_train_C.append(C.detach().numpy().tolist())\n",
    "    X_train_x.append(x[0].cpu().detach().numpy().tolist())\n",
    "    X_train_y.append(y_hat)\n",
    "    \n",
    "    if i % 500 == 0:\n",
    "        print(i / len(train_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n",
      "0.05\n",
      "0.1\n",
      "0.15\n",
      "0.2\n",
      "0.25\n",
      "0.3\n",
      "0.35\n",
      "0.4\n",
      "0.45\n",
      "0.5\n",
      "0.55\n",
      "0.6\n",
      "0.65\n",
      "0.7\n",
      "0.75\n",
      "0.8\n",
      "0.85\n",
      "0.9\n",
      "0.95\n"
     ]
    }
   ],
   "source": [
    "#### Iterate just one image at a time\n",
    "\n",
    "for i, data in enumerate(test_loader):\n",
    "    img, label = data\n",
    "    img, label = img.to(DEVICE), label.to(DEVICE)\n",
    "    logits, x, C = netC(img)\n",
    "    y_hat = torch.argmax(logits).item()\n",
    "    c = torch.mul(x[0], weights[y_hat])\n",
    "    \n",
    "    X_test_c.append(c.cpu().detach().numpy().tolist())\n",
    "    X_test_C.append(C.detach().numpy().tolist())\n",
    "    X_test_x.append(x[0].cpu().detach().numpy().tolist())\n",
    "    X_test_y.append(y_hat)\n",
    "    \n",
    "    if i % 500 == 0:\n",
    "        print(i / len(test_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_c = np.array(X_train_c)\n",
    "X_test_c = np.array(X_test_c)\n",
    "\n",
    "X_train_C = np.array(X_train_C)\n",
    "X_test_C = np.array(X_test_C)\n",
    "\n",
    "X_train_x = np.array(X_train_x)\n",
    "X_test_x = np.array(X_test_x)\n",
    "\n",
    "X_train_y = np.array(X_train_y)\n",
    "X_test_y = np.array(X_test_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(\"data/X_train_cont.npy\", X_train_c)\n",
    "np.save(\"data/X_test_cont.npy\",  X_test_c)\n",
    "np.save(\"data/X_train_conv.npy\", X_train_C)\n",
    "np.save(\"data/X_test_conv.npy\",  X_test_C)\n",
    "\n",
    "np.save(\"data/X_train_x.npy\",    X_train_x)\n",
    "np.save(\"data/X_test_x.npy\",     X_test_x)\n",
    "np.save(\"data/X_train_y.npy\",    X_train_y)\n",
    "np.save(\"data/X_test_y.npy\",     X_test_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(50000, 128)\n",
      "(50000, 128)\n",
      "(50000, 1, 128, 8, 8)\n",
      "(50000,)\n"
     ]
    }
   ],
   "source": [
    "print(X_train_c.shape)\n",
    "print(X_train_x.shape)\n",
    "print(X_train_C.shape)\n",
    "print(X_train_y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "neurips_img",
   "language": "python",
   "name": "neurips_img"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
