{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quantitative Experiments\n",
    "\n",
    "1. Test how much the salient regions on a query overlap with the salient regions on the nearest neighbors\n",
    "\n",
    "2. Sanity-check the salient areas by blacking out the corresponding input pixels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "import time\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision.models as models\n",
    "\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "from functions import *\n",
    "from ANNs import *\n",
    "from tqdm import tqdm\n",
    "from copy import deepcopy\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pre-Requisites for Algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Experiment configuration ---\n",
    "DATAROOT = 'data'  # directory holding the cached .npy arrays loaded further down\n",
    "DEVICE = 'cpu'\n",
    "\n",
    "# Per-channel normalisation constants in use. The commented-out pair (the widely\n",
    "# used ImageNet statistics) is kept for reference only.\n",
    "# MEAN_NORM=(0.485, 0.456, 0.406)\n",
    "# STD_NORM=(0.229, 0.224, 0.225)\n",
    "MEAN_NORM = (0.5, 0.5, 0.5)\n",
    "STD_NORM = (0.5, 0.5, 0.5)\n",
    "\n",
    "NUM_NEIGHBORS = 1  # Num neighbors to check for features\n",
    "XP_FEATURE_NUM = 1  # Num box features to show\n",
    "\n",
    "# --- Model ---\n",
    "# Restore the trained CNN weights on DEVICE and switch the network to eval mode.\n",
    "netC = CNN()\n",
    "netC.load_state_dict(torch.load('weights/cnn.pth', map_location=torch.device(DEVICE)))\n",
    "netC = netC.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class netClassifier(nn.Module):\n",
    "    \"\"\"Classification head of a wrapped network: its `avgpool` followed by its `linear`.\n",
    "\n",
    "    Lets a (possibly modified) feature map be re-scored without re-running the\n",
    "    layers that produced it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, netC):\n",
    "        \"\"\"Keep a reference to `netC`, which must expose `avgpool` and `linear`.\"\"\"\n",
    "        super().__init__()\n",
    "        self.net = netC\n",
    "\n",
    "    def forward(self, C):\n",
    "        \"\"\"Map feature map `C` to logits via pooling, a reshape to (-1, 128) and the linear layer.\"\"\"\n",
    "        pooled = self.net.avgpool(C)\n",
    "        flat = pooled.view(-1, 128)  # feature width is hard-coded to 128\n",
    "        return self.net.linear(flat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inverse_normalize(tensor, mean=MEAN_NORM, std=STD_NORM):\n",
    "    \"\"\"Undo a mean/std normalisation in place: each slice becomes slice * std + mean.\n",
    "\n",
    "    Slices are taken along the first dimension of `tensor` and paired with\n",
    "    `mean[k]` / `std[k]`, so this presumably expects an unbatched, channel-first\n",
    "    image (TODO confirm against callers). The input tensor itself is modified\n",
    "    and returned.\n",
    "    \"\"\"\n",
    "    for channel, (shift, scale) in zip(tensor, zip(mean, std)):\n",
    "        channel.mul_(scale)\n",
    "        channel.add_(shift)\n",
    "    return tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_transformed_data(idxs, loader):\n",
    "    \"\"\"\n",
    "    Fetch the items at positions `idxs` of `loader.dataset` through a fresh DataLoader.\n",
    "\n",
    "    idxs: indices into `loader.dataset`\n",
    "    loader: DataLoader whose dataset is sampled (the loader itself is not iterated)\n",
    "    returns: list with one entry per index, in the order of `idxs`; each entry is the\n",
    "             first element of a batch of size 1 (presumably the transformed image -\n",
    "             this depends on what the dataset's items contain)\n",
    "    \"\"\"\n",
    "    picked = torch.utils.data.Subset(loader.dataset, idxs)\n",
    "    one_by_one = torch.utils.data.DataLoader(picked, shuffle=False, num_workers=0, batch_size=1)\n",
    "    return [batch[0] for batch in one_by_one]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_box_xp(Conv, latent_feature):\n",
    "    \"\"\"\n",
    "    Find the spatial cell of `Conv` whose channel vector is closest (L1) to `latent_feature`.\n",
    "\n",
    "    Conv: feature map indexed [batch, channel, row, col]; each cell must reduce to a single\n",
    "          distance (the code calls `.item()`), i.e. a batch of one\n",
    "    latent_feature: tensor that can be viewed as (-1, channels)\n",
    "    returns: ([row, col], distance / channels); ([None, None], inf) if there are no cells.\n",
    "             On ties the first cell in row-major order wins.\n",
    "    \"\"\"\n",
    "    n_channels = Conv.shape[1]\n",
    "    best_dist = float('inf')\n",
    "    best_coord = [None, None]\n",
    "\n",
    "    for row in range(Conv.shape[2]):\n",
    "        for col in range(Conv.shape[3]):\n",
    "            cell = Conv[:, :, row:row+1, col:col+1]\n",
    "            dist = torch.cdist(cell.view(-1, n_channels),\n",
    "                               latent_feature.view(-1, n_channels), p=1.0).item()\n",
    "            if dist < best_dist:\n",
    "                best_dist, best_coord = dist, [row, col]\n",
    "\n",
    "    return best_coord, best_dist / n_channels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_box_query(Conv, org_logits, net_classifier):\n",
    "    \"\"\"Backward-compatible alias of `get_salient_regions`.\n",
    "\n",
    "    This function used to be a line-for-line copy of `get_salient_regions`; it now\n",
    "    forwards to it so the occlusion logic lives in one place.\n",
    "\n",
    "    NOTE: `get_salient_regions` is defined in a later cell of this notebook, so that\n",
    "    cell must have been executed before this function is called.\n",
    "\n",
    "    Returns the same value as `get_salient_regions`: a list of\n",
    "    [logit_change, [i, j]] entries sorted by descending logit change.\n",
    "    \"\"\"\n",
    "    return get_salient_regions(Conv, org_logits, net_classifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Head-only view of netC (avgpool + linear), used to re-score modified feature maps.\n",
    "net_classifier = netClassifier(netC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class index -> display label; the ten classes are the digits 0-9.\n",
    "LABELS_DICT = {digit: str(digit) for digit in range(10)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weight tensor of netC's linear (logit-producing) layer.\n",
    "weights = netC.linear.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train/test DataLoaders. `load_dataloaders` is not defined in this notebook\n",
    "# (presumably provided by one of the star imports above - TODO confirm).\n",
    "train_loader, test_loader = load_dataloaders()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data and target containers exposed by the datasets behind the two loaders.\n",
    "# NOTE(review): presumably the raw, untransformed samples - confirm against load_dataloaders.\n",
    "X_train = train_loader.dataset.data\n",
    "y_train = train_loader.dataset.targets\n",
    "\n",
    "X_test = test_loader.dataset.data\n",
    "y_test = test_loader.dataset.targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "# X_train, y_train, X_test, y_test = get_MNIST_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cached per-sample arrays stored under DATAROOT.\n",
    "# NOTE(review): the meaning of the `cont` arrays is not visible in this notebook - confirm\n",
    "# with the script that wrote them.\n",
    "X_train_c = np.load(f\"{DATAROOT}/X_train_cont.npy\")\n",
    "X_test_c = np.load(f\"{DATAROOT}/X_test_cont.npy\")\n",
    "\n",
    "# Feature maps (`conv`) and the `x` activations, used below for the ExMatchina and\n",
    "# D-kNN neighbour searches respectively.\n",
    "X_train_C = np.load(f\"{DATAROOT}/X_train_conv.npy\")\n",
    "X_test_C = np.load(f\"{DATAROOT}/X_test_conv.npy\")\n",
    "X_train_x = np.load(f\"{DATAROOT}/X_train_x.npy\")\n",
    "X_test_x = np.load(f\"{DATAROOT}/X_test_x.npy\")\n",
    "\n",
    "# Stored `y` arrays; used below as the labels of the k-NN index and as the query predictions.\n",
    "train_preds = np.load(f\"{DATAROOT}/X_train_y.npy\")\n",
    "test_preds = np.load(f\"{DATAROOT}/X_test_y.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', metric='euclidean', n_neighbors=1)"
      ]
     },
     "execution_count": 91,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit COLE and DkNN\n",
    "twin = KNeighborsClassifier(n_neighbors=1, algorithm=\"brute\", metric='euclidean') \n",
    "twin.fit(X_train_c, train_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_salient_regions(Conv, org_logits, net_classifier):\n",
    "    \"\"\"Rank the spatial cells of a feature map by occlusion saliency.\n",
    "\n",
    "    Each cell (i, j) of `Conv` is zeroed across all channels in turn, the occluded\n",
    "    map is passed through `net_classifier`, and the drop in the logit of the\n",
    "    originally predicted class is recorded.\n",
    "\n",
    "    Args:\n",
    "        Conv: feature map indexed [batch, channel, row, col]. It is cloned for every\n",
    "            occlusion and never modified.\n",
    "        org_logits: logits of the unmodified input, indexed [batch, class]. The\n",
    "            predicted class is read with `.item()`, so this must be a batch of one.\n",
    "        net_classifier: callable mapping a feature map to logits.\n",
    "\n",
    "    Returns:\n",
    "        List of [logit_change, [i, j]] sorted by descending logit_change, where\n",
    "        logit_change = original logit - occluded logit of the predicted class.\n",
    "        Known limitation (from the original code): must be modified for negative logits.\n",
    "    \"\"\"\n",
    "    pred = torch.argmax(org_logits, dim=1).item()\n",
    "    base_logit = org_logits[0][pred].item()  # loop-invariant, read once\n",
    "\n",
    "    results = list()\n",
    "    # Only scalar logit values are kept, so no autograd graph is needed for the\n",
    "    # H*W forward passes below.\n",
    "    with torch.no_grad():\n",
    "        for i in range(Conv.shape[2]):\n",
    "            for j in range(Conv.shape[3]):\n",
    "                C_copy = Conv.clone().detach()\n",
    "                C_copy[:, :, i:i+1, j:j+1] = 0\n",
    "                new_logit = net_classifier(C_copy)[0][pred].item()\n",
    "                results.append([base_logit - new_logit, [i, j]])\n",
    "\n",
    "    return sorted(results, key=lambda x: x[0], reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "def blackout_C(box=None, img=None, pred=None, unit=4, filter_size=7):\n",
    "    \"\"\"Black out the input patch under one feature-map cell and re-run the CNN.\n",
    "\n",
    "    Called without arguments it behaves as before and reads the notebook globals\n",
    "    `query_nb_boxes`, `query_img_trans` and `query_pred` set by the evaluation loop.\n",
    "    The previous version ran `netC` on every grid cell up to the target but only\n",
    "    used the result for the matching cell; the network is now run once.\n",
    "\n",
    "    Args:\n",
    "        box: [row, col] of the cell to occlude. Defaults to `query_nb_boxes[0][1]`,\n",
    "            the first (highest logit drop) entry of the ranked cells.\n",
    "        img: input batch indexed [batch, channel, row, col]. Defaults to\n",
    "            `query_img_trans`. It is cloned, not modified.\n",
    "        pred: class index whose logit is printed. Defaults to `query_pred`.\n",
    "        unit: side length, in pixels, of the square blacked out per cell.\n",
    "        filter_size: number of cells along each side of the grid.\n",
    "\n",
    "    Returns:\n",
    "        (C, i, j): the third output of `netC` for the occluded image and the cell\n",
    "        indices, or None when `box` lies outside the filter_size x filter_size grid.\n",
    "\n",
    "    Side effects: prints the logit and shows the occluded image with matplotlib.\n",
    "    \"\"\"\n",
    "    if box is None:\n",
    "        box = query_nb_boxes[0][1]\n",
    "    if img is None:\n",
    "        img = query_img_trans\n",
    "    if pred is None:\n",
    "        pred = query_pred\n",
    "\n",
    "    # Same membership test the nested loops performed: a cell outside the grid\n",
    "    # never matched, and the function fell through returning None.\n",
    "    if box[0] not in range(filter_size) or box[1] not in range(filter_size):\n",
    "        return None\n",
    "    i, j = int(box[0]), int(box[1])\n",
    "\n",
    "    img_copy = img.clone()\n",
    "    # -1.0 is the normalised value of a zero pixel when mean = std = 0.5 (MEAN_NORM / STD_NORM).\n",
    "    img_copy[:, :, i*unit: i*unit+unit, j*unit: j*unit+unit] = -1.0\n",
    "\n",
    "    logits, _, C = netC(img_copy)\n",
    "\n",
    "    print(logits[0][pred].item(), [i, j])\n",
    "    plt.imshow(img_copy[0][0].detach().numpy())\n",
    "    plt.show()\n",
    "\n",
    "    return C, i, j"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Baseline with ExMatchina"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', metric='euclidean', n_neighbors=1)"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "temp_train = list()\n",
    "temp_test = list()\n",
    "\n",
    "# ExMatchina features: index 0 along the second axis of each stored conv array,\n",
    "# flattened and divided by its sum.\n",
    "# NOTE(review): confirm that `[i][0]` selects a singleton axis and not just the first channel.\n",
    "# A vector whose sum is 0 is kept unchanged (as in the D-kNN section); dividing would\n",
    "# produce NaNs, which KNeighborsClassifier rejects.\n",
    "for i in range(X_train_C.shape[0]):\n",
    "    x = X_train_C[i][0].flatten()  # flatten() already returns a copy\n",
    "    total = x.sum()\n",
    "    temp_train.append(x if total == 0. else x / total)\n",
    "    \n",
    "for i in range(X_test_C.shape[0]):\n",
    "    x = X_test_C[i][0].flatten()\n",
    "    total = x.sum()\n",
    "    temp_test.append(x if total == 0. else x / total)\n",
    "    \n",
    "temp_train = np.array(temp_train)\n",
    "temp_test  = np.array(temp_test)\n",
    "\n",
    "# Fit the 1-NN index used by the ExMatchina baseline (re-binds `twin`)\n",
    "twin = KNeighborsClassifier(n_neighbors=1, algorithm=\"brute\", metric='euclidean') \n",
    "twin.fit(temp_train, train_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 100/10000 [01:10<1:56:14,  1.42it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Agreement: 1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Cut-off kept from the original `if query_idx == 100: break`, which processes\n",
    "# loader positions 0..100 inclusive, i.e. 101 queries.\n",
    "MAX_QUERIES = 101\n",
    "\n",
    "results = list()       # L1 distance between the top salient features of query and neighbour\n",
    "agreement = 0          # neighbours whose stored prediction equals the query's\n",
    "saliency_sim = list()  # |top logit drop of query - top logit drop of neighbour|\n",
    "\n",
    "# NOTE(review): this loop is duplicated verbatim in the D-kNN section; consider moving\n",
    "# it into a function that takes the fitted index and the query features.\n",
    "for query_idx, data in enumerate(tqdm(test_loader)):\n",
    "    \n",
    "    # Get data\n",
    "    # NOTE(review): `img`/`label` are not used below; the query is re-fetched by index,\n",
    "    # which assumes the loader yields one unshuffled sample per step - TODO confirm.\n",
    "    img, label = data\n",
    "    \n",
    "    # Get query information\n",
    "    query_label = y_test[query_idx].item()\n",
    "    query_pred  = test_preds[query_idx].item()\n",
    "    query_img_trans = get_transformed_data([query_idx], test_loader)[0]\n",
    "    query_cont = X_test_c[query_idx]\n",
    "    \n",
    "    # Get query boxes (cells ranked by descending logit drop)\n",
    "    query_logits, query_x, query_C = netC(query_img_trans)\n",
    "    query_nb_boxes = get_salient_regions(query_C, query_logits, net_classifier)\n",
    "    \n",
    "    # Get explanation nn\n",
    "    xp_idxs = twin.kneighbors(X=[temp_test[query_idx]], n_neighbors=NUM_NEIGHBORS, return_distance=False)\n",
    "    xps_imgs_trans = get_transformed_data(xp_idxs[0], train_loader)\n",
    "    xp_logits, xp_x, xp_C = netC(xps_imgs_trans[0])\n",
    "    xp_nb_boxes = get_salient_regions(xp_C, xp_logits, net_classifier)\n",
    "    \n",
    "    if train_preds[xp_idxs[0][0]] == query_pred:\n",
    "        agreement += 1\n",
    "        \n",
    "    saliency_sim.append(abs(query_nb_boxes[0][0] - xp_nb_boxes[0][0]))\n",
    "    \n",
    "    # See distance between query nb feature and nn\n",
    "    window_idx = query_nb_boxes[0][1]  # [row, col] of the query's most salient cell\n",
    "    query_feature = query_C[ :, :, window_idx[0]:window_idx[0]+1, window_idx[1]:window_idx[1]+1]\n",
    "\n",
    "    window_idx = xp_nb_boxes[0][1]  # [row, col] of the neighbour's most salient cell\n",
    "    xp_feature = xp_C[ :, :, window_idx[0]:window_idx[0]+1, window_idx[1]:window_idx[1]+1]\n",
    "\n",
    "    # L1 distance between the two channel vectors\n",
    "    dist = sum(abs(query_feature.flatten() - xp_feature.flatten()))\n",
    "\n",
    "    results.append(dist.item())\n",
    "        \n",
    "    if len(results) >= MAX_QUERIES:\n",
    "        break\n",
    "\n",
    "# Reported after the loop so it is also printed when the loader is shorter than MAX_QUERIES.\n",
    "if results:\n",
    "    print(\"Agreement:\", agreement / len(results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the ExMatchina results as arrays; `results` and `saliency_sim` are re-bound\n",
    "# by the D-kNN section further down.\n",
    "saliency_sim = np.array(saliency_sim)\n",
    "ex_matchina = np.array(deepcopy(results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "58.37529969923567\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x139f05b50>"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAARrUlEQVR4nO3de4yldX3H8fdHVtR4KaLbzbosBSu1NW28ZLWIl0SxdrUqaC1qjG5T7NpUG42tFmvS2KR/aC9q2xh1K8a18YI3AtoWRURNo6ILooBoQQrCuuwuqNVeol349o95Vo/DzO647vN8DzvvVzKZc37nzJ5vnpl57zPPnPNMqgpJ0vTu0j2AJK1WBliSmhhgSWpigCWpiQGWpCZrugdYic2bN9cFF1zQPYYkHaostXin2AO+5ZZbukeQpMPuThFgSToSGWBJamKAJamJAZakJqM+CyLJ9cD3gduAfVW1KcmxwDnACcD1wBlV9Z0x55CkeTTFHvATquphVbVpuH4WcFFVnQRcNFyXpFWn4xDEacD24fJ24PSGGSSp3dgBLuDjSS5NsnVYW1dVu4bLNwPrlvrAJFuT7EiyY+/evSOPKUnTG/uVcI+tqp1Jfh64MMnXZm+sqkqy5AmJq2obsA1g06ZNnrRY0hFn1D3gqto5vN8DnAs8CtidZD3A8H7PmDNI0rwaLcBJ7pnk3vsvA08GrgTOB7YMd9sCnDfWDJI0z8Y8BLEOODfJ/sd5T1VdkOSLwPuTnAncAJwx4gySNLdGC3BVXQc8dIn1W4FTx3pcSbqz8JVwktTkiA7who3Hk6T9bcPG47s3haQ5dKc4Ifuh+tZNN/Kct322ewzOefEp3SNImkNH9B6wJM0zAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUZPcBJjkrypSQfHa6fmOSSJNcmOSfJ0WPPIEnzaIo94JcBV89cfz3wxqp6EPAd4MwJZpCkuTNqgJMcB/wW8PbheoAnAh8c7rIdOH3MGSRpXo29B/wm4FXA7cP1+wHfrap9w/WbgA1LfWCSrUl2JNmxd+/ekceUpOmNFuAkTwP2VNWlh/LxVbWtqjZV1aa1a9ce5ukkqd+aEf/txwDPSPJU4O7AfYC/A45JsmbYCz4O2DniDJI0t0bbA66qV1fVcVV1AvBc4JNV9XzgYuDZw922AOeNNYMkzbOO5wH/KfCKJNeycEz47IYZJKndmIcgfqSqPgV8arh8HfCoKR5XkuaZr4STpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJqMFOMndk3whyZeTXJXkL4b1E5NckuTaJOckOXqsGSRpno25B/wD4IlV9VDgYcDmJCcDrwfeWFUPAr4DnDniDJI0t0YLcC34r+HqXYe3Ap4IfHBY3w6cPtYMkjTPRj0GnOSoJJcDe4ALgW8A362qfcNdbgI2LPOxW5PsSLJj7969Y44pSS1GDXBV3VZV
DwOOAx4F/PJP8bHbqmpTVW1au3btWCNKUptJngVRVd8FLgYeDRyTZM1w03HAzilmkKR5M+azINYmOWa4fA/gN4CrWQjxs4e7bQHOG2sGSZpnaw5+l0O2Htie5CgWQv/+qvpokq8C70vyl8CXgLNHnEGS5tZoAa6qrwAPX2L9OhaOB0vSquYr4SSpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqcmKApzkMStZkySt3Er3gP9hhWuSpBU64LkgkjwaOAVYm+QVMzfdBzhqzMEk6Uh3sJPxHA3ca7jfvWfWv8ePTykpSToEBwxwVX0a+HSSd1bVDRPNJEmrwkpPR3m3JNuAE2Y/pqqeOMZQkrQarDTAHwDeCrwduG28cSRp9VhpgPdV1VtGnUSSVpmVPg3tI0n+MMn6JMfufxt1Mkk6wq10D3jL8P6VM2sFPPDwjiNJq8eKAlxVJ449iCStNisKcJIXLrVeVe86vONI0uqx0kMQj5y5fHfgVOAywABL0iFa6SGIP5q9nuQY4H1jDCRJq8Whno7yvwGPC0vSz2Clx4A/wsKzHmDhJDy/Arx/rKEkaTVY6THgv5m5vA+4oapuGmEeSVo1VnQIYjgpz9dYOCPafYEfjjmUJK0GK/2LGGcAXwB+BzgDuCSJp6OUpJ/BSg9BvAZ4ZFXtAUiyFvgE8MGxBpOkI91KnwVxl/3xHdz6U3ysJGkJK90DviDJx4D3DtefA/zLOCNJ0upwsL8J9yBgXVW9MsmzgMcON30OePfYw0nSkexge8BvAl4NUFUfBj4MkOTXhtuePuJsknREO9hx3HVVdcXixWHthFEmkqRV4mABPuYAt93jMM4hSavOwQK8I8nvL15M8iLg0nFGkqTV4WDHgF8OnJvk+fw4uJuAo4FnjjiXJB3xDhjgqtoNnJLkCcCvDsv/XFWfHH0ySTrCrfR8wBcDF488iyStKr6aTZKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJanJaAFOsjHJxUm+muSqJC8b1o9NcmGSa4b39x1rBkmaZ2PuAe8D/riqHgKcDLwkyUOAs4CLquok4KLhuiStOqMFuKp2VdVlw+XvA1cDG4DTgO3D3bYDp481gyTNs0mOASc5AXg4cAmwrqp2DTfdDKxb5mO2JtmRZMfevXunGFOSJjV6gJPcC/gQ8PKq+t7sbVVVQC31cVW1rao2VdWmtWvXjj2mJE1u1AAnuSsL8X13VX14WN6dZP1w+3pgz5gzSNK8GvNZEAHOBq6uqjfM3HQ+sGW4vAU4b6wZJGmerejP0h+ixwAvAK5Icvmw9mfA64D3JzkTuAE4Y8QZJGlujRbgqvo3IMvcfOpYjytJdxa+Ek6SmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmqzpHmBVuMsaknRPAcADjtvIzhu/2T2GJAzwNG7fx3Pe9tnuKQA458WndI8gaeAhCElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWriX8RYbfzzSNLcMMCrjX8eSZobHoKQpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqMlqAk7wjyZ4kV86sHZvkwiTXDO/vO9bjS9K8G3MP+J3A5kVrZwEXVdVJwEXDdUlalUYLcFV9Bvj2ouXTgO3D5e3A6WM9viTNu6mPAa+rql3D5ZuBdcvdMcnWJDuS7Ni7d+8000nShNp+CVdVBdQBbt9WVZuqatPatWsnnEySpjF1gHcnWQ8wvN8z8eNL0tyYOsDnA1uGy1uA8yZ+fEmaG2M+De29
wOeABye5KcmZwOuA30hyDfCk4bokrUqjnQ+4qp63zE2njvWYknRn4ivhJKmJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmoz2PGDpoO6yhiTdU/CA4zay88Zvdo+hVcgAq8/t+3jO2z7bPQXnvPiU7hG0SnkIQpKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCXdwYaNx5NkLt42bDy+e3OMxnNBSLqDb91041ycpwOO7HN1uAcsSU0MsCQ1McCS1MQAS1ITAyxJTXwWhDQnfxoJ/PNIq40BlubkTyPBkf2UK92RhyAkqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJr4UWZonc3ReirkxR9vkcJ+rwwBL82ROzksxV+ekmJNtAod/u3gIQpKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJatIS4CSbk3w9ybVJzuqYQZK6TR7gJEcBbwaeAjwEeF6Sh0w9hyR169gDfhRwbVVdV1U/BN4HnNYwhyS1SlVN+4DJs4HNVfWi4foLgF+vqpcuut9WYOtw9cHA1ycd9I7uD9zSPMNiznRw8zYPzN9M8zYPHHkz3VJVmxcvzu2fJKqqbcC27jn2S7KjqjZ1zzHLmQ5u3uaB+Ztp3uaB1TNTxyGIncDGmevHDWuStKp0BPiLwElJTkxyNPBc4PyGOSSp1eSHIKpqX5KXAh8DjgLeUVVXTT3HIZibwyEznOng5m0emL+Z5m0eWCUzTf5LOEnSAl8JJ0lNDLAkNTHAS0iyMcnFSb6a5KokLxvWX5tkZ5LLh7enTjjT9UmuGB53x7B2bJILk1wzvL/vhPM8eGY7XJ7ke0lePvU2SvKOJHuSXDmztuR2yYK/H14C/5Ukj5honr9O8rXhMc9NcsywfkKS/53ZVm893PMcYKZlP09JXj1so68n+c0JZzpnZp7rk1w+rI++nQ7wPT/u11JV+bboDVgPPGK4fG/g31l42fRrgT9pmul64P6L1v4KOGu4fBbw+qbZjgJuBn5h6m0EPB54BHDlwbYL8FTgX4EAJwOXTDTPk4E1w+XXz8xzwuz9Jt5GS36ehq/zLwN3A04EvgEcNcVMi27/W+DPp9pOB/ieH/VryT3gJVTVrqq6bLj8feBqYEPvVEs6Ddg+XN4OnN40x6nAN6rqhqkfuKo+A3x70fJy2+U04F214PPAMUnWjz1PVX28qvYNVz/PwnPfJ7PMNlrOacD7quoHVfUfwLUsnD5gspmSBDgDeO/hftwDzLPc9/yoX0sG+CCSnAA8HLhkWHrp8CPHO6b8kR8o4ONJLh1epg2wrqp2DZdvBtZNOM+s5/KT3yxd22i/5bbLBuDGmfvdxPT/sf4eC3tO+52Y5EtJPp3kcRPPstTnaR620eOA3VV1zczaZNtp0ff8qF9LBvgAktwL+BDw8qr6HvAW4BeBhwG7WPgxaSqPrapHsHAWuZckefzsjbXwc9HkzynMwotpngF8YFjq3EZ30LVdlpLkNcA+4N3D0i7g+Kp6OPAK4D1J7jPROHP1eVrkefzkf+iTbaclvud/ZIyvJQO8jCR3ZeET8e6q+jBAVe2uqtuq6nbgHxnhR7PlVNXO4f0e4NzhsXfv/7FneL9nqnlmPAW4rKp2D/O1baMZy22XtpfBJ/ld4GnA84dvZIYf828dLl/KwvHWX5pingN8nlpPFZBkDfAs4JyZWSfZTkt9zzPy15IBXsJwDOps4OqqesPM+uwxnmcCVy7+2JHmuWeSe++/zMIvda5k4SXcW4a7bQHOm2KeRX5ib6VrGy2y3HY5H3jh8Bvsk4H/nPnxcjRJNgOvAp5RVf8zs742C+fHJskDgZOA68aeZ3i85T5P5wPPTXK3JCcOM31hipkGTwK+VlU37V+YYjst9z3P2F9LY/5m8c76BjyWhR81vgJcPrw9
Ffgn4Iph/Xxg/UTzPJCF30x/GbgKeM2wfj/gIuAa4BPAsRNvp3sCtwI/N7M26TZiIf67gP9j4TjcmcttFxZ+Y/1mFvagrgA2TTTPtSwcL9z/tfTW4b6/PXw+LwcuA54+4TZa9vMEvGbYRl8HnjLVTMP6O4E/WHTf0bfTAb7nR/1a8qXIktTEQxCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElN/h9s3cWG4FJZLwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Mean and distribution of the L1 distances between the top salient features of each\n",
    "# query and its ExMatchina neighbour.\n",
    "# NOTE(review): `sns` is not imported explicitly in the import cell; it presumably arrives\n",
    "# through one of the star imports - confirm, or add `import seaborn as sns` there.\n",
    "print(ex_matchina.mean())\n",
    "sns.displot(ex_matchina)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.21728234524183934\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x2280cd2d0>"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAARrElEQVR4nO3da4ylBX3H8e8PVsRWLFhXQtbd4r0SW9GMVtG0KmqQF6KtFYkXTNBFLUajMTH6ovbyQhMvTVtjWZWIjRdQseKlWESUKIpdFbl6pVgWVnbw3jZVV/59cR7qBHd2zu7Oc/5ndr6f5GTPec5zzvOfycw3zz5znnNSVUiSZu+Q7gEkab0ywJLUxABLUhMDLElNDLAkNdnQPcA0TjrppLrooou6x5Ck5WR/HrQm9oBvu+227hEkadWtiQBL0sHIAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwO8gk2bt5DkgC+bNm/p/lIkzZk18YbsnW7ZcROnnn35AT/PeWeesArTSDqYuAcsSU0MsCQ1McCS1MQAS1KT0QKc5PAkX07y9STXJvmrYfl9k1yR5DtJzkty2FgzSNI8G3MP+OfAE6vqYcDxwElJHg28EXhrVT0A+BFwxogzSNLcGi3ANfFfw827DJcCngh8aFh+LvD0sWaQpHk26jHgJIcmuRLYBVwMfBf4cVXtHlbZAWwacwZJmlejBriqflVVxwP3AR4F/P60j02yNcn2JNsXFxfHGlGS2szkVRBV9WPgUuAxwJFJ7jgD7z7Azcs8ZltVLVTVwsaNG2cxpiTN1JivgtiY5Mjh+t2AJwPXMwnxM4fVTgc+OtYMkjTPxnwviGOAc5McyiT051fVx5NcB3wgyd8CXwPeNeIMkjS3RgtwVV0FPHwPy29gcjxYktY1z4STpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJqMFOMnmJJcmuS7JtUlePix/fZKbk1w5XE4eawZJmmcbRnzu3cCrquqrSY4AvpLk4uG+t1bVm0bctiTNvdECXFU7gZ3D9Z8luR7YNNb2JGmtmckx4CTHAg8HrhgWnZXkqiTnJDlqmcdsTbI9yfbFxcVZjClJMzV6gJPcHfgw8Iqq+inwduD+wPFM9pDfvKfHVdW2qlqoqoWNGzeOPaYkzdyoAU5yFybxfW9VXQBQVbdW1a+q6nbgHcCjxpxBkubVmK+CCPAu4PqqesuS5ccsWe0ZwDVjzSBJ82zMV0E8FngecHWSK4dlrwVOS3I8UMCNwJkjziBJc2vMV0F8Hsge7vrkWNuUpLXEM+EkqYkBlqQmBliSmhhgSWpy0AZ40+YtJDngiySNZcyXobW6ZcdNnHr25Qf8POedecIqTCNJv+mg3QOWpHlngCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqcloAU6yOcmlSa5Lcm2Slw/L75nk4iTfHv49aqwZJGmejbkHvBt4VVUdBzwa+IskxwGvAS6pqgcClwy3JWndGS3AVbWzqr46XP8ZcD2wCTgFOHdY7Vzg6WPNIEnzbCbHgJMcCzwcuAI4uqp2Dnd9Hzh6mcdsTbI9yfbFxcVZjClJMzV6gJPc
Hfgw8Iqq+unS+6qqgNrT46pqW1UtVNXCxo0bxx5TkmZu1AAnuQuT+L63qi4YFt+a5Jjh/mOAXWPOIEnzasxXQQR4F3B9Vb1lyV0XAqcP108HPjrWDJI0zzaM+NyPBZ4HXJ3kymHZa4E3AOcnOQP4HvCsEWeQpLk1WoCr6vNAlrn7xLG2K0lrhWfCSVITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNZkqwEkeO80ySdL0pt0D/ocpl0mSprTXT0VO8hjgBGBjklcuuesewKFjDiZJB7uVPpb+MODuw3pHLFn+U+CZYw0lSevBXgNcVZ8DPpfk3VX1vRnNJEnrwkp7wHe4a5JtwLFLH1NVTxxjKElaD6YN8AeBfwLeCfxqvHEkaf2YNsC7q+rto04iSevMtC9D+1iSlyY5Jsk977iMOpkkHeSm3QM+ffj31UuWFXC/1R1HktaPqQJcVfcdexBJWm+mCnCS5+9peVW9Z3XHkaT1Y9pDEI9ccv1w4ETgq4ABlqT9NO0hiJctvZ3kSOADYwwkSevF/r4d5X8DHheWpAMw7THgjzF51QNM3oTnIcD5Yw0lSevBtMeA37Tk+m7ge1W1Y4R5JGndmOoQxPCmPN9g8o5oRwG/GHMoSVoPpv1EjGcBXwb+HHgWcEUS345Skg7AtIcgXgc8sqp2ASTZCHwa+NBYg0nSwW7aV0Ecckd8Bz/Yh8dKkvZg2ohelORTSV6Q5AXAJ4BP7u0BSc5JsivJNUuWvT7JzUmuHC4n7//okrS2rfSZcA8Ajq6qVyf5U+Bxw11fBN67wnO/G/hHfvNsubdW1Zt+c3VJWl9W2gP+Oyaf/0ZVXVBVr6yqVwIfGe5bVlVdBvxwFWaUpIPSSgE+uqquvvPCYdmx+7nNs5JcNRyiOGo/n0OS1ryVAnzkXu67235s7+3A/YHjgZ3Am5dbMcnWJNuTbF9cXNyPTc2ZQzaQ5IAvmzZv6f5KJK2SlV6Gtj3Ji6rqHUsXJnkh8JV93VhV3brkOd4BfHwv624DtgEsLCzUcuutGbfv5tSzLz/gpznvzBNWYRhJ82ClAL8C+EiS5/Dr4C4AhwHP2NeNJTmmqnYON58BXLO39SXpYLbXAA97rCckeQLw0GHxJ6rqMys9cZL3A48H7pVkB/CXwOOTHM/kjX1uBM7c78klaY2b9v2ALwUu3ZcnrqrT9rD4XfvyHJJ0MPNsNklqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWoyWoCTnJNkV5Jrliy7Z5KLk3x7+PeosbYvSfNuzD3gdwMn3WnZa4BLquqBwCXDbUlal0YLcFVdBvzwTotPAc4drp8LPH2s7UvSvJv1MeCjq2rncP37wNHLrZhka5LtSbYvLi7OZrq14JANJDngy6bNW7q/Emnd29C14aqqJLWX+7cB2wAWFhaWXW/duX03p559+QE/zXlnnrAKw0g6ELPeA741yTEAw7+7Zrx9SZobsw7whcDpw/XTgY/OePuSNDfGfBna+4EvAg9OsiPJGcAbgCcn+TbwpOG2JK1Lox0DrqrTlrnrxLG2KUlriWfCSVITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHA
ktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUZEPHRpPcCPwM+BWwu6oWOuaQpE4tAR48oapua9y+JLXyEIQkNekKcAH/luQrSbbuaYUkW5NsT7J9cXFxxuNpWps2byHJAV82bd7S/aVIM9d1COJxVXVzknsDFyf5RlVdtnSFqtoGbANYWFiojiG1slt23MSpZ19+wM9z3pknrMI00trSsgdcVTcP/+4CPgI8qmMOSeo08wAn+e0kR9xxHXgKcM2s55Ckbh2HII4GPpLkju2/r6ouaphDklrNPMBVdQPwsFlvV5LmjS9Dk6QmBliSmhhgSWpigCWpiQFerw7ZsCpnsGnvVutMQc8WPDh1vhmPOt2+2zPYZmC1zhQEv9cHI/eAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigDUfVunMPM8W01rimXCaD56Zp3XIPWBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkB1sFllT5bbsNhh6/K88yjTZu3zNX3aLU+x2+1vq5Zfq6gnwmng8sqfrbcwfoZdbfsuOmg/B6t5tc1K+4BS1ITAyxJTQywJDUxwJLUxABLUpOWACc5Kck3k3wnyWs6ZpCkbjMPcJJDgbcBTwWOA05Lctys55Ckbh17wI8CvlNVN1TVL4APAKc0zCFJrVJVs91g8kzgpKp64XD7ecAfVdVZd1pvK7B1uPlg4Jv7uKl7Abcd4Lid1vL8zt7D2fscXlUP3dcHze2ZcFW1Ddi2v49Psr2qFlZxpJlay/M7ew9n75Nk+/48ruMQxM3A5iW37zMsk6R1pSPA/w48MMl9kxwGPBu4sGEOSWo180MQVbU7yVnAp4BDgXOq6toRNrXfhy/mxFqe39l7OHuf/Zp/5n+EkyRNeCacJDUxwJLUZM0HeKXTmpPcNcl5w/1XJDm2Ycw9mmL2Vya5LslVSS5J8nsdcy5n2lPKk/xZkkoyNy8zmmb2JM8avv/XJnnfrGdczhQ/N1uSXJrka8PPzskdc+5JknOS7EpyzTL3J8nfD1/bVUkeMesZlzPF7M8ZZr46yeVJHrbik1bVmr0w+SPed4H7AYcBXweOu9M6LwX+abj+bOC87rn3YfYnAL81XH/JvMw+7fzDekcAlwFfAha6596H7/0Dga8BRw2379099z7Mvg14yXD9OODG7rmXzPbHwCOAa5a5/2TgX4EAjwau6J55H2Y/YcnPy1OnmX2t7wFPc1rzKcC5w/UPAScmc/FhXSvOXlWXVtX/DDe/xOQ10/Ni2lPK/wZ4I/C/sxxuBdPM/iLgbVX1I4Cq2jXjGZczzewF3GO4/jvALTOcb6+q6jLgh3tZ5RTgPTXxJeDIJMfMZrq9W2n2qrr8jp8Xpvx9XesB3gTctOT2jmHZHtepqt3AT4Dfncl0ezfN7EudwWTPYF6sOP/w38fNVfWJWQ42hWm+9w8CHpTkC0m+lOSkmU23d9PM/nrguUl2AJ8EXjab0VbFvv5ezKupfl/n9lRk/VqS5wILwJ90zzKtJIcAbwFe0DzK/trA5DDE45nsyVyW5A+q6sedQ03pNODdVfXmJI8B/jnJQ6vq9u7B1oMkT2AS4MettO5a3wOe5rTm/18nyQYm/yX7wUym27upTslO8iTgdcDTqurnM5ptGivNfwTwUOCzSW5kcjzvwjn5Q9w03/sdwIVV9cuq+g/gW0yC3G2a2c8Azgeoqi8ChzN5s5u1YE2/VUGSPwTeCZxSVSt2Zq0HeJrTmi8ETh+uPxP4TA1HyZutOHuShwNnM4nvvByDvMNe56+qn1TVvarq2Ko6lskxsadV1X69ackqm+bn5l+Y7P2S5F5MDkncMMMZlzPN7P8JnAiQ5CFMArw40yn334XA84dXQzwa+ElV7eweahpJtgAXAM+rqm9N9aDuvyyu
wl8mT2ayd/Jd4HXDsr9m8ssOkx++DwLfAb4M3K975n2Y/dPArcCVw+XC7pn3Zf47rftZ5uRVEFN+78PkEMp1wNXAs7tn3ofZjwO+wOQVElcCT+meecns7wd2Ar9k8r+MM4AXAy9e8n1/2/C1XT1nPzMrzf5O4EdLfl+3r/ScnoosSU3W+iEISVqzDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1OT/AEPGtnntT+qwAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Mean and distribution of the absolute difference between the top logit drop of each\n",
    "# query and that of its nearest neighbour.\n",
    "# NOTE(review): `sns` is not imported explicitly in the import cell - confirm it is in scope.\n",
    "print(saliency_sim.mean())\n",
    "sns.displot(saliency_sim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Baseline: D-kNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', metric='euclidean', n_neighbors=1)"
      ]
     },
     "execution_count": 94,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "temp_train = list()\n",
    "temp_test = list()\n",
    "\n",
    "# D-kNN features: each stored `x` activation, flattened and divided by its sum.\n",
    "# A vector whose sum is 0 is kept unchanged to avoid dividing by zero.\n",
    "# The loops are bounded by the arrays they index (X_train_x / X_test_x); the earlier\n",
    "# version took the bounds from X_train_C / X_test_C.\n",
    "for i in range(X_train_x.shape[0]):\n",
    "    x = X_train_x[i].flatten()  # flatten() already returns a copy\n",
    "    total = x.sum()\n",
    "    temp_train.append(x if total == 0. else x / total)\n",
    "    \n",
    "for i in range(X_test_x.shape[0]):\n",
    "    x = X_test_x[i].flatten()\n",
    "    total = x.sum()\n",
    "    temp_test.append(x if total == 0. else x / total)\n",
    "        \n",
    "temp_train = np.array(temp_train)\n",
    "temp_test  = np.array(temp_test)\n",
    "\n",
    "# Fit the 1-NN index used by the D-kNN baseline (re-binds `twin`)\n",
    "twin = KNeighborsClassifier(n_neighbors=1, algorithm=\"brute\", metric='euclidean') \n",
    "twin.fit(temp_train, train_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 100/10000 [00:06<10:21, 15.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Agreement: 1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Cut-off kept from the original `if query_idx == 100: break`, which processes\n",
    "# loader positions 0..100 inclusive, i.e. 101 queries.\n",
    "MAX_QUERIES = 101\n",
    "\n",
    "results = list()       # L1 distance between the top salient features of query and neighbour\n",
    "agreement = 0          # neighbours whose stored prediction equals the query's\n",
    "saliency_sim = list()  # |top logit drop of query - top logit drop of neighbour|\n",
    "\n",
    "# NOTE(review): this loop is a verbatim copy of the one in the ExMatchina section;\n",
    "# consider moving it into a function that takes the fitted index and the query features.\n",
    "for query_idx, data in enumerate(tqdm(test_loader)):\n",
    "    \n",
    "    # Get data\n",
    "    # NOTE(review): `img`/`label` are not used below; the query is re-fetched by index,\n",
    "    # which assumes the loader yields one unshuffled sample per step - TODO confirm.\n",
    "    img, label = data\n",
    "    \n",
    "    # Get query information\n",
    "    query_label = y_test[query_idx].item()\n",
    "    query_pred  = test_preds[query_idx].item()\n",
    "    query_img_trans = get_transformed_data([query_idx], test_loader)[0]\n",
    "    query_cont = X_test_c[query_idx]\n",
    "    \n",
    "    # Get query boxes (cells ranked by descending logit drop)\n",
    "    query_logits, query_x, query_C = netC(query_img_trans)\n",
    "    query_nb_boxes = get_salient_regions(query_C, query_logits, net_classifier)\n",
    "    \n",
    "    # Get explanation nn\n",
    "    xp_idxs = twin.kneighbors(X=[temp_test[query_idx]], n_neighbors=NUM_NEIGHBORS, return_distance=False)\n",
    "    xps_imgs_trans = get_transformed_data(xp_idxs[0], train_loader)\n",
    "    xp_logits, xp_x, xp_C = netC(xps_imgs_trans[0])\n",
    "    xp_nb_boxes = get_salient_regions(xp_C, xp_logits, net_classifier)\n",
    "    \n",
    "    if train_preds[xp_idxs[0][0]] == query_pred:\n",
    "        agreement += 1\n",
    "        \n",
    "    saliency_sim.append(abs(query_nb_boxes[0][0] - xp_nb_boxes[0][0]))\n",
    "    \n",
    "    # See distance between query nb feature and nn\n",
    "    window_idx = query_nb_boxes[0][1]  # [row, col] of the query's most salient cell\n",
    "    query_feature = query_C[ :, :, window_idx[0]:window_idx[0]+1, window_idx[1]:window_idx[1]+1]\n",
    "\n",
    "    window_idx = xp_nb_boxes[0][1]  # [row, col] of the neighbour's most salient cell\n",
    "    xp_feature = xp_C[ :, :, window_idx[0]:window_idx[0]+1, window_idx[1]:window_idx[1]+1]\n",
    "\n",
    "    # L1 distance between the two channel vectors\n",
    "    dist = sum(abs(query_feature.flatten() - xp_feature.flatten()))\n",
    "\n",
    "    results.append(dist.item())\n",
    "        \n",
    "    if len(results) >= MAX_QUERIES:\n",
    "        break\n",
    "\n",
    "# Reported after the loop so it is also printed when the loader is shorter than MAX_QUERIES.\n",
    "if results:\n",
    "    print(\"Agreement:\", agreement / len(results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the D-kNN results as arrays, mirroring `ex_matchina` from the previous section.\n",
    "saliency_sim = np.array(saliency_sim)\n",
    "dknn = np.array(deepcopy(results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "69.84689330110456\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x21436c050>"
      ]
     },
     "execution_count": 97,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAP0UlEQVR4nO3dXYxc9XmA8efFxhAVIkNYWcY2whSU1FJVQBsKBkUVNK1DP0wqiomixBdOjdRQBaVNY8pNkHoRqjYfraIUN6A4FQJTQgSJUlJCTaKK1HRJzFcciqGA7Ri8NFDSXiQ1fnsxx2VYeXfHy555Z2een7TamTMz3vfvYz+aPTtzNjITSVL/HVc9gCSNKgMsSUUMsCQVMcCSVMQAS1KRxdUD9GLdunV53333VY8hSUcTc33ggngG/PLLL1ePIEnzbkEEWJKGkQGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCJDG+AVq84gIso/Vqw6o/qvQtKAWhAnZJ+LH+/by4abH6oeg+3XrK0eQdKAGtpnwJI06AywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUpHWAxwRiyLiBxHxjeb66ojYGRF7ImJ7RCxpewZJGkT9eAb8MWB31/WbgM9m5tnAK8CmPswgSQOn1QBHxErgt4AvNdcDuBS4q7nLNuCKNmeQpEHV9jPgzwF/Chxurr8DeDUzDzXX9wErjvbAiNgcERMRMTE5OdnymJLUf60FOCJ+GziYmY/M5fGZuTUzxzNzfGxsbJ6nk6R6bf5W5IuB342Iy4ETgbcDnweWRsTi5lnwSmB/izNI0sBq7RlwZl6fmSsz80zgauCfM/ODwA7gyuZuG4F72ppBkgZZxeuAPwl8PCL20DkmfEvBDJJUrs1DEP8vMx8EHmwuPwtc0I+vK0mDzHfCSVIRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRVoLcEScGBEPR8SjEfFkRNzYbF8dETsjYk9EbI+IJW3NIEmDrM1nwD8DLs3MXwHOBdZFxIXATcBnM/Ns4BVgU4szSNLAai3A2fHfzdXjm48ELgXuarZvA65oawZJGmStHgOOiEURsQs4CNwPPAO8mpmHmrvsA1ZM89jNETEREROTk5NtjilJJVoNcGa+npnnAiuBC4B3HcNjt2bmeGaOj42NtTWiJJXpy6sgMvNVYAdwEbA0IhY3N60E9vdjBkkaNG2+CmIsIpY2l98GvBfYTSfEVzZ32wjc09YMkjTIFs9+lzlbDmyLiEV0Qn9nZn4jIn4I3BERfw78ALilxRkkaWC1FuDMfAw47yjbn6VzPFiSRprvhJOkIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKtJTgCPi4l62SZJ61+sz4L/pcZskqUeLZ7oxIi4C1gJjEfHxrpveDixqczBJ
GnYzBhhYApzU3O/kru2vAVe2NZQkjYIZA5yZ3wG+ExFfzszn+zSTJI2E2Z4BH3FCRGwFzux+TGZe2sZQkjQKeg3wPwB/C3wJeL29cSRpdPQa4EOZ+cVWJ5GkEdPry9C+HhF/GBHLI+LUIx+tTiZJQ67XZ8Abm8+f6NqWwFnzO44kjY6eApyZq9seRJJGTU8BjogPH217Zn5lfseRpNHR6yGId3ddPhG4DPg+YIAlaY56PQTxR93XI2IpcEcbA0nSqJjr6Sj/B/C4sCS9Bb0eA/46nVc9QOckPL8E3NnWUJI0Cno9BvyXXZcPAc9n5r4W5pGkkdHTIYjmpDw/onNGtFOAn7c5lCSNgl5/I8ZVwMPA7wNXATsjwtNRStJb0OshiBuAd2fmQYCIGAO+DdzV1mCSNOx6fRXEcUfi2/jPY3isJOkoen0GfF9EfAu4vbm+AfhmOyNJ0miY7XfCnQ0sy8xPRMTvAZc0N30PuK3t4SRpmM32DPhzwPUAmXk3cDdARPxyc9vvtDibJA212Y7jLsvMx6dubLad2cpEkjQiZgvw0hlue9s8ziFJI2e2AE9ExB9M3RgRHwEeaWckSRoNsx0Dvg74WkR8kDeCOw4sAd4/0wMjYhWd01Uuo3Meia2Z+fnmVxltp3MI4zngqsx8ZY7zS9KCNeMz4Mx8KTPXAjfSieVzwI2ZeVFmvjjLn30I+OPMXANcCHw0ItYAW4AHMvMc4IHmuiSNnF7PB7wD2HEsf3BmHgAONJd/GhG7gRXAeuDXmrttAx4EPnksf7YkDYO+vJstIs4EzgN20nllxYHmphfpHKI42mM2R8RERExMTk72Y0xJ6qvWAxwRJwFfBa7LzNe6b8vM5I3zDDPltq2ZOZ6Z42NjY22PKUl912qAI+J4OvG9rXkjB8BLEbG8uX05cHC6x0vSMGstwBERwC3A7sz8TNdN9wIbm8sbgXvamkGSBlmvJ+OZi4uBDwGPR8SuZtufAZ8G7oyITcDzdM4vLEkjp7UAZ+a/ADHNzZe19XUlaaHwnL6SVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBVp8zdiCOC4xXR+O1Ot01euYv/eF6rHkNTFALft8CE23PxQ9RRsv2Zt9QiSpvAQhCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUU8Gc+oGICzsnlGNunNDPCoGICzsnlGNunNPAQhSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBVpLcARcWtEHIyIJ7q2nRoR90fE083nU9r6+pI06Np8BvxlYN2UbVuABzLzHOCB5rokjaTWApyZ3wV+MmXzemBbc3kbcEVbX1+SBl2/jwEvy8wDzeUXgWV9/vqSNDDKfgiXmQnkdLdHxOaImIiIicnJyT5OJrVvxaoziIjyjxWrzqj+qxhpi/v89V6KiOWZeSAilgMHp7tjZm4FtgKMj49PG2ppIfrxvr1suPmh6jHYfs3a6hFGWr+fAd8LbGwubwTu6fPXl6SB0ebL0G4Hvge8MyL2RcQm4NPAeyPiaeDXm+uSNJJaOwSRmR+Y5qbL2vqakrSQ+E44SSpigCWpSL9fBaFRdtxiIqJ6Ck5fuYr9e1+oHkMywOqjw4d86ZXUxUMQklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCS
VMQAS1IRAyxJRQywJBUxwJJUZHH1AJIKHbeYiKiegtNXrmL/3heqx+g7AyyNssOH2HDzQ9VTsP2atdUjlPAQhCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQV8WxoGj0DcgpGyQBr9AzAKRhH9fSLejMPQUhSEQMsSUUMsCQVMcCSVMQAS1IRXwUhqd4AvDSw4jczG2BJ9Ub0pYEegpCkIgZYkoqUBDgi1kXEUxGxJyK2VMwgSdX6HuCIWAR8AXgfsAb4QESs6fccklSt4hnwBcCezHw2M38O3AGsL5hDkkpFZvb3C0ZcCazLzI801z8E/GpmXjvlfpuBzc3VdwJPAacBL/dx3EqjstZRWSe41mF0GvCjzFw3lwcP7MvQMnMrsLV7W0RMZOZ40Uh9NSprHZV1gmsdRs065xRfqDkEsR9Y1XV9ZbNNkkZKRYD/DTgnIlZHxBLgauDegjkkqVTfD0Fk5qGIuBb4FrAIuDUzn+zx4Vtnv8vQGJW1jso6wbUOo7e0zr7/EE6S1OE74SSpiAGWpCILJsDD/PbliHguIh6PiF0RMdFsOzUi7o+Ip5vPp1TPORcRcWtEHIyIJ7q2HXVt0fHXzT5+LCLOr5v82E2z1k9FxP5m3+6KiMu7bru+WetTEfGbNVMfu4hYFRE7IuKHEfFkRHys2T5U+3WGdc7fPs3Mgf+g88O6Z4CzgCXAo8Ca6rnmcX3PAadN2fYXwJbm8hbgpuo557i29wDnA0/MtjbgcuAfgQAuBHZWzz8Pa/0U8CdHue+a5t/xCcDq5t/3ouo19LjO5cD5zeWTgX9v1jNU+3WGdc7bPl0oz4BH8e3L64FtzeVtwBV1o8xdZn4X+MmUzdOtbT3wlez4V2BpRCzvy6DzYJq1Tmc9cEdm/iwz/wPYQ+ff+cDLzAOZ+f3m8k+B3cAKhmy/zrDO6RzzPl0oAV4B7O26vo+Z/yIWmgT+KSIead6CDbAsMw80l18EltWM1orp1jas+/na5lvvW7sOJQ3FWiPiTOA8YCdDvF+nrBPmaZ8ulAAPu0sy83w6Z4j7aES8p/vG7Hx/M5SvFxzmtTW+CPwicC5wAPir0mnmUUScBHwVuC4zX+u+bZj261HWOW/7dKEEeKjfvpyZ+5vPB4Gv0fm25aUj36Y1nw/WTTjvplvb0O3nzHwpM1/PzMPA3/HGt6QLeq0RcTydKN2WmXc3m4duvx5tnfO5TxdKgIf27csR8QsRcfKRy8BvAE/QWd/G5m4bgXtqJmzFdGu7F/hw81PzC4H/6vqWdkGacqzz/XT2LXTWenVEnBARq4FzgIf7Pd9cREQAtwC7M/MzXTcN1X6dbp3zuk+rf9J4DD+RvJzOTyGfAW6onmce13UWnZ+cPgo8eWRtwDuAB4CngW8Dp1bPOsf13U7n27T/pXNMbNN0a6PzU/IvNPv4cWC8ev55WOvfN2t5rPkPurzr/jc0a30KeF/1/MewzkvoHF54DNjVfFw+bPt1hnXO2z71rciSVGShHIKQpKFjgCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIv8HLT1X9FbXQlwAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(dknn.mean())\n",
    "sns.displot(dknn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.19334502768988657\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x21477f750>"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQhUlEQVR4nO3de6xld1mH8ec7DAWFIgWOTTPOWMACVsSCQ8WBELBoBhIoFewliiUBZpQWIRAigolE/xCVi0YJdICmJYEyUEooAYulFAiUFodaepWrJbSUdgpqCUZxOq9/nFU4jHPZ05m13n3OeT7Jztl77XX6e/ee8nSx9mVSVUiSpremewBJWq0MsCQ1McCS1MQAS1ITAyxJTdZ2DzCLzZs31yWXXNI9hiTdW9nbxmVxBHznnXd2jyBJh92yCLAkrUQGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKarOgAr1u/gSSTXNat39D9cCUtM8viC9nvrW/f8i1OO+eKSdbavnXTJOtIWjlW9BGwJM0zAyxJTQywJDUxwJLUxABLUhMDLElNRgtwkvVJLk9yY5Ibkrx82P76JLcmuWa4PGusGSRpno35PuBdwKuq6uokRwJfTHLpcN9bquqNI64tSXNvtABX1W3AbcP17ye5CVg31nqStNxMcg44ybHA44Grhk1nJ7k2yblJjtrH72xJsiPJjp07d04xpiRNavQAJ3kg8EHgFVV1F/A24JHACSweIb9pb79XVduqamNVbVxYWBh7TEma3KgBTnJfFuP7nqq6CKCqbq+qu6tqN/AO4MQxZ5CkeTXmuyACvAu4qarevGT7MUt2OwW4fqwZJGmejfkuiCcDLwCuS3LNsO21wBlJTgAKuBnYOuIMkjS3xnwXxGeB7OWuj421piQtJ34STpKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAPlzVrSTLZZd36Dd2PWNIhWts9wIqxexennXPFZMtt37ppsrUkjcMjYElqYoAlqYkBlqQmBliSmhhgSWpigCWpyWgBTrI+yeVJbkxyQ5KXD9sfkuTSJF8dfh411gySNM/GPALeBbyqqo4HngScleR44DXAZVV1HHDZcFuSVp3RAlxVt1XV1cP17wM3AeuAk4Hzh93OB5471gySNM8mOQec5Fjg8cBVwNFVddtw13eAo/fxO1uS7EiyY+fOnVOMKUmTGj3ASR4IfBB4RVXdtfS+qiqg9vZ7VbWtqjZW1caFhYWxx5SkyY0a4CT3ZTG+76mqi4bNtyc5Zrj/GOCOMWeQpHk15rsgArwLuKmq3rzkrouBM4frZwIfHmsGSZpnY34b2pOBFwDXJblm2PZa4A3A+5O8CPgmcOqIM0jS3BotwFX1WSD7uPuksdaVpOXCT8JJUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0M8HK1Zi1JJrusW7+h+xFLK86YfymnxrR7F6edc8Vky23fummytaTVwiNgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJanJaAFOcm6SO5Jcv2Tb65Pc
muSa4fKssdaXpHk35hHwecDmvWx/S1WdMFw+NuL6kjTXRgtwVX0G+N5Y/3xJWu46zgGfneTa4RTFUfvaKcmWJDuS7Ni5c+eU80nSJKYO8NuARwInALcBb9rXjlW1rao2VtXGhYWFicaTpOlMGuCqur2q7q6q3cA7gBOnXF+S5smkAU5yzJKbpwDX72tfSVrp1s6yU5InV9XnDrRtj/svAJ4GPCzJLcCfAU9LcgJQwM3A1ns3tiQtfzMFGPh74AkzbPuRqjpjL5vfNeN6krTi7TfASX4d2AQsJHnlkrseBNxnzMEkaaU70BHwEcADh/2OXLL9LuD5Yw0lSavBfgNcVZ8GPp3kvKr65kQzSdKqMOs54Psl2QYcu/R3quo3xhhKklaDWQP8AeDtwDuBu8cbR5JWj1kDvKuq3jbqJJK0ysz6QYyPJHlpkmOSPOSey6iTab6sWUuSyS7r1m/ofsTS6GY9Aj5z+PnqJdsKeMThHUdza/cuTjvnismW275102RrSV1mCnBVPXzsQSRptZn1o8i/v7ftVfXuwzuOJK0es56CeOKS6/cHTgKuBgywJN1Ls56CeNnS20keDLxvjIEkabW4t19H+QPA88KSdAhmPQf8ERbf9QCLX8Lzi8D7xxpKklaDWc8Bv3HJ9V3AN6vqlhHmkaRVY6ZTEMOX8vwri9+IdhTwwzGHkqTVYKYAJzkV+ALwO8CpwFVJ/DpKSToEs56CeB3wxKq6AyDJAvAJ4MKxBpOklW7Wd0GsuSe+g+8exO9KkvZi1iPgS5J8HLhguH0a8LFxRpKk1eFAfyfcLwBHV9Wrk/w28JThrs8D7xl7OElayQ50BPy3wJ8AVNVFwEUASX55uO/ZI84mSSvagc7jHl1V1+25cdh27CgTSdIqcaAAP3g/9/3UYZxDkladAwV4R5KX7LkxyYuBL44zkiStDgc6B/wK4ENJfpcfB3cjcARwyohzSdKKt98AV9XtwKYkTwceO2z+aFV9cvTJJGmFm/X7gC8HLh95FklaVfw0myQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCaT2vWkmSSy7r1G7ofrVapWf9aemlau3dx2jlXTLLU9q2bJllH2pNHwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU1GC3CSc5PckeT6JdsekuTSJF8dfh411vqSNO/GPAI+D9i8x7bXAJdV1XHAZcNtSVqVRgtwVX0G+N4em08Gzh+unw88d6z1JWneTX0O+Oiqum24/h3g6H3tmGRLkh1JduzcuXOa6SRpQm0vwlVVAbWf+7dV1caq2riwsDDhZJI0jakDfHuSYwCGn3dMvL4kzY2pA3wxcOZw/UzgwxOvL0lzY8y3oV0AfB54dJJbkrwIeAPwm0m+CjxjuC1Jq9Jo3wdcVWfs466TxlpTkpYTPwknSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1WduxaJKbge8DdwO7qmpjxxyS1KklwIOnV9WdjetLUitPQUhSk64AF/BPSb6YZMvedkiyJcmOJDt27tw58XiSNL6uAD+lqp4APBM4K8lT99yhqrZV1caq2riwsDD9hJI0spYAV9Wtw887gA8BJ3bMIUmdJg9wkgckOfKe68BvAddPPYckdet4F8TRwIeS3LP+e6vqkoY5JKnV5AGuqm8AvzL1upI0b3wbmiQ1McCS1MQAS1ITAyxJTQywJDUx
wJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwNKatSSZ7LJu/YZJH9669Rt8bHP6+Dr+TjhpvuzexWnnXDHZctu3bppsLYBv3/KtyR7fSn5scPgfn0fAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhM/iizp8Bm+V0OzMcCSDp8V/r0ah5unICSpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpq0BDjJ5iRfTvK1JK/pmEGSuk0e4CT3Ad4KPBM4HjgjyfFTzyFJ3TqOgE8EvlZV36iqHwLvA05umEOSWqWqpl0weT6wuapePNx+AfBrVXX2HvttAbYMNx8NfHnGJR4G3HmYxh2D8x26eZ/R+Q7dvM94sPPdWVWb99y49vDNc3hV1TZg28H+XpIdVbVxhJEOC+c7dPM+o/Mdunmf8XDN13EK4lZg/ZLbPzdsk6RVpSPA/wwcl+ThSY4ATgcubphDklpNfgqiqnYlORv4OHAf4NyquuEwLnHQpy0m5nyHbt5ndL5DN+8zHpb5Jn8RTpK0yE/CSVITAyxJTZZtgA/0ceYk90uyfbj/qiTHztl8T01ydZJdw3ujJzXDfK9McmOSa5NcluTn52y+P0hyXZJrkny249OUs36kPsnzklSSSd9WNcNz+MIkO4fn8JokL56n+YZ9Th3+PbwhyXunnG+WGZO8Zcnz95Uk/3FQC1TVsruw+OLd14FHAEcAXwKO32OflwJvH66fDmyfs/mOBR4HvBt4/hw+f08Hfnq4/odz+Pw9aMn15wCXzNtzOOx3JPAZ4Epg4zzNB7wQ+Icpn7eDnO844F+Ao4bbPztvM+6x/8tYfFPBzGss1yPgWT7OfDJw/nD9QuCkJJmX+arq5qq6Ftg90UwHO9/lVfVfw80rWXy/9jzNd9eSmw8Apn41edaP1P8F8FfAf085HPP/kf9Z5nsJ8Naq+neAqrpjDmdc6gzggoNZYLkGeB3wrSW3bxm27XWfqtoF/Cfw0Emmm22+Tgc734uAfxx1op8003xJzkrydeCvgT+aaLZ7HHDGJE8A1lfVR6ccbDDrn/HzhtNMFyZZv5f7xzLLfI8CHpXkc0muTPL/Pso7spn/dzKcons48MmDWWC5BlgTSfJ7wEbgb7pn2VNVvbWqHgn8MfCn3fMslWQN8GbgVd2z7MdHgGOr6nHApfz4/zHOi7UsnoZ4GotHl+9I8uDOgfbjdODCqrr7YH5puQZ4lo8z/2ifJGuBnwG+O8l08/9x65nmS/IM4HXAc6rqfyaaDQ7++Xsf8NwxB9qLA814JPBY4FNJbgaeBFw84QtxB3wOq+q7S/5c3wn86kSzwWx/xrcAF1fV/1bVvwFfYTHIUzmYfw9P5yBPPwDL9kW4tcA3WDzkv+fk+C/tsc9Z/OSLcO+fp/mW7Hse078IN8vz93gWX4A4bk7/fI9bcv3ZwI55m3GP/T/FtC/CzfIcHrPk+inAlXM232bg/OH6w1g8HfDQeZpx2O8xwM0MH2w7qDWmejAjPDnPYvG/iF8HXjds+3MWj9YA7g98APga8AXgEXM23xNZ/C/8D1g8Mr9hzub7BHA7cM1wuXjO5vs74IZhtsv3F7+uGffYd9IAz/gc/uXwHH5peA4fM2fzhcXTODcC1wGnz+OfMfB64A335p/vR5ElqclyPQcsScueAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmvwfRoyF4gCTFOcAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(saliency_sim.mean())\n",
    "sns.displot(saliency_sim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## My Method: Baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', metric='euclidean', n_neighbors=1)"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "NUM_NEIGHBORS = 1\n",
    "FEATURE_NUM = 1  # pick most nb feature\n",
    "\n",
    "# Fit COLE and DkNN\n",
    "twin = KNeighborsClassifier(n_neighbors=1, algorithm=\"brute\", metric='euclidean') \n",
    "twin.fit(X_train_c, train_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 100/10000 [00:03<06:35, 25.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Agreement: 1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "results = list()\n",
    "agreement = 0\n",
    "saliency_sim = list()\n",
    "\n",
    "for query_idx, data in enumerate(tqdm(test_loader)):\n",
    "    \n",
    "    # Get data\n",
    "    img, label = data\n",
    "    \n",
    "    # Get query information\n",
    "    query_label = y_test[query_idx].item()\n",
    "    query_pred  = test_preds[query_idx].item()\n",
    "    query_img_trans = get_transformed_data([query_idx], test_loader)[0]\n",
    "    query_cont = X_test_c[query_idx]\n",
    "    \n",
    "    # Get query boxes\n",
    "    query_logits, query_x, query_C = netC(query_img_trans)\n",
    "    query_nb_boxes = get_salient_regions(query_C, query_logits, net_classifier)\n",
    "    \n",
    "    # Get explanation nn\n",
    "    xp_idxs = twin.kneighbors(X=[X_test_c[query_idx]], n_neighbors=NUM_NEIGHBORS, return_distance=False)\n",
    "    xps_imgs_trans = get_transformed_data(xp_idxs[0], train_loader)\n",
    "    xp_logits, xp_x, xp_C = netC(xps_imgs_trans[0])\n",
    "    xp_nb_boxes = get_salient_regions(xp_C, xp_logits, net_classifier)\n",
    "    \n",
    "    if train_preds[xp_idxs[0][0]] == query_pred:\n",
    "        agreement += 1\n",
    "        \n",
    "    saliency_sim.append(abs(query_nb_boxes[0][0] - xp_nb_boxes[0][0]))\n",
    "    \n",
    "    # See distance between query nb feature and nn\n",
    "    window_idx = query_nb_boxes[0][1]  # -1 just to make hyperparam easier to think about\n",
    "    query_feature = query_C[ :, :, window_idx[0]:window_idx[0]+1, window_idx[1]:window_idx[1]+1]\n",
    "\n",
    "    window_idx = xp_nb_boxes[0][1]  # -1 just to make hyperparam easier to think about\n",
    "    xp_feature = xp_C[ :, :, window_idx[0]:window_idx[0]+1, window_idx[1]:window_idx[1]+1]\n",
    "\n",
    "    dist = sum(abs(query_feature.flatten() - xp_feature.flatten()))\n",
    "\n",
    "    results.append(dist.item())\n",
    "        \n",
    "    if query_idx == 100:\n",
    "        print(\"Agreement:\", agreement / len(results))\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [],
   "source": [
    "saliency_sim = np.array(saliency_sim)\n",
    "twin_sys = np.array(deepcopy(results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "66.24577645027992\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x21452c0d0>"
      ]
     },
     "execution_count": 102,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAASA0lEQVR4nO3df6zld13n8eernYIGqm3lOhmmU1sQcckmtuRasYBRUBya1YKLQGOwZHGnm7WGBnVTJVFM/AOUH8YfgQ62oW4qFIWGumqldhuIQcsOdWinFGzBQjsOM1NYt6wadNq3f5zv4GGcO3NmOt/v+8zc5yM5ued+zjn3+8733vu8537vOeemqpAkTe+07gEkab0ywJLUxABLUhMDLElNDLAkNdnQPcAitm7dWrfeemv3GJJ0JDnWG5wU94AfeeSR7hEk6YQ7KQIsSaciAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTU7ZAG/ech5J2k+bt5zXvSskLamT4gXZj8ffPfwQr7r2Y91jcNOVl3SPIGlJnbL3gCVp2RlgSWpigCWpiQGWpCajBTjJNyT5eJJPJrk3ya8M6xckuTPJA0luSvKksWaQpGU25j3grwIvqqrvAi4EtiZ5HvAW4B1V9e3A/wVeN+IMkrS0Rgtwzfz/4d0zhlMBLwL+cFi/AXjZWDNI0jIb9RhwktOT7AT2AbcBnwX+vqoODFd5GNg85gyStKxGDXBVPVZVFwLnAhcD37nobZNsS7IjyY79+/ePNaIktZnkURBV9ffAHcD3AmclOfgMvHOB3WvcZntVrVbV6srKyhRjStKkxnwUxEqSs4bz3wj8EHAfsxC/YrjaFcCHxppBkpbZmK8FsQm4IcnpzEL//qr6X0k+Bbwvya8Cfw1cN+IMkrS0RgtwVd0NXHSY9c8xOx4sSeuaz4STpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJqMFOMmWJHck+VSSe5O8flh/U5LdSXYOp0vHmkGSltmGET/2AeBnq+quJGcCn0hy23DZO6rqrSNuW5KW3mgBrqo9wJ7h/FeS3AdsHmt7knSymeQYcJLzgYuAO4elq5LcneT6JGevcZttSXYk2bF///4pxpSkSY0e4CRPBT4AXF1VjwLvBJ4JXMjsHvLbDne7qtpeVatVtbqysjL2mJI0uVEDnOQMZvG9sao+CFBVe6vqsap6HHg3cPGYM0jSshrzURABrgPuq6q3z61vmrvay4FdY80gSctszEdBPB94DXBPkp3D2i8Clye5ECjgQeDKEWeQpKU15qMg/gLIYS76k7G2KUknE58JJ0lNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNRnz9YAFcNoGZq9N3+vp525h90Nf6B5D0hwDPLbHD/Cqaz/WPQU3XXlJ9wiSDuEhCElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpqMFuAkW5LckeRTSe5N8vph/ZwktyW5f3h79lgzSNIyG/Me8AHgZ6vqOcDzgJ9O8hzgGuD2qnoWcPvwviStO6MFuKr2VNVdw/mvAPcBm4HLgBuGq90AvGysGSRpmU1yDDjJ+cBFwJ3AxqraM1z0RWDjGrfZ
lmRHkh379++fYkxJmtToAU7yVOADwNVV9ej8ZVVVQB3udlW1vapWq2p1ZWVl7DElaXKjBjjJGczie2NVfXBY3ptk03D5JmDfmDNI0rIa81EQAa4D7quqt89ddAtwxXD+CuBDY80gSctsw4gf+/nAa4B7kuwc1n4ReDPw/iSvAz4PvHLEGSRpaY0W4Kr6CyBrXPzisbYrSScLnwknSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUZKEAJ3n+ImuSpMUteg/4txZckyQt6Ij/FTnJ9wKXACtJ3jB30TcBp485mCSd6o72b+mfBDx1uN6Zc+uPAq8YayhJWg+OGOCq+gjwkSTvqarPTzSTJK0LR7sHfNCTk2wHzp+/TVW9aIyhJGk9WDTAfwC8C/hd4LHxxpGk9WPRAB+oqneOOokkrTOLPgztj5L89ySbkpxz8DTqZJJ0ilv0HvAVw9ufn1sr4BkndhxJWj8WCnBVXTD2IJK03iwU4CQ/ebj1qvq9EzuOJK0fix6C+O65898AvBi4CzDAknScFj0E8TPz7yc5C3jfGANJ0npxvC9H+Q+Ax4Ul6QlY9BjwHzF71APMXoTnPwDvH2soSVoPFj0G/Na58weAz1fVwyPMI0nrxkKHIIYX5fk0s1dEOxv45zGHkqT1YNH/iPFK4OPAjwOvBO5M4stRStITsOghiDcC311V+wCSrAB/DvzhWINJ0qlu0UdBnHYwvoMvHcNtJUmHsWhEb03yZ0lem+S1wB8Df3KkGyS5Psm+JLvm1t6UZHeSncPp0uMfXZJObkf7n3DfDmysqp9P8mPAC4aL/hK48Sgf+z3Ab/Pvny33jqp667+/uiStL0e7B/wbzP7/G1X1wap6Q1W9Abh5uGxNVfVR4MsnYEZJOiUdLcAbq+qeQxeHtfOPc5tXJbl7OERx9nF+DEk66R0twGcd4bJvPI7tvRN4JnAhsAd421pXTLItyY4kO/bv338cm5Kk5Xa0AO9I8l8PXUzyU8AnjnVjVbW3qh6rqseBdwMXH+G626tqtapWV1ZWjnVTkrT0jvY44KuBm5P8BP8W3FXgScDLj3VjSTZV1Z7h3ZcDu450fUk6lR0xwFW1F7gkyQ8A/3FY/uOq+t9H+8BJ3gt8P/C0JA8Dvwx8f5ILmb2wz4PAlcc9uSSd5BZ9PeA7gDuO5QNX1eWHWb7uWD6GJJ3KfDabJDUxwJLUxABLUhMDLElNDLAkNVn09YB1sjttA0m6p+Dp525h90Nf6B5DWgoGeL14/ACvuvZj3VNw05WXdI8gLQ0PQUhSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1IT/yOGpuW/RpK+xgBrWv5rJOlrPAQhSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1GS0ACe5Psm+JLvm1s5JcluS+4e3Z4+1fUladmPeA34PsPWQtWuA26vqWcDtw/uStC6NFuCq+ijw5UOWLwNuGM7fALxsrO1L0rKb+hjwxqraM5z/IrBxrSsm2ZZkR5Id+/fvn2Y6SZpQ2x/hqqqAOsLl26tqtapWV1ZWJpxMkqYxdYD3JtkEMLzdN/H2JWlpTB3gW4ArhvNXAB+aePuStDTGfBjae4G/BJ6d5OEkrwPeDPxQkvuBHxzel6R1acNYH7iqLl/johePtU1JOpn4TDhJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoY
YElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWGm3ech5JWk+bt5zXvRvWrQ3dA0jr2d89/BCvuvZjrTPcdOUlrdtfz7wHLElNDLAkNTHAktTEAEtSk5Y/wiV5EPgK8BhwoKpWO+aQpE6dj4L4gap6pHH7ktTKQxCS1KTrHnABH05SwLVVtf3QKyTZBmwDOO88HyiuE+y0DSTpnmI5LMm+OP2MJ/PYv3y1ewyefu4Wdj/0hUm21RXgF1TV7iTfCtyW5NNV9dH5KwxR3g6wurpaHUPqFPb4gfYnQMCSPAliifbFsswxlZZDEFW1e3i7D7gZuLhjDknqNHmAkzwlyZkHzwMvAXZNPYckdes4BLERuHk45rQB+P2qurVhDklqNXmAq+pzwHdNvV1JWjY+DE2SmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKatAQ4ydYkn0nyQJJrOmaQpG6TBzjJ6cDvAC8FngNcnuQ5U88hSd067gFfDDxQVZ+rqn8G3gdc1jCHJLVKVU27weQVwNaq+qnh/dcA31NVVx1yvW3AtuHdZwOfmXTQo3sa8Ej3EIdYtpmWbR5wpkU502LmZ3qkqrYey403nPh5Toyq2g5s755jLUl2VNVq9xzzlm2mZZsHnGlRzrSYJzpTxyGI3cCWuffPHdYkaV3pCPD/AZ6V5IIkTwJeDdzSMIcktZr8EERVHUhyFfBnwOnA9VV179RznADLeHhk2WZatnnAmRblTIt5QjNN/kc4SdKMz4STpCYGWJKaGOCjSLIlyR1JPpXk3iSvH9bflGR3kp3D6dKJ53owyT3DtncMa+ckuS3J/cPbsyec59lz+2JnkkeTXD31fkpyfZJ9SXbNrR12v2TmN4enxN+d5LkTzvTrST49bPfmJGcN6+cn+ae5/fWuCWda83OV5BeG/fSZJD884Uw3zc3zYJKdw/pU+2mt7/8T8zVVVZ6OcAI2Ac8dzp8J/A2zp1C/Cfi5xrkeBJ52yNqvAdcM568B3tI02+nAF4Fvm3o/Ad8HPBfYdbT9AlwK/CkQ4HnAnRPO9BJgw3D+LXMznT9/vYn302E/V8PX+yeBJwMXAJ8FTp9ipkMufxvwSxPvp7W+/0/I15T3gI+iqvZU1V3D+a8A9wGbe6da02XADcP5G4CXNc3xYuCzVfX5qTdcVR8FvnzI8lr75TLg92rmr4CzkmyaYqaq+nBVHRje/Stmj4efzBr7aS2XAe+rqq9W1d8CDzB7SYHJZkoS4JXAe0/0do8y01rf/yfka8oAH4Mk5wMXAXcOS1cNv2ZcP+Wv+4MCPpzkE8PTtgE2VtWe4fwXgY0Tz3TQq/n6b5TO/QRr75fNwENz13uYnh+u/4XZvaaDLkjy10k+kuSFE89yuM/VMuynFwJ7q+r+ubVJ99Mh3/8n5GvKAC8oyVOBDwBXV9WjwDuBZwIXAnuY/Xo0pRdU1XOZvarcTyf5vvkLa/b70OSPMczsyTU/CvzBsNS9n75O135ZS5I3AgeAG4elPcB5VXUR8Abg95N800TjLNXn6hCX8/U/1CfdT4f5/v+aJ/I1ZYAXkOQMZjv/xqr6IEBV7a2qx6rqceDdjPAr2ZFU1e7h7T7g5mH7ew/+ujO83TflTIOXAndV1d5hvtb9NFhrv7Q+LT7Ja4H/BPzE8E3M8Gv+l4bzn2B2vPU7ppjnCJ+r7v20Afgx4Ka5WSfbT4f7/ucEfU0Z4KMYjj1dB9xXVW+fW58/rvNyYNehtx1xpqckOfPgeWZ/0NnF7CndVwxXuwL40FQzzfm6eyqd+2nOWvvlFuAnh79cPw/4
f3O/Vo4qyVbgfwA/WlX/OLe+ktlrZpPkGcCzgM9NNNNan6tbgFcneXKSC4aZPj7FTIMfBD5dVQ8fXJhqP631/c+J+poa+6+IJ/sJeAGzXy/uBnYOp0uB/wncM6zfAmyacKZnMPur9CeBe4E3DuvfAtwO3A/8OXDOxPvqKcCXgG+eW5t0PzGL/x7gX5gdf3vdWvuF2V+qf4fZvad7gNUJZ3qA2bHCg19T7xqu+5+Hz+lO4C7gRyacac3PFfDGYT99BnjpVDMN6+8B/tsh151qP631/X9CvqZ8KrIkNfEQhCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUpN/BUIY8xtKpoYKAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(twin_sys.mean())\n",
    "sns.displot(twin_sys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.17792157194401959\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x214037cd0>"
      ]
     },
     "execution_count": 103,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQzklEQVR4nO3de6ykBXnH8e8PVrRVrKhbQtelqxYv1Fq0R2vRGC3WrCSKVssl1WKKLlWwGo2p1SY19Y/a1kub1iirErBRBBEiVsUiokYR7IorF6nXYllAOGhbjE3Vhad/zLv1dNnL7HLeec7l+0kmZ+ady/vs7PLlPe/MO5OqQpI0ewd0DyBJq5UBlqQmBliSmhhgSWpigCWpyZruAaaxcePGuvjii7vHkKT9lV0tXBZbwLfffnv3CJK06JZFgCVpJTLAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktRkRQd43frDSTLKad36w7v/eJKWuWXxgez76+ZtN3LCGZeP8tjnnnr0KI8rafVY0VvAkrSUGWBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpqMFuAk65NcluRrSa5L8sph+RuT3JRk63A6dqwZJGkpG/MribYDr6mqq5IcDHw5ySXDdW+vqreMuG5JWvJGC3BV3QLcMpz/YZLrgXVjrU+SlpuZ7ANOsgF4HHDlsOj0JFcnOTPJIbu5z6YkW5JsmZ+fn8WYkjRTowc4yf2ADwOvqqo7gHcCDweOYrKF/NZd3a+qNlfVXFXNrV27duwxJWnmRg1wknsxie/7q+oCgKq6tarurKq7gHcDTxxzBklaqsZ8F0SA9wLXV9XbFiw/bMHNngdcO9YMkrSUjfkuiCcDLwKuSbJ1WPZ64KQkRwEF3ACcOuIMkrRkjfkuiM8D2cVVHx9rnZK0nHgknCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTdZ0D7BsHbCGJKM89IH3ujd3/vTHozz2Lz1kPTfd+O+jPLakfWOA99dd2znhjMtHeehzTz161MeWtDS4C0KSmhhgSWpigCWpiQGWpCYGWJKaGGBJajJagJOsT3JZkq8luS7JK4flD0xySZJvDj8PGWsGSVrKxtwC3g68pqqOBJ4EnJbkSOB1wKVVdQRw6XBZklad0QJcVbdU1VXD+R8C1wPrgOOAs4ebnQ08d6wZJGkpm8k+4CQbgMcBVwKHVtUtw1XfAw7dzX02JdmSZMv8/PwsxpSkmRo9wEnuB3wYeFVV3bHwuqoqoHZ1v6raXFVzVTW3du3asceUpJkbNcBJ7sUkvu+vqguGxbcmOWy4/jDgtjFnkKSlasx3QQR4L3B9Vb1twVUXAScP508GPjLWDJK0lI35aWhPBl4EXJNk67Ds9cCbgfOSnAJ8Fzh+xBkkackaLcBV9Xlgdx+Ye8xY65Wk5cIj4SSpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhg
SWoyVYCTPHmaZTtdf2aS25Jcu2DZG5PclGTrcDp230eWpJVh2i3gv59y2UJnARt3sfztVXXUcPr4lOuXpBVnzZ6uTPJbwNHA2iSvXnDV/YED93Tfqvpckg33eEJJWqH2tgV8EHA/JqE+eMHpDuAF+7nO05NcPeyiOGR3N0qyKcmWJFvm5+f3c1WStHTtcQu4qj4LfDbJWVX13UVY3zuBNwE1/Hwr8Ie7WfdmYDPA3NxcLcK6JWlJ2WOAF7h3ks3AhoX3qarf3peVVdWtO84neTfwT/tyf0laSaYN8IeAdwHvAe7c35UlOayqbhkuPg+4dk+3l6SVbNoAb6+qd+7LAyc5B3ga8OAk24A/B56W5CgmuyBuAE7dl8eUpJVk2gB/NMnLgQuBH+9YWFU/2N0dquqkXSx+776NJ0kr17QBPnn4+doFywp42OKOI0mrx1QBrqqHjj2IJK02UwU4yR/sanlVvW9xx5Gk1WPaXRBPWHD+PsAxwFWAAZak/TTtLohXLLyc5AHAB8cYSJJWi/39OMofAe4XXo4OWEOSUU7r1h/e/aeTlpVp9wF/lMm7HmDyITyPBs4bayiN6K7tnHDG5aM89LmnHj3K40or1bT7gN+y4Px24LtVtW2EeSRp1ZhqF8TwoTz/yuST0A4BfjLmUJK0Gkz7jRjHA18Cfg84Hrgyyf5+HKUkiel3QbwBeEJV3QaQZC3wKeD8sQaTpJVu2ndBHLAjvoPv78N9JUm7MO0W8MVJPgmcM1w+AfD73CTpHtjbd8L9CnBoVb02ye8CTxmu+iLw/rGHk6SVbG9bwH8L/ClAVV0AXACQ5NeG65494myStKLtbT/uoVV1zc4Lh2UbRplIklaJvQX4AXu47ucWcQ5JWnX2FuAtSV6688IkLwG+PM5IkrQ67G0f8KuAC5P8Pj8L7hxwEJMv1ZQk7ac9Bnj4GvmjkzwdeMyw+GNV9enRJ5OkFW7azwO+DLhs5FkkaVXxaDZJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJanJaAFOcmaS25Jcu2DZA5NckuSbw89Dxlq/JC11Y24BnwVs3GnZ64BLq+oI4NLhsiStSqMFuKo+B/xgp8XHAWcP588GnjvW+iVpqZv1PuBDq+qW4fz3gEN3d8Mkm5JsSbJlfn5+NtNJ0gy1vQhXVQXUHq7fXFVzVTW3du3aGU4mSbMx6wDfmuQwgOHnbTNevyQtGbMO8EXAycP5k4GPzHj9krRkjPk2tHOALwKPTLItySnAm4HfSfJN4BnDZUlaldaM9cBVddJurjpmrHVK0nLikXCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwFo8B6whySindesP7/7TSYtuTfcAWkHu2s4JZ1w+ykOfe+rRozyu1MktYElqYoAlqYkBlqQmBliSmhhgSWpigCWpScvb0JLcAPwQuBPYXlVzHXNIUqfO9wE/vapub1y/JLVyF4QkNekKcAH/nOTLSTbt6gZJNiXZkmTL/Pz8jMeTpPF1BfgpVfV44FnAaUmeuvMNqmpzVc1V1dzatWtnP6EkjawlwFV10/DzNuBC4Ikdc0hSp5kHOMl9kxy84zzwTODaWc8hSd063gVxKHBhkh3r/0BVXdwwhyS1mnmAq+o7wK/Per2StNT4NjRJamKAJamJAZakJgZYkpoYYElqYoC16q1bf7jf5qwWfiuy
Vr2bt93otzmrhVvAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMPRdbycMAahq+xklYMA6zl4a7tfl6DVhx3QUhSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLI1pOITar7zXrngosjQmD6HWHrgFLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywtFyNeJhzEtYcdJ9leRj1uvWHL5u5PRRZWq5GPMwZJoc6L8fDqG/eduOymdstYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYtAU6yMcnXk3wryes6ZpCkbjMPcJIDgXcAzwKOBE5KcuSs55Ckbh1bwE8EvlVV36mqnwAfBI5rmEOSWqWqZrvC5AXAxqp6yXD5RcBvVtXpO91uE7BpuPhI4OtTruLBwO2LNO49tVRmcY67WyqzOMfdLZVZFnOO26tq484Ll+xnQVTVZmDzvt4vyZaqmhthpH22VGZxjrtbKrM4x90tlVlmMUfHLoibgPULLj9kWCZJq0pHgP8FOCLJQ5McBJwIXNQwhyS1mvkuiKranuR04JPAgcCZVXXdIq5in3dbjGipzOIcd7dUZnGOu1sqs4w+x8xfhJMkTXgknCQ1McCS1GTZBnhvhzMnuXeSc4frr0yyoWmOpya5Ksn24T3Qo5lillcn+VqSq5NcmuSXm+b4oyTXJNma5PNjHgk57WHvSZ6fpJKM8rajKZ6TFyeZH56TrUle0jHHcJvjh38n1yX5wBhzTDNLkrcveD6+keQ/m+Y4PMllSb4y/Ldz7KKtvKqW3YnJi3ffBh4GHAR8FThyp9u8HHjXcP5E4NymOTYAjwXeB7yg+Tl5OvDzw/mXNT4n919w/jnAxV3PyXC7g4HPAVcAc03PyYuBfxjr38c+zHEE8BXgkOHyL3b+3Sy4/SuYvGDf8ZxsBl42nD8SuGGx1r9ct4CnOZz5OODs4fz5wDFJMus5quqGqroauGuR170/s1xWVf89XLyCyXuwO+a4Y8HF+wJjvRI87WHvbwL+Cvif5jnGNs0cLwXeUVX/AVBVtzXOstBJwDlNcxRw/+H8LwA3L9bKl2uA1wE3Lri8bVi2y9tU1Xbgv4AHNcwxK/s6yynAJ7rmSHJakm8Dfw388QhzTDVLkscD66vqYyPNMNUcg+cPv+Ken2T9Lq6fxRyPAB6R5AtJrkhyt8NnZzgLAMOusocCn26a443AC5NsAz7OZGt8USzXAOseSPJCYA74m64ZquodVfVw4E+AP+uYIckBwNuA13SsfycfBTZU1WOBS/jZb2+ztobJboinMdnqfHeSBzTNssOJwPlVdWfT+k8CzqqqhwDHAv84/Nu5x5ZrgKc5nPn/bpNkDZNfHb7fMMesTDVLkmcAbwCeU1U/7ppjgQ8Czx1hjmlmORh4DPCZJDcATwIuGuGFuL0+J1X1/QV/H+8BfmORZ5hqDiZbgBdV1U+r6t+AbzAJcscsO5zIOLsfpp3jFOA8gKr6InAfJh/Uc8+NsYN97BOT/0t/h8mvJTt2nP/qTrc5jf//Itx5HXMsuO1ZjPsi3DTPyeOYvOBwRPMcRyw4/2xgS9csO93+M4zzItw0z8lhC84/D7iiaY6NwNnD+Qcz+fX8QV1/N8CjgBsYDhprek4+Abx4OP9oJvuAF2WeRf8DzerE5FeBbwxBecOw7C+YbNnB5P9SHwK+BXwJeFjTHE9gslXxIyZb4Nc1PiefAm4Ftg6ni5rm+DvgumGGy/YUxbFn2em2owR4yufkL4fn5KvDc/KopjnCZLfM14BrgBM7/26Y7H9981gzTPmcHAl8Yfi72Qo8c7HW7aHIktRkue4DlqRlzwBLUhMDLElNDLAkNTHAktTEAEtSEwMsSU3+F/hCXWZMmO1uAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(saliency_sim.mean())\n",
    "sns.displot(saliency_sim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## My Advanced Method"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What are the constraints for explanation-by-example?\n",
    "1. It should preferably be a correctly classified instance of the same class. (not always possible)\n",
    "2. It should have a similar latent feature which can be referenced.\n",
    "3. That feature should have a similar contribution to the classification. (measured with logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_NEIGHBORS = 10\n",
    "FEATURE_NUM = 1  # pick most nb feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', metric='euclidean', n_neighbors=1)"
      ]
     },
     "execution_count": 105,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "keep = (train_preds == y_train.detach().numpy())\n",
    "\n",
    "temp_train = X_train_c[keep]\n",
    "temp_preds = train_preds[keep]\n",
    "\n",
    "temp_train_loader = deepcopy(train_loader)\n",
    "temp_train_loader.dataset.data = temp_train_loader.dataset.data[keep]\n",
    "temp_train_loader.dataset.targets = temp_train_loader.dataset.targets[keep]\n",
    "\n",
    "# Fit twin\n",
    "twin = KNeighborsClassifier(n_neighbors=1, algorithm=\"brute\", metric='euclidean') \n",
    "twin.fit(temp_train, temp_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Fit twin\n",
    "# twin = KNeighborsClassifier(n_neighbors=1, algorithm=\"brute\", metric='euclidean') \n",
    "# twin.fit(X_train_c, train_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 100/10000 [00:09<15:56, 10.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Contamination: 100.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "results = list()\n",
    "saliency_sim = list()\n",
    "contamination = 0\n",
    "\n",
    "\n",
    "for query_idx, data in enumerate(tqdm(test_loader)):\n",
    "    \n",
    "    # Get data\n",
    "    img, label = data\n",
    "    \n",
    "    # Get query information\n",
    "    query_label = y_test[query_idx].item()\n",
    "    query_pred  = test_preds[query_idx].item()\n",
    "    query_img_trans = get_transformed_data([query_idx], test_loader)[0]\n",
    "    query_cont = X_test_c[query_idx]\n",
    "    \n",
    "    # Get query salient feature\n",
    "    query_logits, query_x, query_C = netC(query_img_trans)\n",
    "    query_nb_boxes = get_salient_regions(query_C, query_logits, net_classifier)\n",
    "    window_idx = query_nb_boxes[0][1]  # Only the most NB one\n",
    "    query_feature = query_C[ :, :, window_idx[0]:window_idx[0]+1, window_idx[1]:window_idx[1]+1]\n",
    "\n",
    "    # Get explanation nn\n",
    "    xp_idxs = twin.kneighbors(X=[X_test_c[query_idx]], n_neighbors=NUM_NEIGHBORS, return_distance=False)\n",
    "    \n",
    "    \n",
    "    # How many NNs are different class\n",
    "    contamination += (temp_preds[xp_idxs[0]] == query_pred).sum()\n",
    "    \n",
    "    \n",
    "    # Iterate all three salient regions in xp nn\n",
    "    min_dist = float('inf')\n",
    "    min_logit = float('inf')\n",
    "    for i in range(NUM_NEIGHBORS):\n",
    "\n",
    "        xps_img_trans = get_transformed_data([xp_idxs[0][i]], temp_train_loader)\n",
    "        xp_logits, xp_x, xp_C = netC(xps_img_trans[0])\n",
    "        xp_pred = torch.argmax(xp_logits, dim=1).item()\n",
    "        \n",
    "        coord, dist_query_to_xp = get_box_xp(xp_C, query_feature)  # get similar latent feature in nn xp\n",
    "        xp_feature = xp_C[ :, :, coord[0]:coord[0]+1, coord[1]:coord[1]+1]\n",
    "        \n",
    "        # Get logit change\n",
    "        temp = xp_C.clone()\n",
    "        temp[ :, :, coord[0]:coord[0]+1, coord[1]:coord[1]+1] = 0.0\n",
    "        new_xp_logits = net_classifier(temp)\n",
    "        logit_change = xp_logits[0][xp_pred] - new_xp_logits[0][xp_pred]\n",
    "        logit_change = abs( logit_change - query_nb_boxes[0][0] )\n",
    "                       \n",
    "        if dist_query_to_xp < min_dist:\n",
    "            min_dist = dist_query_to_xp\n",
    "            min_logit = logit_change\n",
    "\n",
    "    results.append(min_dist)\n",
    "    saliency_sim.append(min_logit.item())\n",
    "    \n",
    "    if query_idx == 100:\n",
    "        break\n",
    "            \n",
    "print('Contamination:', contamination / len(results)*NUM_NEIGHBORS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the per-query lists filled by the loop above into NumPy arrays\n",
    "results = np.array(deepcopy(results))\n",
    "saliency_sim = np.array(saliency_sim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.20668514264692175\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x213d2a750>"
      ]
     },
     "execution_count": 110,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOdElEQVR4nO3dfaxkd13H8c+3rBWFEkrYNM26taCEWBEBF9RCCIiaoqGAIJX4UBOwVUEkGCKKiU//4LOJEmxVQkkQCgihBCxiRYipTytBQJEUDdhCabdghGgUl379Y6fhZt3tzrY753sfXq9kcmfOzO35/no375ycO3NudXcAWN5Z0wMA7FUCDDBEgAGGCDDAEAEGGLJveoB1XHLJJX399ddPjwFwT9WJNu6II+A77rhjegSAM25HBBhgNxJggCECDDBEgAGGCDDAEAEGGCLAAEMEGGCIAAMMEWCAIQIMMESAAYYIMMAQAQYYIsBnyIGDF6SqFrvtO/u+i+7vwMELpv8Xw66zIy7IvhN86pabc9lVNy62v2uvvHjx/QFnliNggCECDDBEgAGGCDDAEAEGGCLAAEMEGGCIAAMMEWCAIQIMMESAAYYIMMAQAQYYIsAAQwQYYIgAAwwRYIAhAgwwRIABhggwwBABBhgiwABDNhbgqjpYVe+pqn+qqn+sqp9cbX9QVb27qm5afT13UzMAbGebPAI+muSnuvuiJN+S5AVVdVGSlyW5obsfluSG1WOAPWdjAe7uW7v7/av7n0/ykSQHkjw9yTWrl12T5BmbmgFgO1vkHHBVXZjk0Un+Jsl53X3r6qlPJzlviRkAtpuNB7iq7p/kj5O8uLs/t/W57u4kfZLvu6KqDlfV4SNHjtyjfR84eEGqapEbwOnat8n/eFV9WY7F93Xd/ZbV5tuq6vzuvrWqzk9y+4m+t7uvTnJ1khw6dOiEkT6VT91ycy676sZ78q2n7dorL15kP8Duscl3QVSSP0zyke7+zS1PXZfk8tX9y5O8bVMzAGxnmzwCfnySH0zyoar6wGrbzyZ5RZI3VtXzknwiyXM2OAPAtrWxAHf3XyY52cnRp2xqvwA7hU/CAQwRYIAhAgwwRIABhggwwBABBhgiwABDBBhgiAADDBFggCECDDBEgAGGCDDAEAEGGCLAAEMEGGCIAAMMEWCAIQIMMESAAYYIMMAQAQYYIsAAQwQYYIgAAwwRYIAhAgwwRIABhggwwBABBhgiwABDBBhgiAADDBFggCECDDBEgAGGCDDAEAEGGCLAAEMEGGCIAAMMEWCAIQIMMESAAYYIMMAQAQYYIsAAQwQYYIgAAwwRYIAhAsx6ztqXqlrsduDgBdMrho3bNz0AO8SdR3PZVTcutrtrr7x4sX3BFEfAAEMEGGCIAAMMEWCAIQIMMESAAYYIMMAQAQYYIsAAQzYW4Kp6dVXdXlUf3rLtF6rqk1X1gdXtuza1f4DtbpNHwK9JcskJtv9Wdz9qdXvnBvcPsK1tLMDd/b4kn93Ufx9gp5s4B/zCqvrg6hTFuSd7UVVdUVWHq+rwkSNHlpwPYBFLB/hVSb4myaOS3JrkN072wu6+ursPdfeh/fv3LzQewHIWDXB339bdX+zuO5P8fpLHLbl/gO1k0QBX1flbHj4zyYdP9lqA3W5jF2SvqtcneVKSB1fVLUl+PsmTqupRSTrJx5Ncuan9A2x3Gwtwdz/3BJv/cFP7A9hpfBIOYIgAAwwRYIAhAgwwRIABhggwwBABBhgiwABDBBhgiAADDBFggCECDDBEgAGGCDDAEAEGGCLAAEMEGGCIAAMMEWCAIQIMMESAAYYIMMAQAQYYIsAAQwQYYIgAAwwRYIAhAgwwRIABhggwwJC1AlxVj19nGwDrW/cI+HfW3AbAmvbd3ZNV9a1JLk6yv6pesuWpByS5zyYHA9jt7jbASc5Ocv/V687Zsv1zSZ69qaEA9oK7DXB3vzfJe6vqNd39iYVmAtgT
TnUEfJcvr6qrk1y49Xu6+9s2MRTAXrBugN+U5PeS/EGSL25uHIC9Y90AH+3uV210EoA9Zt23ob29qn68qs6vqgfdddvoZAC73LpHwJevvr50y7ZO8tAzOw7A3rFWgLv7IZseBGCvWSvAVfVDJ9re3a89s+MA7B3rnoJ47Jb7903ylCTvTyLAAPfQuqcgfmLr46p6YJI3bGIggL3inl6O8j+TOC8McC+sew747Tn2rofk2EV4vi7JGzc1FMBesO454F/fcv9okk909y0bmAdgz1jrFMTqojz/nGNXRDs3yRc2ORTAXrDuX8R4TpK/TfK9SZ6T5G+qyuUoAe6FdU9BvDzJY7v79iSpqv1J/izJmzc1GMBut+67IM66K74rnzmN7wXgBNY9Ar6+qt6V5PWrx5cleedmRgLYG071N+G+Nsl53f3SqvqeJE9YPfVXSV636eEAdrNTHQH/dpKfSZLufkuStyRJVX3D6rmnbXA2gF3tVOdxz+vuDx2/cbXtwo1MBLBHnCrAD7yb577iDM4BsOecKsCHq+pHjt9YVc9P8vebGQlgbzjVOeAXJ3lrVX1/vhTcQ0nOTvLMDc4FsOvdbYC7+7YkF1fVk5M8YrX5Hd395xufDGCXW/d6wO9J8p4NzwKwp2zs02xV9eqqur2qPrxl24Oq6t1VddPq67mb2j/AdrfJjxO/Jsklx217WZIbuvthSW5YPQbYkzYW4O5+X5LPHrf56UmuWd2/JskzNrV/gO1u6QvqnNfdt67ufzrJeQvvH2DbGLuiWXd3vvRnjv6fqrqiqg5X1eEjR44sOBnAMpYO8G1VdX6SrL7efrIXdvfV3X2ouw/t379/sQEBlrJ0gK9Lcvnq/uVJ3rbw/gG2jU2+De31OXbZyodX1S1V9bwkr0jyHVV1U5JvXz0G2JPWvSD7aevu557kqadsap8AO4k/KwQwRIABhggwwBABBhgiwABDBBhgiAADDBFggCECDDBEgNmeztqXqlrkduDgBdOrZY/a2EeR4V6582guu+rGRXZ17ZUXL7IfOJ4jYIAhAgwwRIABhggwwBABBhgiwABDBBhgiAADDBFggCECDDBEgAGGCDAseOEfF/9hKxfjgQUv/JO4+A9f4ggYYIgAAwwRYIAhAgwwRIABhggwwBABBhgiwABDBBhgiAADDBFggCECDDBEgAGGCDDAEAEGGCLAAEMEGGCIAAMMEWCAIQIMMESAAYYIMMAQAQYYIsAAQwQYYIgAAwwRYIAhAgwwRIABhggwwBABBhgiwABDBBhgiAADDBFggCECDDBEgAGG7JvYaVV9PMnnk3wxydHuPjQxB8CkkQCvPLm77xjcP8AopyAAhkwFuJP8aVX9fVVdcaIXVNUVVXW4qg4fOXJk4fEANm8qwE/o7sckeWqSF1TVE49/QXdf3d2HuvvQ/v37l58QYMNGAtzdn1x9vT3JW5M8bmIOgEmLB7iq7ldV59x1P8l3Jvnw0nMATJt4F8R5Sd5aVXft/4+6+/qBOQBGLR7g7v7XJN+49H4BthtvQwMYIsAAQwQYYIgAAwwRYIAhAgwwRIABhggwwBABBhgiwABDBBhgiAADDBFggCECDDBEgAGGCDDAEAEGGCLAAEMEGGCIAAMMEWCAIQIMMESAAYYIMMAQAQYYIsAAQwQYYIgAAwwRYIAhAgwwRIABhggwwBABBhgiwABDBBhgiAADDBFgWNpZ+1JVi90OHLxgesWcxL7pAWDPufNoLrvqxsV2d+2VFy+2L06PI2CAIQIMMESAAYYIMMAQAQYYIsAAQwQYYIgAAwwRYIAhAgwwRIABhggwwBABht1uwauvLX3ltQMHL9jRV5ZzNTTY7Ra8+trSV1771C037+gryzkCBhgiwABDBBhgiAADDBFggCECDDBEgAGGCDDAEAEGGDIS4Kq6pKo+WlUfq6qXTcwAMG3xAFfVfZK8MslTk1yU5LlVddHScwBMmzgCflySj3X3v3b3F5K8IcnTB+YA
GFXdvewOq56d5JLufv7q8Q8m+ebufuFxr7siyRWrhw9P8tFFB00enOSOhfe5FGvbuXbz+nbz2u7o7kuO37htr4bW3VcnuXpq/1V1uLsPTe1/k6xt59rN69vNazuZiVMQn0xycMvjr1ptA9hTJgL8d0keVlUPqaqzk3xfkusG5gAYtfgpiO4+WlUvTPKuJPdJ8uru/sel51jD2OmPBVjbzrWb17eb13ZCi/8SDoBjfBIOYIgAAwzZ0wE+1Ueiq+qJVfX+qjq6ev/yjrLG+l5SVf9UVR+sqhuq6qsn5rwn1ljbj1bVh6rqA1X1lzvt05brfly/qp5VVV1VO+btW2v87H64qo6sfnYfqKrnT8y5iO7ek7cc+wXgvyR5aJKzk/xDkouOe82FSR6Z5LVJnj098wbW9+QkX7m6/2NJrp2e+wyu7QFb7l+a5Prpuc/k+lavOyfJ+5L8dZJD03OfwZ/dDyf53elZl7jt5SPgU34kurs/3t0fTHLnxID30jrre093/9fq4V/n2Huyd4J11va5LQ/vl2Qn/bZ53Y/r/3KSX0ny30sOdy+5FMEWeznAB5LcvOXxLattu8Xpru95Sf5koxOdOWutrapeUFX/kuRXk7xoodnOhFOur6oek+Rgd79jycHOgHX/XT5rdWrszVV18ATP7wp7OcCsVNUPJDmU5NemZzmTuvuV3f01SX46yc9Nz3OmVNVZSX4zyU9Nz7Ihb09yYXc/Msm7k1wzPM/G7OUA7/aPRK+1vqr69iQvT3Jpd//PQrPdW6f7s3tDkmdscqAz7FTrOyfJI5L8RVV9PMm3JLluh/wi7pQ/u+7+zJZ/i3+Q5JsWmm1xeznAu/0j0adcX1U9OslVORbf2wdmvKfWWdvDtjz87iQ3LTjfvXW36+vu/+juB3f3hd19YY6dv7+0uw/PjHta1vnZnb/l4aVJPrLgfIvatldD27Q+yUeiq+qXkhzu7uuq6rFJ3prk3CRPq6pf7O6vHxx7beusL8dOOdw/yZuqKkn+rbsvHRt6TWuu7YWro/v/TfLvSS6fm/j0rLm+HWnNtb2oqi5NcjTJZ3PsXRG7ko8iAwzZy6cgAEYJMMAQAQYYIsAAQwQYYIgAAwwRYIAh/wcAaMVHb0WW3QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Mean of `results`, then a histogram of its distribution\n",
    "print(np.mean(results))\n",
    "sns.displot(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.15790464146302477\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x213a93e50>"
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAP5klEQVR4nO3dfYxlB1nH8e+vXQpqixQ6bsoy64IWQlMikAGhEAWLZiFKQbGlESyxdBugBEJDgvCHRP7ByIuJEuhCm1YDpQWKFMEilpUGgeoCBfoigljYbUu7BQSiEdj28Y85dYdlu3u3vec+M3e+n2Qy9577cp49mfnm7Jl77k1VIUmavSO6B5Ck9coAS1ITAyxJTQywJDUxwJLUZEP3AJPYunVrXXnlld1jSNK9lQMtXBN7wHfccUf3CJI0dWsiwJI0jwywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwIdp0+Jmkkz1a9Pi5u5/lqQGa+IN2VeTW3bv4vTzPz3V57z0nJOn+nyS1gb3gCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqclcB3iM04YlaVrm+lRkTxuWtJrN9R6wJK1mBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpyWgBTrKYZEeSG5Jcn+QVw/LXJ7k5ybXD17PGmkGSVrMxPxFjL3BeVX0+yTHA55J8fLjtrVX1phHXLUmr3mgBrqpbgVuHyz9IciOwaaz1SdJaM5NjwEm2AI8DrhkWnZvkS0kuTHLsLGaQpNVm9AAnORr4APDKqvo+8Hbgl4DHsryH/OZ7eNy2JDuT7NyzZ8/YY0rSzI0a4CT3Yzm+766qywGq6raqurOq7gLeCTzxQI+tqu1VtVRVSwsLC2OOKUktxnwVRIALgBur6i0rlh+/4m7PBa4bawZJWs3GfBXEU4AXAl9Ocu2w7LXAGUkeCxRwE3DOiDNI0qo15qsgPgXkADd9dKx1StJa4plwktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUpPRApxkMcmOJDckuT7JK4blD07y8SRfHb4fO9YMkrSajbkHvBc4r6pOBJ4EvCzJicBrgKuq6gTgquG6JK07owW4qm6tqs8Pl38A3AhsAk4FLh7udjHwnLFmkKTVbCbHgJNsAR4HXANsrKpbh5u+BWy8h8dsS7Izyc49e/bMYkxJmqnRA5zkaOADwCur6vsrb6uqAupAj6uq7VW1VFVLCwsLY48pSTM3aoCT3I/l+L67qi4fFt+W5Pjh9uOB28ecQZJWqzFfBRHgAuDGqnrLipuuAM4cLp8JfGisGSRpNdsw4nM/BXgh8OUk1w7LXgu8EbgsyVnAN4DTRpxBklat0QJcVZ8Ccg83nzLWeiVprfBMOElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmkwU4CRPmWSZJGlyk+4B/+WEyyRJE9pwsBuTPBk4GVhI8qoV
Nz0QOHLMwSRp3h00wMBRwNHD/Y5Zsfz7wPPGGkqS1oODBriqPgl8MslFVfWNGc0kSevCofaA73b/JNuBLSsfU1W/McZQkrQeTBrg9wHvAN4F3DnJA5JcCPw2cHtVnTQsez1wNrBnuNtrq+qjhzOwJM2LSQO8t6refpjPfRHwV8Bf77f8rVX1psN8LkmaO5O+DO3DSV6a5PgkD77762APqKqrge/c9xElaT5Nugd85vD91SuWFfCIe7HOc5P8IbATOK+qvnugOyXZBmwD2Lx5871YzRpyxAaSTP1pH/qwRW7e9c2pP6+k6ZgowFX18Cmt7+3AG1iO9xuANwN/dA/r3A5sB1haWqoprX91umsvp5//6ak/7aXnnDz155Q0PRMFeNhj/SlVtf/x3YOqqttWPOc7gb87nMdL0jyZ9BDEE1ZcfgBwCvB5fvoPbAeV5PiqunW4+lzgusN5vCTNk0kPQbx85fUkDwLee7DHJLkEeBpwXJLdwJ8AT0vyWJYPQdwEnHO4A0vSvJh0D3h//w0c9LhwVZ1xgMUX3Mv1SdLcmfQY8IdZ3muF5TfheTRw2VhDSdJ6MOke8MoTJ/YC36iq3SPMI0nrxkQnYgxvyvNvLL8j2rHAj8YcSpLWg0k/EeM04F+A3wdOA65J4ttRStJ9MOkhiNcBT6iq2wGSLAD/CLx/rMEkad5N+l4QR9wd38G3D+Ox6jKc4jzNr02Lc35auDRDk+4BX5nkY8Alw/XTAd9GcrUb4RRnT2+WpudQnwn3y8DGqnp1kt8Fnjrc9Bng3WMPJ0nz7FB7wH8B/DFAVV0OXA6Q5DHDbb8z4mySNNcOdRx3Y1V9ef+Fw7Ito0wkSevEoQL8oIPc9jNTnEOS1p1DBXhnkrP3X5jkxcDnxhlJktaHQx0DfiXwwSR/wL7gLgFHsfx2kpKke+mgAR7eQP3kJE8HThoWf6SqPjH6ZJI05yZ9P+AdwI6RZ5GkdcWz2SSpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJqMFOMmFSW5Pct2KZQ9O8vEkXx2+HzvW+iVptRtzD/giYOt+y14DXFVVJwBXDdclaV0aLcBVdTXwnf0WnwpcPFy+GHjOWOuXpNVu1seAN1bVrcPlbwEb7+mOSbYl2Zlk5549e2YznSTNUNsf4aqqgDrI7duraqmqlhYWFmY4mSTNxqwDfFuS4wGG77fPeP2StGrMOsBXAGcOl88EPjTj9UvSqjHmy9AuAT4DPCrJ7iRnAW8EfjPJV4FnDNclaV3aMNYTV9UZ93DTKWOtU5LWEs+Ek6QmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgza1Ni5tJMtWvTYubu/9ZmiMbugeQxnLL7l2cfv6np/qcl55z8lSfT+ube8CS1MQAS1ITAyxJTQywJDUxwJLUxABLUpOWl6EluQn4AXAnsLeqljrmkKROna8DfnpV3dG4fklq5SEISWrSFeAC/iHJ55JsO9AdkmxLsjPJzj179sx4PEkaX1eAn1pVjweeCbwsya/tf4eq2l5VS1W1tLCwMPsJJWlkLQGuqpuH77cDHwSe2DGHJHWaeYCT/FySY+6+DPwWcN2s55Ckbh2vgtgIfDDJ3et/T1Vd2TCHJLWaeYCr6uvAr8x6vZK02vgyNElqYoAlqYkBlqQmBliSmhhgSWpigHV4jtjgJw1L
U+KnIuvw3LXXTxqWpsQ9YElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKa+KGc6jd80vJ6tWlxM7fs3jX15z3yfvfnzh//cKrP+dCHLXLzrm9O9TnXMwOsfiN80jKsnU9bvmX3rtH+/X6C9ermIQhJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCXNnU2Lm0ky9a9Ni5unOqenIkuaO2Oe3j1N7gFLUhMDLElNDLAkNTHAktTEAEtSEwMsSU1aApxka5KvJPlaktd0zCBJ3WYe4CRHAm8DngmcCJyR5MRZzyFJ3Tr2gJ8IfK2qvl5VPwLeC5zaMIcktUpVzXaFyfOArVX14uH6C4Ffrapz97vfNmDbcPVRwFfuxeqOA+64D+POE7fFPm6Ln+T22GesbXFHVW3df+GqPRW5qrYD2+/LcyTZWVVLUxppTXNb7OO2+Eluj31mvS06DkHcDCyuuP6wYZkkrSsdAf5X4IQkD09yFPB84IqGOSSp1cwPQVTV3iTnAh8DjgQurKrrR1rdfTqEMWfcFvu4LX6S22OfmW6Lmf8RTpK0zDPhJKmJAZakJnMR4EOd2pzk/kkuHW6/JsmWhjFnYoJt8aokNyT5UpKrkvxix5yzMOkp70l+L0klmduXYk2yLZKcNvxsXJ/kPbOecZYm+D3ZnGRHki8MvyvPGmWQqlrTXyz/Ie8/gEcARwFfBE7c7z4vBd4xXH4+cGn33I3b4unAzw6XX7Ket8Vwv2OAq4HPAkvdczf+XJwAfAE4drj+C91zN2+P7cBLhssnAjeNMcs87AFPcmrzqcDFw+X3A6ckyQxnnJVDbouq2lFV/zNc/SzLr8OeR5Oe8v4G4M+A/53lcDM2ybY4G3hbVX0XoKpun/GMszTJ9ijggcPlnwduGWOQeQjwJmDXiuu7h2UHvE9V7QW+BzxkJtPN1iTbYqWzgL8fdaI+h9wWSR4PLFbVR2Y5WINJfi4eCTwyyT8n+WySnzptdo5Msj1eD7wgyW7go8DLxxhk1Z6KrHEleQGwBPx69ywdkhwBvAV4UfMoq8UGlg9DPI3l/xVdneQxVfVfnUM1OgO4qKrenOTJwN8kOamq7prmSuZhD3iSU5v//z5JNrD8X4pvz2S62ZroNO8kzwBeBzy7qn44o9lm7VDb4hjgJOCfktwEPAm4Yk7/EDfJz8Vu4Iqq+nFV/Sfw7ywHeR5Nsj3OAi4DqKrPAA9g+Y16pmoeAjzJqc1XAGcOl58HfKKGo+tz5pDbIsnjgPNZju88H+c76Laoqu9V1XFVtaWqtrB8PPzZVbWzZ9xRTfI78rcs7/2S5DiWD0l8fYYzztIk2+ObwCkASR7NcoD3THuQNR/g4Zju3ac23whcVlXXJ/nTJM8e7nYB8JAkXwNeBczlp3BMuC3+HDgaeF+Sa5PM5ftwTLgt1oUJt8XHgG8nuQHYAby6qubxf4mTbo/zgLOTfBG4BHjRGDttnoosSU3W/B6wJK1VBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJavJ/6h8oAiKmf0EAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Mean of `saliency_sim`, then a histogram of its distribution\n",
    "print(np.mean(saliency_sim))\n",
    "sns.displot(saliency_sim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Try My Method With Cosine Similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', metric='euclidean', n_neighbors=1)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def _sum_normalize(vec):\n",
    "    \"\"\"Return a flattened copy of `vec` divided by the sum of its entries.\n",
    "\n",
    "    A vector whose entries sum to exactly zero is returned unscaled, which\n",
    "    avoids a division by zero.\n",
    "    \"\"\"\n",
    "    flat = deepcopy(vec.flatten())\n",
    "    total = flat.sum()\n",
    "    return flat if total == 0. else flat / total\n",
    "\n",
    "\n",
    "# NOTE(review): the section title says cosine similarity, but dividing by the\n",
    "# sum and ranking with metric='euclidean' is not equivalent to a cosine ranking\n",
    "# (that would need L2-normalised rows or metric='cosine') -- confirm intent.\n",
    "#\n",
    "# Each split is sized by the array that is actually indexed (X_train_x /\n",
    "# X_test_x) rather than by the row count of X_train_C / X_test_C.\n",
    "temp_train = np.array([_sum_normalize(X_train_x[i]) for i in range(len(X_train_x))])\n",
    "temp_test = np.array([_sum_normalize(X_test_x[i]) for i in range(len(X_test_x))])\n",
    "\n",
    "# Fit the twin 1-NN on the normalised training features with train_preds as targets\n",
    "twin = KNeighborsClassifier(n_neighbors=1, algorithm=\"brute\", metric='euclidean')\n",
    "twin.fit(temp_train, train_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|▎         | 300/10000 [00:09<05:15, 30.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Contamination: 1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Index of the last test query to evaluate (inclusive), i.e. queries 0..300.\n",
    "LAST_QUERY_IDX = 300\n",
    "\n",
    "# NOTE(review): NUM_NEIGHBORS is a notebook-wide global that another cell\n",
    "# reassigns; make sure it holds the intended value when this cell runs.\n",
    "\n",
    "results = list()       # per query: smallest query-to-neighbour feature distance\n",
    "saliency_sim = list()  # per query: logit-change gap of that closest neighbour\n",
    "contamination = 0      # neighbours whose predicted class differs from the query's\n",
    "\n",
    "\n",
    "for query_idx, data in enumerate(tqdm(test_loader)):\n",
    "\n",
    "    # Get data\n",
    "    img, label = data\n",
    "\n",
    "    # Get query information\n",
    "    query_label = y_test[query_idx].item()\n",
    "    query_pred = test_preds[query_idx].item()\n",
    "    query_img_trans = get_transformed_data([query_idx], test_loader)[0]\n",
    "    query_cont = X_test_c[query_idx]\n",
    "\n",
    "    # Get query salient feature: the 1x1 spatial window of query_C at the\n",
    "    # coordinates stored in the first entry returned by get_salient_regions\n",
    "    query_logits, query_x, query_C = netC(query_img_trans)\n",
    "    query_nb_boxes = get_salient_regions(query_C, query_logits, net_classifier)\n",
    "    window_idx = query_nb_boxes[0][1]  # Only the most NB one\n",
    "    query_feature = query_C[:, :, window_idx[0]:window_idx[0]+1, window_idx[1]:window_idx[1]+1]\n",
    "\n",
    "    # Get explanation nn\n",
    "    xp_idxs = twin.kneighbors(X=[temp_test[query_idx]], n_neighbors=NUM_NEIGHBORS, return_distance=False)\n",
    "\n",
    "    # How many NNs are a different class: count the neighbours whose prediction\n",
    "    # differs from the query's (!=), which is what 'contamination' names\n",
    "    contamination += (train_preds[xp_idxs[0]] != query_pred).sum()\n",
    "\n",
    "    # Iterate over the retrieved neighbours and keep the one whose best-matching\n",
    "    # feature is closest to the query feature\n",
    "    min_dist = float('inf')\n",
    "    min_logit = float('inf')\n",
    "    for i in range(NUM_NEIGHBORS):\n",
    "\n",
    "        xps_img_trans = get_transformed_data([xp_idxs[0][i]], train_loader)\n",
    "        xp_logits, xp_x, xp_C = netC(xps_img_trans[0])\n",
    "        xp_pred = torch.argmax(xp_logits, dim=1).item()\n",
    "\n",
    "        coord, dist_query_to_xp = get_box_xp(xp_C, query_feature)  # get similar latent feature in nn xp\n",
    "        xp_feature = xp_C[:, :, coord[0]:coord[0]+1, coord[1]:coord[1]+1]\n",
    "\n",
    "        # Get logit change: zero the matched window in a copy of xp_C, re-classify,\n",
    "        # and measure how much the predicted-class logit drops\n",
    "        temp = xp_C.clone()\n",
    "        temp[:, :, coord[0]:coord[0]+1, coord[1]:coord[1]+1] = 0.0\n",
    "        new_xp_logits = net_classifier(temp)\n",
    "        logit_change = xp_logits[0][xp_pred] - new_xp_logits[0][xp_pred]\n",
    "        # Absolute gap to query_nb_boxes[0][0] (presumably the query's own logit change -- confirm)\n",
    "        logit_change = abs(logit_change - query_nb_boxes[0][0])\n",
    "\n",
    "        if dist_query_to_xp < min_dist:\n",
    "            min_dist = dist_query_to_xp\n",
    "            min_logit = logit_change\n",
    "\n",
    "    results.append(min_dist)\n",
    "    # float() accepts a one-element tensor as well as the initial float('inf'),\n",
    "    # so a query for which no neighbour was selected no longer raises here\n",
    "    saliency_sim.append(float(min_logit))\n",
    "\n",
    "    if query_idx == LAST_QUERY_IDX:\n",
    "        break\n",
    "\n",
    "# Fraction of all retrieved neighbours (queries x NUM_NEIGHBORS) that are contaminated\n",
    "print('Contamination:', contamination / (len(results) * NUM_NEIGHBORS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the per-query lists filled by the loop above into NumPy arrays\n",
    "results = np.array(deepcopy(results))\n",
    "saliency_sim = np.array(saliency_sim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "36.05689384929366\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x2087e6710>"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQgUlEQVR4nO3dbYxlBX3H8e8PRoqPBex2Q5fdgIFoiY3QjBQX01TQZlut0IayWms3DXZJqi1Wq0X7piZ9oYnxIY2xbMC6TagsRQhoGyxd0bbBrC7SVgENlIosT7soVOsLdeXfF/cg03VhZ5c99z8z9/tJJnPPuU//ubn75XDunDOpKiRJ03dE9wCSNKsMsCQ1McCS1MQAS1ITAyxJTea6B1iMDRs21A033NA9hiQdquxv5bLYAn744Ye7R5Ckw25ZBFiSViIDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDPCMWbN2HUlG+Vqzdl33jyctK8vihOw6fO7fdS8bL715lMfedtH6UR5XWqncApakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmo/5FjCTfAL4L/AjYW1XzSY4DtgEnAt8ALqiqR8acQ5KWomlsAb+iqk6rqvlh+RJge1WdAmwfliVp5nTsgjgX2Dpc3gqc1zCDJLUbO8AF/FOSW5JsHtatrqoHhssPAqv3d8ckm5PsTLJzz549I48pSdM39l9FfnlV3ZfkZ4Ebk3xt4ZVVVUlqf3esqi3AFoD5+fn93kaSlrNRt4Cr6r7h+27gWuAM4KEkxwMM33ePOYMkLVWjBTjJs5M89/HLwK8CXwWuBzYNN9sEXDfWDJK0lI25C2I1cG2Sx5/n76rqhiRfAq5KciFwD3DBiDNI0pI1WoCr6m7gJftZ/y3gnLGeV5KWC4+Ek6QmBliSmhhgSWpigCWpiQGWpCYGWJKaGOAlaM3adSQZ5UvS0jH2uSB0CO7fdS8bL715lMfedtH6UR5X0sFzC1iSmhhgSWpigCWpiQE+RH5QJunp8kO4Q+QHZZKeLreAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKajB7gJEcmuTXJp4flk5LsSHJXkm1Jjhp7BklaiqaxBXwxcMeC5fcBH6yqk4FHgAunMIMkLTmjBjjJCcCrgcuG5QBnA1cPN9kKnDfmDJK0VI29Bfwh4J3AY8Py84FHq2rvsLwLWLO/OybZnGRnkp179uwZeUxJmr7RApzkNcDuqrrlUO5fVVuqar6q5letWnWYp5OkfnMjPvZZwGuT/DpwNPA84MPAMUnmhq3gE4D7RpxBkpas0baAq+pdVXVCVZ0IvA74bFW9AbgJOH+42SbgurFmkKSlrOP3gP8MeFuSu5jsE768YQZJajfmLogfq6rPAZ8bLt8NnDGN55Wkpcwj4SSpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBljLwpq160gyyteateu6fzzNqLnuAaTFuH/XvWy89OZRHnvbRetHeVzpQNwClqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoY
YElqYoAlqcmiApzkrMWs2+f6o5N8Mcl/JLktyXuG9Scl2ZHkriTbkhx1aKNL0vK22C3gv1rkuoW+D5xdVS8BTgM2JDkTeB/wwao6GXgEuHCRM0jSivKUp6NM8jJgPbAqydsWXPU84Minum9VFfC/w+Izhq8CzgZ+Z1i/FfgL4KMHO7gkLXcHOh/wUcBzhts9d8H67wDnH+jBkxwJ3AKcDHwE+C/g0araO9xkF7DmSe67GdgMsG7doZ0we83addy/695Duq8OwRFzJOmeQlo2njLAVfV54PNJPl5V9xzsg1fVj4DTkhwDXAu86CDuuwXYAjA/P18H+9zgSbyn7rG9vt7SQVjsX8T4qSRbgBMX3qeqzl7Mnavq0SQ3AS8DjkkyN2wFnwDcd3AjS9LKsNgA/z3w18BlwI8Wc4ckq4AfDvF9JvAqJh/A3cRk98WVwCbguoMdWpJWgsUGeG9VHewHZccDW4f9wEcAV1XVp5PcDlyZ5C+BW4HLD/JxJWlFWGyAP5XkD5nsx/3+4yur6ttPdoeq+k/g9P2svxs44yDnlKQVZ7EB3jR8f8eCdQW84PCOI0mzY1EBrqqTxh5EkmbNogKc5Pf2t76q/vbwjiNJs2OxuyBeuuDy0cA5wJcBAyxJh2ixuyD+aOHycGDFlWMMJEmz4lBPR/k9wP3CkvQ0LHYf8KeY/NYDTE7C8/PAVWMNJUmzYLH7gN+/4PJe4J6q2jXCPJI0Mxa1C2I4Kc/XmJwR7VjgB2MOJUmzYLF/EeMC4IvAbwMXADuSHPB0lJKkJ7fYXRB/Dry0qnbDj0+088/A1WMNJkkr3WJ/C+KIx+M7+NZB3FeStB+L3QK+IclngE8MyxuBfxxnJEmaDQf6m3AnA6ur6h1Jfgt4+XDVF4Arxh5OklayA20Bfwh4F0BVXQNcA5DkF4brfmPE2SRpRTvQftzVVfWVfVcO604cZSJJmhEHCvAxT3HdMw/jHJI0cw4U4J1J/mDflUnexOTPzUuSDtGB9gG/Fbg2yRt4IrjzwFHAb444lySteE8Z4Kp6CFif5BXAi4fV/1BVnx19Mkla4RZ7PuCbmPw5eUnSYeLRbJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsHTEHElG+Vqzdl33T6clbLEnZJdWrsf2svHSm0d56G0XrR/lcbUyuAUsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTUYLcJK1SW5KcnuS25JcPKw/LsmNSe4cvh871gyStJSNuQW8F3h7VZ0KnAm8OcmpwCXA9qo6Bdg+LEvSzBktwFX1QFV9ebj8XeAOYA1wLrB1uNlW4LyxZpCkpWwq+4CTnAicDuwAVlfVA8NVDwKrn+Q+m5PsTLJzz5490xhTkqZq9AAneQ7wSeCtVfWdhddVVQG1v/tV1Zaqmq+q+VWrVo09piRN3agBTvIMJvG9oqquGVY/lOT44frjgd1jziBJS9WYvwUR4HLgjqr6wIKrrgc2DZc3AdeNNYMkLWVzIz72WcAbga8k+fdh3buB9wJXJbkQuAe4YMQZJGnJGi3AVfVvQJ7k6nPGel5JWi48Ek6SmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlsZ0xBxJRvlas3Zd90+np2muewBpRXtsLxsvvXmUh9520fpRHlfT4xawJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1KT0QKc5GNJdif56oJ1xyW5Mcmdw/djx3p+SVrqxtwC/jiwYZ91
lwDbq+oUYPuwLEkzabQAV9W/AN/eZ/W5wNbh8lbgvLGeX5KWumnvA15dVQ8Mlx8EVj/ZDZNsTrIzyc49e/ZMZzpJmqK2D+GqqoB6iuu3VNV8Vc2vWrVqipNJ0nRMO8APJTkeYPi+e8rPL0lLxrQDfD2wabi8Cbhuys8vSUvGmL+G9gngC8ALk+xKciHwXuBVSe4EXjksS9JMmhvrgavq9U9y1TljPackLSceCSdJTQywJDUxwJLUxABLUhMDLElNDLC0XB0xR5JRvtasXdf9082E0X4NTdLIHtvLxktvHuWht120fpTH1f/nFrAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABL+kkjnmfCc008wXNBSPpJI55nAjzXxOPcApakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWNL0jXio83I6zNlDkSVN34iHOi+nw5zdApakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWNLKsowOc/ZQZEkryzI6zNktYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYtAU6yIcnXk9yV5JKOGSSp29QDnORI4CPArwGnAq9Pcuq055Ckbh1bwGcAd1XV3VX1A+BK4NyGOSSpVapquk+YnA9sqKo3DctvBH6pqt6yz+02A5uHxRcCX5/qoOP4GeDh7iGWCF+LJ/haTKzk1+Hhqtqw78oley6IqtoCbOme43BKsrOq5rvnWAp8LZ7gazExi69Dxy6I+4C1C5ZPGNZJ0kzpCPCXgFOSnJTkKOB1wPUNc0hSq6nvgqiqvUneAnwGOBL4WFXdNu05mqyoXSpPk6/FE3wtJmbudZj6h3CSpAmPhJOkJgZYkpoY4JEkWZvkpiS3J7ktycXD+uOS3JjkzuH7sd2zTkOSI5PcmuTTw/JJSXYMh6NvGz6QXfGSHJPk6iRfS3JHkpfN8HviT4Z/G19N8okkR8/a+8IAj2cv8PaqOhU4E3jzcMj1JcD2qjoF2D4sz4KLgTsWLL8P+GBVnQw8AlzYMtX0fRi4oapeBLyEyWsyc++JJGuAPwbmq+rFTD6Qfx0z9r4wwCOpqgeq6svD5e8y+Ye2hslh11uHm20FzmsZcIqSnAC8GrhsWA5wNnD1cJNZeR1+Gvhl4HKAqvpBVT3KDL4nBnPAM5PMAc8CHmDG3hcGeAqSnAicDuwAVlfVA8NVDwKru+aaog8B7wQeG5afDzxaVXuH5V1M/uO00p0E7AH+Ztgdc1mSZzOD74mqug94P/BNJuH9H+AWZux9YYBHluQ5wCeBt1bVdxZeV5PfAVzRvweY5DXA7qq6pXuWJWAO+EXgo1V1OvA99tndMAvvCYBhP/e5TP6j9HPAs4GfOFfCSmeAR5TkGUzie0VVXTOsfijJ8cP1xwO7u+abkrOA1yb5BpMz353NZD/oMcP/esLsHI6+C9hVVTuG5auZBHnW3hMArwT+u6r2VNUPgWuYvFdm6n1hgEcy7Oe8HLijqj6w4KrrgU3D5U3AddOebZqq6l1VdUJVncjkQ5bPVtUbgJuA84ebrfjXAaCqHgTuTfLCYdU5wO3M2Hti8E3gzCTPGv6tPP5azNT7wiPhRpLk5cC/Al/hiX2f72ayH/gqYB1wD3BBVX27ZcgpS/IrwJ9W1WuSvIDJFvFxwK3A71bV9xvHm4okpzH5MPIo4G7g95lsCM3ceyLJe4CNTH5j6FbgTUz2+c7M+8IAS1ITd0FIUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1+T+/Rpw7zS4MEQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Mean of `results`, then a histogram of its distribution\n",
    "print(np.mean(results))\n",
    "sns.displot(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.23285225112968902\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x208f7fbd0>"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAASpElEQVR4nO3df7Bnd13f8ecrWWOsggG57qSb3W4YIpqJ5cdcKSyMFaJOpJakLSYwiFtd3aiV0aFjG+Wf+uMPmbH+aIfR7ATK6iBspNCsv2JxCTI1EFxMJJCIhDQxG5LsBRO0dhSXvPvH90TubHf3fvfuPd/393vv8zHznXvO+Z6z53Xv3Puas5/v+ZGqQpI0e+d1B5CkrcoClqQmFrAkNbGAJamJBSxJTbZ1B5jGVVddVbfeemt3DEk6naxno4U4Av7sZz/bHUGSNtxCFLAkbUYWsCQ1sYAlqYkFLElNLGBJamIBS1ITC1iSmljAktTEApakJhawJDWxgCWpiQUsSU0sYElqYgFLUhML+DR27NxFknW/duzc1f0tSJpzC3FD9g6fOfYQ1914+7q3P3T9ng1MI2kz8ghYkppYwJLUxAKWpCYWsCQ1sYAlqYkFLElNLGBJamIBS1ITC1iSmljAktTEApakJhawJDWxgCWpiQUsSU0sYElqYgFLUhMLWJKaWMCS1MQClqQmFrAkNbGAJanJpi3gc32svCSNbdTH0ie5CLgJuAIo4PuATwKHgN3AA8C1VfX4Ru/bx8pLmndjHwH/MnBrVX098DzgXuAG4EhVXQYcGeYlacsZrYCTfDXwzcBbAarqC1X1BHA1cHBY7SBwzVgZJGmejXkEfCmwAvy3JHcmuSnJVwLbq+qRYZ1Hge0jZpCkuTVmAW8DXgj8SlW9APgbThpuqKpiMjb8/0myP8nRJEdXVlZGjClJPcYs4GPAsaq6Y5h/N5NCfizJxQDD1+On2riqDlTVclUtLy0tjRhTknqMVsBV9SjwUJLnDouuBO4BDgN7h2V7gVvGyiBJ82zU09CANwDvSHIBcD/wvUxK/+Yk+4AHgWtHziBJc2nUAq6qu4DlU7x15Zj7laRFsGmvhJOkeWcBS1ITC1iSmljAktTEApakJhawJDWxgCWpiQUsSU0sYElqYgFLUhMLWJKaWMCS1MQClqQmFrAkNbGAJamJBSxJTSxgSWpiAUtSEwtYkppYwJLUxAKWpCYWsCQ1sYAlqYkFLElNLGBJamIBS1ITC1iSmljAktTEApakJhawJDWxgCWpiQUsSU0sYElqYgFLUhMLWJKaWMCS1MQClqQmFrAkNdk25j+e5AHgr4EvAieqajnJM4FDwG7gAeDaqnp8zBySNI9mcQT88qp6flUtD/M3AEeq6jLgyDAvSVtOxxDE1cDBYfogcE1DBklqN3YBF/A/k3w0yf5h2faqemSYfhTYfqoNk+xPcjTJ0ZWVlZFjStLsjToGDLysqh5O8rXA+5L82eo3q6qS1Kk2rKoDwAGA5eXlU64jSYts1CPgqnp4+HoceC/wIuCxJBcDDF+Pj5lBkubVaAWc5CuTPO2paeDbgY8Dh4G9w2p7gVvGyiBJ82zMIYjtwHuTPLWf36iqW5P8MXBzkn3Ag8C1I2aQpLk1WgFX1f3A806x/HPAlWPtV5IWhVfCSVITC1iSmljAktTEApakJhawJDWxgCWpiQUsSU0sYElqYgFLUhMLWJKaWMCS1MQClqQmFrAkNbGAJamJBSxJTSxgSWpiAUtSEwtYkppYwJLUxAKWpCYWsCQ1sYAlqYkFLElNLGBJamIBS1ITC1iSmljAktTEApakJhawJDWxgCWpiQUsSU0s4LGct40k637t2Lmr+zuQNLJt3QE2rSdPcN2Nt69780PX79nAMJLmkUfAktTEApakJhawJDWxgCWpyegFnOT8JHcm+e1h/tIkdyS5L8mhJBeMnUGS5tEsjoB/FLh31fybgV+squcAjwP7
ZpBBkubOqAWc5BLgXwA3DfMBXgG8e1jlIHDNmBkkaV6NfQT8S8B/AJ4c5r8GeKKqTgzzx4Adp9owyf4kR5McXVlZGTmmJM3eaAWc5DuB41X10fVsX1UHqmq5qpaXlpY2OJ0k9RvzSriXAq9K8krgQuDpwC8DFyXZNhwFXwI8PGIGSZpbox0BV9VPVNUlVbUbeA3w/qp6HXAb8Ophtb3ALWNlkKR51nEe8H8E3pjkPiZjwm9tyCBJ7WZyM56q+gDwgWH6fuBFs9ivJM0zr4STpCYWsCQ1sYAlqYkFLElNLGBJamIBS1ITC1iSmkxVwEleOs0ySdL0pj0C/q9TLpMkTemMV8IleQmwB1hK8sZVbz0dOH/MYJK02a11KfIFwFcN6z1t1fK/4ks31JEkrcMZC7iq/hD4wyRvr6oHZ5RJkraEaW/G8+VJDgC7V29TVa8YI5QkbQXTFvBvAr/K5NluXxwvjiRtHdMW8Imq+pVRk0jSFjPtaWi/leSHk1yc5JlPvUZNJkmb3LRHwHuHrz++alkBz97YOJK0dUxVwFV16dhBdJLztpFk3Zv/40t28vBDf7GBgSRttKkKOMn3nGp5Vf3axsbRP3jyBNfdePu6Nz90/Z4NDCNpDNMOQXzTqukLgSuBPwEsYElap2mHIN6wej7JRcC7xggkSVvFem9H+TeA48KSdA6mHQP+LSZnPcDkJjzfANw8VihJ2gqmHQP++VXTJ4AHq+rYCHkkacuYaghiuCnPnzG5I9ozgC+MGUqStoJpn4hxLfAR4LuAa4E7kng7Skk6B9MOQbwJ+KaqOg6QZAn4A+DdYwWTpM1u2rMgznuqfAefO4ttJUmnMO0R8K1Jfh945zB/HfC740SSpK1hrWfCPQfYXlU/nuRfAy8b3voQ8I6xw0nSZrbWEfAvAT8BUFXvAd4DkOQbh/f+5YjZJGlTW2scd3tV3X3ywmHZ7lESSdIWsVYBX3SG975iA3NI0pazVgEfTfIDJy9M8v3AR8eJJElbw1pjwD8GvDfJ6/hS4S4DFwD/asRckrTpnbGAq+oxYE+SlwNXDIt/p6reP3oySdrkpr0f8G3AbSNnkaQtZbSr2ZJcmOQjSf40ySeS/NSw/NIkdyS5L8mhJBeMlUGS5tmYlxP/HfCKqnoe8HzgqiQvBt4M/GJVPQd4HNg3YgZJmlujFXBN/J9h9suGVwGv4Es38TkIXDNWBkmaZ6PeUCfJ+UnuAo4D7wM+DTxRVSeGVY4BO8bMIEnzatQCrqovVtXzgUuAFwFfP+22SfYnOZrk6MrKylgRJanNTG4pWVVPMDmL4iXARUmeOvviEuDh02xzoKqWq2p5aWlpFjElaabGPAtiaXh8PUm+Avg24F4mRfzU0zT2AreMlUGS5tm09wNej4uBg0nOZ1L0N1fVbye5B3hXkp8F7gTeOmIGSZpboxVwVX0MeMEplt/PZDxYkrY0HyskSU0sYElqYgFLUhMLWJKaWMCS1MQClqQmFrAkNbGAJamJBSxJTSzgzeq8bSRZ92vHzl3d34G06Y15Lwh1evIE1914+7o3P3T9ng0MI+lUPAKWpCYWsCQ1sYAlqYkFLElNLGBJamIBS1ITC1iSmljAktTEApakJhawJDWxgCWpiQUsSU0sYElqYgFLUhMLWJKaWMCS1MQClqQmFrAkNbGAJamJBSxJTSxgSWpiAUtSEwtYkppYwJLUxAKWpCYWsCQ1sYAlqcloBZxkZ5LbktyT5BNJfnRY/swk70vyqeHrM8bKIEnzbMwj4BPAv6+qy4EXA/8uyeXADcCRqroMODLMS9KWM1oBV9UjVfUnw/RfA/cCO4CrgYPDageBa8bKIEnzbCZjwEl2Ay8A7gC2V9Ujw1uPAttPs83+JEeTHF1ZWZlFTEmaqdELOMlXAf8d+LGq+qvV71VVAXWq7arqQFUtV9Xy0tLS2DElaeZGLeAkX8akfN9RVe8ZFj+W5OLh/YuB42NmkKR5NeZZEAHeCtxbVb+w
6q3DwN5hei9wy1gZJGmebRvx334p8Hrg7iR3Dct+Evg54OYk+4AHgWtHzCBJc2u0Aq6q/wXkNG9fOdZ+JWlReCWcJDWxgCWpiQUsSU0sYJ3aedtIsu7Xjp27ur8Dae6NeRaEFtmTJ7juxtvXvfmh6/dsYBhpc/IIWJKaWMCS1MQClqQmFrAkNbGAJamJBSxJTSxgSWpiAUtSEwtY4ziHK+m8ik5bhVfCaRzncCWdV9Fpq/AIWJKaWMCS1MQClqQmFrAkNbGAJamJBSxJTSxgSWpiAWv++DgkbRFeiKH54+OQtEV4BCxJTSxgSWpiAUtSEwtYkppYwJLUxAKWpCYWsCQ1sYAlqYkFLElNLGBJamIBS1ITC1iSmljAktRktAJO8rYkx5N8fNWyZyZ5X5JPDV+fMdb+JWnejXkE/HbgqpOW3QAcqarLgCPDvCRtSaMVcFV9EPjLkxZfDRwcpg8C14y1f0mad7MeA95eVY8M048C20+3YpL9SY4mObqysjKbdJI0Q20fwlVVAXWG9w9U1XJVLS8tLc0wmSTNxqwL+LEkFwMMX4/PeP+SNDdmXcCHgb3D9F7glhnvX5Lmxpinob0T+BDw3CTHkuwDfg74tiSfAr51mJekLWm0pyJX1WtP89aVY+1TkhaJV8JJUhMLWJKaWMCS1MQClqQmFrAkNbGAJamJBSxJTSxgSWpiAUtSEwtYkppYwJLUxAKWpCYWsCQ1sYClObNj5y6SrPu1Y+eu7m9BUxrtdpSS1uczxx7iuhtvX/f2h67fs4FpNCaPgCWpiQUsSU0sYElqYgFLUhMLWJKaWMCS1MQClqQmFrA2n/O2eSGDFoIXYmjzefKEFzJoIXgELElNLGBJamIBS1ITC1g6mR/iaUb8EE46mR/iaUY8ApakJhawJDWxgCWpiQUsbbRz/BCve//n+iFi5yOVFu1xTn4IJ2207g/xmvff+UilRXuck0fAktTEApakJg5BSJovwxj2VmABS5ov5zCGvWgXwbQMQSS5Ksknk9yX5IaODJLUbeYFnOR84C3AdwCXA69Ncvmsc0hSt44j4BcB91XV/VX1BeBdwNUNOSSpVapqtjtMXg1cVVXfP8y/HvhnVfUjJ623H9g/zD4X+ORZ7upZwGfPMW6XRc4Oi53f7D0WOTvAhVV1xdluNLcfwlXVAeDAerdPcrSqljcw0swscnZY7Pxm77HI2WGSfz3bdQxBPAzsXDV/ybBMkraUjgL+Y+CyJJcmuQB4DXC4IYcktZr5EERVnUjyI8DvA+cDb6uqT4ywq3UPX8yBRc4Oi53f7D0WOTusM//MP4STJE14LwhJamIBS1KThS/gtS5rTvLlSQ4N79+RZHdDzFOaIvsbk9yT5GNJjiT5Jx05T2Xay8mT/JsklWSuTjGaJn+Sa4ef/yeS/MasM57OFL83u5LcluTO4XfnlR05TyXJ25IcT/Lx07yfJP9l+N4+luSFs854OlNkf92Q+e4ktyd53pr/aFUt7IvJh3ifBp4NXAD8KXD5Sev8MPCrw/RrgEPduc8i+8uBfzRM/9AiZR/WexrwQeDDwHJ37rP82V8G3Ak8Y5j/2u7cZ5H9APBDw/TlwAPduVdl+2bghcDHT/P+K4HfAwK8GLijO/NZZN+z6vflO6bJvuhHwNNc1nw1cHCYfjdwZebjXndrZq+q26rq/w6zH2ZyzvQ8mPZy8p8B3gz87SzDTWGa/D8AvKWqHgeoquMzzng602Qv4OnD9FcDn5lhvjOqqg8Cf3mGVa4Gfq0mPgxclOTi2aQ7s7WyV9XtT/2+MOXf66IX8A7goVXzx4Zlp1ynqk4Anwe+Zibpzmya7KvtY3JkMA/WzD7813FnVf3OLINNaZqf/dcBX5fkj5J8OMlVM0t3ZtNk/0/Adyc5Bvwu8IbZRNsQZ/t3Ma+m+nud20uR9SVJvhtYBv55d5ZpJDkP+AXg3zZHORfbmAxDfAuTI5kPJvnGqnqiM9SUXgu8var+c5KXAL+e
5IqqerI72FaQ5OVMCvhla6276EfA01zW/A/rJNnG5L9kn5tJujOb6pLsJN8KvAl4VVX93YyyrWWt7E8DrgA+kOQBJmN5h+fog7hpfvbHgMNV9fdV9b+BP2dSyN2myb4PuBmgqj4EXMjkZjeLYKFvVZDknwI3AVdX1Zo9s+gFPM1lzYeBvcP0q4H31zBK3mzN7EleANzIpHznZQwS1sheVZ+vqmdV1e6q2s1kPOxVVbWuG5aMYJrfm//B5OiXJM9iMiRx/wwzns402f8CuBIgyTcwKeCVmaZcv8PA9wxnQ7wY+HxVPdIdahpJdgHvAV5fVX8+1UbdnyxuwCeTr2RydPJp4E3Dsp9m8gcPk1++3wTuAz4CPLs781lk/wPgMeCu4XW4O/O02U9a9wPM0VkQU/7sw2QY5R7gbuA13ZnPIvvlwB8xOUPiLuDbuzOvyv5O4BHg75n8L2Mf8IPAD676ub9l+N7unqffmymy3wQ8vurv9eha/6aXIktSk0UfgpCkhWUBS1ITC1iSmljAktTEApakJhawJDWxgCWpyf8DCUID0rWx0RkAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Mean of `saliency_sim`, then a histogram of its distribution\n",
    "print(np.mean(saliency_sim))\n",
    "sns.displot(saliency_sim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## New Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Take the 100 nearest neighbors (NNs) and filter out those that are not both labelled and predicted as the query's predicted class.\n",
    "2. In each NN, take the closest representation of the query's most nb feature, then measure the logit change when that representation is zeroed out.\n",
    "\n",
    "#### Note: Normalize Distances and Logit distances\n",
    "$Score = \\frac{numMatches}{casePool} \\left(1 - avg(distL1)\\right) \\left(1 - avg(distLogit)\\right)$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_NEIGHBORS = 100\n",
    "FEATURE_NUM = 0  # pick most nb feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', metric='euclidean', n_neighbors=1)"
      ]
     },
     "execution_count": 155,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit twin: 1-nearest-neighbour classifier using brute-force Euclidean search\n",
    "twin = KNeighborsClassifier(\n",
    "    n_neighbors=1,\n",
    "    algorithm='brute',\n",
    "    metric='euclidean',\n",
    ")\n",
    "twin.fit(X_train_c, train_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 1/10000 [00:00<1:42:11,  1.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.6220788386464119 0.7013667583465576\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 2/10000 [00:01<1:47:53,  1.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.6540866205096245 0.7097906970977783\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 3/10000 [00:01<1:41:35,  1.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.7863338408619165 0.7691479015350342\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 4/10000 [00:02<1:34:31,  1.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.7634108706563711 0.8728809213638306\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 5/10000 [00:03<1:43:05,  1.62it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.5389021995663643 0.6266146326065063\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 6/10000 [00:03<1:50:47,  1.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.8026175872981548 0.8722674751281738\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 7/10000 [00:04<1:46:41,  1.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.6284126450121403 0.6811836379766465\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 8/10000 [00:04<1:41:45,  1.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.5317842234671115 0.727643346786499\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 9/10000 [00:05<1:43:03,  1.62it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.96 0.6508559504151344 0.6847880312800407\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 10/10000 [00:06<1:48:20,  1.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.6673754233121871 0.8069199395179749\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 11/10000 [00:06<1:47:26,  1.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.7581985920667649 0.8493252038955689\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 12/10000 [00:07<1:50:58,  1.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.6555050638318062 0.402157096862793\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 13/10000 [00:08<1:53:09,  1.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.7112283872067928 0.9113344430923462\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 14/10000 [00:09<1:53:46,  1.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.8103917718678713 0.8745167493820191\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 15/10000 [00:09<1:49:18,  1.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.7518708999454975 0.7138951063156128\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 16/10000 [00:10<1:45:28,  1.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.5080441564321518 0.48479166746139524\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 17/10000 [00:10<1:46:24,  1.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.776623937562108 0.8163562083244323\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 18/10000 [00:11<1:41:57,  1.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.7636316557228565 0.8193531799316406\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 19/10000 [00:12<1:40:40,  1.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.5367426273226739 0.5972383835911751\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 20/10000 [00:12<1:43:25,  1.61it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.636623956412077 0.6348786282539367\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 21/10000 [00:13<1:39:15,  1.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.5338270524144173 0.8137118983268737\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 22/10000 [00:13<1:39:48,  1.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.5558705255389214 0.5228573989868164\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 23/10000 [00:14<1:44:57,  1.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.5271932478249073 0.24264389038085943\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 24/10000 [00:15<1:47:46,  1.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.4359193706512451 0.43813631534576414\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 25/10000 [00:15<1:52:55,  1.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.604903234988451 0.7719255113601684\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 26/10000 [00:16<1:50:09,  1.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.719418159276247 0.8597093009948731\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 27/10000 [00:17<1:48:41,  1.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.6540668119490147 0.8269237041473388\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 28/10000 [00:17<1:42:27,  1.62it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.6755216449499131 0.6778575253486634\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 29/10000 [00:18<1:44:33,  1.59it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.7670408098399639 0.8635321283340454\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 30/10000 [00:19<1:42:34,  1.62it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.7357065285742284 0.7674863815307618\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 31/10000 [00:19<1:42:01,  1.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.5257382009923458 0.6038850545883179\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 32/10000 [00:20<1:40:55,  1.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.7562592967599631 0.8489441204071044\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 33/10000 [00:20<1:41:02,  1.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.7419671555608511 0.6384826993942261\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 34/10000 [00:21<1:39:34,  1.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.5094688713550568 0.5850406253337861\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 35/10000 [00:21<1:38:14,  1.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.6828008151799441 0.834995174407959\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 36/10000 [00:22<1:40:15,  1.66it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.693244771361351 0.43963317394256596\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 37/10000 [00:23<1:42:18,  1.62it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.99 0.5409245243668557 0.5249174827337265\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 38/10000 [00:23<1:38:48,  1.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.8268647665530443 0.862571783065796\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 39/10000 [00:24<1:36:41,  1.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.6609055684506893 0.7854983806610107\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 40/10000 [00:24<1:37:23,  1.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.7702661310881376 0.8644624185562134\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 41/10000 [00:25<1:34:31,  1.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.7515443959832191 0.7780110263824462\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 42/10000 [00:26<1:33:48,  1.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.5793018469214439 0.4972599375247956\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 43/10000 [00:26<1:35:18,  1.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.5597310189902782 0.6898640275001526\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 44/10000 [00:27<1:35:57,  1.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.5720610782504082 0.530562338232994\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 45/10000 [00:27<1:37:16,  1.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.3696836334466934 0.3775019180774689\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 46/10000 [00:28<1:37:16,  1.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.5718474861979485 0.5949370396137237\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 47/10000 [00:28<1:33:15,  1.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.7549224875867366 0.7554525852203369\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 48/10000 [00:29<1:29:59,  1.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.6261935560405254 0.6481360578536988\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 49/10000 [00:29<1:27:42,  1.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.7302424760162831 0.7980484527349472\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 50/10000 [00:30<1:26:42,  1.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.5010094849765301 0.5881295418739318\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 50/10000 [00:31<1:43:33,  1.60it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0.5665299797058105 0.4773601078987122\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Score the current twin: for every query (represented by X_test_c), compare its\n",
    "# most salient feature with the closest counterpart in each retrieved neighbour.\n",
    "NUM_QUERIES = 51  # test indices 0..50, the same subset the loop previously stopped at\n",
    "\n",
    "term1 = list()  # per query: fraction of neighbours labelled and predicted as the query prediction\n",
    "term2 = list()  # per query: mean of the distances returned by get_box_xp\n",
    "term3 = list()  # per query: mean absolute gap between neighbour and query logit changes\n",
    "\n",
    "results = list()\n",
    "\n",
    "# Loop over indices rather than over test_loader: every input below is fetched\n",
    "# by index, so iterating the loader only loaded batches that were never used.\n",
    "for query_idx in tqdm(range(min(NUM_QUERIES, len(test_preds)))):\n",
    "\n",
    "    # Get query information\n",
    "    query_pred = test_preds[query_idx].item()\n",
    "    query_img_trans = get_transformed_data([query_idx], test_loader)[0]\n",
    "\n",
    "    # Get query salient feature (only the most NB one)\n",
    "    query_logits, query_x, query_C = netC(query_img_trans)\n",
    "    query_nb_boxes = get_salient_regions(query_C, query_logits, net_classifier)\n",
    "    window_idx = query_nb_boxes[0][1]\n",
    "    query_feature = query_C[:, :, window_idx[0]:window_idx[0]+1, window_idx[1]:window_idx[1]+1]\n",
    "\n",
    "    # Get explanation nns\n",
    "    xp_idxs = twin.kneighbors(X=[X_test_c[query_idx]], n_neighbors=NUM_NEIGHBORS, return_distance=False)\n",
    "    nn_idxs = xp_idxs[0]  # kneighbors returns one row of indices per query\n",
    "\n",
    "    # How many NNs are predicted and labelled in query prediction\n",
    "    matches = (train_preds[nn_idxs] == query_pred) * (y_train[nn_idxs] == query_pred).detach().numpy()\n",
    "    t1 = matches.sum() / NUM_NEIGHBORS\n",
    "\n",
    "    # Visit every retrieved neighbour and compare it with the query feature\n",
    "    l1_dists = list()\n",
    "    logit_dists = list()\n",
    "    for nn_idx in nn_idxs:\n",
    "\n",
    "        xps_img_trans = get_transformed_data([nn_idx], train_loader)\n",
    "        xp_logits, xp_x, xp_C = netC(xps_img_trans[0])\n",
    "        xp_pred = torch.argmax(xp_logits, dim=1).item()\n",
    "\n",
    "        # Get distance l1\n",
    "        coord, dist_query_to_xp = get_box_xp(xp_C, query_feature)  # get similar latent feature in nn xp\n",
    "\n",
    "        # Get logit change: zero the matched location and classify again\n",
    "        xp_C[:, :, coord[0]:coord[0]+1, coord[1]:coord[1]+1] = 0.0\n",
    "        new_xp_logits = net_classifier(xp_C)\n",
    "        logit_change = xp_logits[0][xp_pred] - new_xp_logits[0][xp_pred]\n",
    "        abs_logit_change = abs(logit_change - query_nb_boxes[0][0])\n",
    "\n",
    "        l1_dists.append(dist_query_to_xp)\n",
    "        logit_dists.append(abs_logit_change.item())\n",
    "\n",
    "    t2 = np.array(l1_dists).mean()\n",
    "    t3 = np.array(logit_dists).mean()\n",
    "\n",
    "    term1.append(t1)\n",
    "    term2.append(t2)\n",
    "    term3.append(t3)\n",
    "\n",
    "    # Combined score: match rate scaled down by both mean distances\n",
    "    results.append(t1 * (1 - t2) * (1 - t3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.4574464784934077"
      ]
     },
     "execution_count": 157,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(results) / len(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9990196078431373 0.35184913372730514 0.31245239197039143\n"
     ]
    }
   ],
   "source": [
    "print(sum(term1)/len(term1), sum(term2)/len(term2), sum(term3)/len(term3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### New Evaluation with DkNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', metric='euclidean', n_neighbors=1)"
      ]
     },
     "execution_count": 165,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def normalise_by_sum(rows):\n",
    "    \"\"\"Flatten every entry of `rows` and divide it by its own sum.\n",
    "\n",
    "    Entries whose sum is exactly zero are kept unscaled, which avoids a\n",
    "    division by zero. Returns a numpy array with one flattened row per entry.\n",
    "    \"\"\"\n",
    "    normalised = list()\n",
    "    # Bound the loop by the container that is actually indexed\n",
    "    for i in range(len(rows)):\n",
    "        flat = deepcopy(rows[i].flatten())\n",
    "        total = flat.sum()\n",
    "        normalised.append(flat if total == 0. else flat / total)\n",
    "    return np.array(normalised)\n",
    "\n",
    "\n",
    "temp_train = normalise_by_sum(X_train_x)\n",
    "temp_test = normalise_by_sum(X_test_x)\n",
    "\n",
    "# Fit the twin on the sum-normalised representations\n",
    "twin = KNeighborsClassifier(n_neighbors=1, algorithm='brute', metric='euclidean')\n",
    "twin.fit(temp_train, train_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 50/10000 [00:29<1:38:06,  1.69it/s]\n"
     ]
    }
   ],
   "source": [
    "# Score the current twin: for every query (represented by temp_test), compare its\n",
    "# most salient feature with the closest counterpart in each retrieved neighbour.\n",
    "NUM_QUERIES = 51  # test indices 0..50, the same subset the loop previously stopped at\n",
    "\n",
    "term1 = list()  # per query: fraction of neighbours labelled and predicted as the query prediction\n",
    "term2 = list()  # per query: mean of the distances returned by get_box_xp\n",
    "term3 = list()  # per query: mean absolute gap between neighbour and query logit changes\n",
    "\n",
    "results = list()\n",
    "\n",
    "# Loop over indices rather than over test_loader: every input below is fetched\n",
    "# by index, so iterating the loader only loaded batches that were never used.\n",
    "for query_idx in tqdm(range(min(NUM_QUERIES, len(test_preds)))):\n",
    "\n",
    "    # Get query information\n",
    "    query_pred = test_preds[query_idx].item()\n",
    "    query_img_trans = get_transformed_data([query_idx], test_loader)[0]\n",
    "\n",
    "    # Get query salient feature (only the most NB one)\n",
    "    query_logits, query_x, query_C = netC(query_img_trans)\n",
    "    query_nb_boxes = get_salient_regions(query_C, query_logits, net_classifier)\n",
    "    window_idx = query_nb_boxes[0][1]\n",
    "    query_feature = query_C[:, :, window_idx[0]:window_idx[0]+1, window_idx[1]:window_idx[1]+1]\n",
    "\n",
    "    # Get explanation nns\n",
    "    xp_idxs = twin.kneighbors(X=[temp_test[query_idx]], n_neighbors=NUM_NEIGHBORS, return_distance=False)\n",
    "    nn_idxs = xp_idxs[0]  # kneighbors returns one row of indices per query\n",
    "\n",
    "    # How many NNs are predicted and labelled in query prediction\n",
    "    matches = (train_preds[nn_idxs] == query_pred) * (y_train[nn_idxs] == query_pred).detach().numpy()\n",
    "    t1 = matches.sum() / NUM_NEIGHBORS\n",
    "\n",
    "    # Visit every retrieved neighbour and compare it with the query feature\n",
    "    l1_dists = list()\n",
    "    logit_dists = list()\n",
    "    for nn_idx in nn_idxs:\n",
    "\n",
    "        xps_img_trans = get_transformed_data([nn_idx], train_loader)\n",
    "        xp_logits, xp_x, xp_C = netC(xps_img_trans[0])\n",
    "        xp_pred = torch.argmax(xp_logits, dim=1).item()\n",
    "\n",
    "        # Get distance l1\n",
    "        coord, dist_query_to_xp = get_box_xp(xp_C, query_feature)  # get similar latent feature in nn xp\n",
    "\n",
    "        # Get logit change: zero the matched location and classify again\n",
    "        xp_C[:, :, coord[0]:coord[0]+1, coord[1]:coord[1]+1] = 0.0\n",
    "        new_xp_logits = net_classifier(xp_C)\n",
    "        logit_change = xp_logits[0][xp_pred] - new_xp_logits[0][xp_pred]\n",
    "        abs_logit_change = abs(logit_change - query_nb_boxes[0][0])\n",
    "\n",
    "        l1_dists.append(dist_query_to_xp)\n",
    "        logit_dists.append(abs_logit_change.item())\n",
    "\n",
    "    t2 = np.array(l1_dists).mean()\n",
    "    t3 = np.array(logit_dists).mean()\n",
    "\n",
    "    term1.append(t1)\n",
    "    term2.append(t2)\n",
    "    term3.append(t3)\n",
    "\n",
    "    # Combined score: match rate scaled down by both mean distances\n",
    "    results.append(t1 * (1 - t2) * (1 - t3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.47044801993758123"
      ]
     },
     "execution_count": 167,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(results) / len(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 168,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9923529411764705 0.34030622801798227 0.30206476715849895\n"
     ]
    }
   ],
   "source": [
    "print(sum(term1)/len(term1), sum(term2)/len(term2), sum(term3)/len(term3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "neurips_img",
   "language": "python",
   "name": "neurips_img"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
