{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "import time\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "from functions import *\n",
    "from ANNs import *\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "netC = CNN()\n",
    "netC.load_state_dict(torch.load('weights/cnn.pth', map_location=torch.device('cpu')))\n",
    "netC = netC.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "weights = netC.linear.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, test_loader = load_dataloaders()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_c = list()\n",
    "X_test_c = list()\n",
    "\n",
    "X_train_x = list()\n",
    "X_test_x = list()\n",
    "\n",
    "X_train_y = list()\n",
    "X_test_y = list()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect Twin Data and Feature Map Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n",
      "0.008333333333333333\n",
      "0.016666666666666666\n",
      "0.025\n",
      "0.03333333333333333\n",
      "0.041666666666666664\n",
      "0.05\n",
      "0.058333333333333334\n",
      "0.06666666666666667\n",
      "0.075\n",
      "0.08333333333333333\n",
      "0.09166666666666666\n",
      "0.1\n",
      "0.10833333333333334\n",
      "0.11666666666666667\n",
      "0.125\n",
      "0.13333333333333333\n",
      "0.14166666666666666\n",
      "0.15\n",
      "0.15833333333333333\n",
      "0.16666666666666666\n",
      "0.175\n",
      "0.18333333333333332\n",
      "0.19166666666666668\n",
      "0.2\n",
      "0.20833333333333334\n",
      "0.21666666666666667\n",
      "0.225\n",
      "0.23333333333333334\n",
      "0.24166666666666667\n",
      "0.25\n",
      "0.25833333333333336\n",
      "0.26666666666666666\n",
      "0.275\n",
      "0.2833333333333333\n",
      "0.2916666666666667\n",
      "0.3\n",
      "0.30833333333333335\n",
      "0.31666666666666665\n",
      "0.325\n",
      "0.3333333333333333\n",
      "0.3416666666666667\n",
      "0.35\n",
      "0.35833333333333334\n",
      "0.36666666666666664\n",
      "0.375\n",
      "0.38333333333333336\n",
      "0.39166666666666666\n",
      "0.4\n",
      "0.4083333333333333\n",
      "0.4166666666666667\n",
      "0.425\n",
      "0.43333333333333335\n",
      "0.44166666666666665\n",
      "0.45\n",
      "0.4583333333333333\n",
      "0.4666666666666667\n",
      "0.475\n",
      "0.48333333333333334\n",
      "0.49166666666666664\n",
      "0.5\n",
      "0.5083333333333333\n",
      "0.5166666666666667\n",
      "0.525\n",
      "0.5333333333333333\n",
      "0.5416666666666666\n",
      "0.55\n",
      "0.5583333333333333\n",
      "0.5666666666666667\n",
      "0.575\n",
      "0.5833333333333334\n",
      "0.5916666666666667\n",
      "0.6\n",
      "0.6083333333333333\n",
      "0.6166666666666667\n",
      "0.625\n",
      "0.6333333333333333\n",
      "0.6416666666666667\n",
      "0.65\n",
      "0.6583333333333333\n",
      "0.6666666666666666\n",
      "0.675\n",
      "0.6833333333333333\n",
      "0.6916666666666667\n",
      "0.7\n",
      "0.7083333333333334\n",
      "0.7166666666666667\n",
      "0.725\n",
      "0.7333333333333333\n",
      "0.7416666666666667\n",
      "0.75\n",
      "0.7583333333333333\n",
      "0.7666666666666667\n",
      "0.775\n",
      "0.7833333333333333\n",
      "0.7916666666666666\n",
      "0.8\n",
      "0.8083333333333333\n",
      "0.8166666666666667\n",
      "0.825\n",
      "0.8333333333333334\n",
      "0.8416666666666667\n",
      "0.85\n",
      "0.8583333333333333\n",
      "0.8666666666666667\n",
      "0.875\n",
      "0.8833333333333333\n",
      "0.8916666666666667\n",
      "0.9\n",
      "0.9083333333333333\n",
      "0.9166666666666666\n",
      "0.925\n",
      "0.9333333333333333\n",
      "0.9416666666666667\n",
      "0.95\n",
      "0.9583333333333334\n",
      "0.9666666666666667\n",
      "0.975\n",
      "0.9833333333333333\n",
      "0.9916666666666667\n"
     ]
    }
   ],
   "source": [
    "#### Iterate just one image at a time\n",
    "\n",
    "for i, data in enumerate(train_loader):\n",
    "    img, label = data\n",
    "    logits, x, C = netC(img)\n",
    "    y_hat = torch.argmax(logits).item()\n",
    "    c = torch.mul(x[0], weights)\n",
    "    \n",
    "    X_train_c.append(c.detach().numpy().tolist())\n",
    "    X_train_x.append(x[0].detach().numpy().tolist())\n",
    "    X_train_y.append(y_hat)\n",
    "    \n",
    "    if i % 500 == 0:\n",
    "        print(i / len(train_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n",
      "0.05\n",
      "0.1\n",
      "0.15\n",
      "0.2\n",
      "0.25\n",
      "0.3\n",
      "0.35\n",
      "0.4\n",
      "0.45\n",
      "0.5\n",
      "0.55\n",
      "0.6\n",
      "0.65\n",
      "0.7\n",
      "0.75\n",
      "0.8\n",
      "0.85\n",
      "0.9\n",
      "0.95\n"
     ]
    }
   ],
   "source": [
    "#### Iterate just one image at a time\n",
    "\n",
    "for i, data in enumerate(test_loader):\n",
    "    img, label = data\n",
    "    logits, x, C = netC(img)\n",
    "    y_hat = torch.argmax(logits).item()\n",
    "    c = torch.mul(x[0], weights)\n",
    "    \n",
    "    X_test_c.append(c.detach().numpy().tolist())\n",
    "    X_test_x.append(x[0].detach().numpy().tolist())\n",
    "    X_test_y.append(y_hat)\n",
    "    \n",
    "    if i % 500 == 0:\n",
    "        print(i / len(test_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_c = np.array(X_train_c)\n",
    "X_test_c = np.array(X_test_c)\n",
    "\n",
    "X_train_x = np.array(X_train_x)\n",
    "X_test_x = np.array(X_test_x)\n",
    "\n",
    "X_train_y = np.array(X_train_y)\n",
    "X_test_y = np.array(X_test_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(\"data/X_train_cont.npy\", X_train_c)\n",
    "np.save(\"data/X_test_cont.npy\", X_test_c)\n",
    "np.save(\"data/X_train_x.npy\", X_train_x)\n",
    "np.save(\"data/X_test_x.npy\", X_test_x)\n",
    "np.save(\"data/X_train_y.npy\", X_train_y)\n",
    "np.save(\"data/X_test_y.npy\", X_test_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(60000, 10, 128)\n",
      "(60000, 128)\n"
     ]
    }
   ],
   "source": [
    "print(X_train_c.shape)\n",
    "print(X_train_x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10000, 10, 128)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_test_c.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "img_env",
   "language": "python",
   "name": "img_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
