{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "import time\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "from functions import *\n",
    "from ANNs import *\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "netC = CNN()\n",
    "netC.load_state_dict(torch.load('weights/cnn.pth', map_location=torch.device('cpu')))\n",
    "netC = netC.eval()\n",
    "DATAROOT = 'data'\n",
    "DEVICE = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "weights = netC.linear.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, test_loader = load_dataloaders()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_c = np.load(DATAROOT + \"/X_train_cont.npy\")\n",
    "X_test_c = np.load(DATAROOT + \"/X_test_cont.npy\")\n",
    "X_train_C = np.load(DATAROOT + \"/X_train_conv.npy\")\n",
    "X_test_C = np.load(DATAROOT + \"/X_test_conv.npy\")\n",
    "\n",
    "X_train_x = np.load(DATAROOT + \"/X_train_x.npy\")\n",
    "X_test_x = np.load(DATAROOT + \"/X_test_x.npy\")\n",
    "train_preds = np.load(DATAROOT + \"/X_train_y.npy\")\n",
    "test_preds = np.load(DATAROOT + \"/X_test_y.npy\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Experiments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Contributions, L2, Layer x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', metric='euclidean', n_neighbors=1)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit COLE and DkNN\n",
    "twin = KNeighborsClassifier(n_neighbors=1, algorithm=\"brute\", metric='euclidean') \n",
    "twin.fit(X_train_c, train_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10000/10000 [02:17<00:00, 72.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Agreement: 1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "#### Iterate just one image at a time\n",
    "agreement = 0\n",
    "\n",
    "for i in tqdm(range(len(test_preds))):\n",
    "    cnn_pred = test_preds[i]\n",
    "    twin_pred = twin.predict( np.array([X_test_c[i]]) )\n",
    "    \n",
    "    if (twin_pred[0] == cnn_pred):\n",
    "        agreement += 1\n",
    "\n",
    "print(\"Agreement:\", agreement / len(test_preds) )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Contributions, Cosine, Layer x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "temp_train = list()\n",
    "temp_test = list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(X_train_C.shape[0]):\n",
    "    x = X_train_C[i][0].flatten()\n",
    "    x = x/x.sum()\n",
    "    temp_train.append(x)\n",
    "    \n",
    "for i in range(X_test_C.shape[0]):\n",
    "    x = X_test_C[i][0].flatten()\n",
    "    x = x/x.sum()\n",
    "    temp_test.append(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "temp_train = np.array(temp_train)\n",
    "temp_test  = np.array(temp_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', metric='euclidean', n_neighbors=1)"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit COLE and DkNN\n",
    "twin_dknn = KNeighborsClassifier(n_neighbors=1, algorithm=\"brute\", metric='euclidean') \n",
    "twin_dknn.fit(temp_train, train_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 583/10000 [05:48<1:32:50,  1.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|▉         | 883/10000 [09:02<1:42:46,  1.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|▉         | 948/10000 [09:44<1:37:27,  1.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 1015/10000 [10:26<1:29:34,  1.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█▏        | 1248/10000 [13:02<1:44:31,  1.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|█▎        | 1261/10000 [13:12<1:47:39,  1.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█▍        | 1394/10000 [14:36<1:26:02,  1.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|█▋        | 1682/10000 [17:37<1:21:25,  1.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 1790/10000 [18:46<1:32:25,  1.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 18%|█▊        | 1791/10000 [18:47<1:32:36,  1.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 19%|█▊        | 1860/10000 [19:35<1:34:46,  1.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 19%|█▉        | 1879/10000 [19:52<2:02:38,  1.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 2041/10000 [21:52<1:22:54,  1.60it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|██▏       | 2238/10000 [24:01<1:20:31,  1.61it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 23%|██▎       | 2294/10000 [24:46<2:02:22,  1.05it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 25%|██▍       | 2455/10000 [26:23<1:22:36,  1.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|██▋       | 2649/10000 [28:35<1:14:11,  1.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|██▊       | 2772/10000 [30:04<1:15:03,  1.61it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 29%|██▉       | 2928/10000 [31:45<1:11:26,  1.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|██▉       | 2996/10000 [32:26<1:09:24,  1.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 31%|███       | 3074/10000 [33:13<1:08:57,  1.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 34%|███▎      | 3366/10000 [36:09<1:14:39,  1.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 36%|███▌      | 3602/10000 [38:38<1:04:14,  1.66it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 37%|███▋      | 3654/10000 [39:09<1:02:06,  1.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 37%|███▋      | 3728/10000 [39:52<1:01:24,  1.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|███▊      | 3763/10000 [40:13<1:01:13,  1.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 43%|████▎     | 4285/10000 [45:23<56:10,  1.70it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 43%|████▎     | 4290/10000 [45:26<56:15,  1.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|████▎     | 4370/10000 [46:13<55:05,  1.70it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|████▍     | 4444/10000 [46:57<54:18,  1.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 45%|████▌     | 4505/10000 [47:33<53:43,  1.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 45%|████▌     | 4508/10000 [47:35<53:42,  1.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 46%|████▌     | 4572/10000 [48:12<53:18,  1.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 47%|████▋     | 4741/10000 [49:52<51:32,  1.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 49%|████▉     | 4887/10000 [51:19<50:03,  1.70it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 56%|█████▋    | 5637/10000 [58:59<43:25,  1.67it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 57%|█████▋    | 5655/10000 [59:10<43:10,  1.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 59%|█████▉    | 5938/10000 [1:01:59<39:21,  1.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 61%|██████    | 6102/10000 [1:03:35<38:00,  1.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████▌   | 6555/10000 [1:08:28<40:44,  1.41it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████▌   | 6559/10000 [1:08:32<47:08,  1.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████▋   | 6626/10000 [1:09:15<37:16,  1.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 69%|██████▉   | 6884/10000 [1:11:58<31:03,  1.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 83%|████████▎ | 8317/10000 [1:27:46<24:09,  1.16it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 85%|████████▌ | 8528/10000 [1:30:18<15:48,  1.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|█████████▌| 9621/10000 [1:42:25<04:20,  1.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|█████████▋| 9630/10000 [1:42:31<04:07,  1.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|█████████▋| 9643/10000 [1:42:40<04:17,  1.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 97%|█████████▋| 9673/10000 [1:43:02<03:55,  1.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 97%|█████████▋| 9680/10000 [1:43:08<05:13,  1.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 97%|█████████▋| 9693/10000 [1:43:18<03:48,  1.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 97%|█████████▋| 9694/10000 [1:43:19<03:40,  1.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 97%|█████████▋| 9699/10000 [1:43:23<03:30,  1.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|█████████▊| 9793/10000 [1:44:28<02:22,  1.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10000/10000 [1:46:48<00:00,  1.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Agreement: 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "#### Iterate just one image at a time\n",
    "agreement = 0\n",
    "\n",
    "for i in tqdm(range(len(test_preds))):\n",
    "    cnn_pred = test_preds[i]\n",
    "    twin_pred = twin_dknn.predict( np.array([temp_test[i]]) )\n",
    "    \n",
    "    if (twin_pred[0] != cnn_pred):\n",
    "        print(\"fail\")\n",
    "    else:\n",
    "        agreement += 1\n",
    "\n",
    "print(\"Agreement:\", agreement / len(test_preds) )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9936"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(len(test_preds) - 64) / len(test_preds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inter-Agreement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 1/10000 [00:00<1:53:05,  1.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 4/10000 [00:02<1:56:23,  1.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 5/10000 [00:03<1:54:41,  1.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 8/10000 [00:05<1:58:37,  1.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 9/10000 [00:06<1:55:17,  1.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 10/10000 [00:06<1:53:05,  1.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 12/10000 [00:08<1:52:16,  1.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 13/10000 [00:08<1:50:34,  1.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 14/10000 [00:09<1:57:31,  1.42it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 15/10000 [00:10<2:01:46,  1.37it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 16/10000 [00:11<2:00:37,  1.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 17/10000 [00:11<2:04:24,  1.34it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 18/10000 [00:12<2:12:03,  1.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 19/10000 [00:13<2:05:03,  1.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 20/10000 [00:14<2:02:14,  1.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fail\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 20/10000 [00:14<2:02:06,  1.36it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-29-9f76b549f87c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0mxp1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtwin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkneighbors\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mX_test_c\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_neighbors\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_distance\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m     \u001b[0mxp2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtwin_dknn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkneighbors\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtemp_test\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_neighbors\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_distance\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mxp1\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mxp2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Documents/University/Ph.D/Contrastive Explanations Experiments/Image/img_env/lib/python3.7/site-packages/sklearn/neighbors/_base.py\u001b[0m in \u001b[0;36mkneighbors\u001b[0;34m(self, X, n_neighbors, return_distance)\u001b[0m\n\u001b[1;32m    642\u001b[0m                 \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_X\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce_func\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreduce_func\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    643\u001b[0m                 \u001b[0mmetric\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meffective_metric_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mn_jobs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 644\u001b[0;31m                 **kwds))\n\u001b[0m\u001b[1;32m    645\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    646\u001b[0m         \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_method\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'ball_tree'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'kd_tree'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Documents/University/Ph.D/Contrastive Explanations Experiments/Image/img_env/lib/python3.7/site-packages/sklearn/metrics/pairwise.py\u001b[0m in \u001b[0;36mpairwise_distances_chunked\u001b[0;34m(X, Y, reduce_func, metric, n_jobs, working_memory, **kwds)\u001b[0m\n\u001b[1;32m   1615\u001b[0m             \u001b[0mX_chunk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0msl\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1616\u001b[0m         D_chunk = pairwise_distances(X_chunk, Y, metric=metric,\n\u001b[0;32m-> 1617\u001b[0;31m                                      n_jobs=n_jobs, **kwds)\n\u001b[0m\u001b[1;32m   1618\u001b[0m         if ((X is Y or Y is None)\n\u001b[1;32m   1619\u001b[0m                 \u001b[0;32mand\u001b[0m \u001b[0mPAIRWISE_DISTANCE_FUNCTIONS\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Documents/University/Ph.D/Contrastive Explanations Experiments/Image/img_env/lib/python3.7/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     70\u001b[0m                           FutureWarning)\n\u001b[1;32m     71\u001b[0m         \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0marg\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 72\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     73\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0minner_f\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     74\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Documents/University/Ph.D/Contrastive Explanations Experiments/Image/img_env/lib/python3.7/site-packages/sklearn/metrics/pairwise.py\u001b[0m in \u001b[0;36mpairwise_distances\u001b[0;34m(X, Y, metric, n_jobs, force_all_finite, **kwds)\u001b[0m\n\u001b[1;32m   1777\u001b[0m         \u001b[0mfunc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdistance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcdist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1778\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1779\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0m_parallel_pairwise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mY\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1780\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1781\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Documents/University/Ph.D/Contrastive Explanations Experiments/Image/img_env/lib/python3.7/site-packages/sklearn/metrics/pairwise.py\u001b[0m in \u001b[0;36m_parallel_pairwise\u001b[0;34m(X, Y, func, n_jobs, **kwds)\u001b[0m\n\u001b[1;32m   1358\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1359\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0meffective_n_jobs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_jobs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1360\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mY\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1361\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1362\u001b[0m     \u001b[0;31m# enforce a threading backend to prevent data communication overhead\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Documents/University/Ph.D/Contrastive Explanations Experiments/Image/img_env/lib/python3.7/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     70\u001b[0m                           FutureWarning)\n\u001b[1;32m     71\u001b[0m         \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0marg\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 72\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     73\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0minner_f\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     74\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Documents/University/Ph.D/Contrastive Explanations Experiments/Image/img_env/lib/python3.7/site-packages/sklearn/metrics/pairwise.py\u001b[0m in \u001b[0;36meuclidean_distances\u001b[0;34m(X, Y, Y_norm_squared, squared, X_norm_squared)\u001b[0m\n\u001b[1;32m    300\u001b[0m         \u001b[0mYY\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    301\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 302\u001b[0;31m         \u001b[0mYY\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrow_norms\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mY\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msquared\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnewaxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    303\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    304\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Documents/University/Ph.D/Contrastive Explanations Experiments/Image/img_env/lib/python3.7/site-packages/sklearn/utils/extmath.py\u001b[0m in \u001b[0;36mrow_norms\u001b[0;34m(X, squared)\u001b[0m\n\u001b[1;32m     73\u001b[0m         \u001b[0mnorms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcsr_row_norms\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     74\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 75\u001b[0;31m         \u001b[0mnorms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meinsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'ij,ij->i'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     76\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     77\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0msquared\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36meinsum\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
      "\u001b[0;32m~/Documents/University/Ph.D/Contrastive Explanations Experiments/Image/img_env/lib/python3.7/site-packages/numpy/core/einsumfunc.py\u001b[0m in \u001b[0;36meinsum\u001b[0;34m(out, optimize, *operands, **kwargs)\u001b[0m\n\u001b[1;32m   1348\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mspecified_out\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1349\u001b[0m             \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'out'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1350\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mc_einsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0moperands\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1351\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1352\u001b[0m     \u001b[0;31m# Check the kwargs to avoid a more cryptic error later, without having to\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "#### Iterate just one image at a time\n",
    "\n",
    "for i in tqdm(range(len(test_preds))):\n",
    "            \n",
    "    xp1 = twin.kneighbors(X=np.array([X_test_c[i]]), n_neighbors=1, return_distance=False )\n",
    "    \n",
    "    xp2 = twin_dknn.kneighbors(X=np.array([temp_test[i]]), n_neighbors=1, return_distance=False )\n",
    "    \n",
    "    if (xp1 != xp2):\n",
    "        print('Fail')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "img_env",
   "language": "python",
   "name": "img_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
