{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "import time\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torchvision.transforms as transforms\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "\n",
    "import random\n",
    "\n",
    "import argparse\n",
    "import os\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.parallel\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torch.optim as optim\n",
    "import torch.utils.data\n",
    "import torchvision.datasets as dset\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.utils as vutils\n",
    "import torchvision\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "import time\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torchvision.models as models\n",
    "\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "DEVICE = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_dataloaders(b_size=1, shuffle=False):\n",
    "\ttrain_transform = transforms.Compose(\n",
    "\t\t[transforms.ToTensor(),\n",
    "\t\t transforms.Normalize((0.5,), (0.5,), (0.5,))])\n",
    "\n",
    "\ttest_transform = transforms.Compose(\n",
    "\t\t[transforms.ToTensor(),\n",
    "\t\t transforms.Normalize((0.5,), (0.5,), (0.5,))])\n",
    "\n",
    "\ttrain_set = torchvision.datasets.FashionMNIST(\n",
    "\t\troot='./data/FashionMNIST',\n",
    "\t\ttrain=True,\n",
    "\t\tdownload=True,\n",
    "\t\ttransform=train_transform)\n",
    "\n",
    "\ttrain_loader = torch.utils.data.DataLoader(\n",
    "\t\ttrain_set,\n",
    "\t\tbatch_size=b_size,\n",
    "\t\tshuffle=shuffle,\n",
    "    num_workers=2)\n",
    "\n",
    "\ttest_set = torchvision.datasets.FashionMNIST(\n",
    "\t\troot='./data/FashionMNIST',\n",
    "\t\ttrain=False,\n",
    "\t\tdownload=True,\n",
    "\t\ttransform=test_transform)\n",
    "\n",
    "\ttest_loader = torch.utils.data.DataLoader(\n",
    "\t\ttest_set,\n",
    "\t\tbatch_size=b_size,\n",
    "\t\tshuffle=False,\n",
    "    num_workers=2)\n",
    " \n",
    "\treturn train_loader, test_loader\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, test_loader = load_dataloaders(b_size=1, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(CNN, self).__init__()\n",
    "        self.avgpool = torch.nn.AvgPool2d(7)\n",
    "        self.linear = torch.nn.Linear(128, 10)     \n",
    "        self.main = nn.Sequential(\n",
    "\n",
    "            # input is Z, going into a convolution\n",
    "            nn.Conv2d(1, 8, kernel_size=5, stride=1, padding=2),\n",
    "            nn.BatchNorm2d(8),\n",
    "            nn.ReLU(True),\n",
    "            \n",
    "            nn.Conv2d(8, 16, kernel_size=5, stride=2, padding=2),\n",
    "            nn.BatchNorm2d(16),\n",
    "            nn.ReLU(True),\n",
    "\n",
    "            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),\n",
    "            nn.BatchNorm2d(32),\n",
    "            nn.ReLU(True),\n",
    "\n",
    "            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2),\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.ReLU(True),\n",
    "\n",
    "            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),\n",
    "            nn.BatchNorm2d(128),\n",
    "            nn.ReLU(True),\n",
    "\n",
    "            nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0),\n",
    "            nn.BatchNorm2d(128),\n",
    "            nn.ReLU(True)  \n",
    "            )\n",
    "                \n",
    "    def forward(self, x):   \n",
    "        C = self.main(x)\n",
    "        x = self.avgpool(C)\n",
    "        x = x.view(x.shape[0], x.shape[1])\n",
    "        logits = self.linear(x)\n",
    "        return logits, x, C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "netC = CNN()\n",
    "netC.load_state_dict(torch.load('weights/cnn.pth', map_location=torch.device(DEVICE)))\n",
    "netC = netC.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "weights = netC.linear.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_c = list()\n",
    "X_test_c = list()\n",
    "\n",
    "X_train_x = list()\n",
    "X_test_x = list()\n",
    "\n",
    "X_train_y = list()\n",
    "X_test_y = list()\n",
    "\n",
    "X_train_C = list()\n",
    "X_test_C = list()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect Twin Data and Feature Map Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n",
      "0.008333333333333333\n",
      "0.016666666666666666\n",
      "0.025\n",
      "0.03333333333333333\n",
      "0.041666666666666664\n",
      "0.05\n",
      "0.058333333333333334\n",
      "0.06666666666666667\n",
      "0.075\n",
      "0.08333333333333333\n",
      "0.09166666666666666\n",
      "0.1\n",
      "0.10833333333333334\n",
      "0.11666666666666667\n",
      "0.125\n",
      "0.13333333333333333\n",
      "0.14166666666666666\n",
      "0.15\n",
      "0.15833333333333333\n",
      "0.16666666666666666\n",
      "0.175\n",
      "0.18333333333333332\n",
      "0.19166666666666668\n",
      "0.2\n",
      "0.20833333333333334\n",
      "0.21666666666666667\n",
      "0.225\n",
      "0.23333333333333334\n",
      "0.24166666666666667\n",
      "0.25\n",
      "0.25833333333333336\n",
      "0.26666666666666666\n",
      "0.275\n",
      "0.2833333333333333\n",
      "0.2916666666666667\n",
      "0.3\n",
      "0.30833333333333335\n",
      "0.31666666666666665\n",
      "0.325\n",
      "0.3333333333333333\n",
      "0.3416666666666667\n",
      "0.35\n",
      "0.35833333333333334\n",
      "0.36666666666666664\n",
      "0.375\n",
      "0.38333333333333336\n",
      "0.39166666666666666\n",
      "0.4\n",
      "0.4083333333333333\n",
      "0.4166666666666667\n",
      "0.425\n",
      "0.43333333333333335\n",
      "0.44166666666666665\n",
      "0.45\n",
      "0.4583333333333333\n",
      "0.4666666666666667\n",
      "0.475\n",
      "0.48333333333333334\n",
      "0.49166666666666664\n",
      "0.5\n",
      "0.5083333333333333\n",
      "0.5166666666666667\n",
      "0.525\n",
      "0.5333333333333333\n",
      "0.5416666666666666\n",
      "0.55\n",
      "0.5583333333333333\n",
      "0.5666666666666667\n",
      "0.575\n",
      "0.5833333333333334\n",
      "0.5916666666666667\n",
      "0.6\n",
      "0.6083333333333333\n",
      "0.6166666666666667\n",
      "0.625\n",
      "0.6333333333333333\n",
      "0.6416666666666667\n",
      "0.65\n",
      "0.6583333333333333\n",
      "0.6666666666666666\n",
      "0.675\n",
      "0.6833333333333333\n",
      "0.6916666666666667\n",
      "0.7\n",
      "0.7083333333333334\n",
      "0.7166666666666667\n",
      "0.725\n",
      "0.7333333333333333\n",
      "0.7416666666666667\n",
      "0.75\n",
      "0.7583333333333333\n",
      "0.7666666666666667\n",
      "0.775\n",
      "0.7833333333333333\n",
      "0.7916666666666666\n",
      "0.8\n",
      "0.8083333333333333\n",
      "0.8166666666666667\n",
      "0.825\n",
      "0.8333333333333334\n",
      "0.8416666666666667\n",
      "0.85\n",
      "0.8583333333333333\n",
      "0.8666666666666667\n",
      "0.875\n",
      "0.8833333333333333\n",
      "0.8916666666666667\n",
      "0.9\n",
      "0.9083333333333333\n",
      "0.9166666666666666\n",
      "0.925\n",
      "0.9333333333333333\n",
      "0.9416666666666667\n",
      "0.95\n",
      "0.9583333333333334\n",
      "0.9666666666666667\n",
      "0.975\n",
      "0.9833333333333333\n",
      "0.9916666666666667\n"
     ]
    }
   ],
   "source": [
    "#### Iterate just one image at a time\n",
    "\n",
    "for i, data in enumerate(train_loader):\n",
    "    img, label = data\n",
    "    img, label = img.to(DEVICE), label.to(DEVICE)\n",
    "    logits, x, C = netC(img)\n",
    "    y_hat = torch.argmax(logits).item()\n",
    "    c = torch.mul(x[0], weights[y_hat])\n",
    "    \n",
    "    X_train_c.append(c.cpu().detach().numpy().tolist())\n",
    "    X_train_C.append(C.detach().numpy().tolist())\n",
    "    X_train_x.append(x[0].cpu().detach().numpy().tolist())\n",
    "    X_train_y.append(y_hat)\n",
    "    \n",
    "    if i % 500 == 0:\n",
    "        print(i / len(train_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n",
      "0.05\n",
      "0.1\n",
      "0.15\n",
      "0.2\n",
      "0.25\n",
      "0.3\n",
      "0.35\n",
      "0.4\n",
      "0.45\n",
      "0.5\n",
      "0.55\n",
      "0.6\n",
      "0.65\n",
      "0.7\n",
      "0.75\n",
      "0.8\n",
      "0.85\n",
      "0.9\n",
      "0.95\n"
     ]
    }
   ],
   "source": [
    "#### Iterate just one image at a time\n",
    "\n",
    "for i, data in enumerate(test_loader):\n",
    "    img, label = data\n",
    "    img, label = img.to(DEVICE), label.to(DEVICE)\n",
    "    logits, x, C = netC(img)\n",
    "    y_hat = torch.argmax(logits).item()\n",
    "    c = torch.mul(x[0], weights[y_hat])\n",
    "    \n",
    "    X_test_c.append(c.cpu().detach().numpy().tolist())\n",
    "    X_train_C.append(C.detach().numpy().tolist())\n",
    "    X_test_x.append(x[0].cpu().detach().numpy().tolist())\n",
    "    X_test_y.append(y_hat)\n",
    "    \n",
    "    if i % 500 == 0:\n",
    "        print(i / len(test_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_c = np.array(X_train_c)\n",
    "X_test_c = np.array(X_test_c)\n",
    "\n",
    "X_train_C = np.array(X_train_C)\n",
    "X_test_C = np.array(X_test_C)\n",
    "\n",
    "X_train_x = np.array(X_train_x)\n",
    "X_test_x = np.array(X_test_x)\n",
    "\n",
    "X_train_y = np.array(X_train_y)\n",
    "X_test_y = np.array(X_test_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(\"data/X_train_cont.npy\", X_train_c)\n",
    "np.save(\"data/X_test_cont.npy\",  X_test_c)\n",
    "np.save(\"data/X_train_conv.npy\", X_train_C)\n",
    "np.save(\"data/X_test_conv.npy\",  X_test_C)\n",
    "\n",
    "np.save(\"data/X_train_x.npy\",    X_train_x)\n",
    "np.save(\"data/X_test_x.npy\",     X_test_x)\n",
    "np.save(\"data/X_train_y.npy\",    X_train_y)\n",
    "np.save(\"data/X_test_y.npy\",     X_test_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(60000, 128)\n",
      "(60000, 128)\n",
      "(70000, 1, 128, 7, 7)\n",
      "(60000,)\n"
     ]
    }
   ],
   "source": [
    "print(X_train_c.shape)\n",
    "print(X_train_x.shape)\n",
    "print(X_train_C.shape)\n",
    "print(X_train_y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "neurips_img",
   "language": "python",
   "name": "neurips_img"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
