{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quantitative Experiments\n",
    "\n",
    "1. Measure how much the salient regions of a query overlap with the salient regions of its nearest neighbors.\n",
    "\n",
    "2. Sanity-check the salient regions by blacking out the corresponding input pixels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import pickle\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision.models as models\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "# Project-local star imports stay after the explicit imports above so that any\n",
    "# names they export still take precedence, exactly as before.\n",
    "from functions import *\n",
    "from ANNs import *\n",
    "from tqdm import tqdm\n",
    "from copy import deepcopy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites for the Algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATAROOT = 'data'\n",
    "DEVICE = 'cpu'\n",
    "\n",
    "# Normalisation statistics; presumably these match the Normalize transform used by\n",
    "# load_dataloaders() (not defined in this notebook) -- TODO confirm.\n",
    "# ImageNet alternative:\n",
    "# MEAN_NORM=(0.485, 0.456, 0.406)\n",
    "# STD_NORM=(0.229, 0.224, 0.225)\n",
    "MEAN_NORM = (0.5, 0.5, 0.5)\n",
    "STD_NORM = (0.5, 0.5, 0.5)\n",
    "\n",
    "NUM_NEIGHBORS = 1  # Num neighbors to check for features\n",
    "XP_FEATURE_NUM = 1  # Num box features to show\n",
    "\n",
    "# Load the trained CNN onto DEVICE (rather than a hard-coded 'cpu', so the two cannot drift).\n",
    "# NOTE: torch.load unpickles the checkpoint -- only load trusted files.\n",
    "netC = CNN()\n",
    "netC.load_state_dict(torch.load('weights/cnn.pth', map_location=torch.device(DEVICE)))\n",
    "netC = netC.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "class netClassifier(nn.Module):\n",
    "    \"\"\"Classification head of `netC` applied to a precomputed conv feature map.\n",
    "\n",
    "    Runs only `net.avgpool` followed by `net.linear`, so logits can be recomputed\n",
    "    after editing the conv activations (e.g. zeroing one spatial cell).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, netC):\n",
    "        super().__init__()\n",
    "        self.net = netC\n",
    "\n",
    "    def forward(self, C):\n",
    "        \"\"\"Return `net.linear` applied to the pooled, flattened activations `C`.\"\"\"\n",
    "        x = self.net.avgpool(C)\n",
    "        if C.dim() == 4:\n",
    "            # Batched (N, ch, H, W) input: flatten everything but the batch axis.\n",
    "            # The old view(-1, 128) hard-coded the feature size and silently mixed\n",
    "            # samples whenever a sample did not pool to exactly 128 values.\n",
    "            x = torch.flatten(x, 1)\n",
    "        else:\n",
    "            # Unbatched input: keep the original behaviour (rows of 128 features).\n",
    "            x = x.view(-1, 128)\n",
    "        logits = self.net.linear(x)\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inverse_normalize(tensor, mean=MEAN_NORM, std=STD_NORM):\n",
    "    \"\"\"Undo `(x - mean) / std` normalisation in place and return `tensor`.\n",
    "\n",
    "    The channel axis is taken to be dim -3, so a single image (C, H, W) and a\n",
    "    batch (N, C, H, W) are both handled. As before, only the first\n",
    "    min(C, len(mean), len(std)) channels are touched. A 2-D (H, W) input is\n",
    "    treated as one channel and uses mean[0] / std[0].\n",
    "    \"\"\"\n",
    "    if tensor.dim() < 3:\n",
    "        return tensor.mul_(std[0]).add_(mean[0])\n",
    "    # The previous zip(tensor, mean, std) walked the *first* axis, so on a batch it\n",
    "    # paired samples (not channels) with the per-channel statistics.\n",
    "    n = min(tensor.shape[-3], len(mean), len(std))\n",
    "    channels = tensor[..., :n, :, :]  # a view: the in-place ops below update `tensor`\n",
    "    opts = dict(dtype=tensor.dtype, device=tensor.device)\n",
    "    std_t = torch.as_tensor(std[:n], **opts).view(-1, 1, 1)\n",
    "    mean_t = torch.as_tensor(mean[:n], **opts).view(-1, 1, 1)\n",
    "    channels.mul_(std_t).add_(mean_t)\n",
    "    return tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_transformed_data(idxs, loader):\n",
    "    \"\"\"\n",
    "    Fetch dataset items by index, with the dataset's transforms applied.\n",
    "\n",
    "    idxs: iterable of integer indices into `loader.dataset`\n",
    "    loader: a DataLoader whose `.dataset` is indexed\n",
    "    returns: list with one entry per index, in the order of `idxs` -- the first\n",
    "             element (the inputs) of each collated batch of size 1\n",
    "    \"\"\"\n",
    "    single_item_loader = torch.utils.data.DataLoader(\n",
    "        torch.utils.data.Subset(loader.dataset, idxs),\n",
    "        batch_size=1,\n",
    "        num_workers=0,\n",
    "        shuffle=False,\n",
    "    )\n",
    "    return [batch[0] for batch in single_item_loader]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_box_xp(Conv, latent_feature):\n",
    "    \"\"\"\n",
    "    Find the spatial cell of `Conv` whose channel vector is closest (L1) to `latent_feature`.\n",
    "\n",
    "    Cells are scanned row-major and a cell only wins if strictly closer, so ties\n",
    "    go to the first cell. Returns ([row, col], distance), or ([None, None], inf)\n",
    "    when Conv has no spatial cells.\n",
    "    \"\"\"\n",
    "    n_channels = Conv.shape[1]\n",
    "    height, width = Conv.shape[2], Conv.shape[3]\n",
    "    best_coord, best_dist = [None, None], float('inf')\n",
    "    for flat in range(height * width):\n",
    "        row, col = divmod(flat, width)\n",
    "        cell = Conv[:, :, row:row + 1, col:col + 1].view(-1, n_channels)\n",
    "        target = latent_feature.view(-1, n_channels)\n",
    "        dist = torch.cdist(cell, target, p=1.0).item()\n",
    "        if dist < best_dist:\n",
    "            best_coord, best_dist = [row, col], dist\n",
    "    return best_coord, best_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_box_query(Conv, org_logits, net_classifier):\n",
    "    \"\"\"Rank the spatial cells of `Conv` by how much zeroing each one lowers the predicted logit.\n",
    "\n",
    "    Same computation as get_salient_regions (defined later in this notebook).\n",
    "\n",
    "    Conv: conv feature map indexed as [:, :, i, j]; batch 1 is presumably expected\n",
    "        since only row 0 of the logits is read -- TODO confirm.\n",
    "    org_logits: logits of the unmodified map; the prediction is argmax over dim 1.\n",
    "    net_classifier: callable mapping a (modified) feature map to logits.\n",
    "\n",
    "    Returns a list of [logit_drop, [i, j]] sorted by logit_drop, largest first\n",
    "    (equal drops keep row-major order because `sorted` is stable).\n",
    "    \"\"\"\n",
    "    pred = torch.argmax(org_logits, dim=1).item()\n",
    "    base_logit = org_logits[0][pred].item()  # loop-invariant, so read it once\n",
    "    results = list()\n",
    "\n",
    "    # Occlusion scores need no gradients; skipping graph construction saves time\n",
    "    # and memory on each of the H*W forward passes.\n",
    "    with torch.no_grad():\n",
    "        for i in range(Conv.shape[2]):\n",
    "            for j in range(Conv.shape[3]):\n",
    "                C_copy = Conv.detach().clone()\n",
    "                C_copy[:, :, i:i+1, j:j+1] = 0\n",
    "                new_logit = net_classifier(C_copy)[0][pred]\n",
    "                # Raw logit difference. NOTE(review): the original author flagged this\n",
    "                # as needing changes for negative logits.\n",
    "                logit_change = base_logit - new_logit.item()\n",
    "                results.append([logit_change, [i, j]])\n",
    "\n",
    "    return sorted(results, key=lambda x: x[0], reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Head-only wrapper around netC: recomputes logits from (edited) conv feature maps.\n",
    "net_classifier = netClassifier(netC=netC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "LABELS_DICT = {digit: str(digit) for digit in range(10)}  # class index -> display name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weights of netC.linear, the layer that maps pooled conv features to logits (see netClassifier).\n",
    "weights = netC.linear.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load_dataloaders is not defined in this notebook; presumably it comes from the star imports.\n",
    "train_loader, test_loader = load_dataloaders()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arrays held directly by the underlying datasets (dataset.data / dataset.targets);\n",
    "# presumably untransformed -- transformed tensors come from get_transformed_data.\n",
    "X_train = train_loader.dataset.data\n",
    "y_train = train_loader.dataset.targets\n",
    "\n",
    "X_test = test_loader.dataset.data\n",
    "y_test = test_loader.dataset.targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative data source, kept for reference only (not executed):\n",
    "# X_train, y_train, X_test, y_test = get_MNIST_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Precomputed per-sample arrays saved under DATAROOT. Roles inferred from their use below (TODO confirm):\n",
    "#   *_cont -> features the first 1-NN 'twin' is fitted on\n",
    "#   *_conv -> conv feature maps (indexed as X_*_C[i][0] by the ExMatchina baseline)\n",
    "#   *_x    -> features used by the D-kNN baseline\n",
    "#   *_y    -> model predictions (used as kNN labels and for the agreement metric)\n",
    "X_train_c = np.load(DATAROOT + \"/X_train_cont.npy\")\n",
    "X_test_c = np.load(DATAROOT + \"/X_test_cont.npy\")\n",
    "X_train_C = np.load(DATAROOT + \"/X_train_conv.npy\")\n",
    "X_test_C = np.load(DATAROOT + \"/X_test_conv.npy\")\n",
    "\n",
    "X_train_x = np.load(DATAROOT + \"/X_train_x.npy\")\n",
    "X_test_x = np.load(DATAROOT + \"/X_test_x.npy\")\n",
    "train_preds = np.load(DATAROOT + \"/X_train_y.npy\")\n",
    "test_preds = np.load(DATAROOT + \"/X_test_y.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', metric='euclidean', n_neighbors=1)"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit a 1-NN 'twin' on the X_*_cont features, labelled with the model's predictions.\n",
    "# NOTE(review): `twin` is refitted in each baseline section below before it is queried.\n",
    "twin = KNeighborsClassifier(n_neighbors=1, algorithm=\"brute\", metric='euclidean') \n",
    "twin.fit(X_train_c, train_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_salient_regions(Conv, org_logits, net_classifier):\n",
    "    \"\"\"Rank the spatial cells of `Conv` by how much zeroing each one lowers the predicted logit.\n",
    "\n",
    "    Same computation as get_box_query (defined earlier in this notebook).\n",
    "\n",
    "    Conv: conv feature map indexed as [:, :, i, j]; batch 1 is presumably expected\n",
    "        since only row 0 of the logits is read -- TODO confirm.\n",
    "    org_logits: logits of the unmodified map; the prediction is argmax over dim 1.\n",
    "    net_classifier: callable mapping a (modified) feature map to logits.\n",
    "\n",
    "    Returns a list of [logit_drop, [i, j]] sorted by logit_drop, largest first\n",
    "    (equal drops keep row-major order because `sorted` is stable).\n",
    "    \"\"\"\n",
    "    pred = torch.argmax(org_logits, dim=1).item()\n",
    "    base_logit = org_logits[0][pred].item()  # loop-invariant, so read it once\n",
    "    results = list()\n",
    "\n",
    "    # Occlusion scores need no gradients; skipping graph construction saves time\n",
    "    # and memory on each of the H*W forward passes.\n",
    "    with torch.no_grad():\n",
    "        for i in range(Conv.shape[2]):\n",
    "            for j in range(Conv.shape[3]):\n",
    "                C_copy = Conv.detach().clone()\n",
    "                C_copy[:, :, i:i+1, j:j+1] = 0\n",
    "                new_logit = net_classifier(C_copy)[0][pred]\n",
    "                # Raw logit difference. NOTE(review): the original author flagged this\n",
    "                # as needing changes for negative logits.\n",
    "                logit_change = base_logit - new_logit.item()\n",
    "                results.append([logit_change, [i, j]])\n",
    "\n",
    "    return sorted(results, key=lambda x: x[0], reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "def blackout_C(query_nb_boxes=None, query_img_trans=None, query_pred=None,\n",
    "               unit=4, filter_size=7):\n",
    "    \"\"\"Black out the input patch under the most salient conv cell and re-run netC.\n",
    "\n",
    "    The three query arguments default to the notebook globals of the same name\n",
    "    (left behind by the most recent experiment loop); pass them explicitly to\n",
    "    avoid depending on that hidden state.\n",
    "\n",
    "    unit: input pixels per conv cell; filter_size: cells per side of the grid\n",
    "    (4 * 7 = 28 px per side -- presumably MNIST-sized inputs; TODO confirm).\n",
    "\n",
    "    Prints the predicted-class logit of the blacked-out image, shows the image and\n",
    "    returns (C, i, j): netC's conv map and the blacked-out cell. Returns None when\n",
    "    the most salient cell lies outside the filter_size x filter_size grid.\n",
    "    \"\"\"\n",
    "    namespace = globals()\n",
    "    if query_nb_boxes is None:\n",
    "        query_nb_boxes = namespace['query_nb_boxes']\n",
    "\n",
    "    # Cell with the largest logit drop (first entry of the sorted saliency list).\n",
    "    box = query_nb_boxes[0][1]\n",
    "    i, j = box[0], box[1]\n",
    "    # Only this cell is ever reported, so evaluate it directly instead of running\n",
    "    # netC on every preceding cell of the grid.\n",
    "    if not (0 <= i < filter_size and 0 <= j < filter_size):\n",
    "        return None\n",
    "\n",
    "    if query_img_trans is None:\n",
    "        query_img_trans = namespace['query_img_trans']\n",
    "    if query_pred is None:\n",
    "        query_pred = namespace['query_pred']\n",
    "\n",
    "    img_copy = query_img_trans.clone()\n",
    "    # -1.0 is presumably black: (0 - 0.5) / 0.5 with MEAN_NORM = STD_NORM = 0.5.\n",
    "    img_copy[:, :, i*unit: i*unit+unit, j*unit: j*unit+unit] = -1.0\n",
    "\n",
    "    logits, x, C = netC(img_copy)\n",
    "\n",
    "    print(logits[0][query_pred].item(), [i, j])\n",
    "    plt.imshow(img_copy[0][0].detach().numpy())\n",
    "    plt.show()\n",
    "\n",
    "    return C, i, j"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Baseline: ExMatchina"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', metric='euclidean', n_neighbors=1)"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ExMatchina-style baseline: 1-NN on sum-normalised conv feature maps X_*_C[i][0]\n",
    "# (index 0 presumably drops a singleton batch axis -- TODO confirm).\n",
    "# Vectorised version of the old per-row loops. Rows that sum to 0 are kept as-is\n",
    "# (as the D-kNN preprocessing does) instead of becoming NaN, which scikit-learn rejects.\n",
    "train_maps = X_train_C[:, 0].reshape(len(X_train_C), -1)\n",
    "test_maps = X_test_C[:, 0].reshape(len(X_test_C), -1)\n",
    "train_map_sums = train_maps.sum(axis=1, keepdims=True)\n",
    "test_map_sums = test_maps.sum(axis=1, keepdims=True)\n",
    "\n",
    "temp_train = train_maps / np.where(train_map_sums == 0, 1, train_map_sums)\n",
    "temp_test  = test_maps / np.where(test_map_sums == 0, 1, test_map_sums)\n",
    "\n",
    "# Fit the 1-NN queried by the experiment loop below, labelled with the model's predictions.\n",
    "twin = KNeighborsClassifier(n_neighbors=1, algorithm='brute', metric='euclidean')\n",
    "twin.fit(temp_train, train_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 100/10000 [01:10<1:56:14,  1.42it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Agreement: 1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the first N_QUERIES test images (the original loop stopped at query_idx == 100).\n",
    "N_QUERIES = 101\n",
    "\n",
    "results = list()        # L1 distance between the top salient conv cells of query and NN\n",
    "agreement = 0           # number of queries whose NN has the same prediction\n",
    "saliency_sim = list()   # |top logit drop (query) - top logit drop (NN)|\n",
    "\n",
    "# Index the dataset directly: the old enumerate(test_loader) only supplied the index\n",
    "# and silently assumed the loader was not shuffled.\n",
    "for query_idx in tqdm(range(min(N_QUERIES, len(test_loader.dataset)))):\n",
    "\n",
    "    # Get query information\n",
    "    query_label = y_test[query_idx].item()\n",
    "    query_pred  = test_preds[query_idx].item()\n",
    "    query_img_trans = get_transformed_data([query_idx], test_loader)[0]\n",
    "    query_cont = X_test_c[query_idx]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # Get query boxes (cells ranked by logit drop when zeroed)\n",
    "        query_logits, query_x, query_C = netC(query_img_trans)\n",
    "        query_nb_boxes = get_salient_regions(query_C, query_logits, net_classifier)\n",
    "\n",
    "        # Get explanation nn\n",
    "        xp_idxs = twin.kneighbors(X=[temp_test[query_idx]], n_neighbors=NUM_NEIGHBORS, return_distance=False)\n",
    "        xps_imgs_trans = get_transformed_data(xp_idxs[0], train_loader)\n",
    "        xp_logits, xp_x, xp_C = netC(xps_imgs_trans[0])\n",
    "        xp_nb_boxes = get_salient_regions(xp_C, xp_logits, net_classifier)\n",
    "\n",
    "    if train_preds[xp_idxs[0][0]] == query_pred:\n",
    "        agreement += 1\n",
    "\n",
    "    saliency_sim.append(abs(query_nb_boxes[0][0] - xp_nb_boxes[0][0]))\n",
    "\n",
    "    # L1 distance between the most salient conv cell of the query and of its NN\n",
    "    qi, qj = query_nb_boxes[0][1]\n",
    "    query_feature = query_C[:, :, qi:qi+1, qj:qj+1]\n",
    "    xi, xj = xp_nb_boxes[0][1]\n",
    "    xp_feature = xp_C[:, :, xi:xi+1, xj:xj+1]\n",
    "\n",
    "    dist = (query_feature.flatten() - xp_feature.flatten()).abs().sum()\n",
    "    results.append(dist.item())\n",
    "\n",
    "if results:\n",
    "    print(\"Agreement:\", agreement / len(results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert this baseline's per-query metrics to arrays (np.array already copies the list).\n",
    "saliency_sim = np.array(saliency_sim)\n",
    "ex_matchina = np.array(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "58.37529969923567\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x139f05b50>"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAARrUlEQVR4nO3de4yldX3H8fdHVtR4KaLbzbosBSu1NW28ZLWIl0SxdrUqaC1qjG5T7NpUG42tFmvS2KR/aC9q2xh1K8a18YI3AtoWRURNo6ILooBoQQrCuuwuqNVeol349o95Vo/DzO647vN8DzvvVzKZc37nzJ5vnpl57zPPnPNMqgpJ0vTu0j2AJK1WBliSmhhgSWpigCWpiQGWpCZrugdYic2bN9cFF1zQPYYkHaostXin2AO+5ZZbukeQpMPuThFgSToSGWBJamKAJamJAZakJqM+CyLJ9cD3gduAfVW1KcmxwDnACcD1wBlV9Z0x55CkeTTFHvATquphVbVpuH4WcFFVnQRcNFyXpFWn4xDEacD24fJ24PSGGSSp3dgBLuDjSS5NsnVYW1dVu4bLNwPrlvrAJFuT7EiyY+/evSOPKUnTG/uVcI+tqp1Jfh64MMnXZm+sqkqy5AmJq2obsA1g06ZNnrRY0hFn1D3gqto5vN8DnAs8CtidZD3A8H7PmDNI0rwaLcBJ7pnk3vsvA08GrgTOB7YMd9sCnDfWDJI0z8Y8BLEOODfJ/sd5T1VdkOSLwPuTnAncAJwx4gySNLdGC3BVXQc8dIn1W4FTx3pcSbqz8JVwktTkiA7who3Hk6T9bcPG47s3haQ5dKc4Ifuh+tZNN/Kct322ewzOefEp3SNImkNH9B6wJM0zAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUZPcBJjkrypSQfHa6fmOSSJNcmOSfJ0WPPIEnzaIo94JcBV89cfz3wxqp6EPAd4MwJZpCkuTNqgJMcB/wW8PbheoAnAh8c7rIdOH3MGSRpXo29B/wm4FXA7cP1+wHfrap9w/WbgA1LfWCSrUl2JNmxd+/ekceUpOmNFuAkTwP2VNWlh/LxVbWtqjZV1aa1a9ce5ukkqd+aEf/txwDPSPJU4O7AfYC/A45JsmbYCz4O2DniDJI0t0bbA66qV1fVcVV1AvBc4JNV9XzgYuDZw922AOeNNYMkzbOO5wH/KfCKJNeycEz47IYZJKndmIcgfqSqPgV8arh8HfCoKR5XkuaZr4STpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJqMFOMndk3whyZeTXJXkL4b1E5NckuTaJOckOXqsGSRpno25B/wD4IlV9VDgYcDmJCcDrwfeWFUPAr4DnDniDJI0t0YLcC34r+HqXYe3Ap4IfHBY3w6cPtYMkjTPRj0GnOSoJJcDe4ALgW8A362qfcNdbgI2LPOxW5PsSLJj7969Y44pSS1GDXBV3VZV
DwOOAx4F/PJP8bHbqmpTVW1au3btWCNKUptJngVRVd8FLgYeDRyTZM1w03HAzilmkKR5M+azINYmOWa4fA/gN4CrWQjxs4e7bQHOG2sGSZpnaw5+l0O2Htie5CgWQv/+qvpokq8C70vyl8CXgLNHnEGS5tZoAa6qrwAPX2L9OhaOB0vSquYr4SSpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqcmKApzkMStZkySt3Er3gP9hhWuSpBU64LkgkjwaOAVYm+QVMzfdBzhqzMEk6Uh3sJPxHA3ca7jfvWfWv8ePTykpSToEBwxwVX0a+HSSd1bVDRPNJEmrwkpPR3m3JNuAE2Y/pqqeOMZQkrQarDTAHwDeCrwduG28cSRp9VhpgPdV1VtGnUSSVpmVPg3tI0n+MMn6JMfufxt1Mkk6wq10D3jL8P6VM2sFPPDwjiNJq8eKAlxVJ449iCStNisKcJIXLrVeVe86vONI0uqx0kMQj5y5fHfgVOAywABL0iFa6SGIP5q9nuQY4H1jDCRJq8Whno7yvwGPC0vSz2Clx4A/wsKzHmDhJDy/Arx/rKEkaTVY6THgv5m5vA+4oapuGmEeSVo1VnQIYjgpz9dYOCPafYEfjjmUJK0GK/2LGGcAXwB+BzgDuCSJp6OUpJ/BSg9BvAZ4ZFXtAUiyFvgE8MGxBpOkI91KnwVxl/3xHdz6U3ysJGkJK90DviDJx4D3DtefA/zLOCNJ0upwsL8J9yBgXVW9MsmzgMcON30OePfYw0nSkexge8BvAl4NUFUfBj4MkOTXhtuePuJsknREO9hx3HVVdcXixWHthFEmkqRV4mABPuYAt93jMM4hSavOwQK8I8nvL15M8iLg0nFGkqTV4WDHgF8OnJvk+fw4uJuAo4FnjjiXJB3xDhjgqtoNnJLkCcCvDsv/XFWfHH0ySTrCrfR8wBcDF488iyStKr6aTZKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJanJaAFOsjHJxUm+muSqJC8b1o9NcmGSa4b39x1rBkmaZ2PuAe8D/riqHgKcDLwkyUOAs4CLquok4KLhuiStOqMFuKp2VdVlw+XvA1cDG4DTgO3D3bYDp481gyTNs0mOASc5AXg4cAmwrqp2DTfdDKxb5mO2JtmRZMfevXunGFOSJjV6gJPcC/gQ8PKq+t7sbVVVQC31cVW1rao2VdWmtWvXjj2mJE1u1AAnuSsL8X13VX14WN6dZP1w+3pgz5gzSNK8GvNZEAHOBq6uqjfM3HQ+sGW4vAU4b6wZJGmerejP0h+ixwAvAK5Icvmw9mfA64D3JzkTuAE4Y8QZJGlujRbgqvo3IMvcfOpYjytJdxa+Ek6SmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmqzpHmBVuMsaknRPAcADjtvIzhu/2T2GJAzwNG7fx3Pe9tnuKQA458WndI8gaeAhCElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWriX8RYbfzzSNLcMMCrjX8eSZobHoKQpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqMlqAk7wjyZ4kV86sHZvkwiTXDO/vO9bjS9K8G3MP+J3A5kVrZwEXVdVJwEXDdUlalUYLcFV9Bvj2ouXTgO3D5e3A6WM9viTNu6mPAa+rql3D5ZuBdcvdMcnWJDuS7Ni7d+8000nShNp+CVdVBdQBbt9WVZuqatPatWsnnEySpjF1gHcnWQ8wvN8z8eNL0tyYOsDnA1uGy1uA8yZ+fEmaG2M+De29
wOeABye5KcmZwOuA30hyDfCk4bokrUqjnQ+4qp63zE2njvWYknRn4ivhJKmJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmoz2PGDpoO6yhiTdU/CA4zay88Zvdo+hVcgAq8/t+3jO2z7bPQXnvPiU7hG0SnkIQpKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCXdwYaNx5NkLt42bDy+e3OMxnNBSLqDb91041ycpwOO7HN1uAcsSU0MsCQ1McCS1MQAS1ITAyxJTXwWhDQnfxoJ/PNIq40BlubkTyPBkf2UK92RhyAkqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJr4UWZonc3ReirkxR9vkcJ+rwwBL82ROzksxV+ekmJNtAod/u3gIQpKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJatIS4CSbk3w9ybVJzuqYQZK6TR7gJEcBbwaeAjwEeF6Sh0w9hyR169gDfhRwbVVdV1U/BN4HnNYwhyS1SlVN+4DJs4HNVfWi4foLgF+vqpcuut9WYOtw9cHA1ycd9I7uD9zSPMNiznRw8zYPzN9M8zYPHHkz3VJVmxcvzu2fJKqqbcC27jn2S7KjqjZ1zzHLmQ5u3uaB+Ztp3uaB1TNTxyGIncDGmevHDWuStKp0BPiLwElJTkxyNPBc4PyGOSSp1eSHIKpqX5KXAh8DjgLeUVVXTT3HIZibwyEznOng5m0emL+Z5m0eWCUzTf5LOEnSAl8JJ0lNDLAkNTHAS0iyMcnFSb6a5KokLxvWX5tkZ5LLh7enTjjT9UmuGB53x7B2bJILk1wzvL/vhPM8eGY7XJ7ke0lePvU2SvKOJHuSXDmztuR2yYK/H14C/5Ukj5honr9O8rXhMc9NcsywfkKS/53ZVm893PMcYKZlP09JXj1so68n+c0JZzpnZp7rk1w+rI++nQ7wPT/u11JV+bboDVgPPGK4fG/g31l42fRrgT9pmul64P6L1v4KOGu4fBbw+qbZjgJuBn5h6m0EPB54BHDlwbYL8FTgX4EAJwOXTDTPk4E1w+XXz8xzwuz9Jt5GS36ehq/zLwN3A04EvgEcNcVMi27/W+DPp9pOB/ieH/VryT3gJVTVrqq6bLj8feBqYEPvVEs6Ddg+XN4OnN40x6nAN6rqhqkfuKo+A3x70fJy2+U04F214PPAMUnWjz1PVX28qvYNVz/PwnPfJ7PMNlrOacD7quoHVfUfwLUsnD5gspmSBDgDeO/hftwDzLPc9/yoX0sG+CCSnAA8HLhkWHrp8CPHO6b8kR8o4ONJLh1epg2wrqp2DZdvBtZNOM+s5/KT3yxd22i/5bbLBuDGmfvdxPT/sf4eC3tO+52Y5EtJPp3kcRPPstTnaR620eOA3VV1zczaZNtp0ff8qF9LBvgAktwL+BDw8qr6HvAW4BeBhwG7WPgxaSqPrapHsHAWuZckefzsjbXwc9HkzynMwotpngF8YFjq3EZ30LVdlpLkNcA+4N3D0i7g+Kp6OPAK4D1J7jPROHP1eVrkefzkf+iTbaclvud/ZIyvJQO8jCR3ZeET8e6q+jBAVe2uqtuq6nbgHxnhR7PlVNXO4f0e4NzhsXfv/7FneL9nqnlmPAW4rKp2D/O1baMZy22XtpfBJ/ld4GnA84dvZIYf828dLl/KwvHWX5pingN8nlpPFZBkDfAs4JyZWSfZTkt9zzPy15IBXsJwDOps4OqqesPM+uwxnmcCVy7+2JHmuWeSe++/zMIvda5k4SXcW4a7bQHOm2KeRX5ib6VrGy2y3HY5H3jh8Bvsk4H/nPnxcjRJNgOvAp5RVf8zs742C+fHJskDgZOA68aeZ3i85T5P5wPPTXK3JCcOM31hipkGTwK+VlU37V+YYjst9z3P2F9LY/5m8c76BjyWhR81vgJcPrw9
Ffgn4Iph/Xxg/UTzPJCF30x/GbgKeM2wfj/gIuAa4BPAsRNvp3sCtwI/N7M26TZiIf67gP9j4TjcmcttFxZ+Y/1mFvagrgA2TTTPtSwcL9z/tfTW4b6/PXw+LwcuA54+4TZa9vMEvGbYRl8HnjLVTMP6O4E/WHTf0bfTAb7nR/1a8qXIktTEQxCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElN/h9s3cWG4FJZLwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(\"ExMatchina mean L1 distance between top salient conv features:\", ex_matchina.mean())\n",
    "grid = sns.displot(ex_matchina)\n",
    "grid.set_axis_labels(\"L1 distance (query vs. nearest neighbour)\", \"Count\")\n",
    "grid.set(title=\"ExMatchina: top-salient feature distance\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.21728234524183934\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x2280cd2d0>"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAARrElEQVR4nO3da4ylBX3H8e8PVsRWLFhXQtbd4r0SW9GMVtG0KmqQF6KtFYkXTNBFLUajMTH6ovbyQhMvTVtjWZWIjRdQseKlWESUKIpdFbl6pVgWVnbw3jZVV/59cR7qBHd2zu7Oc/5ndr6f5GTPec5zzvOfycw3zz5znnNSVUiSZu+Q7gEkab0ywJLUxABLUhMDLElNDLAkNdnQPcA0TjrppLrooou6x5Ck5WR/HrQm9oBvu+227hEkadWtiQBL0sHIAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwO8gk2bt5DkgC+bNm/p/lIkzZk18YbsnW7ZcROnnn35AT/PeWeesArTSDqYuAcsSU0MsCQ1McCS1MQAS1KT0QKc5PAkX07y9STXJvmrYfl9k1yR5DtJzkty2FgzSNI8G3MP+OfAE6vqYcDxwElJHg28EXhrVT0A+BFwxogzSNLcGi3ANfFfw827DJcCngh8aFh+LvD0sWaQpHk26jHgJIcmuRLYBVwMfBf4cVXtHlbZAWwacwZJmlejBriqflVVxwP3AR4F/P60j02yNcn2JNsXFxfHGlGS2szkVRBV9WPgUuAxwJFJ7jgD7z7Azcs8ZltVLVTVwsaNG2cxpiTN1JivgtiY5Mjh+t2AJwPXMwnxM4fVTgc+OtYMkjTPxnwviGOAc5McyiT051fVx5NcB3wgyd8CXwPeNeIMkjS3RgtwVV0FPHwPy29gcjxYktY1z4STpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJqMFOMnmJJcmuS7JtUlePix/fZKbk1w5XE4eawZJmmcbRnzu3cCrquqrSY4AvpLk4uG+t1bVm0bctiTNvdECXFU7gZ3D9Z8luR7YNNb2JGmtmckx4CTHAg8HrhgWnZXkqiTnJDlqmcdsTbI9yfbFxcVZjClJMzV6gJPcHfgw8Iqq+inwduD+wPFM9pDfvKfHVdW2qlqoqoWNGzeOPaYkzdyoAU5yFybxfW9VXQBQVbdW1a+q6nbgHcCjxpxBkubVmK+CCPAu4PqqesuS5ccsWe0ZwDVjzSBJ82zMV0E8FngecHWSK4dlrwVOS3I8UMCNwJkjziBJc2vMV0F8Hsge7vrkWNuUpLXEM+EkqYkBlqQmBliSmhhgSWpy0AZ40+YtJDngiySNZcyXobW6ZcdNnHr25Qf8POedecIqTCNJv+mg3QOWpHlngCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqcloAU6yOcmlSa5Lcm2Slw/L75nk4iTfHv49aqwZJGmejbkHvBt4VVUdBzwa+IskxwGvAS6pqgcClwy3JWndGS3AVbWzqr46XP8ZcD2wCTgFOHdY7Vzg6WPNIEnzbCbHgJMcCzwcuAI4uqp2Dnd9Hzh6mcdsTbI9yfbFxcVZjClJMzV6gJPc
Hfgw8Iqq+unS+6qqgNrT46pqW1UtVNXCxo0bxx5TkmZu1AAnuQuT+L63qi4YFt+a5Jjh/mOAXWPOIEnzasxXQQR4F3B9Vb1lyV0XAqcP108HPjrWDJI0zzaM+NyPBZ4HXJ3kymHZa4E3AOcnOQP4HvCsEWeQpLk1WoCr6vNAlrn7xLG2K0lrhWfCSVITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNZkqwEkeO80ySdL0pt0D/ocpl0mSprTXT0VO8hjgBGBjklcuuesewKFjDiZJB7uVPpb+MODuw3pHLFn+U+CZYw0lSevBXgNcVZ8DPpfk3VX1vRnNJEnrwkp7wHe4a5JtwLFLH1NVTxxjKElaD6YN8AeBfwLeCfxqvHEkaf2YNsC7q+rto04iSevMtC9D+1iSlyY5Jsk977iMOpkkHeSm3QM+ffj31UuWFXC/1R1HktaPqQJcVfcdexBJWm+mCnCS5+9peVW9Z3XHkaT1Y9pDEI9ccv1w4ETgq4ABlqT9NO0hiJctvZ3kSOADYwwkSevF/r4d5X8DHheWpAMw7THgjzF51QNM3oTnIcD5Yw0lSevBtMeA37Tk+m7ge1W1Y4R5JGndmOoQxPCmPN9g8o5oRwG/GHMoSVoPpv1EjGcBXwb+HHgWcEUS345Skg7AtIcgXgc8sqp2ASTZCHwa+NBYg0nSwW7aV0Ecckd8Bz/Yh8dKkvZg2ohelORTSV6Q5AXAJ4BP7u0BSc5JsivJNUuWvT7JzUmuHC4n7//okrS2rfSZcA8Ajq6qVyf5U+Bxw11fBN67wnO/G/hHfvNsubdW1Zt+c3VJWl9W2gP+Oyaf/0ZVXVBVr6yqVwIfGe5bVlVdBvxwFWaUpIPSSgE+uqquvvPCYdmx+7nNs5JcNRyiOGo/n0OS1ryVAnzkXu67235s7+3A/YHjgZ3Am5dbMcnWJNuTbF9cXNyPTc2ZQzaQ5IAvmzZv6f5KJK2SlV6Gtj3Ji6rqHUsXJnkh8JV93VhV3brkOd4BfHwv624DtgEsLCzUcuutGbfv5tSzLz/gpznvzBNWYRhJ82ClAL8C+EiS5/Dr4C4AhwHP2NeNJTmmqnYON58BXLO39SXpYLbXAA97rCckeQLw0GHxJ6rqMys9cZL3A48H7pVkB/CXwOOTHM/kjX1uBM7c78klaY2b9v2ALwUu3ZcnrqrT9rD4XfvyHJJ0MPNsNklqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWoyWoCTnJNkV5Jrliy7Z5KLk3x7+PeosbYvSfNuzD3gdwMn3WnZa4BLquqBwCXDbUlal0YLcFVdBvzwTotPAc4drp8LPH2s7UvSvJv1MeCjq2rncP37wNHLrZhka5LtSbYvLi7OZrq14JANJDngy6bNW7q/Emnd29C14aqqJLWX+7cB2wAWFhaWXW/duX03p559+QE/zXlnnrAKw0g6ELPeA741yTEAw7+7Zrx9SZobsw7whcDpw/XTgY/OePuSNDfGfBna+4EvAg9OsiPJGcAbgCcn+TbwpOG2JK1Lox0DrqrTlrnrxLG2KUlriWfCSVITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHA
ktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUZEPHRpPcCPwM+BWwu6oWOuaQpE4tAR48oapua9y+JLXyEIQkNekKcAH/luQrSbbuaYUkW5NsT7J9cXFxxuNpWps2byHJAV82bd7S/aVIM9d1COJxVXVzknsDFyf5RlVdtnSFqtoGbANYWFiojiG1slt23MSpZ19+wM9z3pknrMI00trSsgdcVTcP/+4CPgI8qmMOSeo08wAn+e0kR9xxHXgKcM2s55Ckbh2HII4GPpLkju2/r6ouaphDklrNPMBVdQPwsFlvV5LmjS9Dk6QmBliSmhhgSWpigCWpiQFerw7ZsCpnsGnvVutMQc8WPDh1vhmPOt2+2zPYZmC1zhQEv9cHI/eAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigDUfVunMPM8W01rimXCaD56Zp3XIPWBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkB1sFllT5bbsNhh6/K88yjTZu3zNX3aLU+x2+1vq5Zfq6gnwmng8sqfrbcwfoZdbfsuOmg/B6t5tc1K+4BS1ITAyxJTQywJDUxwJLUxABLUpOWACc5Kck3k3wnyWs6ZpCkbjMPcJJDgbcBTwWOA05Lctys55Ckbh17wI8CvlNVN1TVL4APAKc0zCFJrVJVs91g8kzgpKp64XD7ecAfVdVZd1pvK7B1uPlg4Jv7uKl7Abcd4Lid1vL8zt7D2fscXlUP3dcHze2ZcFW1Ddi2v49Psr2qFlZxpJlay/M7ew9n75Nk+/48ruMQxM3A5iW37zMsk6R1pSPA/w48MMl9kxwGPBu4sGEOSWo180MQVbU7yVnAp4BDgXOq6toRNrXfhy/mxFqe39l7OHuf/Zp/5n+EkyRNeCacJDUxwJLUZM0HeKXTmpPcNcl5w/1XJDm2Ycw9mmL2Vya5LslVSS5J8nsdcy5n2lPKk/xZkkoyNy8zmmb2JM8avv/XJnnfrGdczhQ/N1uSXJrka8PPzskdc+5JknOS7EpyzTL3J8nfD1/bVUkeMesZlzPF7M8ZZr46yeVJHrbik1bVmr0w+SPed4H7AYcBXweOu9M6LwX+abj+bOC87rn3YfYnAL81XH/JvMw+7fzDekcAlwFfAha6596H7/0Dga8BRw2379099z7Mvg14yXD9OODG7rmXzPbHwCOAa5a5/2TgX4EAjwau6J55H2Y/YcnPy1OnmX2t7wFPc1rzKcC5w/UPAScmc/FhXSvOXlWXVtX/DDe/xOQ10/Ni2lPK/wZ4I/C/sxxuBdPM/iLgbVX1I4Cq2jXjGZczzewF3GO4/jvALTOcb6+q6jLgh3tZ5RTgPTXxJeDIJMfMZrq9W2n2qrr8jp8Xpvx9XesB3gTctOT2jmHZHtepqt3AT4Dfncl0ezfN7EudwWTPYF6sOP/w38fNVfWJWQ42hWm+9w8CHpTkC0m+lOSkmU23d9PM/nrguUl2AJ8EXjab0VbFvv5ezKupfl/n9lRk/VqS5wILwJ90zzKtJIcAbwFe0DzK/trA5DDE45nsyVyW5A+q6sedQ03pNODdVfXmJI8B/jnJQ6vq9u7B1oMkT2AS4MettO5a3wOe5rTm/18nyQYm/yX7wUym27upTslO8iTgdcDTqurnM5ptGivNfwTwUOCzSW5kcjzvwjn5Q9w03/sdwIVV9cuq+g/gW0yC3G2a2c8Azgeoqi8ChzN5s5u1YE2/VUGSPwTeCZxSVSt2Zq0HeJrTmi8ETh+uPxP4TA1HyZutOHuShwNnM4nvvByDvMNe56+qn1TVvarq2Ko6lskxsadV1X69ackqm+bn5l+Y7P2S5F5MDkncMMMZlzPN7P8JnAiQ5CFMArw40yn334XA84dXQzwa+ElV7eweahpJtgAXAM+rqm9N9aDuvyyu
wl8mT2ayd/Jd4HXDsr9m8ssOkx++DwLfAb4M3K975n2Y/dPArcCVw+XC7pn3Zf47rftZ5uRVEFN+78PkEMp1wNXAs7tn3ofZjwO+wOQVElcCT+meecns7wd2Ar9k8r+MM4AXAy9e8n1/2/C1XT1nPzMrzf5O4EdLfl+3r/ScnoosSU3W+iEISVqzDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1OT/AEPGtnntT+qwAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(\"ExMatchina mean |top logit drop (query) - top logit drop (NN)|:\", saliency_sim.mean())\n",
    "grid = sns.displot(saliency_sim)\n",
    "grid.set_axis_labels(\"|query top logit drop - NN top logit drop|\", \"Count\")\n",
    "grid.set(title=\"ExMatchina: saliency-strength difference\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Baseline: D-kNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', metric='euclidean', n_neighbors=1)"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# D-kNN baseline: 1-NN on sum-normalised X_*_x feature vectors.\n",
    "# Vectorised; rows that sum to 0 are kept unchanged (as before). Row counts now come\n",
    "# from X_*_x itself -- the old loops iterated X_*_C.shape[0] while indexing X_*_x.\n",
    "train_feats = X_train_x.reshape(len(X_train_x), -1)\n",
    "test_feats = X_test_x.reshape(len(X_test_x), -1)\n",
    "train_feat_sums = train_feats.sum(axis=1, keepdims=True)\n",
    "test_feat_sums = test_feats.sum(axis=1, keepdims=True)\n",
    "\n",
    "temp_train = train_feats / np.where(train_feat_sums == 0, 1, train_feat_sums)\n",
    "temp_test  = test_feats / np.where(test_feat_sums == 0, 1, test_feat_sums)\n",
    "\n",
    "# Fit the 1-NN queried by the experiment loop below, labelled with the model's predictions.\n",
    "twin = KNeighborsClassifier(n_neighbors=1, algorithm='brute', metric='euclidean')\n",
    "twin.fit(temp_train, train_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 100/10000 [00:04<07:03, 23.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Agreement: 1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the first N_QUERIES test images (the original loop stopped at query_idx == 100).\n",
    "N_QUERIES = 101\n",
    "\n",
    "results = list()        # L1 distance between the top salient conv cells of query and NN\n",
    "agreement = 0           # number of queries whose NN has the same prediction\n",
    "saliency_sim = list()   # |top logit drop (query) - top logit drop (NN)|\n",
    "\n",
    "# Index the dataset directly: the old enumerate(test_loader) only supplied the index\n",
    "# and silently assumed the loader was not shuffled.\n",
    "for query_idx in tqdm(range(min(N_QUERIES, len(test_loader.dataset)))):\n",
    "\n",
    "    # Get query information\n",
    "    query_label = y_test[query_idx].item()\n",
    "    query_pred  = test_preds[query_idx].item()\n",
    "    query_img_trans = get_transformed_data([query_idx], test_loader)[0]\n",
    "    query_cont = X_test_c[query_idx]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # Get query boxes (cells ranked by logit drop when zeroed)\n",
    "        query_logits, query_x, query_C = netC(query_img_trans)\n",
    "        query_nb_boxes = get_salient_regions(query_C, query_logits, net_classifier)\n",
    "\n",
    "        # Get explanation nn\n",
    "        xp_idxs = twin.kneighbors(X=[temp_test[query_idx]], n_neighbors=NUM_NEIGHBORS, return_distance=False)\n",
    "        xps_imgs_trans = get_transformed_data(xp_idxs[0], train_loader)\n",
    "        xp_logits, xp_x, xp_C = netC(xps_imgs_trans[0])\n",
    "        xp_nb_boxes = get_salient_regions(xp_C, xp_logits, net_classifier)\n",
    "\n",
    "    if train_preds[xp_idxs[0][0]] == query_pred:\n",
    "        agreement += 1\n",
    "\n",
    "    saliency_sim.append(abs(query_nb_boxes[0][0] - xp_nb_boxes[0][0]))\n",
    "\n",
    "    # L1 distance between the most salient conv cell of the query and of its NN\n",
    "    qi, qj = query_nb_boxes[0][1]\n",
    "    query_feature = query_C[:, :, qi:qi+1, qj:qj+1]\n",
    "    xi, xj = xp_nb_boxes[0][1]\n",
    "    xp_feature = xp_C[:, :, xi:xi+1, xj:xj+1]\n",
    "\n",
    "    dist = (query_feature.flatten() - xp_feature.flatten()).abs().sum()\n",
    "    results.append(dist.item())\n",
    "\n",
    "if results:\n",
    "    print(\"Agreement:\", agreement / len(results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert this baseline's per-query metrics to arrays (np.array already copies the list).\n",
    "saliency_sim = np.array(saliency_sim)\n",
    "dknn = np.array(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "69.84689330110456\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x1338356d0>"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAP0UlEQVR4nO3dXYxc9XmA8efFxhAVIkNYWcY2whSU1FJVQBsKBkUVNK1DP0wqiomixBdOjdRQBaVNY8pNkHoRqjYfraIUN6A4FQJTQgSJUlJCTaKK1HRJzFcciqGA7Ri8NFDSXiQ1fnsxx2VYeXfHy555Z2een7TamTMz3vfvYz+aPTtzNjITSVL/HVc9gCSNKgMsSUUMsCQVMcCSVMQAS1KRxdUD9GLdunV53333VY8hSUcTc33ggngG/PLLL1ePIEnzbkEEWJKGkQGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCJDG+AVq84gIso/Vqw6o/qvQtKAWhAnZJ+LH+/by4abH6oeg+3XrK0eQdKAGtpnwJI06AywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUpHWAxwRiyLiBxHxjeb66ojYGRF7ImJ7RCxpewZJGkT9eAb8MWB31/WbgM9m5tnAK8CmPswgSQOn1QBHxErgt4AvNdcDuBS4q7nLNuCKNmeQpEHV9jPgzwF/Chxurr8DeDUzDzXX9wErjvbAiNgcERMRMTE5OdnymJLUf60FOCJ+GziYmY/M5fGZuTUzxzNzfGxsbJ6nk6R6bf5W5IuB342Iy4ETgbcDnweWRsTi5lnwSmB/izNI0sBq7RlwZl6fmSsz80zgauCfM/ODwA7gyuZuG4F72ppBkgZZxeuAPwl8PCL20DkmfEvBDJJUrs1DEP8vMx8EHmwuPwtc0I+vK0mDzHfCSVIRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRVoLcEScGBEPR8SjEfFkRNzYbF8dETsjYk9EbI+IJW3NIEmDrM1nwD8DLs3MXwHOBdZFxIXATcBnM/Ns4BVgU4szSNLAai3A2fHfzdXjm48ELgXuarZvA65oawZJGmStHgOOiEURsQs4CNwPPAO8mpmHmrvsA1ZM89jNETEREROTk5NtjilJJVoNcGa+npnnAiuBC4B3HcNjt2bmeGaOj42NtTWiJJXpy6sgMvNVYAdwEbA0IhY3N60E9vdjBkkaNG2+CmIsIpY2l98GvBfYTSfEVzZ32wjc09YMkjTIFs9+lzlbDmyLiEV0Qn9nZn4jIn4I3BERfw78ALilxRkkaWC1FuDMfAw47yjbn6VzPFiSRprvhJOkIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKtJTgCPi4l62SZJ61+sz4L/pcZskqUeLZ7oxIi4C1gJjEfHxrpveDixqczBJ
GnYzBhhYApzU3O/kru2vAVe2NZQkjYIZA5yZ3wG+ExFfzszn+zSTJI2E2Z4BH3FCRGwFzux+TGZe2sZQkjQKeg3wPwB/C3wJeL29cSRpdPQa4EOZ+cVWJ5GkEdPry9C+HhF/GBHLI+LUIx+tTiZJQ67XZ8Abm8+f6NqWwFnzO44kjY6eApyZq9seRJJGTU8BjogPH217Zn5lfseRpNHR6yGId3ddPhG4DPg+YIAlaY56PQTxR93XI2IpcEcbA0nSqJjr6Sj/B/C4sCS9Bb0eA/46nVc9QOckPL8E3NnWUJI0Cno9BvyXXZcPAc9n5r4W5pGkkdHTIYjmpDw/onNGtFOAn7c5lCSNgl5/I8ZVwMPA7wNXATsjwtNRStJb0OshiBuAd2fmQYCIGAO+DdzV1mCSNOx6fRXEcUfi2/jPY3isJOkoen0GfF9EfAu4vbm+AfhmOyNJ0miY7XfCnQ0sy8xPRMTvAZc0N30PuK3t4SRpmM32DPhzwPUAmXk3cDdARPxyc9vvtDibJA212Y7jLsvMx6dubLad2cpEkjQiZgvw0hlue9s8ziFJI2e2AE9ExB9M3RgRHwEeaWckSRoNsx0Dvg74WkR8kDeCOw4sAd4/0wMjYhWd01Uuo3Meia2Z+fnmVxltp3MI4zngqsx8ZY7zS9KCNeMz4Mx8KTPXAjfSieVzwI2ZeVFmvjjLn30I+OPMXANcCHw0ItYAW4AHMvMc4IHmuiSNnF7PB7wD2HEsf3BmHgAONJd/GhG7gRXAeuDXmrttAx4EPnksf7YkDYO+vJstIs4EzgN20nllxYHmphfpHKI42mM2R8RERExMTk72Y0xJ6qvWAxwRJwFfBa7LzNe6b8vM5I3zDDPltq2ZOZ6Z42NjY22PKUl912qAI+J4OvG9rXkjB8BLEbG8uX05cHC6x0vSMGstwBERwC3A7sz8TNdN9wIbm8sbgXvamkGSBlmvJ+OZi4uBDwGPR8SuZtufAZ8G7oyITcDzdM4vLEkjp7UAZ+a/ADHNzZe19XUlaaHwnL6SVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBVp8zdiCOC4xXR+O1Ot01euYv/eF6rHkNTFALft8CE23PxQ9RRsv2Zt9QiSpvAQhCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUU8Gc+oGICzsnlGNunNDPCoGICzsnlGNunNPAQhSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBVpLcARcWtEHIyIJ7q2nRoR90fE083nU9r6+pI06Np8BvxlYN2UbVuABzLzHOCB5rokjaTWApyZ3wV+MmXzemBbc3kbcEVbX1+SBl2/jwEvy8wDzeUXgWV9/vqSNDDKfgiXmQnkdLdHxOaImIiIicnJyT5OJrVvxaoziIjyjxWrzqj+qxhpi/v89V6KiOWZeSAilgMHp7tjZm4FtgKMj49PG2ppIfrxvr1suPmh6jHYfs3a6hFGWr+fAd8LbGwubwTu6fPXl6SB0ebL0G4Hvge8MyL2RcQm4NPAeyPiaeDXm+uSNJJaOwSRmR+Y5qbL2vqakrSQ+E44SSpigCWpSL9fBaFRdtxiIqJ6Ck5fuYr9e1+oHkMywOqjw4d86ZXUxUMQklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCS
VMQAS1IRAyxJRQywJBUxwJJUZHH1AJIKHbeYiKiegtNXrmL/3heqx+g7AyyNssOH2HDzQ9VTsP2atdUjlPAQhCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQV8WxoGj0DcgpGyQBr9AzAKRhH9fSLejMPQUhSEQMsSUUMsCQVMcCSVMQAS1IRXwUhqd4AvDSw4jczG2BJ9Ub0pYEegpCkIgZYkoqUBDgi1kXEUxGxJyK2VMwgSdX6HuCIWAR8AXgfsAb4QESs6fccklSt4hnwBcCezHw2M38O3AGsL5hDkkpFZvb3C0ZcCazLzI801z8E/GpmXjvlfpuBzc3VdwJPAacBL/dx3EqjstZRWSe41mF0GvCjzFw3lwcP7MvQMnMrsLV7W0RMZOZ40Uh9NSprHZV1gmsdRs065xRfqDkEsR9Y1XV9ZbNNkkZKRYD/DTgnIlZHxBLgauDegjkkqVTfD0Fk5qGIuBb4FrAIuDUzn+zx4Vtnv8vQGJW1jso6wbUOo7e0zr7/EE6S1OE74SSpiAGWpCILJsDD/PbliHguIh6PiF0RMdFsOzUi7o+Ip5vPp1TPORcRcWtEHIyIJ7q2HXVt0fHXzT5+LCLOr5v82E2z1k9FxP5m3+6KiMu7bru+WetTEfGbNVMfu4hYFRE7IuKHEfFkRHys2T5U+3WGdc7fPs3Mgf+g88O6Z4CzgCXAo8Ca6rnmcX3PAadN2fYXwJbm8hbgpuo557i29wDnA0/MtjbgcuAfgQAuBHZWzz8Pa/0U8CdHue+a5t/xCcDq5t/3ouo19LjO5cD5zeWTgX9v1jNU+3WGdc7bPl0oz4BH8e3L64FtzeVtwBV1o8xdZn4X+MmUzdOtbT3wlez4V2BpRCzvy6DzYJq1Tmc9cEdm/iwz/wPYQ+ff+cDLzAOZ+f3m8k+B3cAKhmy/zrDO6RzzPl0oAV4B7O26vo+Z/yIWmgT+KSIead6CDbAsMw80l18EltWM1orp1jas+/na5lvvW7sOJQ3FWiPiTOA8YCdDvF+nrBPmaZ8ulAAPu0sy83w6Z4j7aES8p/vG7Hx/M5SvFxzmtTW+CPwicC5wAPir0mnmUUScBHwVuC4zX+u+bZj261HWOW/7dKEEeKjfvpyZ+5vPB4Gv0fm25aUj36Y1nw/WTTjvplvb0O3nzHwpM1/PzMPA3/HGt6QLeq0RcTydKN2WmXc3m4duvx5tnfO5TxdKgIf27csR8QsRcfKRy8BvAE/QWd/G5m4bgXtqJmzFdGu7F/hw81PzC4H/6vqWdkGacqzz/XT2LXTWenVEnBARq4FzgIf7Pd9cREQAtwC7M/MzXTcN1X6dbp3zuk+rf9J4DD+RvJzOTyGfAW6onmce13UWnZ+cPgo8eWRtwDuAB4CngW8Dp1bPOsf13U7n27T/pXNMbNN0a6PzU/IvNPv4cWC8ev55WOvfN2t5rPkPurzr/jc0a30KeF/1/MewzkvoHF54DNjVfFw+bPt1hnXO2z71rciSVGShHIKQpKFjgCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIv8HLT1X9FbXQlwAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(np.mean(dknn))\n",
    "sns.displot(data=dknn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.19334502768988657\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x1d9bf9950>"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQhUlEQVR4nO3de6xld1mH8ec7DAWFIgWOTTPOWMACVsSCQ8WBELBoBhIoFewliiUBZpQWIRAigolE/xCVi0YJdICmJYEyUEooAYulFAiUFodaepWrJbSUdgpqCUZxOq9/nFU4jHPZ05m13n3OeT7Jztl77XX6e/ee8nSx9mVSVUiSpremewBJWq0MsCQ1McCS1MQAS1ITAyxJTdZ2DzCLzZs31yWXXNI9hiTdW9nbxmVxBHznnXd2jyBJh92yCLAkrUQGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKarOgAr1u/gSSTXNat39D9cCUtM8viC9nvrW/f8i1OO+eKSdbavnXTJOtIWjlW9BGwJM0zAyxJTQywJDUxwJLUxABLUhMDLElNRgtwkvVJLk9yY5Ibkrx82P76JLcmuWa4PGusGSRpno35PuBdwKuq6uokRwJfTHLpcN9bquqNI64tSXNvtABX1W3AbcP17ye5CVg31nqStNxMcg44ybHA44Grhk1nJ7k2yblJjtrH72xJsiPJjp07d04xpiRNavQAJ3kg8EHgFVV1F/A24JHACSweIb9pb79XVduqamNVbVxYWBh7TEma3KgBTnJfFuP7nqq6CKCqbq+qu6tqN/AO4MQxZ5CkeTXmuyACvAu4qarevGT7MUt2OwW4fqwZJGmejfkuiCcDLwCuS3LNsO21wBlJTgAKuBnYOuIMkjS3xnwXxGeB7OWuj421piQtJ34STpKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAPlzVrSTLZZd36Dd2PWNIhWts9wIqxexennXPFZMtt37ppsrUkjcMjYElqYoAlqYkBlqQmBliSmhhgSWpigCWpyWgBTrI+yeVJbkxyQ5KXD9sfkuTSJF8dfh411gySNM/GPALeBbyqqo4HngScleR44DXAZVV1HHDZcFuSVp3RAlxVt1XV1cP17wM3AeuAk4Hzh93OB5471gySNM8mOQec5Fjg8cBVwNFVddtw13eAo/fxO1uS7EiyY+fOnVOMKUmTGj3ASR4IfBB4RVXdtfS+qiqg9vZ7VbWtqjZW1caFhYWxx5SkyY0a4CT3ZTG+76mqi4bNtyc5Zrj/GOCOMWeQpHk15rsgArwLuKmq3rzkrouBM4frZwIfHmsGSZpnY34b2pOBFwDXJblm2PZa4A3A+5O8CPgmcOqIM0jS3BotwFX1WSD7uPuksdaVpOXCT8JJUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0M8HK1Zi1JJrusW7+h+xFLK86YfymnxrR7F6edc8Vky23fummytaTVwiNgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJanJaAFOcm6SO5Jcv2Tb65Pc
muSa4fKssdaXpHk35hHwecDmvWx/S1WdMFw+NuL6kjTXRgtwVX0G+N5Y/3xJWu46zgGfneTa4RTFUfvaKcmWJDuS7Ni5c+eU80nSJKYO8NuARwInALcBb9rXjlW1rao2VtXGhYWFicaTpOlMGuCqur2q7q6q3cA7gBOnXF+S5smkAU5yzJKbpwDX72tfSVrp1s6yU5InV9XnDrRtj/svAJ4GPCzJLcCfAU9LcgJQwM3A1ns3tiQtfzMFGPh74AkzbPuRqjpjL5vfNeN6krTi7TfASX4d2AQsJHnlkrseBNxnzMEkaaU70BHwEcADh/2OXLL9LuD5Yw0lSavBfgNcVZ8GPp3kvKr65kQzSdKqMOs54Psl2QYcu/R3quo3xhhKklaDWQP8AeDtwDuBu8cbR5JWj1kDvKuq3jbqJJK0ysz6QYyPJHlpkmOSPOSey6iTab6sWUuSyS7r1m/ofsTS6GY9Aj5z+PnqJdsKeMThHUdza/cuTjvnismW275102RrSV1mCnBVPXzsQSRptZn1o8i/v7ftVfXuwzuOJK0es56CeOKS6/cHTgKuBgywJN1Ls56CeNnS20keDLxvjIEkabW4t19H+QPA88KSdAhmPQf8ERbf9QCLX8Lzi8D7xxpKklaDWc8Bv3HJ9V3AN6vqlhHmkaRVY6ZTEMOX8vwri9+IdhTwwzGHkqTVYKYAJzkV+ALwO8CpwFVJ/DpKSToEs56CeB3wxKq6AyDJAvAJ4MKxBpOklW7Wd0GsuSe+g+8exO9KkvZi1iPgS5J8HLhguH0a8LFxRpKk1eFAfyfcLwBHV9Wrk/w28JThrs8D7xl7OElayQ50BPy3wJ8AVNVFwEUASX55uO/ZI84mSSvagc7jHl1V1+25cdh27CgTSdIqcaAAP3g/9/3UYZxDkladAwV4R5KX7LkxyYuBL44zkiStDgc6B/wK4ENJfpcfB3cjcARwyohzSdKKt98AV9XtwKYkTwceO2z+aFV9cvTJJGmFm/X7gC8HLh95FklaVfw0myQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCaT2vWkmSSy7r1G7ofrVapWf9aemlau3dx2jlXTLLU9q2bJllH2pNHwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU1GC3CSc5PckeT6JdsekuTSJF8dfh411vqSNO/GPAI+D9i8x7bXAJdV1XHAZcNtSVqVRgtwVX0G+N4em08Gzh+unw88d6z1JWneTX0O+Oiqum24/h3g6H3tmGRLkh1JduzcuXOa6SRpQm0vwlVVAbWf+7dV1caq2riwsDDhZJI0jakDfHuSYwCGn3dMvL4kzY2pA3wxcOZw/UzgwxOvL0lzY8y3oV0AfB54dJJbkrwIeAPwm0m+CjxjuC1Jq9Jo3wdcVWfs466TxlpTkpYTPwknSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1WduxaJKbge8DdwO7qmpjxxyS1KklwIOnV9WdjetLUitPQUhSk64AF/BPSb6YZMvedkiyJcmOJDt27tw58XiSNL6uAD+lqp4APBM4K8lT99yhqrZV1caq2riwsDD9hJI0spYAV9Wtw887gA8BJ3bMIUmdJg9wkgckOfKe68BvAddPPYckdet4F8TRwIeS3LP+e6vqkoY5JKnV5AGuqm8AvzL1upI0b3wbmiQ1McCS1MQAS1ITAyxJTQywJDUx
wJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwNKatSSZ7LJu/YZJH9669Rt8bHP6+Dr+TjhpvuzexWnnXDHZctu3bppsLYBv3/KtyR7fSn5scPgfn0fAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhM/iizp8Bm+V0OzMcCSDp8V/r0ah5unICSpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpq0BDjJ5iRfTvK1JK/pmEGSuk0e4CT3Ad4KPBM4HjgjyfFTzyFJ3TqOgE8EvlZV36iqHwLvA05umEOSWqWqpl0weT6wuapePNx+AfBrVXX2HvttAbYMNx8NfHnGJR4G3HmYxh2D8x26eZ/R+Q7dvM94sPPdWVWb99y49vDNc3hV1TZg28H+XpIdVbVxhJEOC+c7dPM+o/Mdunmf8XDN13EK4lZg/ZLbPzdsk6RVpSPA/wwcl+ThSY4ATgcubphDklpNfgqiqnYlORv4OHAf4NyquuEwLnHQpy0m5nyHbt5ndL5DN+8zHpb5Jn8RTpK0yE/CSVITAyxJTZZtgA/0ceYk90uyfbj/qiTHztl8T01ydZJdw3ujJzXDfK9McmOSa5NcluTn52y+P0hyXZJrkny249OUs36kPsnzklSSSd9WNcNz+MIkO4fn8JokL56n+YZ9Th3+PbwhyXunnG+WGZO8Zcnz95Uk/3FQC1TVsruw+OLd14FHAEcAXwKO32OflwJvH66fDmyfs/mOBR4HvBt4/hw+f08Hfnq4/odz+Pw9aMn15wCXzNtzOOx3JPAZ4Epg4zzNB7wQ+Icpn7eDnO844F+Ao4bbPztvM+6x/8tYfFPBzGss1yPgWT7OfDJw/nD9QuCkJJmX+arq5qq6Ftg90UwHO9/lVfVfw80rWXy/9jzNd9eSmw8Apn41edaP1P8F8FfAf085HPP/kf9Z5nsJ8Naq+neAqrpjDmdc6gzggoNZYLkGeB3wrSW3bxm27XWfqtoF/Cfw0Emmm22+Tgc734uAfxx1op8003xJzkrydeCvgT+aaLZ7HHDGJE8A1lfVR6ccbDDrn/HzhtNMFyZZv5f7xzLLfI8CHpXkc0muTPL/Pso7spn/dzKcons48MmDWWC5BlgTSfJ7wEbgb7pn2VNVvbWqHgn8MfCn3fMslWQN8GbgVd2z7MdHgGOr6nHApfz4/zHOi7UsnoZ4GotHl+9I8uDOgfbjdODCqrr7YH5puQZ4lo8z/2ifJGuBnwG+O8l08/9x65nmS/IM4HXAc6rqfyaaDQ7++Xsf8NwxB9qLA814JPBY4FNJbgaeBFw84QtxB3wOq+q7S/5c3wn86kSzwWx/xrcAF1fV/1bVvwFfYTHIUzmYfw9P5yBPPwDL9kW4tcA3WDzkv+fk+C/tsc9Z/OSLcO+fp/mW7Hse078IN8vz93gWX4A4bk7/fI9bcv3ZwI55m3GP/T/FtC/CzfIcHrPk+inAlXM232bg/OH6w1g8HfDQeZpx2O8xwM0MH2w7qDWmejAjPDnPYvG/iF8HXjds+3MWj9YA7g98APga8AXgEXM23xNZ/C/8D1g8Mr9hzub7BHA7cM1wuXjO5vs74IZhtsv3F7+uGffYd9IAz/gc/uXwHH5peA4fM2fzhcXTODcC1wGnz+OfMfB64A335p/vR5ElqclyPQcsScueAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmvwfRoyF4gCTFOcAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(np.mean(saliency_sim))\n",
    "sns.displot(data=saliency_sim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## My Method: Baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', metric='euclidean', n_neighbors=1)"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "NUM_NEIGHBORS = 1\n",
    "FEATURE_NUM = 1  # pick most nb feature\n",
    "\n",
    "# Fit the twin: brute-force Euclidean 1-NN over X_train_c, labelled with train_preds\n",
    "twin = KNeighborsClassifier(n_neighbors=1, algorithm=\"brute\", metric='euclidean') \n",
    "twin.fit(X_train_c, train_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 100/10000 [00:03<05:37, 29.34it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Agreement: 1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "results = list()\n",
    "agreement = 0\n",
    "saliency_sim = list()\n",
    "\n",
    "for query_idx, data in enumerate(tqdm(test_loader)):\n",
    "    \n",
    "    # Get query information\n",
    "    query_pred = test_preds[query_idx].item()\n",
    "    query_img_trans = get_transformed_data([query_idx], test_loader)[0]\n",
    "    \n",
    "    # Get query salient regions (entry 0 is used as the most salient one below)\n",
    "    query_logits, query_x, query_C = netC(query_img_trans)\n",
    "    query_nb_boxes = get_salient_regions(query_C, query_logits, net_classifier)\n",
    "    \n",
    "    # Get the explanation (nearest training point under the twin) and its salient regions\n",
    "    xp_idxs = twin.kneighbors(X=[X_test_c[query_idx]], n_neighbors=NUM_NEIGHBORS, return_distance=False)\n",
    "    xps_imgs_trans = get_transformed_data(xp_idxs[0], train_loader)\n",
    "    xp_logits, xp_x, xp_C = netC(xps_imgs_trans[0])\n",
    "    xp_nb_boxes = get_salient_regions(xp_C, xp_logits, net_classifier)\n",
    "    \n",
    "    # Count queries whose explanation shares the query's predicted class\n",
    "    if train_preds[xp_idxs[0][0]] == query_pred:\n",
    "        agreement += 1\n",
    "    \n",
    "    # Absolute difference between the first salient region's score on query and explanation\n",
    "    saliency_sim.append(abs(query_nb_boxes[0][0] - xp_nb_boxes[0][0]))\n",
    "    \n",
    "    # L1 distance between the latent feature at the first salient location of each image\n",
    "    window_idx = query_nb_boxes[0][1]\n",
    "    query_feature = query_C[:, :, window_idx[0]:window_idx[0]+1, window_idx[1]:window_idx[1]+1]\n",
    "    window_idx = xp_nb_boxes[0][1]\n",
    "    xp_feature = xp_C[:, :, window_idx[0]:window_idx[0]+1, window_idx[1]:window_idx[1]+1]\n",
    "    dist = (query_feature - xp_feature).abs().sum()\n",
    "    results.append(dist.item())\n",
    "    \n",
    "    # Only the first 101 queries are evaluated to keep the run short\n",
    "    if query_idx == 100:\n",
    "        break\n",
    "\n",
    "# Reported after the loop so it is printed however the loop ends\n",
    "print(\"Agreement:\", agreement / len(results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "saliency_sim = np.array(saliency_sim)\n",
    "# The baseline loop above collects its feature distances in `results`\n",
    "twin_sys = np.array(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "65.55672784976389\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x134287090>"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAATUElEQVR4nO3dfawld33f8ffHXhyIofFDb1eb9a7sFIvISYWhN9QYFAkbEkPS7KaitlEEq8rpWiqkUKIUU1SplfoHSGlI+kS9jSlLRY2JY2udJnLibpxEkVOTtePgJ1wbB+Nd1rtrggOlUsjCt3+csXy13eu93pw533Pveb+kozPzm4fzndHcj+b+zsycVBWSpNk7o7sASVpUBrAkNTGAJamJASxJTQxgSWqyqbuAtbjqqqvqzjvv7C5Dkk5XTta4Ls6An3322e4SJGnq1kUAS9JGZABLUhMDWJKajBrASf5ZkoeTPJTk5iQvT3JRknuTPJHkliRnjVmDJM2r0QI4yVbgnwLLVfXDwJnAtcDHgI9X1auBrwPXjVWDJM2zsbsgNgGvSLIJ+F7gMHAFcOswfS+wc+QaJGkujRbAVXUI+EXgK0yC9y+A+4Dnqur4MNtBYOvJlk+yO8mBJAeOHTs2VpmS1GbMLohzgR3ARcD3A2cDV611+araU1XLVbW8tLQ0UpWS1GfMLoi3An9WVceq6q+A24A3AecMXRIAFwCHRqxBkubWmAH8FeCyJN+bJMCVwCPA3cA7h3l2AftGrEGS5taYfcD3Mvmy7X7gweGz9gAfAj6Y5AngfOCmsWqQpHmW9fCTRMvLy3XgwIHuMiTpdK3fh/FI0kZkAEtSEwP4Jdq6bTtJpvraum1792ZJarAuHsg+T7568GmuufGeqa7zlusvn+r6JK0PngFLUhMDWJKaGMCS1MQAlqQmBrAkNTGAJamJASxJTQxgSWpiAEtSEwNYkpoYwJLUxACWpCYGsCQ1MYAlqYkBLElNDGBJamIAS1ITA1iSmhjAktTEAJakJgawJDUxgCWpiQEsSU0MYElqYgBLUhMDWJKajBbASV6T5IEVr28k+UCS85LcleTx4f3csWqQpHk2WgBX1WNVdWlVXQr8XeD/ArcDNwD7q+piYP8wLkkLZ1ZdEFcCX6qqp4AdwN6hfS+wc0Y1SNJcmVUAXwvcPAxvrqrDw/AzwOaTLZBkd5IDSQ4cO3ZsFjVK0kyNHsBJzgJ+Cvi1E6dVVQF1suWqak9VLVfV8tLS0shVStLszeIM+O3A/VV1ZBg/kmQLwPB+dAY1SNLcmUUAv4sXuh8A7gB2DcO7gH0zqEGS5s6oAZzkbOBtwG0rmj8KvC3J48Bbh3FJWjibxlx5VX0LOP+Etq8xuSpCkhaad8JJUhMDWJKaGMCS1MQAlqQmBrAkNTGAJamJASxJTQxgSWpiAEtSEwNYkpoYwJLUxACWpCYGsCQ1MYAlqYkBLElNDGBJamIAS1ITA1iSmhjAktTEAJakJgawJDUxgCWpiQEsSU0MYElqYgBLUhMDWJKaGMCS1MQAlqQmBrAkNTGAJanJqAGc5Jwktyb5YpJHk7wxyXlJ7kry+PB+7pg1SNK8GvsM+FeAO6vqB4HXAo8CNwD7q+piYP8wLkkLZ7QATvJ9wI8CNwFU1ber6jlgB7B3mG0vsHOsGiRpno15BnwRcAz4r0n+JMmvJjkb2FxVh4d5ngE2n2zhJLuTHEhy4NixYyOWKUk9xgzgTcDrgU9U1euAb3FCd0NVFVAnW7iq9lTVclUtLy0tjVimJPUYM4APAger6t5h/FYmgXwkyRaA4f3oiDVI0twaLYCr6hng6SSvGZquBB4B7gB2DW27gH1j1SBJ82zTyOv/OeAzSc4CngT+EZPQ/1yS64CngKtHrkGS5tKoAVxVDwDLJ5l05ZifK0nrgXfCSVITA1iSmhjAktTEAJakJgawJDUxgCWpiQEsSU0MYElqYgBLUhMDWJKa
bOgA3rptO0mm+pKkaRn7YTytvnrwaa658Z6prvOW6y+f6vokLa4NfQYsSfPMAJakJgawJDUxgCWpiQEsSU0MYElqYgBLUhMDWJKaGMCS1MQAlqQmBrAkNTGAJamJASxJTQxgSWpiAEtSEwNYkpoYwJLUxACWpCaj/iRRki8D3wS+AxyvquUk5wG3ABcCXwaurqqvj1mHJM2jWZwBv6WqLq2q5WH8BmB/VV0M7B/GJWnhdHRB7AD2DsN7gZ0NNUhSu7EDuIDfSXJfkt1D2+aqOjwMPwNsHrkGSZpLY/8s/Zur6lCSvwXcleSLKydWVSWpky04BPZugO3bt49cpiTN3qhnwFV1aHg/CtwOvAE4kmQLwPB+dJVl91TVclUtLy0tjVmmJLUYLYCTnJ3kVc8PAz8GPATcAewaZtsF7BurBkmaZ2N2QWwGbk/y/Of896q6M8kfA59Lch3wFHD1iDVI0twaLYCr6kngtSdp/xpw5VifK0nrhXfCSVITA1iSmhjAktTEAJakJgawJDUxgCWpiQEsSU0MYElqYgBLUhMDWJKaGMCS1MQAlqQmBrAkNTGAJamJASxJTQxgSWpiAEtSEwNYkpqsKYCTvGktbZKktVvrGfC/X2ObJGmNXvRHOZO8EbgcWErywRWT/gZw5piFSdJGd6pfRT4LeOUw36tWtH8DeOdYRUnSInjRAK6q3wd+P8mnquqpGdW0eM7YRJKpr/b7L9jGoae/MvX1SpqOU50BP+97kuwBLly5TFVdMUZRC+e7x7nmxnumvtpbrr986uuUND1rDeBfA/4z8KvAd8YrR5IWx1oD+HhVfWLUSiRpwaz1MrTfSPJPkmxJct7zr1Erk6QNbq1nwLuG919Y0VbAD0y3HElaHGsK4Kq6aOxCJGnRrCmAk7znZO1V9enpliNJi2OtXRA/smL45cCVwP2AASxJp2mtXRA/t3I8yTnAZ9eybJIzgQPAoar6ySQXDcueD9wHvLuqvv1SipakjeB0H0f5LWCt/cLvBx5dMf4x4ONV9Wrg68B1p1mDJK1ra30c5W8kuWN4/SbwGHD7Gpa7APgJJjdwkMn9tlcAtw6z7AV2nkbdkrTurbUP+BdXDB8Hnqqqg2tY7peBf84LD/I5H3iuqo4P4weBrSdbMMluYDfA9u3b11imJK0fazoDHh7K80UmQXoucMo+2yQ/CRytqvtOp7Cq2lNVy1W1vLS0dDqrkKS5ttYuiKuBzwP/ELgauDfJqR5H+Sbgp5J8mcmXblcAvwKck+T5M+8LgEOnUbckrXtr/RLuI8CPVNWuqnoP8AbgX77YAlX14aq6oKouBK4Ffreqfga4mxeeJbwL2HdalUvSOrfWAD6jqo6uGP/aS1j2RB8CPpjkCSZ9wjed5nokaV1b65dwdyb5beDmYfwa4LfW+iFV9XvA7w3DTzI5g5akhXaq34R7NbC5qn4hyT8A3jxM+iPgM2MXJ0kb2anOgH8Z+DBAVd0G3AaQ5O8M0/7+iLVJ0oZ2qn7czVX14ImNQ9uFo1QkSQviVAF8zotMe8UU65CkhXOqAD6Q5B+f2JjkZ5k8SEeSdJpO1Qf8AeD2JD/DC4G7DJwF/PSIdUnShveiAVxVR4DLk7wF+OGh+Ter6ndHr0ySNri1Pg/4biZ3sEmSpuR072aTJP01GcCS1MQAlqQmBrAkNTGAJamJASxJTQxgSWpiAEtSEwNYkpoYwJLUxACWpCYGsCQ1MYAlqYkBLElNDGBJamIAS1ITA1iSmhjAktTEAJakJgawJDUxgCWpiQEsSU1GC+AkL0/y+SR/muThJP96aL8oyb1JnkhyS5KzxqpBkubZmGfAfwlcUVWvBS4FrkpyGfAx4ONV9Wrg68B1I9YgSXNrtACuif8zjL5seBVwBXDr0L4X2DlWDZI0z0btA05yZpIHgKPAXcCXgOeq6vgwy0Fg6yrL7k5yIMmBY8eOjVmmJLUYNYCr6jtVdSlwAfAG4AdfwrJ7qmq5qpaXlpbG
KlGS2szkKoiqeg64G3gjcE6STcOkC4BDs6hBkubNmFdBLCU5Zxh+BfA24FEmQfzOYbZdwL6xapCkebbp1LOcti3A3iRnMgn6z1XV/0jyCPDZJP8G+BPgphFrkKS5NVoAV9UXgNedpP1JJv3BkrTQvBNOkpoYwJLUxADeyM7YRJKpvrZu2969VdKGMeaXcOr23eNcc+M9U13lLddfPtX1SYvMM2BJamIAS1ITA1iSmhjAktTEAJakJgawJDUxgCWpiQEsSU0MYElqYgBLUhMDWJKaGMCS1MQAlqQmBrAkNTGAJamJASxJTQxgSWpiAEtSEwNYL42/MydNjb8Jp5fG35mTpsYzYElqYgBLUhMDWJKaGMCS1MQAlqQmBrAkNRktgJNsS3J3kkeSPJzk/UP7eUnuSvL48H7uWDVI0jwb8wz4OPDzVXUJcBnw3iSXADcA+6vqYmD/MC5JC2e0AK6qw1V1/zD8TeBRYCuwA9g7zLYX2DlWDZI0z2bSB5zkQuB1wL3A5qo6PEx6Bti8yjK7kxxIcuDYsWOzKFOSZmr0AE7ySuDXgQ9U1TdWTquqAupky1XVnqparqrlpaWlscuUpJkbNYCTvIxJ+H6mqm4bmo8k2TJM3wIcHbMGSZpXY14FEeAm4NGq+qUVk+4Adg3Du4B9Y9UgSfNszKehvQl4N/BgkgeGtn8BfBT4XJLrgKeAq0esQZLm1mgBXFV/CGSVyVeO9bmStF54J5wkNTGAJamJASxJTQxgSWpiAEtSEwNYkpoYwJLUxACWpCYGsCQ1MYAlqYkBrH5nbCLJ1F9bt23v3jLpRY35MB5pbb57nGtuvGfqq73l+sunvk5pmjwDlqQmBrAkNTGAJamJASxJTQxgSWpiAEtSEwNYkpoYwJLUxACWpCYGsCQ1MYAlqYkBLElNDGBJamIAS1ITA1jagLZu2+4zltcBnwcsbUBfPfi0z1heBzwDlqQmBrAkNRktgJN8MsnRJA+taDsvyV1JHh/ezx3r86UxjNG3uq76VUf4/b51tf1TNmYf8KeA/wB8ekXbDcD+qvpokhuG8Q+NWIM0VWP0ra6rftURfr9vXW3/lI12BlxVfwD8+QnNO4C9w/BeYOdYny9J827WfcCbq+rwMPwMsHnGny9Jc6PtMrSqqiS12vQku4HdANu3L24fkf4ahv5KaV7NOoCPJNlSVYeTbAGOrjZjVe0B9gAsLy+vGtTSquyv1JybdRfEHcCuYXgXsG/Gny9Jc2PMy9BuBv4IeE2Sg0muAz4KvC3J48Bbh3FJWkijdUFU1btWmXTlWJ8pSeuJd8JJUhMDWJKaGMCS1MQAlqQmBrAkNTGAJamJv4ghdfOW6YVlAEvdvGV6YdkFIUlNDGBJamIAS1ITA1iSmhjAktTEAJakJgawJDUxgCWpiQEsSU0MYElqYgBL6jU8C2Oar63btndv1Zr4LAhJvRb4WRieAUtSEwNYkpoYwJLUxACWpCYGsCQ1MYAlqYkBLGnjGeHa4jGuL/Y6YEkbzwjXFsP0ry/2DFiSmhjAktTEAJakJi0BnOSqJI8leSLJDR01SFK3mQdwkjOB/wi8HbgEeFeSS2ZdhyR16zgDfgPwRFU9WVXfBj4L7GioQ5Japapm+4HJO4Grqupnh/F3A3+vqt53wny7gd3D6GuAx4C/CTw7w3LnkfvAfQDuA1hf++DZqrrqxMa5vQ64qvYAe1a2JTlQVctNJc0F94H7ANwHsDH2QUcXxCFg24rxC4Y2SVooHQH8x8DFSS5KchZwLXBHQx2S1GrmXRBVdTzJ+4DfBs4EPllVD69x8T2nnmXDcx+4D8B9ABtgH8z8SzhJ0oR3wklSEwNYkpqsiwBe5FuXk3w5yYNJHkhyYGg7L8ldSR4f3s/trnOaknwyydEkD61oO+k2Z+LfDcfGF5K8vq/y6VllH/yrJIeGY+GBJO9YMe3Dwz54LMmP91Q9XUm2Jbk7ySNJHk7y/qF9wxwLcx/A3roMwFuq6tIV
1zzeAOyvqouB/cP4RvIp4MSL1lfb5rcDFw+v3cAnZlTj2D7F/78PAD4+HAuXVtVvAQx/D9cCPzQs85+Gv5v17jjw81V1CXAZ8N5hWzfMsTD3AYy3Lp/MDmDvMLwX2NlXyvRV1R8Af35C82rbvAP4dE38L+CcJFtmUuiIVtkHq9kBfLaq/rKq/gx4gsnfzbpWVYer6v5h+JvAo8BWNtCxsB4CeCvw9Irxg0Pboijgd5LcN9yeDbC5qg4Pw88Am3tKm6nVtnnRjo/3Df9ef3JF19OG3wdJLgReB9zLBjoW1kMAL7o3V9Xrmfx79d4kP7pyYk2uI1yoawkXcZsHnwD+NnApcBj4t63VzEiSVwK/Dnygqr6xctp6PxbWQwAv9K3LVXVoeD8K3M7kX8sjz/9rNbwf7atwZlbb5oU5PqrqSFV9p6q+C/wXXuhm2LD7IMnLmITvZ6rqtqF5wxwL6yGAF/bW5SRnJ3nV88PAjwEPMdn+XcNsu4B9PRXO1GrbfAfwnuEb8MuAv1jx7+mGckJ/5k8zORZgsg+uTfI9SS5i8iXU52dd37QlCXAT8GhV/dKKSRvnWKiquX8B7wD+N/Al4CPd9cxwu38A+NPh9fDz2w6cz+Tb38eB/wmc113rlLf7Zib/Yv8Vk36861bbZiBMrpL5EvAgsNxd/4j74L8N2/gFJmGzZcX8Hxn2wWPA27vrn9I+eDOT7oUvAA8Mr3dspGPBW5Elqcl66IKQpA3JAJakJgawJDUxgCWpiQEsSU0MYElqYgBLUpP/B7viOHNW2WKqAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(np.mean(twin_sys))\n",
    "sns.displot(data=twin_sys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.17792157194401959\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x133355490>"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQzklEQVR4nO3de6ykBXnH8e8PVrRVrKhbQtelqxYv1Fq0R2vRGC3WrCSKVssl1WKKLlWwGo2p1SY19Y/a1kub1iirErBRBBEiVsUiokYR7IorF6nXYllAOGhbjE3Vhad/zLv1dNnL7HLeec7l+0kmZ+ady/vs7PLlPe/MO5OqQpI0ewd0DyBJq5UBlqQmBliSmhhgSWpigCWpyZruAaaxcePGuvjii7vHkKT9lV0tXBZbwLfffnv3CJK06JZFgCVpJTLAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktRkRQd43frDSTLKad36w7v/eJKWuWXxgez76+ZtN3LCGZeP8tjnnnr0KI8rafVY0VvAkrSUGWBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpqMFuAk65NcluRrSa5L8sph+RuT3JRk63A6dqwZJGkpG/MribYDr6mqq5IcDHw5ySXDdW+vqreMuG5JWvJGC3BV3QLcMpz/YZLrgXVjrU+SlpuZ7ANOsgF4HHDlsOj0JFcnOTPJIbu5z6YkW5JsmZ+fn8WYkjRTowc4yf2ADwOvqqo7gHcCDweOYrKF/NZd3a+qNlfVXFXNrV27duwxJWnmRg1wknsxie/7q+oCgKq6tarurKq7gHcDTxxzBklaqsZ8F0SA9wLXV9XbFiw/bMHNngdcO9YMkrSUjfkuiCcDLwKuSbJ1WPZ64KQkRwEF3ACcOuIMkrRkjfkuiM8D2cVVHx9rnZK0nHgknCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTdZ0D7BsHbCGJKM89IH3ujd3/vTHozz2Lz1kPTfd+O+jPLakfWOA99dd2znhjMtHeehzTz161MeWtDS4C0KSmhhgSWpigCWpiQGWpCYGWJKaGGBJajJagJOsT3JZkq8luS7JK4flD0xySZJvDj8PGWsGSVrKxtwC3g68pqqOBJ4EnJbkSOB1wKVVdQRw6XBZklad0QJcVbdU1VXD+R8C1wPrgOOAs4ebnQ08d6wZJGkpm8k+4CQbgMcBVwKHVtUtw1XfAw7dzX02JdmSZMv8/PwsxpSkmRo9wEnuB3wYeFVV3bHwuqoqoHZ1v6raXFVzVTW3du3asceUpJkbNcBJ7sUkvu+vqguGxbcmOWy4/jDgtjFnkKSlasx3QQR4L3B9Vb1twVUXAScP508GPjLWDJK0lI35aWhPBl4EXJNk67Ds9cCbgfOSnAJ8Fzh+xBkkackaLcBV9Xlgdx+Ye8xY65Wk5cIj4SSpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhg
SWoyVYCTPHmaZTtdf2aS25Jcu2DZG5PclGTrcDp230eWpJVh2i3gv59y2UJnARt3sfztVXXUcPr4lOuXpBVnzZ6uTPJbwNHA2iSvXnDV/YED93Tfqvpckg33eEJJWqH2tgV8EHA/JqE+eMHpDuAF+7nO05NcPeyiOGR3N0qyKcmWJFvm5+f3c1WStHTtcQu4qj4LfDbJWVX13UVY3zuBNwE1/Hwr8Ie7WfdmYDPA3NxcLcK6JWlJ2WOAF7h3ks3AhoX3qarf3peVVdWtO84neTfwT/tyf0laSaYN8IeAdwHvAe7c35UlOayqbhkuPg+4dk+3l6SVbNoAb6+qd+7LAyc5B3ga8OAk24A/B56W5CgmuyBuAE7dl8eUpJVk2gB/NMnLgQuBH+9YWFU/2N0dquqkXSx+776NJ0kr17QBPnn4+doFywp42OKOI0mrx1QBrqqHjj2IJK02UwU4yR/sanlVvW9xx5Gk1WPaXRBPWHD+PsAxwFWAAZak/TTtLohXLLyc5AHAB8cYSJJWi/39OMofAe4XXo4OWEOSUU7r1h/e/aeTlpVp9wF/lMm7HmDyITyPBs4bayiN6K7tnHDG5aM89LmnHj3K40or1bT7gN+y4Px24LtVtW2EeSRp1ZhqF8TwoTz/yuST0A4BfjLmUJK0Gkz7jRjHA18Cfg84Hrgyyf5+HKUkiel3QbwBeEJV3QaQZC3wKeD8sQaTpJVu2ndBHLAjvoPv78N9JUm7MO0W8MVJPgmcM1w+AfD73CTpHtjbd8L9CnBoVb02ye8CTxmu+iLw/rGHk6SVbG9bwH8L/ClAVV0AXACQ5NeG65494myStKLtbT/uoVV1zc4Lh2UbRplIklaJvQX4AXu47ucWcQ5JWnX2FuAtSV6688IkLwG+PM5IkrQ67G0f8KuAC5P8Pj8L7hxwEJMv1ZQk7ac9Bnj4GvmjkzwdeMyw+GNV9enRJ5OkFW7azwO+DLhs5FkkaVXxaDZJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJanJaAFOcmaS25Jcu2DZA5NckuSbw89Dxlq/JC11Y24BnwVs3GnZ64BLq+oI4NLhsiStSqMFuKo+B/xgp8XHAWcP588GnjvW+iVpqZv1PuBDq+qW4fz3gEN3d8Mkm5JsSbJlfn5+NtNJ0gy1vQhXVQXUHq7fXFVzVTW3du3aGU4mSbMx6wDfmuQwgOHnbTNevyQtGbMO8EXAycP5k4GPzHj9krRkjPk2tHOALwKPTLItySnAm4HfSfJN4BnDZUlaldaM9cBVddJurjpmrHVK0nLikXCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwFo8B6whySindesP7/7TSYtuTfcAWkHu2s4JZ1w+ykOfe+rRozyu1MktYElqYoAlqYkBlqQmBliSmhhgSWpigCWpScvb0JLcAPwQuBPYXlVzHXNIUqfO9wE/vapub1y/JLVyF4QkNekKcAH/nOTLSTbt6gZJNiXZkmTL/Pz8jMeTpPF1BfgpVfV44FnAaUmeuvMNqmpzVc1V1dzatWtnP6EkjawlwFV10/DzNuBC4Ikdc0hSp5kHOMl9kxy84zzwTODaWc8hSd063gVxKHBhkh3r/0BVXdwwhyS1mnmAq+o7wK/Per2StNT4NjRJamKAJamJAZakJgZYkpoYYElqYoC16q1bf7jf5qwWfiuy
Vr2bt93otzmrhVvAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMPRdbycMAahq+xklYMA6zl4a7tfl6DVhx3QUhSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLI1pOITar7zXrngosjQmD6HWHrgFLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywtFyNeJhzEtYcdJ9leRj1uvWHL5u5PRRZWq5GPMwZJoc6L8fDqG/eduOymdstYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYtAU6yMcnXk3wryes6ZpCkbjMPcJIDgXcAzwKOBE5KcuSs55Ckbh1bwE8EvlVV36mqnwAfBI5rmEOSWqWqZrvC5AXAxqp6yXD5RcBvVtXpO91uE7BpuPhI4OtTruLBwO2LNO49tVRmcY67WyqzOMfdLZVZFnOO26tq484Ll+xnQVTVZmDzvt4vyZaqmhthpH22VGZxjrtbKrM4x90tlVlmMUfHLoibgPULLj9kWCZJq0pHgP8FOCLJQ5McBJwIXNQwhyS1mvkuiKranuR04JPAgcCZVXXdIq5in3dbjGipzOIcd7dUZnGOu1sqs4w+x8xfhJMkTXgknCQ1McCS1GTZBnhvhzMnuXeSc4frr0yyoWmOpya5Ksn24T3Qo5lillcn+VqSq5NcmuSXm+b4oyTXJNma5PNjHgk57WHvSZ6fpJKM8rajKZ6TFyeZH56TrUle0jHHcJvjh38n1yX5wBhzTDNLkrcveD6+keQ/m+Y4PMllSb4y/Ldz7KKtvKqW3YnJi3ffBh4GHAR8FThyp9u8HHjXcP5E4NymOTYAjwXeB7yg+Tl5OvDzw/mXNT4n919w/jnAxV3PyXC7g4HPAVcAc03PyYuBfxjr38c+zHEE8BXgkOHyL3b+3Sy4/SuYvGDf8ZxsBl42nD8SuGGx1r9ct4CnOZz5OODs4fz5wDFJMus5quqGqroauGuR170/s1xWVf89XLyCyXuwO+a4Y8HF+wJjvRI87WHvbwL+Cvif5jnGNs0cLwXeUVX/AVBVtzXOstBJwDlNcxRw/+H8LwA3L9bKl2uA1wE3Lri8bVi2y9tU1Xbgv4AHNcwxK/s6yynAJ7rmSHJakm8Dfw388QhzTDVLkscD66vqYyPNMNUcg+cPv+Ken2T9Lq6fxRyPAB6R5AtJrkhyt8NnZzgLAMOusocCn26a443AC5NsAz7OZGt8USzXAOseSPJCYA74m64ZquodVfVw4E+AP+uYIckBwNuA13SsfycfBTZU1WOBS/jZb2+ztobJboinMdnqfHeSBzTNssOJwPlVdWfT+k8CzqqqhwDHAv84/Nu5x5ZrgKc5nPn/bpNkDZNfHb7fMMesTDVLkmcAbwCeU1U/7ppjgQ8Czx1hjmlmORh4DPCZJDcATwIuGuGFuL0+J1X1/QV/H+8BfmORZ5hqDiZbgBdV1U+r6t+AbzAJcscsO5zIOLsfpp3jFOA8gKr6InAfJh/Uc8+NsYN97BOT/0t/h8mvJTt2nP/qTrc5jf//Itx5HXMsuO1ZjPsi3DTPyeOYvOBwRPMcRyw4/2xgS9csO93+M4zzItw0z8lhC84/D7iiaY6NwNnD+Qcz+fX8QV1/N8CjgBsYDhprek4+Abx4OP9oJvuAF2WeRf8DzerE5FeBbwxBecOw7C+YbNnB5P9SHwK+BXwJeFjTHE9gslXxIyZb4Nc1PiefAm4Ftg6ni5rm+DvgumGGy/YUxbFn2em2owR4yufkL4fn5KvDc/KopjnCZLfM14BrgBM7/26Y7H9981gzTPmcHAl8Yfi72Qo8c7HW7aHIktRkue4DlqRlzwBLUhMDLElNDLAkNTHAktTEAEtSEwMsSU3+F/hCXWZMmO1uAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(np.mean(saliency_sim))\n",
    "sns.displot(data=saliency_sim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## My Advanced Method"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What are the constraints on an explanation-by-example?\n",
    "1. It should preferably be a correctly classified instance of the same class as the query (not always possible).\n",
    "2. It should have a similar latent feature that can be referenced.\n",
    "3. That feature should make a similar contribution to the classification (measured by the change in logits when the feature is zeroed out)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 273,
   "metadata": {},
   "outputs": [],
   "source": [
    "FEATURE_NUM = 1  # pick most nb feature\n",
    "NUM_NEIGHBORS = 10  # candidate twins examined per query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 274,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', metric='euclidean', n_neighbors=1)"
      ]
     },
     "execution_count": 274,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep only training points whose prediction matches their label\n",
    "keep = train_preds == y_train.detach().numpy()\n",
    "temp_train, temp_preds = X_train_c[keep], train_preds[keep]\n",
    "\n",
    "# Apply the same mask to a copy of the training loader's dataset\n",
    "temp_train_loader = deepcopy(train_loader)\n",
    "temp_train_loader.dataset.data = temp_train_loader.dataset.data[keep]\n",
    "temp_train_loader.dataset.targets = temp_train_loader.dataset.targets[keep]\n",
    "\n",
    "# Fit the twin on the filtered set\n",
    "twin = KNeighborsClassifier(metric='euclidean', algorithm='brute', n_neighbors=1)\n",
    "twin.fit(temp_train, temp_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 275,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative (disabled): fit the twin on all training points instead of only the\n",
    "# correctly classified ones. Uncomment to compare against the filtered twin.\n",
    "# twin = KNeighborsClassifier(n_neighbors=1, algorithm=\"brute\", metric='euclidean') \n",
    "# twin.fit(X_train_c, train_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 276,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|▎         | 300/10000 [00:27<14:43, 10.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Contamination: 100.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "results = list()\n",
    "saliency_sim = list()\n",
    "contamination = 0  # retrieved neighbours whose predicted class differs from the query's\n",
    "\n",
    "for query_idx, data in enumerate(tqdm(test_loader)):\n",
    "    \n",
    "    # Get query information\n",
    "    query_pred = test_preds[query_idx].item()\n",
    "    query_img_trans = get_transformed_data([query_idx], test_loader)[0]\n",
    "    \n",
    "    # Get query salient feature\n",
    "    query_logits, query_x, query_C = netC(query_img_trans)\n",
    "    query_nb_boxes = get_salient_regions(query_C, query_logits, net_classifier)\n",
    "    window_idx = query_nb_boxes[0][1]  # Only the most NB one\n",
    "    query_feature = query_C[:, :, window_idx[0]:window_idx[0]+1, window_idx[1]:window_idx[1]+1]\n",
    "\n",
    "    # Get the NUM_NEIGHBORS candidate explanations from the twin\n",
    "    xp_idxs = twin.kneighbors(X=[X_test_c[query_idx]], n_neighbors=NUM_NEIGHBORS, return_distance=False)\n",
    "    \n",
    "    # How many NNs have a different (predicted) class than the query\n",
    "    contamination += (temp_preds[xp_idxs[0]] != query_pred).sum()\n",
    "    \n",
    "    # Iterate all candidate neighbours; keep the one whose closest latent feature is nearest\n",
    "    # to the query's salient feature, and record its logit-contribution mismatch\n",
    "    min_dist = float('inf')\n",
    "    min_logit = float('inf')\n",
    "    for xp_idx in xp_idxs[0]:\n",
    "\n",
    "        xps_img_trans = get_transformed_data([xp_idx], temp_train_loader)\n",
    "        xp_logits, xp_x, xp_C = netC(xps_img_trans[0])\n",
    "        xp_pred = torch.argmax(xp_logits, dim=1).item()\n",
    "        \n",
    "        coord, dist_query_to_xp = get_box_xp(xp_C, query_feature)  # get similar latent feature in nn xp\n",
    "        \n",
    "        # Logit drop when that feature is zeroed, compared with the query's saliency score\n",
    "        ablated_C = xp_C.clone()\n",
    "        ablated_C[:, :, coord[0]:coord[0]+1, coord[1]:coord[1]+1] = 0.0\n",
    "        new_xp_logits = net_classifier(ablated_C)\n",
    "        logit_change = xp_logits[0][xp_pred] - new_xp_logits[0][xp_pred]\n",
    "        logit_change = abs(logit_change - query_nb_boxes[0][0])\n",
    "                       \n",
    "        if dist_query_to_xp < min_dist:\n",
    "            min_dist = dist_query_to_xp\n",
    "            min_logit = logit_change\n",
    "\n",
    "    results.append(min_dist)\n",
    "    saliency_sim.append(min_logit.item())\n",
    "    \n",
    "    # Only the first 301 queries are evaluated to keep the run short\n",
    "    if query_idx == 300:\n",
    "        break\n",
    "\n",
    "# Fraction of all retrieved neighbours whose predicted class differs from the query's\n",
    "print('Contamination:', contamination / (len(results) * NUM_NEIGHBORS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 277,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the per-query lists gathered by the loop above into NumPy arrays (np.array copies)\n",
    "saliency_sim = np.array(saliency_sim)\n",
    "results = np.array(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 278,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "26.804295013909325\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x173d5fa90>"
      ]
     },
     "execution_count": 278,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQpUlEQVR4nO3dbYwchXnA8f9jH2+FJIbkZLnns0wEgiK1mNQhYFCVmlI5aQq0orwIpf7g1EglFShREmikSkj9EKQqgKoqxQKKKyEwIaQQVEGJY1K1RKbHW3hxEA6F2mDwkYJoU4nE+OmHHYvLyS9re2efubv/T1rdzsze3mNv8mduPDsbmYkkafjmVQ8gSXOVAZakIgZYkooYYEkqYoAlqchI9QD9WLVqVT700EPVY0jSoYq9rZwRe8BvvfVW9QiSNHAzIsCSNBsZYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoA1MGPjS4iIgdzGxpdU/3Gk1s2IC7JrZnh9+zYuveWxgTzXhitXDOR5pC5zD1iSihhgSSpigCWpiAGWpCIGWJKKeBbEDDU2voTXt28byHPNP+Io3v/lewN5Lkn9M8Az1KBP+RrEc3nqmHRwPAQhSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFWg9wRMyPiKci4sFm+cSI2BwRWyNiQ0Qc2fYMktRFw9gDvhrYMmX5BuDGzDwJeBtYM4QZJKlzWg1wRCwG/gC4tVkOYCVwb/OQ9cBFbc4gSV3V9h7wTcBXgd3N8keBdzJzV7O8HRjb2zdGxNqImIiIicnJyZbHlKThay3AEfE5YGdmPnEo35+Z6zJzeWYuHx0dHfB0klSvzctRngNcEBGfBY4GPgzcDCyIiJFmL3gx8FqLM0hSZ7W2B5yZ12Xm4sxcClwG/CAzrwA2ARc3D1sN3N/WDJLUZRXnAX8N+FJEbKV3TPi2ghkkqdxQPhEjMx8FHm3uvwycOYyfK0ld5jvhJKmIAe7D2PgSImIgt7HxJdV/HEkd4Ydy9mHQH4ApSeAesCSVMcCSVMQAS1IRAyxJRQywJBXxLIhhmzdC76qckuY6Azxsu3cN5JQ2T2eTZj4PQUhSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQOsbmou2+mnUGs283KU6iYv26k5wD1gSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSrSWoAj4uiIeDwinomI5yPi+mb9iRGxOSK2RsSGiDiyrRkkqcva3AN+D1iZmacDy4BVEXEWcANwY2aeBLwNrGlxBknqrNYCnD3/2ywe0dwSWAnc26xfD1zU1gyS1GWtHgOOiPkR8TSwE3gE+CnwTmbuah6yHRjbx/eujYiJiJiYnJxsc0xJKtFqgDPz/cxcBiwGzgROPYjvXZeZyzNz+ejoaFsjSlKZoZwFkZnvAJuAs4EFETHSbFoMvDaMGSSpa9o8C2I0IhY0948Bzge20Avxxc3DVgP3tzWDJHXZyIEfcsgWAesjYj690N+TmQ9GxAvA3RHx18BTwG0tziBJndVagDPzx8AZe1n/Mr3jwZI0p/lOOEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQi
BliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIn0FOCLO6WedJKl//e4B/22f6yRJfRrZ38aIOBtYAYxGxJembPowML/NwSRptttvgIEjgeOax31oyvp3gYvbGkqS5oL9Bjgzfwj8MCLuyMxXhzSTJM0JB9oD3uOoiFgHLJ36PZm5so2hJGku6DfA3wb+HrgVeL+9cSRp7ug3wLsy81utTiJJc0y/p6F9LyL+PCIWRcQJe26tTiZJs1y/e8Crm69fmbIugY8PdhxpwOaNEBGH/TS/vnic17b91wAGkj7QV4Az88S2B5FasXsXl97y2GE/zYYrVwxgGOlX9RXgiPjTva3PzH8c7DiSNHf0ewjik1PuHw2cBzwJGGBJOkT9HoL4i6nLEbEAuLuNgSRprjjUy1H+HPC4sCQdhn6PAX+P3lkP0LsIz28A97Q1lCTNBf0eA/6bKfd3Aa9m5vYW5pGkOaOvQxDNRXl+Qu+KaMcDv2hzKEmaC/r9RIxLgMeBPwEuATZHhJejlKTD0O8hiK8Dn8zMnQARMQp8H7i3rcEkabbr9yyIeXvi2/jZQXyvJGkv+t0DfigiHgbuapYvBf65nZEkaW440GfCnQQszMyvRMQfA+c2m34E3Nn2cJI0mx1oD/gm4DqAzLwPuA8gIn6z2faHLc4mSbPagY7jLszMZ6evbNYtbWUiSZojDhTgBfvZdswA55CkOedAAZ6IiD+bvjIivgA80c5IkjQ3HOgY8DXAdyPiCj4I7nLgSOCP9veNETFO73KVC+ldR2JdZt7cfJTRBnqHMF4BLsnMtw9xfkmasfa7B5yZb2bmCuB6erF8Bbg+M8/OzDcO8Ny7gC9n5mnAWcBVEXEacC2wMTNPBjY2y5I05/R7PeBNwKaDeeLM3AHsaO7/T0RsAcaAC4FPNw9bDzwKfO1gnluSZoOhvJstIpYCZwCb6Z1ZsaPZ9Aa9QxR7+561ETEREROTk5PDGFOShqr1AEfEccB3gGsy892p2zIz+eA6w0zbti4zl2fm8tHR0bbHlKShazXAEXEEvfje2byRA+DNiFjUbF8E7NzX90vSbNZagCMigNuALZn5zSmbHgBWN/dXA/e3NYMkdVm/F+M5FOcAnweejYinm3V/CXwDuCci1gCv0ru+sCTNOa0FODP/DYh9bD6vrZ8rSTOF1/SVpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlvoxb4SIGMhtbHxJ9Z9GHTFSPYA0I+zexaW3PDaQp9pw5YqBPI9mPveAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkq0lqAI+L2iNgZEc9NWXdCRDwSES81X49v6+dLUte1uQd8B7Bq2rprgY2ZeTKwsVmWpDmptQBn5r8C/z1t9YXA+ub+euCitn6+JHXdsI8BL8zMHc39N4CFQ/75ktQZZf8Il5kJ5L62R8TaiJiIiInJyckhTia1bECfsOynK898w/5U5DcjYlFm7oiIRcDOfT0wM9cB6wCWL1++z1BLM86APmHZT1ee+Ya9B/wAsLq5vxq4f8g/X5I6o83T0O4CfgScEhHbI2IN8A3g/Ih4Cfi9ZlmS5qTWDkFk5uX72HReWz9TkmYS3wknSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhED
LElFDLAkFTHAklRkVgd4bHwJEXHYN0lqw0j1AG16ffs2Lr3lscN+ng1XrhjANJL0q2b1HrAkdZkBlqQiBliSihhgSSpigCWpiAGWZqp5IwM5zTIiGBtfUv2nmZNm9Wlo0qy2e9dATrMET7Ws4h6wJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywpIFdWc2rqh0cr4YmaWBXVvOqagfHPWBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJc1qY+NLBnKpzTYut+nlKCXNaq9v3zaQS23C4C+36R6wJBUxwJJUpCTAEbEqIl6MiK0RcW3FDJJUbegBjoj5wN8BnwFOAy6PiNOGPYckVavYAz4T2JqZL2fmL4C7gQsL5pCkUpGZw/2BERcDqzLzC83y54FPZeYXpz1uLbC2WTwFeHGog+7dx4C3qoc4gK7P2PX5oPszdn0+6P6Mw57vrcxcNX1lZ09Dy8x1wLrqOaaKiInMXF49x/50fcauzwfdn7Hr80H3Z+zKfBWHIF4DxqcsL27WSdKcUhHg/wBOjogTI+JI4DLggYI5JKnU0A9BZOauiPgi8DAwH7g9M58f9hyHqFOHRPah6zN2fT7o/oxdnw+6P2Mn5hv6P8JJknp8J5wkFTHAklTEAO9DRNweETsj4rkp606IiEci4qXm6/GF841HxKaIeCEino+Iqzs449ER8XhEPNPMeH2z/sSI2Ny8FX1D84+xZSJifkQ8FREPdnS+VyLi2Yh4OiImmnVdep0XRMS9EfGTiNgSEWd3bL5Tmr+7Pbd3I+KaLsxogPftDmD6idPXAhsz82RgY7NcZRfw5cw8DTgLuKp5S3eXZnwPWJmZpwPLgFURcRZwA3BjZp4EvA2sqRsRgKuBLVOWuzYfwO9m5rIp56526XW+GXgoM08FTqf3d9mZ+TLzxebvbhnw28D/Ad/txIyZ6W0fN2Ap8NyU5ReBRc39RcCL1TNOme1+4Pyuzgj8GvAk8Cl670AaadafDTxcONdiev/nWwk8CESX5mtmeAX42LR1nXidgY8A/0nzD/pdm28v8/4+8O9dmdE94IOzMDN3NPffABZWDrNHRCwFzgA207EZm1/vnwZ2Ao8APwXeycxdzUO2A2NF4wHcBHwV2N0sf5RuzQeQwL9ExBPNW/ShO6/zicAk8A/NYZxbI+LYDs033WXAXc398hkN8CHK3n82y8/hi4jjgO8A12Tmu1O3dWHGzHw/e7/6LaZ3IaZTK+eZKiI+B+zMzCeqZzmAczPzE/SuIHhVRPzO1I3Fr/MI8AngW5l5BvBzpv0q34X/HQI0x/IvAL49fVvVjAb44LwZEYsAmq87K4eJiCPoxffOzLyvWd2pGffIzHeATfR+pV8QEXveBFT5VvRzgAsi4hV6V+VbSe94ZlfmAyAzX2u+7qR37PJMuvM6bwe2Z+bmZvleekHuynxTfQZ4MjPfbJbLZzTAB+cBYHVzfzW9464lIiKA24AtmfnNKZu6NONoRCxo7h9D7xj1Fnohvrh5WNmMmXldZi7OzKX0fjX9QWZe0ZX5ACLi2Ij40J779I5hPkdHXufMfAPYFhGnNKvOA16gI/NNczkfHH6ALsxYfVC8qzd6L9QO4Jf0/iu/ht7xwY3AS8D3gRMK5zuX3q9MPwaebm6f7diMvwU81cz4HPBXzfqPA48DW+n9OnhUB17vTwMPdm2+ZpZnmtvzwNeb9V16nZcBE83r/E/A8V2ar5nxWOBnwEemrCuf0bciS1IRD0FIUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQV+X8wGLhAlts77wAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Latent distance between each query's salient patch and its best match in the explanation NN(s)\n",
    "print('Mean distance:', results.mean())\n",
    "dist_grid = sns.displot(results)\n",
    "dist_grid.set_axis_labels('Latent distance (query salient patch to closest NN patch)', 'Count')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 279,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.1530880591402022\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x173f8d450>"
      ]
     },
     "execution_count": 279,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAR+0lEQVR4nO3da7BdB1mH8eefHkpRLm0hZGqa2DKUSweGyxwQAqNCxamobVVsYQCDU0gHhcGBQYt88fYBRgUdx9FkKBIdLikVbLlYrKHAaKEQKLe2ILVSmt5yihQVRzD09cNZJWlNm92Ts/a7zznPb+bM2WvtvbLfWZM8s7L2XnunqpAkTd+67gEkaa0ywJLUxABLUhMDLElNDLAkNZnrHmASp59+el166aXdY0jSUuVQK1fEEfDtt9/ePYIkLbsVEWBJWo0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCHsHHTZpIs6Wfjps3d40taIVbEB7JP2817b+Sc7Vcsadtd521Z5mkkrVYeAUtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNRg1wkmOTXJTkK0muTfLMJMcnuSzJ14bfx405gyTNqrGPgP8UuLSqHgc8CbgWOB/YXVWnALuHZUlac0YLcJKHAT8OXABQVd+rqjuAM4Gdw8N2AmeNNYMkzbIxj4BPBhaAv0pyVZK3JflhYENV3TI85lZgw4gzSNLMGjPAc8BTgb+oqqcA3+EepxuqqoA61MZJtiXZk2TPwsLCiGNKUo8xA7wX2FtVVw7LF7EY5NuSnAAw/N53qI2rakdVzVfV/Pr160ccU5J6jBbgqroVuDHJY4dVpwHXAJcAW4d1W4GLx5pBkmbZ3Mh//quBdyY5Grge+FUWo39hknOBG4CzR55BkmbSqAGuqs8D84e467Qxn1eSVgKvhJOkJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJanJqg7wxk2bSXK/fyRpGua6BxjTzXtv5JztV9zv7Xadt2WEaSTp7kYNcJKvA/8JfB/YX1XzSY4HdgEnAV8Hzq6qb405hyTNommcgnhOVT25quaH5fOB3VV1CrB7WJakNafjHPCZwM7h9k7grIYZJKnd2AEu4B+SfDbJtmHdhqq6Zbh9K7DhUBsm2ZZkT5I9CwsLI48pSdM39otwz66qm5I8ErgsyVcOvrOqKkkdasOq2gHsAJifnz/kYyRpJRv1CLiqbhp+7wPeDzwduC3JCQDD731jziBJs2q0ACf54SQPues28NPAl4FLgK3Dw7YCF481gyTNsjFPQWwA3j9c2DAHvKuqLk3yGeDCJOcCNwBnjziDJM2s0QJcVdcDTzrE+m8Cp431vJK0UqzqS5ElaZYZYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZ4ua2bO+zX3h/qZ+Omzd2TS5qyVf219C3u3M8526+435vtOm/LCMNImmUeAUtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNRk9wEmOSnJVkg8OyycnuTLJdUl2JTl67BkkaRZN4wj4NcC1By2/GXhrVT0a+BZw7hRm
kKSZM2qAk5wI/CzwtmE5wHOBi4aH7ATOGnMGSZpVYx8B/wnwm8Cdw/LDgTuqav+wvBfYeKgNk2xLsifJnoWFhZHHlKTpGy3ASX4O2FdVn13K9lW1o6rmq2p+/fr1yzydJPWbG/HPfhZwRpLnA8cADwX+FDg2ydxwFHwicNOIM0jSzBrtCLiq3lBVJ1bVScALgY9W1YuBy4EXDA/bClw81gySNMs63gf8W8Brk1zH4jnhCxpmkKR2Y56C+IGq+hjwseH29cDTp/G8kjTLvBJOkpoYYElqYoAlqYkBlqQmBnhWrJsjyZJ+Nm7a3D29pCWYyrsgNIE793PO9iuWtOmu87Ys8zCSpsEjYElqMlGAkzxrknWSpMlNegT8ZxOukyRN6D7PASd5JrAFWJ/ktQfd9VDgqDEHk6TV7nAvwh0NPHh43EMOWv8fHPhAHUnSEtxngKvq48DHk7yjqm6Y0kyStCZM+ja0BybZAZx08DZV9dwxhpKktWDSAL8X+EsWv9vt++ONI0lrx6QB3l9VfzHqJJK0xkz6NrQPJPm1JCckOf6un1Enk6RVbtIj4K3D79cftK6ARy3vOJK0dkwU4Ko6eexBJGmtmSjASX7lUOur6q+XdxxJWjsmPQXxtINuHwOcBnwOMMCStESTnoJ49cHLSY4F3jPGQJK0Viz14yi/A3heWJKOwKTngD/A4rseYPFDeB4PXDjWUJK0Fkx6DviPDrq9H7ihqvaOMI8krRkTnYIYPpTnKyx+ItpxwPfGHEqS1oJJvxHjbODTwC8DZwNXJvHjKCXpCEx6CuKNwNOqah9AkvXAPwIXjTWYJK12k74LYt1d8R18835sK0k6hEmPgC9N8hHg3cPyOcCHxxlJktaGw30n3KOBDVX1+iS/CDx7uOuTwDvHHk6SVrPDHQH/CfAGgKp6H/A+gCRPHO77+RFnk6RV7XDncTdU1ZfuuXJYd9IoE0nSGnG4AB97H/c9aBnnkKQ153AB3pPkFfdcmeTlwGfHGUmS1obDnQP+DeD9SV7MgeDOA0cDvzDiXJK06t1ngKvqNmBLkucATxhWf6iqPjr6ZJK0yk36ecCXA5ePPIskrSmjXc2W5Jgkn07yhSRXJ/ndYf3JSa5Mcl2SXUmOHmsGSZplY15O/F3guVX1JODJwOlJngG8GXhrVT0a+BZw7ogzSNLMGi3Atei/hsUHDD8FPJcDH+KzEzhrrBkkaZaN+oE6SY5K8nlgH3AZ8K/AHVW1f3jIXmDjmDNI0qwaNcBV9f2qejJwIvB04HGTbptkW5I9SfYsLCyMNaIktZnKR0pW1R0svovimcCxSe5698WJwE33ss2Oqpqvqvn169dPY0xJmqox3wWxfvj6epI8CHgecC2LIb7r2zS2AhePNYMkzbJJPw94KU4AdiY5isXQX1hVH0xyDfCeJH8AXAVcMOIMkjSzRgtwVX0ReMoh1l/P4vlgLZd1cyRZ0qY/cuImbrrxG8s8kKRJjHkErGm5cz/nbL9iSZvuOm/LMg8jaVJ+r5skNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDrCXZuGkzSZb0s3HT5u7xpZkw1z2AVqab997IOduvWNK2u87bsszTSCvTaEfASTYluTzJNUmuTvKaYf3xSS5L8rXh93FjzSBJs2zMUxD7gddV1anAM4BfT3IqcD6wu6pOAXYPy5K05owW4Kq6pao+N9z+T+BaYCNwJrBzeNhO4KyxZpCkWTaVF+GSnAQ8BbgS2FBVtwx33QpsuJdttiXZk2TPwsLCNMaUpKkaPcBJHgz8LfAbVfUfB99XVQXUobarqh1VNV9V8+vXrx97TEmaulED
nOQBLMb3nVX1vmH1bUlOGO4/Adg35gySNKvGfBdEgAuAa6vqLQfddQmwdbi9Fbh4rBkkaZaN+T7gZwEvBb6U5PPDut8G3gRcmORc4Abg7BFnkKSZNVqAq+qfgNzL3aeN9byStFJ4KbIkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTfxGjLVu3RyLV41LmjYDvNbduX9JXy3k1wpJR85TEJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNRgtwkrcn2ZfkywetOz7JZUm+Nvw+bqznl6RZN+YR8DuA0++x7nxgd1WdAuweliVpTRotwFX1CeDf77H6TGDncHsncNZYzy9Js27a54A3VNUtw+1bgQ339sAk25LsSbJnYWFhOtNpOtbNkeR+/2zctLl7cmlZzXU9cVVVkrqP+3cAOwDm5+fv9XFage7czznbr7jfm+06b8sIw0h9pn0EfFuSEwCG3/um/PySNDOmHeBLgK3D7a3AxVN+fkmaGWO+De3dwCeBxybZm+Rc4E3A85J8DfipYVmS1qTRzgFX1Yvu5a7TxnpOSVpJvBJOkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYK0JGzdtXtLXIPlVSBpT21cSSdN0894bl/Q1SOBXIWk8HgFLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ18X3AWjnWzZGkewpp2RhgrRx37vdiCq0qnoKQpCYGWJKaGGBJamKApcMZXvy7vz9zRx8z9U9g81PfVhZfhJMOZ4kv/u06b8vUXzT0U99WFo+AJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgaRYt8eKPjuc8kgs4Oi4cmaWLVbwQQ5pFR3Dxx0p6zo4LR2bpYpWWI+Akpyf5apLrkpzfMYMkdZt6gJMcBfw58DPAqcCLkpw67TkkqVvHEfDTgeuq6vqq+h7wHuDMhjkkqVWqarpPmLwAOL2qXj4svxT4sap61T0etw3YNiw+FvjqEp7uEcDtRzDuauK+OMB9cXfujwPG2he3V9Xp91w5sy/CVdUOYMeR/BlJ9lTV/DKNtKK5Lw5wX9yd++OAae+LjlMQNwGbDlo+cVgnSWtKR4A/A5yS5OQkRwMvBC5pmEOSWk39FERV7U/yKuAjwFHA26vq6pGe7ohOYawy7osD3Bd35/44YKr7YuovwkmSFnkpsiQ1McCS1GTFB/hwlzUneWCSXcP9VyY5qWHMqZlgf7w2yTVJvphkd5If7ZhzGia95D3JLyWpJKv2rViT7IskZw9/N65O8q5pzzhNE/w72Zzk8iRXDf9Wnj/KIFW1Yn9YfBHvX4FHAUcDXwBOvcdjfg34y+H2C4Fd3XM374/nAD803H7lat0fk+yL4XEPAT4BfAqY75678e/FKcBVwHHD8iO7527eHzuAVw63TwW+PsYsK/0IeJLLms8Edg63LwJOyxF/bt/MOuz+qKrLq+q/h8VPsfg+7NVo0kvefx94M/A/0xxuyibZF68A/ryqvgVQVfumPOM0TbI/CnjocPthwM1jDLLSA7wRuPGg5b3DukM+pqr2A98GHj6V6aZvkv1xsHOBvx91oj6H3RdJngpsqqoPTXOwBpP8vXgM8Jgk/5zkU0n+32Wzq8gk++N3gJck2Qt8GHj1GIPM7KXIGleSlwDzwE90z9IhyTrgLcDLmkeZFXMsnob4SRb/V/SJJE+sqjs6h2r0IuAdVfXHSZ4J/E2SJ1TVncv5JCv9CHiSy5p/8Jgkcyz+d+KbU5lu+ia6zDvJTwFvBM6oqu9OabZpO9y+eAjwBOBjSb4OPAO4ZJW+EDfJ34u9wCVV9b9V9W/Av7AY
5NVokv1xLnAhQFV9EjiGxQ/qWVYrPcCTXNZ8CbB1uP0C4KM1nFlfhQ67P5I8BdjOYnxX83m++9wXVfXtqnpEVZ1UVSexeD78jKra0zPuqCb5d/J3LB79kuQRLJ6SuH6KM07TJPvjG8BpAEkez2KAF5Z7kBUd4OGc7l2XNV8LXFhVVyf5vSRnDA+7AHh4kuuA1wKr9hs4Jtwffwg8GHhvks8nWZWfwzHhvlgTJtwXHwG+meQa4HLg9VW1Kv+nOOH+eB3wiiRfAN4NvGyMAzcvRZakJiv6CFiSVjIDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1+T8ZPMvLjnnqiQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# |NN logit drop - query box score| at each query's best-matching NN patch\n",
    "print('Mean saliency difference:', saliency_sim.mean())\n",
    "sim_grid = sns.displot(saliency_sim)\n",
    "sim_grid.set_axis_labels('|NN logit drop - query box score|', 'Count')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Try My Method With Cosine Similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 237,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', metric='euclidean', n_neighbors=1)"
      ]
     },
     "execution_count": 237,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def _l2_normalize_rows(feats):\n",
    "    # Flatten each sample and scale it to unit L2 norm; all-zero samples are kept unchanged.\n",
    "    # For unit vectors ||a - b||^2 = 2 - 2*cos(a, b), so a Euclidean nearest-neighbour\n",
    "    # search on these rows ranks neighbours exactly as cosine similarity does.\n",
    "    # NOTE(review): assumes iterating feats yields one sample per step (numpy/torch rows).\n",
    "    rows = list()\n",
    "    for sample in feats:\n",
    "        x = deepcopy(sample.flatten())\n",
    "        norm = (x * x).sum() ** 0.5\n",
    "        rows.append(x if norm == 0. else x / norm)\n",
    "    return np.array(rows)\n",
    "\n",
    "\n",
    "# Length is taken from the feature arrays themselves rather than from X_train_C / X_test_C\n",
    "temp_train = _l2_normalize_rows(X_train_x)\n",
    "temp_test  = _l2_normalize_rows(X_test_x)\n",
    "\n",
    "# Fit the twin 1-NN (brute force) on the normalized training features and model predictions\n",
    "twin = KNeighborsClassifier(n_neighbors=1, algorithm='brute', metric='euclidean')\n",
    "twin.fit(temp_train, train_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 238,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|▎         | 300/10000 [00:24<13:05, 12.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Contamination: 99.80066445182723\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "results = list()       # per query: smallest latent distance, query salient patch -> NN patch\n",
    "saliency_sim = list()  # per query: |NN logit drop - query box score| at that best match\n",
    "contamination = 0      # number of retrieved NNs whose predicted class differs from the query's\n",
    "LAST_QUERY_IDX = 300   # stop after this query index, i.e. evaluate the first 301 test queries\n",
    "\n",
    "\n",
    "for query_idx, _ in enumerate(tqdm(test_loader)):\n",
    "\n",
    "    # Get query information\n",
    "    query_pred = test_preds[query_idx].item()\n",
    "    query_img_trans = get_transformed_data([query_idx], test_loader)[0]\n",
    "\n",
    "    # Get query salient feature: only the top-ranked box from get_salient_regions is used.\n",
    "    # NOTE(review): entry [0] is presumably a saliency score in logit units (it is compared\n",
    "    # against a logit drop below) and entry [1] the (row, col) index into query_C - confirm.\n",
    "    query_logits, query_x, query_C = netC(query_img_trans)\n",
    "    query_nb_boxes = get_salient_regions(query_C, query_logits, net_classifier)\n",
    "    query_box_score, window_idx = query_nb_boxes[0][0], query_nb_boxes[0][1]\n",
    "    query_feature = query_C[ :, :, window_idx[0]:window_idx[0]+1, window_idx[1]:window_idx[1]+1]\n",
    "\n",
    "    # Get explanation nn (twin fitted on the normalized features)\n",
    "    xp_idxs = twin.kneighbors(X=[temp_test[query_idx]], n_neighbors=NUM_NEIGHBORS, return_distance=False)\n",
    "\n",
    "    # Count retrieved NNs whose predicted class differs from the query's prediction\n",
    "    contamination += int((train_preds[xp_idxs[0]] != query_pred).sum())\n",
    "\n",
    "    # Among the NUM_NEIGHBORS explanation NNs, keep the one whose matched patch is closest\n",
    "    min_dist = float('inf')\n",
    "    min_logit = float('inf')\n",
    "    for i in range(NUM_NEIGHBORS):\n",
    "\n",
    "        xps_img_trans = get_transformed_data([xp_idxs[0][i]], train_loader)\n",
    "        xp_logits, xp_x, xp_C = netC(xps_img_trans[0])\n",
    "        xp_pred = torch.argmax(xp_logits, dim=1).item()\n",
    "\n",
    "        coord, dist_query_to_xp = get_box_xp(xp_C, query_feature)  # get similar latent feature in nn xp\n",
    "\n",
    "        # Drop in the NN's predicted-class logit when the matched [coord[0], coord[1]] location\n",
    "        # of xp_C is zeroed, compared (absolute difference) with the query's top box score\n",
    "        occluded_C = xp_C.clone()\n",
    "        occluded_C[ :, :, coord[0]:coord[0]+1, coord[1]:coord[1]+1] = 0.0\n",
    "        new_xp_logits = net_classifier(occluded_C)\n",
    "        logit_change = xp_logits[0][xp_pred] - new_xp_logits[0][xp_pred]\n",
    "        logit_change = abs(logit_change - query_box_score)\n",
    "\n",
    "        if dist_query_to_xp < min_dist:\n",
    "            min_dist = dist_query_to_xp\n",
    "            min_logit = logit_change\n",
    "\n",
    "    results.append(min_dist)\n",
    "    # float() handles both a tensor and the untouched float('inf') sentinel\n",
    "    saliency_sim.append(float(min_logit))\n",
    "\n",
    "    if query_idx == LAST_QUERY_IDX:\n",
    "        break\n",
    "\n",
    "# Percentage of all retrieved NNs (queries x NUM_NEIGHBORS) with a different predicted class\n",
    "print('Contamination (%):', 100 * contamination / (len(results) * NUM_NEIGHBORS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 239,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the per-query lists gathered by the loop above into NumPy arrays (np.array copies)\n",
    "saliency_sim = np.array(saliency_sim)\n",
    "results = np.array(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 240,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25.857497668345506\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x1626ba710>"
      ]
     },
     "execution_count": 240,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQWElEQVR4nO3dfaxkBXnH8e9vd0GsL13QDdnuS8BAtKStYFaExTQKtdm2VmlDQWPtpsFCUm0wWi3aPxqbNtGk8SVNY9mAdZtYXYpakDZYimjbYLCLYEFWAlIpi8AuVWLrH9iFp3/Mod5sYHdg59xn7r3fTzK5c8683GfC8OVw7pkzqSokSYtvVfcAkrRSGWBJamKAJamJAZakJgZYkpqs6R5gGtu2bavrrruuewxJerbyVCuXxBbwI4880j2CJM3ckgiwJC1HBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgtduwaTNJZn7ZsGlz90uTDmlJnJBdy9t3997PBZfdNPPn3XXx1pk/pzRLbgFLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwHPAk9FIK5Mn45kDnoxGWpncApakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAX6GNmzaTJKZXkazas3MZ92wafN480orzJruAZaa7+69nwsuu2mmz7nr4q0zfb7/98SBpTOrtAK5BSxJTUYPcJLVSW5Ncu2wfGKSm5Pck2RXkqPHnkGS5tFibAFfAuxZsPwh4CNVdRLwfeDCRZhBkubOqAFOshH4FeDyYTnA2cBVw112AueOOYMkzaux/wj3UeC9wAuG5RcBj1bVgWF5L7DhqR6Y5CLgIoDNm/3L+9wYjqyQdORGC3CS1wP7quqWJK95po+vqh3ADoAtW7bUbKfTs+aRFdLMjLkFfBbwhiS/DBwDvBD4GLA2yZphK3gj8MCIM0jS3BptH3BVva+qNlbVCcCbgC9V1VuAG4HzhrttB64eawZJmmcdxwH/AfCuJPcw2Sd8RcMMktRuUT4JV1VfBr48XL8XOH0xfq8kzTM/CSdJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDrOVr1RqSzPSyYdPm7lelZWRN9wDSaJ44wAWX3TTTp9x18daZPp9WNreAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJamKAJamJAZakJqMFOMkxSb6W5BtJvpnkA8P6E5PcnOSeJLuSHD3WDJI0z8bcAn4MOLuqXg6cCmxLcgbwIeAjVXUS8H3gwhFnkKS5NVqAa+J/hsWjhksBZwNXDet3AueONYMkzbNR9wEnWZ3kNmAfcD3wbeDRqjow3GUvsOFpHntRkt1Jdu/fv3/MMSWpxagBrqrHq+pUYCNwOvCyZ/DYHVW1paq2rFu3bqwRJanNohwFUVWPAjcCZwJrkzz5XXQbgQcWYwZJmjdjHgWxLsna4fpzgdcBe5iE+LzhbtuBq8eaQZLm2Zjfirwe2JlkNZPQX1lV1ya5E/hMkj8BbgWuGHEGSZpbowW4qv4dOO0p1t/LZH+wJK1ofhJOkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGW
pCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmkwV4CRnTbNOkjS9abeA/3zKdZKkKR3yO+GSnAlsBdYledeCm14IrB5zMEla7g73pZxHA88f7veCBet/wI+/Wl6S9CwcMsBV9RXgK0k+WVX3LdJMkrQiTPu19M9JsgM4YeFjqursMYaSpJVg2gD/LfCXwOXA4+ONI0krx7QBPlBVHx91EklaYaY9DO0LSX43yfokxz15GXUySVrmpt0C3j78fM+CdQW8ZLbjSNLKMVWAq+rEsQeRpJVmqgAn+a2nWl9Vfz3bcSRp5Zh2F8QrF1w/BjgH+DpggCXpWZp2F8TvLVxOshb4zBgDSdJK8WxPR/lDwP3CknQEpt0H/AUmRz3A5CQ8Pw1cOdZQkrQSTLsP+M8WXD8A3FdVe0eYR5JWjKl2QQwn5fkWkzOiHQv8aMyhJGklmPYbMc4Hvgb8BnA+cHMST0cpSUdg2l0Qfwi8sqr2ASRZB/wTcNVYg0nScjftURCrnozv4L+ewWMlSU9h2i3g65J8Efj0sHwB8A/jjCRJK8PhvhPuJOD4qnpPkl8HXj3c9FXgU2MPJ0nL2eG2gD8KvA+gqj4HfA4gyc8Ot/3qiLNJ0rJ2uP24x1fV7QevHNadMMpEkrRCHC7Aaw9x23NnOIckrTiHC/DuJL9z8MokbwNuGWckSVoZDrcP+J3A55O8hR8HdwtwNPBrI84lScveIQNcVQ8DW5O8FviZYfXfV9WXRp9Mkpa5ac8HfCNw48izSNKK4qfZJKmJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpyWgBTrIpyY1J7kzyzSSXDOuPS3J9kruHn8eONYMkzbMxt4APAO+uqlOAM4C3JzkFuBS4oapOBm4YliVpxRktwFX1YFV9fbj+38AeYAPwRmDncLedwLljzSBJ82xR9gEnOQE4DbiZybdsPDjc9BBw/GLMIEnzZvQAJ3k+8FngnVX1g4W3VVUB9TSPuyjJ7iS79+/fP/aYkrToRg1wkqOYxPdTw5d6AjycZP1w+3pg31M9tqp2VNWWqtqybt26MceUpBZjHgUR4ApgT1V9eMFN1wDbh+vbgavHmkGS5tlUJ2R/ls4C3grcnuS2Yd37gQ8CVya5ELgPOH/EGSRpbo0W4Kr6VyBPc/M5Y/1eSVoq/CScJDUxwNIzsWoNSWZ62bBpc/erUpMx9wFLy88TB7jgsptm+pS7Lt460+fT0uEWsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLDUbdUaksz0smHT5u5XpSms6R5AWvGeOMAFl90006fcdfHWmT6fxuEWsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUpPRApzkE0n2Jbljwbrjklyf5O7h57Fj/X5JmndjbgF/Eth20LpLgRuq6mTghmFZklak0QJcVf8MfO+g1W8Edg7XdwLnjvX7JWneLfY+4OOr6sHh+kPA8Yv8+yVpbrT9Ea6qCqinuz3JRUl2J9m9f//+RZxMkhbHYgf44STrAYaf+57ujlW1o6q2VNWWdevWLdqAkrRYFjvA1wDbh+vbgasX+fdL0twY8zC0TwNfBV6aZG+SC4EPAq9LcjfwC8OyJK1Io50PuKre/DQ3nTPW75SkpcRPwklSEwMsSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktTEAEtSEwMsSU0MsKSpbdi0mSQzvWzYtLn7ZbUZ7XzAkpaf
7+69nwsuu2mmz7nr4q0zfb6lxC1gSWpigCWpiQGWpCYGWJKaGGBJauJRENJytGoNSbqn0GEYYGk5euLAzA8Xg5V9yNgY3AUhSU0MsCQ1McCS1MQAS1ITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNTHAktRkWQd4jO+vkqRZWdYn4/H7qyTNs2W9BSxJ88wAS1ITAyxJTQywJDVZ1n+Ek7QEjPD1SauPeg6P/+9jM31OgJ/auIkH7v/PmT2fAZbUa4SvT9p18dYl8ZVM7oKQpCYGWJKaGGBJamKAJamJAZakJgZYkpoYYElqYoAlqYkBlqQmBliSmhhgSWpigCWpiQGWpCYGWJKaGGBJatIS4CTbktyV5J4kl3bMIEndFj3ASVYDfwH8EnAK8OYkpyz2HJLUrWML+HTgnqq6t6p+BHwGeGPDHJLUKlW1uL8wOQ/YVlVvG5bfCryqqt5x0P0uAi4aFl8K3LWog07nxcAj3UMcAefv5fy9FnP+R6pq28Er5/Y74apqB7Cje45DSbK7qrZ0z/FsOX8v5+81D/N37IJ4ANi0YHnjsE6SVpSOAP8bcHKSE5McDbwJuKZhDklqtei7IKrqQJJ3AF8EVgOfqKpvLvYcMzLXu0im4Py9nL9X+/yL/kc4SdKEn4STpCYGWJKaGOApJflEkn1J7liw7rgk1ye5e/h5bOeMh5JkU5Ibk9yZ5JtJLhnWL4nXkOSYJF9L8o1h/g8M609McvPwsfZdwx9251KS1UluTXLtsLyUZv9OktuT3JZk97BuSbx3AJKsTXJVkm8l2ZPkzHmY3wBP75PAwQdSXwrcUFUnAzcMy/PqAPDuqjoFOAN4+/AR8KXyGh4Dzq6qlwOnAtuSnAF8CPhIVZ0EfB+4sG/Ew7oE2LNgeSnNDvDaqjp1wbGzS+W9A/Ax4Lqqehnwcib/HPrnryovU16AE4A7FizfBawfrq8H7uqe8Rm8lquB1y3F1wD8BPB14FVMPsm0Zlh/JvDF7vmeZuaNTP4lPxu4FshSmX2Y7zvAiw9atyTeO8BPAv/BcNDBPM3vFvCROb6qHhyuPwQc3znMtJKcAJwG3MwSeg3D/8LfBuwDrge+DTxaVQeGu+wFNjSNdzgfBd4LPDEsv4ilMztAAf+Y5JbhNAGwdN47JwL7gb8adgFdnuR5zMH8BnhGavKf0bk/pi/J84HPAu+sqh8svG3eX0NVPV5VpzLZmjwdeFnvRNNJ8npgX1Xd0j3LEXh1Vb2CyVkM357k5xfeOOfvnTXAK4CPV9VpwA85aHdD1/wG+Mg8nGQ9wPBzX/M8h5TkKCbx/VRVfW5YvaReA0BVPQrcyOR/29cmefIDRfP6sfazgDck+Q6Ts/+dzWSf5FKYHYCqemD4uQ/4PJP/AC6V985eYG9V3TwsX8UkyO3zG+Ajcw2wfbi+ncl+1bmUJMAVwJ6q+vCCm5bEa0iyLsna4fpzmey/3sMkxOcNd5vL+avqfVW1sapOYPLR+y9V1VtYArMDJHlekhc8eR34ReAOlsh7p6oeAu5P8tJh1TnAnczB/H4SbkpJPg28hskp7B4G/gj4O+BKYDNwH3B+VX2vacRDSvJq4F+A2/nxfsj3M9kPPPevIcnPATuZfHx9FXBlVf1xkpcw2ao8DrgV+M2qeqxv0kNL8hrg96vq9Utl9mHOzw+La4C/qao/TfIilsB7ByDJqcDlwNHAvcBvM7yPaJzfAEtSE3dBSFITAyxJTQywJDUxwJLUxABLUhMDLElNDLAkNfk/OWQ/vNRMpYwAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Latent distance between each query's salient patch and its best match in the explanation NN(s)\n",
    "print('Mean distance:', results.mean())\n",
    "dist_grid = sns.displot(results)\n",
    "dist_grid.set_axis_labels('Latent distance (query salient patch to closest NN patch)', 'Count')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 241,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.1492092688614348\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x138f87210>"
      ]
     },
     "execution_count": 241,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAASV0lEQVR4nO3dfZBdd13H8fcnDQVFoC3GTE0TW6SC9QHQBSE4jlB0Igotii0ManSKqYpPg6NW8R8fZhR1FEYZbAQkziCkVpgW1GqNRUYrxQAFbItSKrUpbbNFEMVRDP36x57YNW6yt9k993tv9v2aubPnnHsePnu388np795zbqoKSdL0beoOIEkblQUsSU0sYElqYgFLUhMLWJKabO4OMIldu3bVdddd1x1Dkk5WVlo4F2fA999/f3cESVp3c1HAknQqsoAlqYkFLElNLGBJamIBS1ITC1iSmljAktTEApakJhawJDWxgCWpiQUsSU0sYElqYgFLUhMLWJKaWMCr2LZ9B0nW9Ni2fUf3ryFpBs3FDdk7ffzQXVx65Y1r2sf+y3euUxpJpxLPgCWpiQUsSU0sYElqYgFLUhMLWJKaWMCS1MQClqQmFrAkNbGAJanJqAWc5IwkVyf5cJLbkjwjyVlJrk/ykeHnmWNmkKRZNfYZ8KuB66rqicCTgNuAK4ADVXU+cGCYl6QNZ7QCTvIY4BuA1wNU1Wer6lPARcC+YbV9wMVjZZCkWTbmGfB5wCLwe0nen+R1SR4JbK2qe4Z17gW2rrRxkj1JDiY5uLi4OGJMSeoxZgFvBr4GeG1VPQX4DMcMN1RVAbXSxlW1t6oWqmphy5YtI8aUpB5jFvAh4FBV3TTMX81SId+X5GyA4efhETNI0swarYCr6l7griRPGBZdCNwKXAvsHpbtBq4ZK4MkzbKxb8j+I8CbkpwO3AF8H0ulf1WSy4A7gUtGziBJM2nUAq6qm4GFFZ66cMzjStI88Eo4SWpiAUtSEwtYkppYwJLUxAKWpCYWsCQ1sYAlqYkFLElNLGBJamIBS1ITC1iSmljAktTEApakJhawJDWxgCWpiQUsSU0sYElqYgFLUhMLWJKaWMCS1MQClqQmFrAkNbGAJamJBSxJTSxgSWpiAUtSEwtYkppYwJLUxAKWpCYWsCQ1sYAlqYkFLElNLGBJanJKF/C27TtIsqaHJI1lc3eAMX380F1ceuWNa9rH/st3rlMaSfq/TukzYEmaZRawJDWxgCWpiQUsSU0sYElqMuqnIJJ8DPg34HPAkapaSHIWsB84F/gYcElVfXLMHJI0i6ZxBvysqnpyVS0M81cAB6rqfODAMC9JG07HEMRFwL5heh9wcUMGSWo3dgEX8OdJ3ptkz7Bsa1XdM0zfC2xdacMke5IcTHJwcXFx5JiSNH1jXwn39VV1d5IvAq5P8uHlT1ZVJamVNqyqvcBegIWFhRXXkaR5NuoZcFXdPfw8DLwNeBpwX5KzAYafh8fMIEmzarQCTvLIJI86Og18M/D3wLXA7mG13cA1Y2WQpFk25hDEVuBtwx3FNgN/UFXXJfk74KoklwF3ApeMmEGSZtZoBVxVdwBPWmH5J4ALxzquJM0Lr4STpCYWsCQ1sYAlqYkFLElNLGBJamIBS1ITC1iSmljAktTEAp6GTZtJsubHtu07un8TSeto7LuhCeCBI1x65Y1r3s3+y3euQxhJs8IzYElqYgFLUhMLWJKaWMCS1MQClqQmFrAkNbGAJamJBSxJTSxgSWpiAUtSEwtYkppYwJLUxAKWpCYWsCQ1sYAlqYkFLElNLGBJamIBS1ITC1iSmljAktTEApakJhawJDWxgCWpiQU8TzZtJsmaHtu27+j+LSQNNncH0EPwwBEuvfLGNe1i/+U71ymMpLXyDFiSmljAktTEApakJhawJDUZvYCTnJbk/UneMcyfl+SmJLcn2Z/k9LEzSNIsmsYZ8I8Bty2bfyXwm1X1eOCTwGVT
yCBJM2fUAk5yDvCtwOuG+QDPBq4eVtkHXDxmBkmaVWOfAb8K+CnggWH+scCnqurIMH8I2LbShkn2JDmY5ODi4uLIMSVp+kYr4CTfBhyuqveezPZVtbeqFqpqYcuWLeucTpL6jXkl3DOB5yd5LvAI4NHAq4EzkmwezoLPAe4eMYMkzazRzoCr6meq6pyqOhd4EfCXVfUS4AbghcNqu4FrxsogSbOs43PAPw28PMntLI0Jv74hgyS1m8rNeKrqncA7h+k7gKdN47iSNMu8Ek6SmljAktTEApakJhawJDWxgCWpiQUsSU0sYElqYgFLUhMLWJKaWMCS1MQClqQmFrAkNbGAJamJBSxJTSxgSWpiAUtSk4kKOMkzJ1kmSZrcpGfAvzXhMknShE74lURJngHsBLYkefmypx4NnDZmMEk61a32nXCnA18wrPeoZcs/zYPfbCxJOgknLOCq+ivgr5K8sarunFImSdoQJv1W5Icn2Qucu3ybqnr2GKEkaSOYtID/EPgd4HXA58aLI0kbx6QFfKSqXjtqEknaYCb9GNrbk/xQkrOTnHX0MWoySTrFTXoGvHv4+ZPLlhXwuPWNI0kbx0QFXFXnjR1EkjaaiQo4yfestLyqfn9940jSxjHpEMRTl00/ArgQeB9gAUvSSZp0COJHls8nOQN4yxiBJGmjONnbUX4GcFxYktZg0jHgt7P0qQdYugnPlwNXjRVKkjaCSceAf33Z9BHgzqo6NEIeSdowJhqCGG7K82GW7oh2JvDZMUNJ0kYw6TdiXAK8B/hO4BLgpiTejlKS1mDSIYhXAE+tqsMASbYAfwFcPVYwSTrVTfopiE1Hy3fwiYewrSRpBZOeAV+X5M+ANw/zlwJ/Mk4kSdoYVvtOuMcDW6vqJ5N8O/D1w1N/C7xp7HCSdCpb7Qz4VcDPAFTVW4G3AiT5quG5542YTZJOaauN426tqg8du3BYdu6JNkzyiCTvSfKBJLck+flh+XlJbkpye5L9SU4/6fSSNMdWK+AzTvDc562y7X8Bz66qJwFPBnYleTrwSuA3q+rxwCeByyaLKkmnltUK+GCS7z92YZKXAu890Ya15N+H2YcNjwKezYMfX9sHXPxQAkvSqWK1MeAfB96W5CU8WLgLwOnAC1bbeZLThu0eD7wG+Cjwqao6MqxyCNh2nG33AHsAduzYsdqhJGnunLCAq+o+YGeSZwFfOSz+46r6y0l2XlWfA5483L7ybcATJw1WVXuBvQALCwu1yuqSNHcmvR/wDcANJ3uQqvpUkhuAZwBnJNk8nAWfA9x9svuVpHk22tVsSbYMZ74k+Tzgm4DbWCryo/eR2A1cM1YGSZplk14JdzLOBvYN48CbgKuq6h1JbgXekuSXgPcDrx8xgyTNrNEKuKo+CDxlheV3AE8b67iSNC+8oY4kNbGAJamJBSxJTSxgSWpiAUtSEwtYkppYwJLUxAKWpCYWsCQ1sYAlqYkFLElNLGBJamIBS1ITC3ij2bSZJGt6bNvuV0RJ62HM+wFrFj1whEuvvHFNu9h/+c51CiNtbJ4BS1ITC1iSmljAktTEApakJhawJDWxgCWpiQUsSU0sYElqYgFLUhMLWJKaWMCS1MQClqQmFrAkNbGAJamJBSxJTSxgSWpiAUtSEwtYkppYwJLUxAKWpCYWsCQ1sYAlqYkFLElNLGBJajJaASfZnuSGJLcmuSXJjw3Lz0pyfZKPDD/PHCuDJM2yMc+AjwA/UVUXAE8HXpbkAuAK4EBVnQ8cGOYlacMZrYCr6p6qet8w/W/AbcA24CJg37DaPuDisTJI0iybyhhwknOBpwA3AVur6p7hqXuBrdPIIEmzZvQCTvIFwB8BP15Vn17+XFUVUMfZbk+Sg0kOLi4ujh1TkqZu1AJO8jCWyvdNVfXWYfF9Sc4enj8bOLzStlW1t6oWqmphy5YtY8aUpBZjfgoiwOuB26rqN5Y9dS2we5jeDVwzVgZJmmWbR9z3M4HvBj6U5OZh2c8C
vwJcleQy4E7gkhEzSNLMGq2Aq+qvgRzn6QvHOq4kzQuvhJOkJhawJDWxgCWpiQUsSU0sYElqYgFLUhMLWJKaWMCS1MQClqQmFrAkNbGAJamJBSxJTSxgSWpiAUtSEwtYkppYwJLUxAKWpCYWsCQ1sYAlqYkFrDbbtu8gyZoe27bv6P41pJM25rciSyf08UN3cemVN65pH/sv37lOaaTp8wxYkppYwJLUxAKWpCYWsCQ1sYAlqYkFLElNLGBJamIBS1ITL8TQQ7dpM0m6U0hzzwLWQ/fAkTVfwQZexSY5BCFJTSxgSWpiAWu+DePR3lFN88gxYM23dRiPdixaXTwDlqQmFrAkNbGAJamJBSxJTSxgSWoyWgEneUOSw0n+ftmys5Jcn+Qjw88zxzq+JM26Mc+A3wjsOmbZFcCBqjofODDMS9KGNFoBV9W7gH85ZvFFwL5heh9w8VjHl6RZN+0x4K1Vdc8wfS+w9XgrJtmT5GCSg4uLi9NJJ0lT1PYmXFUVUCd4fm9VLVTVwpYtW6aYTJKmY9oFfF+SswGGn4enfHxJmhnTLuBrgd3D9G7gmikfX5JmxpgfQ3sz8LfAE5IcSnIZ8CvANyX5CPCcYV6SNqTR7oZWVS8+zlMXjnVMSZonXgknSU0sYElqYgFLUhMLWJKaWMCS1MQClqQmFrAkNbGAJamJBSxJTSxgadNmkqz5sW37jjVH2bZ9x0zk0HSMdimyNDceOMKlV9645t3sv3znmvfx8UN3rTnLeuTQdHgGLElNLGBJamIBS1ITC1iSmljAktTEApakJhawJDWxgCWpiQUsSU0sYElqYgFLUhMLWJKaWMDSelmHu6rNivW4K5t3Zludd0OT1ss63FVtVu5kth53ZYPZ+X1mlWfAktTEApakJhawJDVxDFg61QxvBmr2WcDSqeYUejPwVOcQhCQ1sYAlqYkFLElNHAOWNJ51eEPwtIc9nM/993+17wPgi8/Zzt13/fOa93OUBSxpPOv0huAs7OPoftaTQxCS1MQClqQmFrAkNbGAJamJBSxJTVoKOMmuJP+Q5PYkV3RkkKRuUy/gJKcBrwG+BbgAeHGSC6adQ5K6dZwBPw24varuqKrPAm8BLmrIIUmtUlXTPWDyQmBXVb10mP9u4Ouq6oePWW8PsGeYfQLwDydxuC8E7l9D3Gkx5/oy5/qbl6yzmvP+qtp17MKZvRKuqvYCe9eyjyQHq2phnSKNxpzry5zrb16yzkvOozqGIO4Gti+bP2dYJkkbSkcB/x1wfpLzkpwOvAi4tiGHJLWa+hBEVR1J8sPAnwGnAW+oqltGOtyahjCmyJzry5zrb16yzktOoOFNOEnSEq+Ek6QmFrAkNZn7Al7tsuYkD0+yf3j+piTnNsQ8mmW1rN+Q5H1Jjgyfl24xQc6XJ7k1yQeTHEjyJTOa8weSfCjJzUn+uuuKy0kvvU/yHUkqScvHqCZ4Pb83yeLwet6c5KUdOYcsq76mSS4Z/ju9JckfTDvjRKpqbh8svYn3UeBxwOnAB4ALjlnnh4DfGaZfBOyf4aznAl8N/D7wwhnO+Szg84fpH+x4TSfM+ehl088HrpvFnMN6jwLeBbwbWJjFnMD3Ar897WwnmfV84P3AmcP8F3XnXukx72fAk1zWfBGwb5i+Grgwa/2SqpOzataq+lhVfRB4oCHfUZPkvKGq/mOYfTdLn+WetklyfnrZ7COBjnecJ730/heBVwL/Oc1wy8zTLQImyfr9wGuq6pMAVXV4yhknMu8FvA24a9n8oWHZiutU1RHgX4HHTiXdcXIMVso6Cx5qzsuAPx010comypnkZUk+Cvwq8KNTyrbcqjmTfA2wvar+eJrBjjHp3/07hqGnq5NsX+H5aZgk65cBX5bkb5K8O8n/uwx4Fsx7AatRku8CFoBf685yPFX1mqr6UuCngZ/rznOsJJuA3wB+ojvLBN4OnFtVXw1cz4P/ZzmLNrM0DPGNwIuB301yRmeglcx7AU9yWfP/
rpNkM/AY4BNTSXecHINZvQR7opxJngO8Anh+Va39+74fuof6er4FuHjMQMexWs5HAV8JvDPJx4CnA9c2vBG36utZVZ9Y9rd+HfC1U8p2rEn+9oeAa6vqv6vqn4B/ZKmQZ0v3IPQaB+M3A3cA5/HgYPxXHLPOy/i/b8JdNatZl637RvrehJvkNX0KS2+CnD/jf/vzl00/Dzg4izmPWf+d9LwJN8nrefay6RcA757hv/0uYN8w/YUsDVk8tiPvCX+X7gDr8Md4Lkv/un0UeMWw7BdYOjMDeATwh8DtwHuAx81w1qey9C/3Z1g6S79lRnP+BXAfcPPwuHZGc74auGXIeMOJiq8z5zHrthTwhK/nLw+v5weG1/OJHTknzBqWhnZuBT4EvKgr64keXoosSU3mfQxYkuaWBSxJTSxgSWpiAUtSEwtYkppYwJLUxAKWpCb/A611QOZ8fxEzAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# |NN logit drop - query box score| at each query's best-matching NN patch\n",
    "print('Mean saliency difference:', saliency_sim.mean())\n",
    "sim_grid = sns.displot(saliency_sim)\n",
    "sim_grid.set_axis_labels('|NN logit drop - query box score|', 'Count')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "neurips_img",
   "language": "python",
   "name": "neurips_img"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
