{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Combine Good Twin-System Explanation with Clustering of FAMs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "import time\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision.models as models\n",
    "\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "from functions import *\n",
    "from ANNs import *\n",
    "from tqdm import tqdm\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pre-Requisites for Algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "netC = CNN()\n",
    "netC.load_state_dict(torch.load('weights/cnn.pth', map_location=torch.device('cpu')))\n",
    "netC = netC.eval()\n",
    "DATAROOT = 'data'\n",
    "DEVICE = 'cpu'\n",
    "\n",
    "# MEAN_NORM=(0.485, 0.456, 0.406)\n",
    "# STD_NORM=(0.229, 0.224, 0.225)\n",
    "MEAN_NORM=(0.5, 0.5, 0.5)\n",
    "STD_NORM=(0.5, 0.5, 0.5)\n",
    "\n",
    "NUM_NEIGHBORS = 100  # Num neighbors to check for features\n",
    "XP_FEATURE_NUM = 3  # Num box features to show"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class netClassifier(nn.Module):\n",
    "    \n",
    "    def __init__(self, netC):\n",
    "        super(netClassifier, self).__init__()\n",
    "        self.net = netC\n",
    "        \n",
    "    def forward(self, C):\n",
    "        x = self.net.avgpool(C)\n",
    "        x = x.view(-1, 128)\n",
    "        logits = self.net.linear(x)\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "net_classifier = netClassifier(netC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "LABELS_DICT = {0:\"0\", 1:\"1\", 2:\"2\", 3:\"3\", 4:\"4\", 5:\"5\", 6:\"6\", 7:\"7\", 8:\"8\", 9:\"9\" }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "weights = netC.linear.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, test_loader = load_dataloaders()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = train_loader.dataset.data\n",
    "y_train = train_loader.dataset.targets\n",
    "\n",
    "X_test = test_loader.dataset.data\n",
    "y_test = test_loader.dataset.targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# X_train, y_train, X_test, y_test = get_MNIST_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_c = np.load(DATAROOT + \"/X_train_cont.npy\")\n",
    "X_test_c = np.load(DATAROOT + \"/X_test_cont.npy\")\n",
    "X_train_x = np.load(DATAROOT + \"/X_train_x.npy\")\n",
    "X_test_x = np.load(DATAROOT + \"/X_test_x.npy\")\n",
    "train_preds = np.load(DATAROOT + \"/X_train_y.npy\")\n",
    "test_preds = np.load(DATAROOT + \"/X_test_y.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='brute', n_neighbors=1)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit COLE and DkNN\n",
    "twin = KNeighborsClassifier(n_neighbors=1, algorithm=\"brute\") \n",
    "twin.fit(X_train_c, train_preds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Choose Explanation Instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inverse_normalize(tensor, mean=MEAN_NORM, std=STD_NORM):\n",
    "    for t, m, s in zip(tensor, mean, std):\n",
    "        t.mul_(s).add_(m)\n",
    "    return tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_transformed_data(idxs, loader):\n",
    "    \"\"\"\n",
    "    Takes in indexs and a dataloader\n",
    "    returns: Those indexs transformed\n",
    "    \"\"\"\n",
    "    \n",
    "    subset = torch.utils.data.Subset(loader.dataset, idxs)\n",
    "    loader_subset = torch.utils.data.DataLoader(subset, batch_size=1, num_workers=0, shuffle=False)\n",
    "\n",
    "    transformed_data = list()\n",
    "    for data in loader_subset:\n",
    "        transformed_data.append(data[0])\n",
    "        \n",
    "    return transformed_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "incorrect_preds = list()\n",
    "\n",
    "for i in range(len(test_preds)):\n",
    "\n",
    "    label = y_test[i]\n",
    "    pred = test_preds[i]\n",
    "    \n",
    "    if pred != label:\n",
    "        incorrect_preds.append(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Random Incorrect Classification\n",
    "# rand_int = random.randint(0, len(incorrect_preds))\n",
    "# query_idx = incorrect_preds[rand_int]\n",
    "\n",
    "# Random Classification\n",
    "query_idx = random.randint(0, test_loader.dataset.data.shape[0])\n",
    "query_idx = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "query_label = y_test[query_idx].item()\n",
    "query_pred  = test_preds[query_idx].item()\n",
    "query_img_trans = get_transformed_data([query_idx], test_loader)[0]\n",
    "query_cont = X_test_c[query_idx]\n",
    "\n",
    "xp_idxs = twin.kneighbors(X=[query_cont], n_neighbors=100, return_distance=False)[0]\n",
    "xps_imgs_trans = get_transformed_data(xp_idxs, train_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Examine Explanation with Three Neighbors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABYEAAAFNCAYAAABWqqBpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAcGUlEQVR4nO3de5jddX0n8M93ZnJPjNwDgQRISCwoZLsKAosFlQ2K1KJNsVYusiB0sVXbVbtsbdkVFHxQ7ipLFQTRImCpi2LXG+7SEMKiXBZBhBAucr8nJiSZzK9/zOR55gkJfGYyJ2fON6/X85znmWTe5/v7zGHmOyfv3+8cStM0AQAAAABAnbraPQAAAAAAAK2jBAYAAAAAqJgSGAAAAACgYkpgAAAAAICKKYEBAAAAACqmBAYAAAAAqJgSuIOUUm4spZywue87mpVSmlLK7IGPv1pK+cww11leStl9ZKcDtkT26leyVwOjkf36lezXwGhkv34l+zXDoQRug1LK0lLKO9s9xzqllFMHfvDX3VaWUvpKKduOwNoHD6y1vJSyrJTy61LKh0di7vU1TXNy0zSfTcz0il8CTdNMbppmSSvmWu/Yy9e7rS2lXNDq4wJDZ6+2V9uroTPYr+3X9mvoDPZr+7X9ur2UwETTNJ8b+MGf3DTN5Ig4KyJubJrmmRE6xGMD674uIj4dEZeUUvZcP1RK6Rmh441a6z3O0yJiZURc3eaxgA5gr9587NXAprBfbz72a2BT2K83H/v16KAEHkVKKVuVUq4vpTxdSnl+4OOd14vNKqUsLqW8VEr551LK1oPu/9ZSysJSygullDtKKQcPY4YSEcdExDc26YvZgKbfdRHxfETsWUo5rpTyr6WUc0opz0bEaaWUcaWUs0spD5dSniz9L2uYMGi+T5ZSHi+lPFZKOX692S8rpZw+6M/vLaXcPvBYPVBKOayUckZEHBQRFw6cfbpwIDv4pRRTSymXD/x3eKiU8rellK6Bzx1XSrlpYMbnSykPllLeNcyH5P0R8VRE/N9h3h9oA3u1vRroDPZr+zXQGezX9ms2DyXw6NIVEZdGxMyImBH9Z0YuXC9zTEQcHxE7RkRvRJwfEVFKmR4R34+I0yNi64j4LxFxbSllu/UPUkqZMbA5ztjADAdFxPYRce1IfEHrHberlHJkRLw+Iu4a+Ov9ImJJROwQEWdExJkRMSci5kXE7IiYHhF/N3D/w6L/6zo0IvaIiI2+jKSUsm9EXB4Rnxw43tsiYmnTNP8t+jeajw6chfroBu5+QURMjYjdI+IPov8xH/yyjf0i4tcRsW1EfCEivjbwCyNKKX9TSrk++ZAcGxGXN03TJPPA6GCvtlcDncF+bb8GOoP92n7N5tA0jdtmvkXE0oh4ZyI3LyKeH/TnGyPizEF/3jMiVkdEd/S/tOCK9e7/LxFx7KD7npA45tci4rIR/FoPjoi+iHghIp6LiNsj4gMDnzsuIh4elC0R8buImDXo7/aPiAcHPv76el//nIhoImL2wJ8vi4jTBz6+OCLO2chMr3gs1q0z8Fiujog9B33upOh/Sci6me8f9LmJA/edNsTHZWZErI2I3dr9/ejm5rbhm73aXm2vdnPrjJv92n5tv3Zz64yb/dp+bb9u76369x3pJKWUiRFxTkQcFhFbDfz1lFJKd9M0awf+/MiguzwUEWOi/yzMzIhYUEo5YtDnx0TEz4Z4/AUR8d7hfQUb9VjTNOu/lGOdwV/PdtG/kdw2cDIpon8z7B74eKeIuG1Q/qFXOeYuEfGDoY8a20b/4zZ47Yei/yzcOk+s+6BpmhUDs04e4nGOjoibmqZ5cBgzAm1kr7ZXA53Bfm2/BjqD/dp+zeahBB5d/joi5kbEfk3TPFFKmRcRv4z+H/51dhn08YyIWBMRz0T/BnJF0zQnbsLxj4z+M1Q3bsIaQzX48v9nov9lH3s1TfPbDWQfj1d+/RvzSETMShxzfc9E/2M6
MyJ+Neg4G5pnUxwT/S/3ADqPvdpeDXQG+7X9GugM9mv7NZuB9wRunzGllPGDbj0RMSX6f/BfKP1vcv73G7jfh0opew6cqfofEXHNwJmxb0bEEaWU+aWU7oE1Dy6vfDP1V3NstPF9WZqm6YuISyLinFLK9hH97+9TSpk/EPlORBw36Ovf0OOzztci4sOllHcMvP/O9FLKGwY+92T0v8fNhmZYO3CcM0opU0opMyPir6L/8R0RpZQDov9smv8TJox+9ur12KuBUcp+vR77NTBK2a/XY79mc1ECt88Pon+TW3c7LSLOjYgJ0X8GZlFE/HAD97si+t/v5YmIGB8RfxkR0TTNI9H/0oVTI+Lp6D/788nYwH/j0v9m6MvLoDdDL/1vpv726H8D8Xb6dETcHxGLSikvRcSPo/+MYDRNc0P0P0Y/Hcj8dGOLNE2zOPrfwPyciHgxIn4e/We0IiLOi4g/Lv3/R8vzN3D3v4j+9+NZEhE3RcS3ov89eF5TKeXUUsoNrxE7NiK+2zTNssyaQFvZqzfMXg2MNvbrDbNfA6ON/XrD7Ne0XGnTiQ4AAAAAADYDVwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFSs59U+eWjXgmZzDQIwkn7Ud3Vp9wybk/0a6FT2a4DOYL8G6Awb269dCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxXraPcBo9+yJ+6ezM46+P52996kd0tnVq8aks9O/nctOfHR5es2+23+VzgKMpP++5LZ09t+Pa80Mc39yYjo79sHx6eyMH/5uOOO8qjFPvpjO9i5ZOuLHHw16dt81nb3v5B3T2UmPlHR2hwsWprNQixeOyT9nvuXMr6Szt69alc7+6Tc+kc5OerRJZ7f5h5vT2U7Ss/P0dHbtDq9v3SAJS/9wajq7emb+e2YoZv/P3nS2LLyjJTPAxjx66gHp7J2nXJjO/uPy7dLZM381P52dcF3+Z7rn5fx+TftNuWpRu0fgVbgSGAAAAACgYkpgAAAAAICKKYEBAAAAACqmBAYAAAAAqJgSGAAAAACgYkpgAAAAAICKKYEBAAAAACqmBAYAAAAAqJgSGAAAAACgYkpgAAAAAICK9bR7gNHuU5/8Vjr7/knP5xeeNYxhMg7OxZb2rkgved7ThwxvFl7T4qdmprOTvjg1ne35yW3DGQdGnafWTklnD/jle9LZ
Q6ffm87++h2XpLN90ZfOxgn5aNaPV+Yfr6ue3jed7Ws655zxh7a/IZ09ZMLydPaWVWPS2TMumJfOQi2ee2M+u6ZZm87uNTb/z5U7T7wgnV3VrElnP3T0e9PZTrJg2qJ09k8mP9XCSTrDQT8+JZ2durCFg8AGzLw4/9x29q4npbP/Mv/cdPb2fb+Zzvbt26SzrdAVJZ3ti9bMevfq3nT2TWPzz0NbNW/WzZ/vTmc//L2T09k5n7k7ne1btiyd3dJ0zr/qAAAAAAAYMiUwAAAAAEDFlMAAAAAAABVTAgMAAAAAVEwJDAAAAABQMSUwAAAAAEDFlMAAAAAAABVTAgMAAAAAVEwJDAAAAABQMSUwAAAAAEDFeto9wGh3/qkfSGf/bu98p77VPU06+/zvlXR27N4vpHJfeON302ues+Mt6ez3V0xOZw+fuDydbZWVzepU7pZVk9JrHjx+TX6AITy2s486KZ2d85P8CDCaXXzQQens1k/cl87+csrUdPaIOceks/d/cEo62zc+93vgsnddnF5zStfL6exZ03+Qzm7dPS6dbbeuIZzf7hvCusf/08np7KxYNISVoQ7b3Jl/brvqg/nnSuPKmOGMM6LrXj07v1/SWVY1Q3jenv8Wh81u7bPPpbNzPpLPfmzv/5TO9r5ufDq7atux6ewjh6WjMfGhXMW1YmZvftEh6Ho5/zx07leeSWfXbJfvWVrhoVPyG+A9b7s0nb13wUXp7LuvyX8vdt10ezq7pXElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAV62n3AKPdpGtuGUK2NTO8rgVrXjDt4HT29AN3TWdf9/P709kvHDw7nW2VnpV9qdykOx9Pr7nN/7k2nX3T2DHp7MSl+SzUoveJJ1uybt+yZfnwbXeno7NuG8Ywr+GMmDfyi0bE2oN/P519YY9xLZlhKF56+4pU7u63fb0lx5977qPpbG9LJoDRbeo3F6WzRy49OZ19/MCJwxln1Osbm8++cf6v09lb79k9nR0zeXU6O/b2Sels1u7vXpLOXjv7+yN+/IiIA8/+q3R22pULWzIDjGZ9d96bzg7lCsMJQ8jOuW4I4Q6ydgjZrvyvgSHpmjIlldtzp9b8Lv72sh3S2TH3PJzODuWx3dK4EhgAAAAAoGJKYAAAAACAiimBAQAAAAAqpgQGAAAAAKiYEhgAAAAAoGJKYAAAAACAiimBAQAAAAAqpgQGAAAAAKiYEhgAAAAAoGI97R6A9uh94sl0dtK1+ezaIcww6Zpnh5BurydP2D+d3Wts/sfq7OfmprO7Xrokne1NJ4EtVfeNv0hnt7mxVVPkrdzugFzwba2dA9h0XTfdns5Ov6l1c3SKFz+bz86Jznl+/dQfzG7Jut9Zvn06u+OFi9PZZjjDAIxiTx/1xlTultkXpdfsLvlrTU9b/Ifp7B7P5v/twsa5EhgAAAAAoGJKYAAAAACAiimBAQAAAAAqpgQGAAAAAKiYEhgAAAAAoGJKYAAAAACAiimBAQAAAAAqpgQGAAAAAKiYEhgAAAAAoGJKYAAAAACAivW0ewBop56Zu6RyF556YXrNMaU7nb36vHems9s8fnM6C9AJVs9/czp75YnnpHJdMSa95gG//LN0dutH7ktnAbZk931131Tu3n0uSq+5pmnS2TMvOSqd3al3YToL0Ama/fdJZ//201ekcn2R34OXrlmezu7yHZXk5uZKYAAAAACAiimBAQAAAAAqpgQGAAAAAKiYEhgAAAAAoGJKYAAAAACAiimBAQAAAAAqpgQGAAAAAKiYEhgAAAAAoGJKYAAAAACAiimBAQAAAAAq1tPuAaCd7v3E9FTuLeNKes27V69MZ7f+1Yp0FqA2v/vLF9PZ3xubO29935qX02uO/cZW6SzAlmzVu9+Szl4z/8JUriu602se+9Ch6exO
Zy9MZwFqc//J+b318In55+JZJ9z3Z+ns+OsXj/jxeXWuBAYAAAAAqJgSGAAAAACgYkpgAAAAAICKKYEBAAAAACqmBAYAAAAAqJgSGAAAAACgYkpgAAAAAICKKYEBAAAAACqmBAYAAAAAqJgSGAAAAACgYj3tHgBG2qrD35LO/uKPz0kmx6XX/POPfSydnbBwcToL0Am6585OZ6/f+7IhrDw2lXrfrSelV5xx9S1DOD5AXXqm7ZDOfvaiL6eze4/tHs44r+qu778hnd05Fo748QHa6ZmP7J/OLn772UNYeXwqddQDh6VXnHhSSWd700lGiiuBAQAAAAAqpgQGAAAAAKiYEhgAAAAAoGJKYAAAAACAiimBAQAAAAAqpgQGAAAAAKiYEhgAAAAAoGJKYAAAAACAiimBAQAAAAAqpgQGAAAAAKhYT7sHgJH28Lvy5zYml3Gp3J8+eGh6zYk/vCOdbdJJgM7wm7+fnM5O7Rqbzv545ZRUbtfTe9Nr9qWTAPV58vDd09l/N3bkrx16w9WnpLNzzr09nbW3A7WZdcx96exWXRPS2a++ODOVW3XcpPSavUuWprNsfq4EBgAAAAComBIYAAAAAKBiSmAAAAAAgIopgQEAAAAAKqYEBgAAAAComBIYAAAAAKBiSmAAAAAAgIopgQEAAAAAKqYEBgAAAAComBIYAAAAAKBiPe0eADK6pkxJZ48+6KZ09qW+l1O5pz63e3rNcatuTWcBanPZW7/eknX/4l8/mMrtcccvWnJ8gE7w8nv2TWd/dNoXh7DyuKEP8xp2u251Otu3YsWIHx+gnX5z0X757G5fSWf7oklnz/3ee1K53ZfcnF6T0c2VwAAAAAAAFVMCAwAAAABUTAkMAAAAAFAxJTAAAAAAQMWUwAAAAAAAFVMCAwAAAABUTAkMAAAAAFAxJTAAAAAAQMWUwAAAAAAAFVMCAwAAAABUrKfdA0DGb07bK529ftsvp7Pv/c37U7lxP7g1vSZAbX77Nweks/uOuy2dfXztqnR225+NS2cBtlSP/MeSzk7uas2+Oud//XkqN/emX6TXbIY7DMBm1D1nVjr7w/d8aQgrT0gnD7lrQTo7+4z/n8r1pVdktHMlMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxXraPQBbrhc/9NZ09s6jzk9nH+hdk84uP2vnVG5cPJ5eE6ATdO81N5391slfSmf7ojudnf8Pn0pnZ1y2MJ0FqMnLR+ybzt75vvOGsPKYdHJ536p0do/Lc9mmtze9JkC79EzbIZ3d49sPpbO79YwfzjivacLnp6azfcuWtGQGRi9XAgMAAAAAVEwJDAAAAABQMSUwAAAAAEDFlMAAAAAAABVTAgMAAAAAVEwJDAAAAABQMSUwAAAAAEDFlMAAAAAAABVTAgMAAAAAVEwJDAAAAABQsZ52D0BdeqbvlM5+/DNXpbPjSv5b9QN3HJ3ObnfDreksQE1WTZuczs4d053OPta7Kp2d8b+Xp7MAW6rnjs/vlePKmJbM8KnHDk1ny8I7WjIDQDvc8193S2e/N+2GIaxc0sl9zv9oOjv95wuHMANbGlcCAwAAAABUTAkMAAAAAFAxJTAAAAAAQMWUwAAAAAAAFVMCAwAAAABUTAkMAAAAAFAxJTAAAAAAQMWUwAAAAAAAFVMCAwAAAABUTAkMAAAAAFCxnnYPwOhXevLfJvtc/2g6u2Dys+nslcu2T2d3+Ez+3EZfOglQlwePbVqy7mUv7JcPL7qzJTMAjHbd226Tzl4x79IhrJx/3v7M2pXp7F3nvymdnRqL0lmAdlg9/83p7I1Hnp3O9sWEdHavyz+azu521sJ0Fl6NK4EBAAAAACqmBAYAAAAAqJgSGAAAAACgYkpgAAAAAICKKYEBAAAAACqmBAYAAAAAqJgSGAAAAACgYkpgAAAAAICKKYEB
AAAAACqmBAYAAAAAqFhPuwegA+wzNx397PZXtGSEiz63IJ19/R03t2QGgNGuOXBeOvvdg74yhJXz54xvOT4/Q8TdQ8gC1GPNVRPT2b3GtuafbD9dOTOdnXrlopbMADCSuqZMSeXOu/jC9Jo7dk9IZ9c0a9PZabfkszBSXAkMAAAAAFAxJTAAAAAAQMWUwAAAAAAAFVMCAwAAAABUTAkMAAAAAFAxJTAAAAAAQMWUwAAAAAAAFVMCAwAAAABUTAkMAAAAAFAxJTAAAAAAQMV62j0A7dG955x09iP/+M8tmWHPr5+Szu56xaKWzABQkyV/NCGdnTduXDq7plmbznY/9WI625tOAox+PbvOSGe/MOuqIaw8dujDZGb46lHp7LRY2JIZAEbSQ5fOTOX2GpPfV/uiSWf3/+LH09lp19lX2fxcCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxZTAAAAAAAAVUwIDAAAAAFRMCQwAAAAAUDElMAAAAABAxXraPQDtce9/3iqdPWLiSy2ZYecbV+fDTdOSGQBGu56Zu6SzX/qjb6Sza5q16eyxS9+Zzq596ul0FqAm93xip3R2zpixLZnhsHuOTGd3vGBxOuuZONAuPbvNTGfv2v/yET/+IXctSGd3/vYD6WzvcIaBTeRKYAAAAACAiimBAQAAAAAqpgQGAAAAAKiYEhgAAAAAoGJKYAAAAACAiimBAQAAAAAqpgQGAAAAAKiYEhgAAAAAoGJKYAAAAACAiimBAQAAAAAq1tPuARhZLx+xbyr3kyO+OIRVJw5vGAA22f0f2TmdnT/xxXR28arudPa5j+dnaFbdlc4CdILuObNSuS8ffmmLJ3ltD/5223R2j96HWzgJwMh44KzXpbN90aRyXVHSa049uTed7X3iyXQW2sGVwAAAAAAAFVMCAwAAAABUTAkMAAAAAFAxJTAAAAAAQMWUwAAAAAAAFVMCAwAAAABUTAkMAAAAAFAxJTAAAAAAQMWUwAAAAAAAFetp9wCMrMcO7E7lZvRMbMnxr1y2fTo75qXV6WwznGEAKjB13jMtWXfrrpfT2WW7TUpnJ986nGkARq/VO01N5d4xYUVLjn/fmvxz5jkXeH4N1OWwWfe09fgv/f6O6ezEpQ+3cBLYdK4EBgAAAAComBIYAAAAAKBiSmAAAAAAgIopgQEAAAAAKqYEBgAAAAComBIYAAAAAKBiSmAAAAAAgIopgQEAAAAAKqYEBgAAAAComBIYAAAAAKBiPe0egNHv88/umc7ePH/XdLZ5/K5hTAPASFjVdKezy3bJnzOePJxhAEaxsY+9mMr9bOX49JqHTHg5nV1wyV+ns7v8v4XpLEAnWPSlN6ez8943N5X7D7ssSa856ZEV6WyTTkJ7uBIYAAAAAKBiSmAAAAAAgIopgQEAAAAAKqYEBgAAAAComBIYAAAAAKBiSmAAAAAAgIopgQEAAAAAKqYEBgAAAAComBIYAAAAAKBiSmAAAAAAgIqVpmk2+slDuxZs/JMAo9iP+q4u7Z5hc7JfA53Kfg3QGezXAJ1hY/u1K4EBAAAAACqmBAYAAAAAqJgSGAAAAACgYkpgAAAAAICKKYEBAAAAACqmBAYAAAAAqJgSGAAAAACgYkpgAAAAAICKKYEBAAAAACqmBAYAAAAAqFhpmqbdMwAAAAAA0CKuBAYAAAAAqJgSGAAAAACgYkpgAAAAAICKKYEBAAAAACqmBAYAAAAAqJgSGAAAAACgYv8GHH+vVTgG5jYAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 1800x3600 with 4 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "f, axarr = plt.subplots(1, 4, figsize=(25,50))\n",
    "\n",
    "query_heading = str(\"Label:\" + str(LABELS_DICT[query_label]) + \n",
    "                    \"    Prediction: \" + str(LABELS_DICT[query_pred]))\n",
    "query = query_img_trans.clone()[0].permute(1,2,0) \n",
    "axarr[0].imshow( query )\n",
    "axarr[0].axis('off')\n",
    "axarr[0].set_title( query_heading )\n",
    "\n",
    "for i in range(3):\n",
    "    xp = xps_imgs_trans[i][0].permute(1,2,0)\n",
    "    axarr[i+1].imshow(xp)\n",
    "    axarr[i+1].axis('off')\n",
    "    xp_heading = str(\"Label:\" + str(LABELS_DICT[y_train [ xp_idxs[i].item() ].item() ]) + \n",
    "                    \"    Prediction: \" + str(LABELS_DICT[train_preds[ xp_idxs[i].item() ].item() ]))\n",
    "    axarr[i+1].set_title(xp_heading)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## New Box Explanation Method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_box_xp(Conv, latent_feature):\n",
    "    max_dist = float('inf')\n",
    "    coord = [None, None]\n",
    "    for i in range(Conv.shape[2]):\n",
    "        for j in range(Conv.shape[3]):\n",
    "            query_seg = Conv[:, :, i:i+1, j:j+1 ]\n",
    "            new_dist = torch.cdist(query_seg.view(-1, Conv.shape[1]),\n",
    "                                   latent_feature.view(-1, Conv.shape[1]), p=1.0).item()\n",
    "            if new_dist < max_dist:\n",
    "                max_dist = new_dist\n",
    "                coord = [i, j]\n",
    "    return coord, max_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_box_query(Conv, org_logits, net_classifier):\n",
    "    \n",
    "    results = list()\n",
    "    pred = torch.argmax(org_logits, dim=1).item()\n",
    "    \n",
    "    for i in range(Conv.shape[2]):\n",
    "        for j in range(Conv.shape[3]):\n",
    "            C_copy = Conv.clone().detach()\n",
    "            C_copy[:, :, i:i+1, j:j+1 ] = 0\n",
    "            new_logit = net_classifier(C_copy)[0][pred]\n",
    "            logit_change = org_logits[0][pred].item() - new_logit.item()   # Must modify for negative logits\n",
    "            results.append( [logit_change, [i, j]] )\n",
    "            \n",
    "    return sorted(results, key=lambda x: x[0], reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "query_logits, query_x, query_C = netC(query_img_trans)\n",
    "query_nb_boxes = get_box_query(query_C, query_logits, net_classifier)\n",
    "xp_idxs = twin.kneighbors(X=query_x.detach().numpy(), n_neighbors=NUM_NEIGHBORS, return_distance=False)\n",
    "xps_imgs_trans = get_transformed_data(xp_idxs[0], train_loader)\n",
    "xp_feature_meta = [[None, None, float('inf'), []] for _ in range(XP_FEATURE_NUM)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "for nn_idx, img in enumerate(xps_imgs_trans):\n",
    "    \n",
    "    # Get nn data\n",
    "    logits, x, C = netC(img)\n",
    "        \n",
    "    # Search nn for all n features\n",
    "    for i in range(XP_FEATURE_NUM):\n",
    "        \n",
    "        # Get xp feature in query\n",
    "        xp_window_idx = query_nb_boxes[i][1]\n",
    "        xp_feature = query_C[:, :, xp_window_idx[0]: xp_window_idx[0]+1, xp_window_idx[1]: xp_window_idx[1]+1 ]\n",
    "        \n",
    "        # search neighbor for one of n xp features\n",
    "        coord, max_dist = get_box_xp(C, xp_feature)\n",
    "        if max_dist < xp_feature_meta[i][2]:\n",
    "            xp_feature_meta[i][0] = nn_idx\n",
    "            xp_feature_meta[i][2] = max_dist\n",
    "            xp_feature_meta[i][1] = xp_idxs[0][i]\n",
    "            xp_feature_meta[i][3] = coord"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[81, 19704, 33.635894775390625, [3, 4]],\n",
       " [51, 25009, 17.653839111328125, [3, 3]],\n",
       " [56, 34317, 32.83951187133789, [4, 3]]]"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xp_feature_meta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1.703824520111084, [3, 4]]"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query_nb_boxes[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd4AAAI+CAYAAAASMmvcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAwKklEQVR4nO3deZhU1bnv8d/bdDfSjQKigICIQyBijpgT8UoSc8AhagwnccRoTJzilJOYSZN4QoLRaNQYFcfrTRyOl6gXjXGMidEoURE8DmhQVFREZZJ5bmj6vX/s3Z6i96ruqu6q1QPfz/P4SL2199qreli/vWqv3mXuLgAAEEdFe3cAAICtCcELAEBEBC8AABERvAAARETwAgAQEcELAEBEBG8zzOwmM5uQ/nuMmX1QwrYPMLM3StVeM8cZY2YNZrbGzA6LcLwLzWytmbmZVZb7eADyYwxr1fHKPoaVNXjN7GozW25m08xscE79BDObVMD++5nZI2a2wsyWmdkMMzulBP062cyebmk7dz/L3S9q6/HSY7qZ7ZHT9j/cfXgp2i7AfHfv6e6Ppn0Za2avpl/XpWZ2n5kNKrZRM/tG+rpOb6y5+y8k7VXCvgPthjFsi2MyhpVI2YLXzPaT9BlJAyQ9Leknab2XpPMk/ayF/UdLekLSU5L2kNRX0tmSDi9Xn5scv1uM47ST1yQd6u69JQ2U9JakG4tpwMz6SLpA0qyS9w7oABjDOrROPYaVc8a7q6Sn3b1O0uOSdkvrv5J0hbuvamH/KyTd7u6XufsST7zg7sc1bmBm3zKzOemZ5ANmNjDnOTezs8zsrfSs6HpL7CnpJkmj07cuVqTb32ZmN6Znp2sljU1rF+d2yswuMLMlZjbXzE7MqT+Ze9aUe0ZqZlPT8sz0mOObvu1jZnumbawws1lm9u85z92W9v9hM1ttZtPNbPeWvgH5uPsid5+fU9qsZGAoxqWSJkla0tp+AB0cYxhjWFmUM3hnSTrAzHpIOkjSLDPbV9Jwd/9DczuaWY2k0ZLuaWabA5V84Y6TtJOk9yTd1WSzL0saJWnvdLtD3f11SWdJmpa+ddE7Z/sTlPxSbavkDLepAZJ2kDRI0jcl3WxmLb7V4u5fSP85Mj3m3U1eS5WkByX9VVI/Sd+RNLlJ28dLulBSH0lz0n427v+Qmf2kpX40OeaQ9Bd2vaQfSbq8iH33k7Svkl9+oKtiDEsxhpVW2YLX3f8p6V5Jz0kaouSLMknSd83su2Y21cwmm1nvwO590r4taOYQJ0q6xd1fTM9If6rkDHBozja/dvcV7j5P0t8l7dNCt+9392fcvcHdN+TZZoK717n7U5IeVvLL0Fb7S+qZ9nejuz8h6SFJX8vZ5j53n+Hu9ZImK+e1uPuX3f3XxRzQ3eelv7A7KHnLbHYh+1ny9tUNkv7D3RuKOSbQmTCGFYUxrAhlXVzl7le5+0h3H6/kmzs1PeYZSs4gX1d63aSJ5ZIalJwF5jNQyRli47HWSFqq5Eyu0cKcf69T8oPRnPdbeH65u6/Nefxe2o+2Gijp/SY/BO+pba+lIO6+TNLtku63wlbwnSPpFXd/rhTHBzoyxrCCMYYVIcqfE5lZfyU/qL+U9CklL3qTpOeVvIWyBXdfJ2mapKObaXa+pF1yjlGrZPHChwV0Kd9HMrX0UU190uM0GpL2Q5LWSqrJeW5AAf1oNF/SzmaW+/0YosJeSylUKnl7aLsCtj1I0pFmttDMFkr6rKQrzey6cnYQaE+MYS1iDCtCrL/j/a2kiekP47uSRplZT0ljJL2TZ5/zJZ1sZueZWV9JMrORZtZ4DeROSaeY2T5m1l3SJZKmu/vcAvqzSNJgM6tuxWu50MyqzewAJddfpqT1lyUdZWY1liy5Py1wzN0UNl3JGeD5ZlZlZmMkjVP2ek9JmNlRZjbczCrMbEcl35+X0jNHmdlEM3syz+4nS9pTydtE+0j6byXXbf6zHH0FOgjG
MMawkil78KYLCHq7+32S5O4zlFxXeF/SWEnB9/Xd/VlJB6b/vWNmyyTdLOmR9Pm/SZqg5BrMAkm7K7l4X4gnlCycWGhmxaxoW6jkLaT5Sq5RnOXujdcVrpK0UckP5+3p87kmSro9XfG3xTUVd9+o5If0cCUr7G6Q9I2ctptlZn82swuKeB2DJD0qabWkV5W8JXZkzvM7S3omtGN6vWlh439KXvMqd19ZxPGBToMx7GMTxRhWEube0jsT6MzM7AuS/iKpTtJ4d/9LAfu8LOkgd1/aiuP9QtIPJHWXVOvum4ttAwAadcUxjOAFACAi7tUMAEBEBC8AABERvAAARETwAgAQUbN3+Tik4lhWXqHdPNYwxdq7D+jcGMPQnvKNYcx4AQCIiOAFACAighcAgIgIXgAAIiJ4AQCIiOAFACAighcAgIgIXgAAIiJ4AQCIiOAFACAighcAgIgIXgAAIiJ4AQCIiOAFACAighcAgIgIXgAAIiJ4AQCIiOAFACAighcAgIgIXgAAIiJ4AQCIiOAFACAighcAgIgIXgAAIiJ4AQCIiOAFACAighcAgIgIXgAAIiJ4AQCIiOAFACAighcAgIgIXgAAIiJ4AQCIiOAFACAighcAgIgIXgAAIiJ4AQCIiOAFACAighcAgIgq27sDhVj6rdHB+pCT5gTrsxf3z9Q21lUFtx10Z7he88GaYL3h5deCdQAACsGMFwCAiAheAAAiIngBAIiI4AUAICKCFwCAiDrFqubzz/tDsH507fLwDrsX0fiYcHlu/bpg/ZqPxhbReMcwY/EuwXrtlb2C9crHXyhnd4AuoXKnAcH6678emKn1e6w6uO2279WVtE+t9fYx3YN161t4/2p7bgjWXxo1OVjvZuF532ZvyNQmLN4nuO3ML2W/1pJUv2BhsN5RMOMFACAighcAgIgIXgAAIiJ4AQCIqFMsrpp0wfHB+s/3Dp839HndM7Xle1pw2+q9VwTrl3/qj8H6VTtNz9QeXtczuO0RNeHbThZjvW8M1qfX1QbrY7bZlC0G+ixJe4w/M1gf9nhhfQO2Zh8ct1uw/sbB12aLB5e5Mx1YdqlUWvfNBbdxYb+XgvVRx34+WO8/icVVAAAgRfACABARwQsAQEQELwAAERG8AABE1ClWNdfeE16VW3tP4W1sV+Qxrx0wJli/+HNDs20/NSe47eVj9ijyqFmV68NrAmtfWRCs9516b6b2L9VVwW1r5obrAFpWvTr71xOSdN/a7TO1I2uXFdX2gs3rM7WFm8O3dfx0dfnmTysbwreBvH9N4ffl3blqabA+tke47ZCGPGujq/J8Dzo6ZrwAAERE8AIAEBHBCwBARAQvAAAREbwAAETUKVY1t4f6hYuC9dp7s/V8dxytvSe8mq8UFp0+Oljfqzr7Lf3NsuHBbYfe+k6wXt/6bgFbje1vmRas33bPyExt0hf3LKrtHh9l79FetWRdcNtVe/Yuqu1iVK0JryaufvT5gttYfvKRwfozv7qu4DZGv/D1YH3HW8Pfg46OGS8AABERvAAARETwAgAQEcELAEBEBC8AABGxqrmDq9xl52D9ugvCKwKrrFumNuWag4Pb9l3QOVcEAh3Z5lWrMrV895svqt089dpZbW66JKyqOlj/aP98PS/CI9n7X3dmzHgBAIiIGW8Hc6M/pl6q03z1lCTZgm2C2+1y/JJg3QKnUjfPvUGStPPKJVreo1YnHPXD0nQWAJq4YeOf1UsbNN+2lSStu/K54HZVt4U/YS3kd+/ljGHb1Or44zr3GEbwdjC9VKceZbqFRU19nZT9tDEAKJle2lC+MWxTXVnajY3g7WAaZ7o/sjGSpMqdwtd4J9z1p2B9/22y13jP+PnZkqSbH7qhzf0DgOY0znTPq0rWlrzxw32C27057saC2zz9onMkSb+7v2uMYQRvBzf7+4OC9VHdLViftTE7pd3+teRWc1VrN2/xGABKpWLYrpIkeye5PFaxW/K4mIDNp/9zKyVJVavq
t3gcvqFlx8fiKgAAIiJ4AQCIiOAFACAighcAgIgIXgAAImJVcwdRd8QoSVLDsy8mjz+bPH7xmKvy7NE9WD373HMztR7Pzkj+4WskSfbszDb0FACyGrZJ4sQrbIvHxTrx3S9mav7628k/6tZv+biTYsYLAEBEBC8AABERvAAARETwAgAQEcELAEBErGruIOYdnpwD1b2x5eOeFl69/LV3DwnWax7Nrlj2EvQPAJrzxmm1kqR1v+22xeNivXPbsEytb9205B+e3J3Z6zr3pxQx4wUAICKCFwCAiAheAAAiIngBAIiIxVXtoGLbbTO1kw54WpLU/9aVWzxe1bAh2MbiS3YL1rvXPV+KLgJAUS47+G5J0u63Lt7icT6zNtYH6/3+9GamtrmNfetomPECABARM94OZvt5a1W9rl6nnprMeLerXB3c7po3/3ewXuGr8ra9u1ZoPd9yAGXU8706Va7brANPTP42sqpbeGb7iYbwHzpevuKBTM19k6SuM4Z1/lfQxaztU122tterUivzfKoRAJTChu0rtU2Z2u4qYxjB28HcOGXsFo+/t334mu25Z54ZrHd/hGu8ANrPXx8YscXjI2uXBbd7K8813gtGfSlT27xkads71oFwjRcAgIiY8baDtybulak9tMMNwW2/8tbRwTozWwDtZeOh+2Zqh/R4Ns/W4beGz3jt68F6nyVvtbZbnQYzXgAAIiJ4AQCIiOAFACAighcAgIgIXgAAImJVcxmt/Pr+wfor4ydlam/Xbwpuu+aywcF6dy1ofccAoA3mHZaNjp4Vxd3YYtPDO+Z5hlXNAACghAheAAAiIngBAIiI4AUAICKCFwCAiFjVXAKVgwYG69+bcHew3t2yX/bjZ54U3HbHP3NPZgDto3LXXYL168fdWnAb8+rXB+sDH5wXrIc/s6hrYcYLAEBEBC8AABERvAAARETwAgAQEYurimSV2S/ZyIc+CG57bM+lwfrk1f0ytf4TwudADUX0DQBK6c0zwwtHD+qxruA2Dn70+8H6sPe33oWjzHgBAIiI4AUAICKCFwCAiAheAAAiIngBAIiIVc3FGjk8U7qo3x1FNXH9Jcdmar1nTmt1lwCgLbr13T5Y/+oXnyu4jYfX9QrWh3/3lWDdC26562HGCwBARAQvAAAREbwAAERE8AIAEBHBCwBARKxqzqPbiGHB+hl33V9wGyNu+XawPvSOwlcKAkC5rZrcO1i/pP9jBbfxg2nHBeufqHuxNV3q0pjxAgAQEcELAEBEBC8AABERvAAARMTiqjxmn9MnWB9Xs6rgNgY/uTH8hG/NN0sD0J42HjYqU3twr2vybN09WL16eXbx6Z4/nh/ctr7gnm09mPECABARwQsAQEQELwAAERG8AABERPACABDRVr+qecO4/YL1x8ddmWePmvJ1BgBKxKqqg/WzJ/2/TK1nRXj18gf164P1Kb/5YqbWZ8G0Inq3dWPGCwBARAQvAAAREbwAAERE8AIAEBHBCwBARFv9qub5n+sWrA+pLHz18uTV/YL1qlXhezVzp2YA5bbh4JHB+pG1ha8+Purl04P1frexgrktmPECABARwQsAQEQELwAAERG8AABERPACABDRVr+quViXLh2RqU07dGhwW1/wapl7A2Br122PXYP1idf9vuA2ptWF/7qj36Xh+z2jbZjxAgAQEcELAEBEBC8AABERvAAARGTu+W9geEjFsdzdEO3msYYp1t59QOe2NY9hKx/ZI1j/x8i7M7V/+8G3g9tue/dzJe3T1ibfGMaMFwCAiAheAAAiIngBAIiI4AUAICKCFwCAiJpd1QwAAEqLGS8AABERvAAARETwAgAQEcELAEBEBC8AABERvAAARETwAgAQEcELAEBEBC8AABERvAAARETwAgAQEcELAEBEBC8AABERvAAARETwAgAQEcHbDDO7ycwmpP8eY2YflLDtA8zsjVK118xxxphZg5mtMbPDIhzvtPRYbmZ7lPt4AMIYv1p1vCjj
V1mD18yuNrPlZjbNzAbn1E8ws0kF7L+fmT1iZivMbJmZzTCzU0rQr5PN7OmWtnP3s9z9orYeLz3mFt9Id/+Huw8vRdsFmO/uPd390bQvY83s1fTrutTM7jOzQYU0ZGY7mNkz6X4r0u/t5xqfd/ffu3vPcr0QIBbGry2OyfhVQmULXjPbT9JnJA2Q9LSkn6T1XpLOk/SzFvYfLekJSU9J2kNSX0lnSzq8XH1ucvxuMY7TTl6TdKi795Y0UNJbkm4scN81kk6VtKOkPpIuk/SgmVWWoZ9Au2D86tA6/fhVzhnvrpKedvc6SY9L2i2t/0rSFe6+qoX9r5B0u7tf5u5LPPGCux/XuIGZfcvM5qRnkw+Y2cCc59zMzjKzt9Izm+stsaekmySNTt9SWJFuf5uZ3Zieoa6VNDatXZzbKTO7wMyWmNlcMzsxp/6kmZ2e8/jjs1Izm5qWZ6bHHN/0rR8z2zNtY4WZzTKzf8957ra0/w+b2Wozm25mu7f0DcjH3Re5+/yc0mYlg0Mh+25w9zfcvUGSpfv2kbR9a/sDdECMX4xfZVPO4J0l6QAz6yHpIEmzzGxfScPd/Q/N7WhmNZJGS7qnmW0OlHSppOMk7STpPUl3Ndnsy5JGSdo73e5Qd39d0lmSpqVvX/TO2f4EJb9Y2yo5y21qgKQdJA2S9E1JN5tZi2+3uPsX0n+OTI95d5PXUiXpQUl/ldRP0nckTW7S9vGSLlTyQzIn7Wfj/g+Z2U9a6keTYw5Jf2nXS/qRpMuL3P8VSRskPSDpd+6+uJj9gQ6O8SvF+FV6ZQted/+npHslPSdpiJIvzCRJ3zWz75rZVDObbGa9A7v3Sfu2oJlDnCjpFnd/MT0r/amSs8ChOdv82t1XuPs8SX+XtE8L3b7f3Z9x9wZ335BnmwnuXufuT0l6WMkvRFvtL6ln2t+N7v6EpIckfS1nm/vcfYa710uarJzX4u5fdvdfF3NAd5+X/tLuoORts9lF7r+3pO2U/LK3eL0J6EwYv4rC+FWksi6ucver3H2ku49X8g2emh7zDCVnka8rvXbSxHJJDUrOBPMZqOQssfFYayQtVXI212hhzr/XKfnhaM77LTy/3N3X5jx+L+1HWw2U9H769kdu2215LQVx92WSbpd0vxV5nSN92+ZOST8xs5Gl6A/QUTB+FYzxq0hR/pzIzPor+WH9paRPSXrF3TdJel7J2yhbcPd1kqZJOrqZZudL2iXnGLVKFjB8WECXvMh6oz7pcRoNSfshSWsl1eQ8N6CAfjSaL2lnM8v9fgxRYa+lFCqVvEW0XSv3r9L/XAMDuhTGrxYxfhUp1t/x/lbSxPQH8l1Jo8ysp6Qxkt7Js8/5kk42s/PMrK8kmdlIM2u8DnKnpFPMbB8z6y7pEknT3X1uAf1ZJGmwmVW34rVcaGbVZnaAkmswU9L6y5KOMrMaS5bdnxY4Zr5v7nQlZ4Hnm1mVmY2RNE7Zaz4lYWZHmdlwM6swsx2VfH9eSs8eZWYTzezJPPvub2afT78GPczsx5L6p68B6IoYvxi/SqrswZsuIujt7vdJkrvPUHJt4X1JYyUF39t392clHZj+946ZLZN0s6RH0uf/JmmCkuswCyTtruQCfiGeULJ4YqGZLSni5SxU8jbSfCXXKc5y98ZrC1dJ2qjkB/T29PlcEyXdnq762+K6irtvVPKDerikJZJukPSNnLabZWZ/NrMLingdgyQ9Kmm1pFeVvC12ZM7zO0t6Js++3SVdr+RtsQ8lfUnSEU1WGQJdAuPXxyaK8atkzL2ldyfQmZnZFyT9RVKdpPHu/pcC9nlZ0kHuvrQVxztFyS/xNpJGuHu+GQEANKurjl8ELwAAEXGvZgAAIiJ4AQCIiOAFACAighcAgIiavdPHIRXHsvIK7eaxhinW3n1A58YYhvaUbwxjxgsAQEQELwAAERG8AABERPACABARwQsAQEQELwAAERG8AABERPACABARwQsAQEQELwAAERG8AABERPAC
ABARwQsAQEQELwAAERG8AABERPACABARwQsAQEQELwAAERG8AABERPACABARwQsAQEQELwAAERG8AABERPACABARwQsAQEQELwAAERG8AABERPACABARwQsAQEQELwAAERG8AABERPACABARwQsAQEQELwAAERG8AABERPACABARwQsAQEQELwAAERG8AABEVNneHSjE0m+NDtaHnDQnWJ+9uH+mtrGuKrjtoDvD9ZoP1gTrDS+/FqwDQD5WVR2szzt/32B9/eD6cnYnY7udVgfrM/e7M1gfNvUbmdqm5dsEt7WN4fndJy95J1jfvGhxsN6VMOMFACAighcAgIgIXgAAIiJ4AQCIiOAFACCiTrGq+fzz/hCsH127PLzD7kU0PiZcnlu/Lli/5qOxRTTeMcxYvEuwXntlr2C98vEXytkdYKuzcezewfrMc66N3JPibPJwfdYBt7a57Q++uj5YP+iBH2Zqn/zP14Pbbl61qs39aA/MeAEAiIjgBQAgIoIXAICICF4AACLqFIurJl1wfLD+873D5w19Xs+uCFi+pwW3rd57RbB++af+GKxftdP0TO3hdT2D2x5RE77tZDHW+8ZgfXpdbbA+ZptN2WKgz5K0x/gzg/VhjxfWNwAd1wNr+wTrz67eo+A2htcsDNZP2e79VvUp1+DKHsH6G0fdkKl9ss/pwW33+PpLbe5He2DGCwBARAQvAAAREbwAAERE8AIAEBHBCwBARJ1iVXPtPeFVubX3FN7GdkUe89oBY4L1iz83NNv2U3OC214+pvDVg/lUrm8I1mtfWRCs9516b6b2L9VVwW1r5obrAEprmxffDdYP/M45ZTvmtrOWBuub3wiPVyGzB386WL9jv3EFtzH/gPBflMw+7vqC2xi163vBep6bBnd4zHgBAIiI4AUAICKCFwCAiAheAAAiIngBAIioU6xqbg/1CxcF67X3Zuub87RRe094VWEpLDp9dLC+V3X2W/qbZcOD2w699Z1gvb713QIQsHlJeCyo+WP5xoh841Ix6j/4MFivyVMP2WXNvuEnjmtNj7oGZrwAAERE8AIAEBHBCwBARAQvAAAREbwAAETEquYOrnKXnYP16y64Llivsm6Z2pRrDg5u23fBtNZ3DAAKMPeItsfMnNvCf5nRV51zDCN4W+lGf0y9VKf56lnSdgdqjVaqu862Q0raLgA0Ktf4JTGGFYLgbaVeqlOPMvzFaznaBIBc5Rq/JMawQhC8rdR4pvgjG1PSdn/jT5a0PQBoqlzjl8QYVggWVwEAEBEz3g5u9vcHBeujuoc/XHrWxvWZ2vavrStpnwCgqfd/9llJUt1/vZQ8/kby+IWjfpNnj+7B6if/fnqm9olbnw9u60X2saNgxgsAQEQELwAAERG8AABERPACABARwQsAQESsau4g6o4YJUlqePbF5PFnk8cvHnNVnj3CKwLPPvfcTK3HszPa3kEASFVsu22mNvSguZKkbR7YsMXjnhXhsSqfXs9sk6l5fde6KQczXgAAIiJ4AQCIiOAFACAighcAgIgIXgAAImJVcwcx7/DkHKjujS0f97TwisCvvRv+rMuaR2dmap31fqYAOqaldw3I1J4ZdpckqapmqSTp/mEPNtvG71cOCdb7/2N5ptZQbAc7OGa8AABERPACABARwQsAQEQELwAAEbG4qh2Ebrd20gFPS5L637pyi8erGjYE21h8yW7Beve68AdGA0CxKgcNDNZ/9Im/FtzGks3rg/UHjv18sN4wa3bBbXdWzHgBAIiIGW8rDdQa9VC9fuNPFr/z2uyXfbtTN0mSBsxeqY01fFsAlE+h45ctDv8544ATV2ZqVd2SDzKwWXVSDXO65jDCt9LKPJ8O1FYbayq1tk91WdoGAKl845ckqaZC3rdb+drvAgjeVjrbwjewKERFbfYa7363LGtLdwCgYIWOX5X9wtd4vzn5mUztyFrGsELxfgAAABEx420Hb03cK1N7aIcbgtt+5a2jg/Xuj7B6
GUB5vX3mLsH6kbXN3w4y17o896zdPOuN1nSpS2DGCwBARAQvAAAREbwAAERE8AIAEBHBCwBARKxqLqOVX98/WH9l/KRM7e36TcFt11w2OFjvrgWt7xgA5Ng89l+D9b+dfEWePXoU3PaRk84P1nfSswW30dUw4wUAICKCFwCAiAheAAAiIngBAIiI4AUAICJWNZdA5aDwJ3h8b8LdwXp3y37Zj595UnDbHf/MPZkBlNeH/7ZNsN6/W+Grl/d88vRgffffTmtVn7oyZrwAAERE8AIAEBHBCwBARAQvAAARsbiqSFaZ/ZKNfOiD4LbH9lwarE9e3S9T6z8hfA7UUETfAKAlH/74s5nay9+6Js/W4XFpnW/M1Pr8LbxAS+6Fdm2rwYwXAICICF4AACIieAEAiIjgBQAgIoIXAICIWNVcrJHDM6WL+t1RVBPXX3JsptZ7JrdVA1A63XboG6yf/c0HM7WKIudgYy79YabW79at94Pti8WMFwCAiAheAAAiIngBAIiI4AUAICKCFwCAiFjVnEe3EcOC9TPuur/gNkbc8u1gfegdz7WqTwCQUdEtWH7r/DxjWK+/Ftz0nE11wfqAp5ZlatxXvnDMeAEAiIjgBQAgIoIXAICICF4AACJicVUes8/pE6yPq1lVcBuDn8x+WLQkPhgaQMnYyE8G66+deF3Bbby5KTxWnTohe2tISer9T25x2xbMeAEAiIjgBQAgIoIXAICICF4AACIieAEAiGirX9W8Ydx+wfrj467Ms0dN+ToDAEWac+K2bW7j6kUHB+u972D1cjkw4wUAICKCFwCAiAheAAAiIngBAIiI4AUAIKKtflXz/M+FP0R6SGXhq5cnr+4XrFetCt//lDs1AyiV6756a5vbeP7/jgzW++vZNreNLGa8AABERPACABARwQsAQEQELwAAERG8AABEtNWvai7WpUtHZGrTDh0a3NYXvFrm3gDY2v3il6cG6wddel2w/qXZX83U+k9fXcouoQXMeAEAiIjgBQAgIoIXAICICF4AACIy9/w3MDyk4ljuboh281jDFGvvPqBzYwxDe8o3hjHjBQAgIoIXAICICF4AACIieAEAiIjgBQAgomZXNQMAgNJixgsAQEQELwAAERG8AABERPACABARwQsAQEQELwAAERG8AABERPACABARwQsAQEQELwAAERG8AABERPACABARwQsAQEQELwAAERG8AABERPA2w8xuMrMJ6b/HmNkHJWz7ADN7o1TtNXOcMWbWYGZrzOywCMe70MzWmpmbWWW5jwcgP8awVh2v7GNYWYPXzK42s+VmNs3MBufUTzCzSQXsv5+ZPWJmK8xsmZnNMLNTStCvk83s6Za2c/ez3P2ith4vPaab2R45bf/D3YeXou0CzHf3nu7+aKBftzTtW3PMbISZ/Xf6fV1uZn8zsxGNz7v7LyTtVcK+A+2GMWyLYzKGlUjZgtfM9pP0GUkDJD0t6SdpvZek8yT9rIX9R0t6QtJTkvaQ1FfS2ZIOL1efmxy/W4zjtCcz+7yk3Yvcbb6kYyRtL2kHSQ9IuqvEXQPaHWNYx9dZx7Byznh3lfS0u9dJelzSbmn9V5KucPdVLex/haTb3f0yd1/iiRfc/bjGDczsW2Y2Jz2TfMDMBuY852Z2lpm9lZ5tXm+JPSXdJGl0+tbFinT728zsxvTsdK2ksWnt4txOmdkFZrbEzOaa2Yk59SfN7PScxx+fkZrZ1LQ8Mz3m+KZv+5jZnmkbK8xslpn9e85zt6X9f9jMVpvZdDMr9odtC+lbKNdK+k4x+7n7Cnef6+4uySRtVjKoAF0NYxhjWFmUM3hnSTrAzHpIOkjSLDPbV9Jwd/9DczuaWY2k0ZLuaWabAyVdKuk4STtJek/Zs5YvSxolae90u0Pd/XVJZ0malr510Ttn+xOU/FJtq+QMt6kBSs6QBkn6pqSbzazFt1rc/QvpP0emx7y7yWupkvSgpL9K6qfkB2lyk7aPl3ShpD6S5qT9
bNz/ITP7SUv9aOL7kqa6+ytF7td4zBWSNij5wb+kNW0AHRxjWIoxrLTKFrzu/k9J90p6TtIQSZdLmiTpu2b2XTObamaTzax3YPc+ad8WNHOIEyXd4u4vpmekP1VyBjg0Z5tfp2c38yT9XdI+LXT7fnd/xt0b3H1Dnm0muHuduz8l6WElvwxttb+knml/N7r7E5IekvS1nG3uc/cZ7l4vabJyXou7f9ndf13owcxsZ0lnSvp5azuc/rL3kvQfkl5qbTtAR8UYVhTGsCKUdXGVu1/l7iPdfbySb+7U9JhnKDmDfF3pdZMmlktqUHIWmM9AJWeIjcdaI2mpkjO5Rgtz/r1OyQ9Gc95v4fnl7r425/F7aT/aaqCk9929oUnbbXktzbla0i/dfWUb2lD6tbhJ0n+ZWb+2tAV0RIxhBWMMK0KUPycys/5KflB/KelTkl5x902SnlfyFsoW3H2dpGmSjm6m2fmSdsk5Rq2SxQsfFtAlL7LeqE96nEZD0n5I0lpJNTnPDSigH43mS9rZzHK/H0NU2GtpjYMkXWFmC82s8Zdhmpmd0Iq2KpS87kEtbQh0VoxhLWIMK/KAMfxW0sT0h/FdSaPMrKekMZLeybPP+ZJONrPzzKyvJJnZSDNrvAZyp6RTzGwfM+uu5D366e4+t4D+LJI02MyqW/FaLjSzajM7QMn1lylp/WVJR5lZjSXL2k8LHHM3hU1XcgZ4vplVmdkYSeNUvpV2wySNVPJWzz5pbZyk+6SPF0LcFtrRzA4xs0+bWTcz207J93a5kjN/oKtiDGMMK5myB2+6gKC3u98nSe4+Q8l1hfcljZUUfF/f3Z+VdGD63ztmtkzSzZIeSZ//m6QJSq7BLFCypPz4Arv1hJKFEwvNbEkRL2ehkm/QfCXXKM5y99npc1dJ2qjkh/P29PlcEyXdnq742+KairtvVPJDc7ikJZJukPSNnLabZWZ/NrMLCn0R7r7Y3Rc2/peWl7j7+vTfO0t6Js/uvZUMGCslva3k635YM9eTgE6NMexjE8UYVhKWrKhGV2VmX5D0F0l1ksa7+19a2L5a0kxJe6dvpRV7vF9I+oGk7pJq3X1z8b0GgERXHMMIXgAAIuJezQAARETwAgAQEcELAEBEBC8AABE1+1mDh1Qcy8ortJvHGqZYe/cBnRtjGNpTvjGMGS8AABERvAAARETwAgAQEcELAEBEBC8AABERvAAARETwAgAQEcELAEBEBC8AABERvAAARETwAgAQEcELAEBEBC8AABERvAAARETwAgAQEcELAEBEBC8AABERvAAARETwAgAQEcELAEBEBC8AABERvAAARETwAgAQEcELAEBEBC8AABERvAAARETwAgAQEcELAEBEBC8AABERvAAARETwAgAQEcELAEBEBC8AABERvAAARETwAgAQEcELAEBEBC8AABERvAAARETwAgAQUWV7d6AQS781OlgfctKcYH324v6Z2sa6quC2g+4M12s+WBOsN7z8WrAOAMXKN7Zt6mmRexI287wbgvVNvrnNbY946rRgveb5mkxt8J8+DG5b/+57be5He2DGCwBARAQvAAAREbwAAERE8AIAEBHBCwBARObueZ88pOLY/E9GdNqb7wbrR9cuL9sx59avC9av+Whs2Y5ZLjMW7xKs117ZK1ivfPyFcnanYI81TOkYSzvRaXWUMaxb/37B+llPPx2sH9xjRaZWZd2C25ZihXE+HeWYFy/ZO7jtn279t2B9wNXPlrRPrZVvDGPGCwBARAQvAAAREbwAAERE8AIAEFGnuGXkpAuOD9Z/vnf4vKHP69n1FMv3DK/Tqd57RbB++af+GKxftdP0TO3hdT2D2x5RE77tZDHW+8ZgfXpdbbA+ZptN2WKgz5K0x/gzg/VhjxfWNwCFeePHuwXrB/d4KHJPOqcf930pWJ+88wGRe1IazHgBAIiI4AUAICKCFwCAiAheAAAiIngBAIioU6xqrr0nvCq39p7C29iuyGNeO2BM
sH7x54Zm235qTnDby8fsUeRRsyrXNwTrta8sCNb7Tr03U/uX6qrgtjVzw3UApTX8sneC9b+N6x2sh24Zmc8lSz4TrM9Zu2PBbeRTofAdNxuU/SuRVx79ZHDb3m+Gx7CJv/p9sD62x4YCe9d5MeMFACAighcAgIgIXgAAIiJ4AQCIiOAFACCiTrGquT3UL1wUrNfem63n+0jo2nuWlrBHW1p0+uhgfa/q7Lf0N8uGB7cdemt4pWV967sFIGDzosXB+i8v+2awPqE2cG/58O3mNfhPHwTr9XPnFdS3UtlZxX34/LnHhO/B/8pnbytBbzo2ZrwAAERE8AIAEBHBCwBARAQvAAAREbwAAETEquYOrnKXnYP16y64Llivsm6Z2pRrDg5u23fBtNZ3DECb9f1d238HO/pfIVTuukuw/vkh4b+qCI1h967ZIbjt7j98rvUda0fMeAEAiIjgBQAgIoIXAICICF4AACJicVUHN/v7g4L1Ud3D94+btXF9prb9a+tK2icAKNSigwcG61MGTQnWN3l2cdWvfv+14LYDi7xNZUfBjBcAgIgIXgAAIiJ4AQCIiOAFACAighcAgIhY1dxB1B0xKlh/8Zir8uzRPVg9+9xzM7Uez85obbcAoCDrjvxfwfrvLrg6zx7hv8y4eMnemdrODy0Jbru5kI51QMx4AQCIiOAFACAighcAgIgIXgAAIiJ4AQCIiFXNHcS8w8PnQD0tvHr5a+8eEqzXPDozU/PWdwsAMipqazO1JSeE7wk/rCq8ejmfOx/5Qqa262vTimqjo2PGCwBARAQvAAAREbwAAERE8AIAEBGLq9pBxbbbZmonHfB0cNtVDRuC9cWX7Basd697vvUdA4AC2NDBmdoLo39fVBt7Tz0jWN/9p11rIVUIM14AACJixttKN/pj6qU6zVfP4ndem/2yb3fqJknS9vPWam2fat04ZWxbuwgAQW0avyTZO9kPXulxzKLkuXfr5X0rtP6vO7Wpj10ZwdtKvVSnHqovebvV60rfJgDkKtf4JUm2rqEs7XYlBG8rNZ4p/sjGFL1vRW32Gu9+tyyTJJ16avhaLwCUSlvGL0nqttvwTO2ee+6Q9D8zX+THNV4AACJixtsO3pq4V6b20A43SJK6Va2VJP18h1clSV956+hgG90fYfUygPKq2GdEsH7EH/6RqVVZN0mSpR9w3/g4n9oZNW3sXefFjBcAgIgIXgAAIiJ4AQCIiOAFACAighcAgIhY1VxGK7++f7D+yvhJmdrb9cmdq4Z48sfn8+rXS5LWXJa9J6okddeCUnQRAPKaP7Z3sH7Kdm9naps8+X+lPH28WZJ0wtvjgm0M/tMHwfrWcAshZrwAAERE8AIAEBHBCwBARAQvAAAREbwAAETEquYSqBw0MFj/3oS7g/Xulv2yHz/zJEnS79Ykn/Bxevp4xz9zT2YA7aPXYYX/9cSp7x0mSfrFhgclSRemjzceE/6YwM0fzWtj7zovZrwAAERE8AIAEBHBCwBARAQvAAARsbiqSFaZfsnq7ePHIx8K3/rs2J5Lg/XJq/tlav0nJOdAVe/aFo/DyxIAoHW69e8nW1qd/LtvMhat/uyuwW2f+NQNwXrj7SFz/fdzw5K2VvXY4vHuHz3Xpv52Rcx4AQCIiBlvKw30Neqhel1R/7i2PS1w+iepW/WaYP2wzR9lap9+K1m2v/v6hVpfUV26jgJAEzttXqkevkmXL/2TJKl+6jbB7aqODv85UeMHIeT6w+Jkdjziw/la2717aTraRRG8rbRS5fnBWl9RrZWVNWVpGwAkaWVFj7Jdx1rbvVrLamvL03gXQfC20jlVh33873/9/cbgNhf1ezlYfzRwjffO8YeUpF8A0JL/2OG4LR7nvcZ7Xb5rvJsztRPuPqftHdtKcI0XAICImPEWa+TwTOmifncU1cT1lxybqfWeOa3VXQKAQoVmt3+59to8W3crb2e2Usx4AQCIiOAFACAighcAgIgIXgAAIiJ4AQCI
iFXNeXQbMSxYP+Ou+wtuY8Qt3w7Wh97BvUsBtI8Fn83Ot6osvHo5X/3T07+Rqe3+Q8a1QjHjBQAgIoIXAICICF4AACIieAEAiIjFVXnMPqdPsD6uZlXBbQx+MvzhCfLwxwgCQLGsKvwxou/9577B+pPjL8/UNnm4ja+8OS5YH3J29qNNsx+bgHyY8QIAEBHBCwBARAQvAAAREbwAAERE8AIAENFWv6p5w7j9gvXHx12ZZ4+a8nUGAIpUMSz7wfaS9MLpV+fZI7yCOeTdJX2D9SGLXi24DWQx4wUAICKCFwCAiAheAAAiIngBAIiI4AUAIKKtflXz/M+FP+h5SGXhq5cnr+4XrFetCt+rmTs1A+gMdvve0mC9PnI/uhpmvAAARETwAgAQEcELAEBEBC8AABERvAAARLTVr2ou1qVLR2Rq0w4dGtzWF3A/UwAdy00rPpmp3f5/DgtuO+DDZ8vdna0SM14AACIieAEAiIjgBQAgIoIXAICIzD3/DQwPqTiWuxui3TzWMMXauw/o3BjD0J7yjWHMeAEAiIjgBQAgIoIXAICICF4AACIieAEAiKjZVc0AAKC0mPECABARwQsAQEQELwAAERG8AABERPACABARwQsAQET/HwO/cCcph0VSAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 720x720 with 6 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def draw_feature_box(ax, img, coord, unit, title):\n",
    "    \"\"\"Show `img` on `ax` and outline feature-map cell `coord` = (row, col) in red.\n",
    "\n",
    "    `unit` is the side length, in input pixels, of one feature-map cell; the outlined\n",
    "    square spans x in [col*unit, (col+1)*unit] and y in [row*unit, (row+1)*unit].\n",
    "    \"\"\"\n",
    "    ax.imshow(img)\n",
    "    x = coord[1] * unit\n",
    "    y = (coord[0] + 1) * unit\n",
    "    # One closed polyline instead of four separate edge segments\n",
    "    ax.plot([x, x + unit, x + unit, x, x], [y, y, y - unit, y - unit, y], color='red')\n",
    "    ax.axis('off')\n",
    "    ax.title.set_text(title)\n",
    "\n",
    "\n",
    "# Input pixels per feature-map cell (presumably image side / C side; assumes square)\n",
    "unit = X_train.shape[1] / query_C.shape[2]\n",
    "\n",
    "# The query image is identical in every row, so de-normalise it only once\n",
    "query_img_disp = inverse_normalize(query_img_trans.clone()[0]).permute(1, 2, 0)\n",
    "\n",
    "f, axarr = plt.subplots(XP_FEATURE_NUM, 2, figsize=(10, 10))\n",
    "\n",
    "for coord_idx in range(XP_FEATURE_NUM):\n",
    "    # Left column: the query with its coord_idx-th box\n",
    "    coord = query_nb_boxes[coord_idx][1]\n",
    "    draw_feature_box(axarr[coord_idx, 0], query_img_disp, coord, unit,\n",
    "                     '% Contribution: ' + str(coord))\n",
    "\n",
    "    # Right column: neighbour image xp_feature_meta[..][0] with box xp_feature_meta[..][3]\n",
    "    nn_idx, coord = xp_feature_meta[coord_idx][0], xp_feature_meta[coord_idx][3]\n",
    "    nn_img = inverse_normalize(xps_imgs_trans[nn_idx].clone()[0]).permute(1, 2, 0)\n",
    "    draw_feature_box(axarr[coord_idx, 1], nn_img, coord, unit,\n",
    "                     '% Contribution: ' + str(coord))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Look for Quantitative Evaluation Methods"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gradient x Input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gradient_x_input(img_trans, box):\n",
    "    \"\"\"Gradient x Input map of one cell of the feature map C, w.r.t. input channel 0.\n",
    "\n",
    "    img_trans: transformed query batch as passed to netC; only item 0 is used.\n",
    "    box: (row, col) cell of C to explain.\n",
    "\n",
    "    Returns a numpy array with the spatial shape of the input: the gradient of the\n",
    "    sum of all channels of C at `box` w.r.t. input channel 0, times input channel 0.\n",
    "    \"\"\"\n",
    "    # Differentiate w.r.t. a detached copy so the caller's tensor is left untouched\n",
    "    img = img_trans.clone().detach().requires_grad_(True)\n",
    "    _, _, C = netC(img)\n",
    "    # Summing per-channel gradients equals the gradient of the channel sum\n",
    "    # (linearity), so a single backward pass is enough\n",
    "    target = C[0, :, box[0], box[1]].sum()\n",
    "    grad, = torch.autograd.grad(target, img)\n",
    "    return (grad[0, 0] * img[0, 0]).detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "box = query_nb_boxes[0][1]\n",
    "plt.imshow(gradient_x_input(query_img_trans, box))\n",
    "plt.title('Gradient x Input for box ' + str(box))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Black out Input Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "def blackout_C(box=None, unit=4, grid_size=7, fill_value=-1.0):\n",
    "    \"\"\"Black out one grid cell of the query image and re-run netC on the result.\n",
    "\n",
    "    box: (row, col) grid cell to black out; defaults to query_nb_boxes[0][1].\n",
    "    unit: side length of one grid cell in pixels.\n",
    "    grid_size: number of grid cells per side; `box` must lie inside the grid.\n",
    "    fill_value: value written into the cell. Assuming the transform uses the\n",
    "        MEAN_NORM = STD_NORM = 0.5 set above, -1.0 is a normalised 0 (black) pixel.\n",
    "\n",
    "    Prints the logit of class query_pred for the occluded image together with\n",
    "    [row, col], shows channel 0 of the occluded image, and returns (C, row, col)\n",
    "    where C is netC's feature map for the occluded image.\n",
    "    Raises ValueError if `box` lies outside the grid.\n",
    "    \"\"\"\n",
    "    if box is None:\n",
    "        box = query_nb_boxes[0][1]\n",
    "    i, j = int(box[0]), int(box[1])\n",
    "    if not (0 <= i < grid_size and 0 <= j < grid_size):\n",
    "        raise ValueError(f'box {(i, j)} is outside the {grid_size}x{grid_size} grid')\n",
    "\n",
    "    # Only the image occluded at `box` is needed, so run the network once\n",
    "    # instead of once per grid cell\n",
    "    img_copy = query_img_trans.clone()\n",
    "    img_copy[:, :, i*unit: i*unit+unit, j*unit: j*unit+unit] = fill_value\n",
    "\n",
    "    with torch.no_grad():\n",
    "        logits, _, C = netC(img_copy)\n",
    "\n",
    "    print(logits[0][query_pred].item(), [i, j])\n",
    "    plt.imshow(img_copy[0][0].detach().numpy())\n",
    "    plt.show()\n",
    "\n",
    "    return C, i, j"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.3436455726623535 [3, 4]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAANN0lEQVR4nO3df4wc9XnH8c8n9vkcn0GxTXBd48aUElQnIqY6GSJQ5RaVEqrEoFYWVps6EuHSCtSgRlUprRqr6g+r+SVaVamc4MaJHFAEQVjCaeNaSSlt6viMjH+mNSFG2D37YtzGJhT/wE//uHF0wO3csTO7s/bzfkmr3Z1nZ+fR2J+bmZ3d+ToiBODi97amGwDQHYQdSIKwA0kQdiAJwg4kMb2bC5vh/pipgW4uEkjlVf1Yp+OUJ6pVCrvtWyU9KGmapC9GxNqy18/UgK73zVUWCaDEttjastb2brztaZL+TtIHJC2RtMr2knbfD0BnVTlmXybpuYh4PiJOS3pE0op62gJQtyphXyjpxXHPDxXTXsf2kO1h28NndKrC4gBU0fFP4yNiXUQMRsRgn/o7vTgALVQJ+2FJi8Y9v6KYBqAHVQn7dklX277S9gxJd0raVE9bAOrW9qm3iDhr+15J/6SxU2/rI2JvbZ0BqFWl8+wRsVnS5pp6AdBBfF0WSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kESlIZttH5R0UtJrks5GxGAdTQGoX6WwF34pIo7V8D4AOojdeCCJqmEPSd+0vcP20EQvsD1ke9j28Bmdqrg4AO2quht/U0Qctn25pC22vxcRT41/QUSsk7ROki713Ki4PABtqrRlj4jDxf2opMclLaujKQD1azvstgdsX3L+saRbJO2pqzEA9aqyGz9f0uO2z7/PVyPiH2vpCkDt2g57RDwv6X019gKggzj1BiRB2IEkCDuQBGEHkiDsQBJ1/BAmhZfufn/L2uLfPlA67/7R+aX1U6f6SusLvzqjtD7r0Msta+d27iudF3mwZQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJDjPPkV/9AcbW9Z+ffaJ8pmvqrjw5eXlH5xpfZ79wWOTzHwR2zb6rpa1WZ9+R+m807fuqLmb5rFlB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkHNG9QVou9dy43jd3bXl1+vFvXN+yduza8r+Zc/aXr+P/+XmX1vuv/d/S+qfe+2jL2i2zzpTO++QrM0vrvzbr1dJ6Fa+cO11a/+6p8t6Wv/1c28u+8sm7S+vvvnt72+/dpG2xVSfi+IT/odiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAS/J59igYe3VZSq/bel1abXQ/+1C+3rP3ZjYvLl/0vz5XW/2r5z7XT0pRM/7/y8+QDu0ZK63P/tXzFXzuj9Xn6WT8ov1b/xWjSLbvt9bZHbe8ZN22u7S22DxT3czrbJoCqprIb/yVJt75h2v2StkbE1ZK2Fs8B9LBJwx4RT0k6/obJKyRtKB5vkHR7vW0BqFu7x+zzI+L8AdURSS0HM7M9JGlIkmZqVpuLA1BV5U/jY+yXNC1/6RER6yJiMCIG+9RfdXEA2tRu2I/aXiBJxf1ofS0B6IR2w75J0uri8WpJT9TTDoBOmfSY3fbDGrty+WW2D0n6pKS1kr5m+y5JL0ha2ckmUe7skaMtawOPta5J0muTvPfAoy+10VE9jn70/aX19/SVj1v/qeOtL9i/+B+eL533bGn1wjRp2CNiVYvShXkVCiApvi4LJEHYgSQIO5AEYQeSIOxAElxKGhetl0pO3c374ne62En3cClpAIQdyIKwA0kQdiAJwg4kQdiBJAg7kASXksZFa+6+V5puoaewZQeS
IOxAEoQdSIKwA0kQdiAJwg4kQdiBJDjPjouW//3ZplvoKWzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kMWnYba+3PWp7z7hpa2wftr2zuN3W2TYBVDWVLfuXJN06wfTPRcTS4ra53rYA1G3SsEfEU5KOd6EXAB1U5Zj9Xtu7it38Oa1eZHvI9rDt4TM6VWFxAKpoN+yfl3SVpKWSRiR9ptULI2JdRAxGxGCf+ttcHICq2gp7RByNiNci4pykL0haVm9bAOrWVthtLxj39A5Je1q9FkBvmPT37LYflrRc0mW2D0n6pKTltpdKCkkHJX2scy2il73tkktK6zc8/VLL2n3zdpTOu2Lo90rr/Zu3l9bxepOGPSJWTTD5oQ70AqCD+AYdkARhB5Ig7EAShB1IgrADSXApaVRyYM17SuvfeOfft6x96MAdpfNyaq1ebNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnOs6PUj37rhtL63jv/trT+/TOtL0V2Yu2i0nn7daS0jreGLTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJMF59uSmL/zp0vrv/+nDpfV+95XWV+76zZa1y77B79W7iS07kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBefaLnKeX/xNf9+SLpfWVs39UWt94cl5p/fI/mdaydq50TtRt0i277UW2v2V7n+29tj9eTJ9re4vtA8X9nM63C6BdU9mNPyvpExGxRNINku6xvUTS/ZK2RsTVkrYWzwH0qEnDHhEjEfFM8fikpP2SFkpaIWlD8bINkm7vUI8AavCWjtltL5Z0naRtkuZHxEhROiJpfot5hiQNSdJMzWq7UQDVTPnTeNuzJT0m6b6IODG+FhEhKSaaLyLWRcRgRAz2qb9SswDaN6Ww2+7TWNA3RsTXi8lHbS8o6gskjXamRQB1mHQ33rYlPSRpf0R8dlxpk6TVktYW9090pENU875rSst/fvnGSm//N3+xsrT+jme/U+n9UZ+pHLPfKOnDknbb3llMe0BjIf+a7bskvSCp/F8dQKMmDXtEPC3JLco319sOgE7h67JAEoQdSIKwA0kQdiAJwg4kwU9cLwLTlry7Ze13Hnm80ntf89DvltYXf+U/Kr0/uoctO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwXn2i8D37ml9Yd8PDbxS6b2v+Pbp8hfEhBcoQg9iyw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSXCe/QLw6geXlda//cHPlFRn19sMLlhs2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgiamMz75I0pclzZcUktZFxIO210i6W9IPi5c+EBGbO9VoZv9947TS+s9Mb/9c+saT80rrfSfKf8/Or9kvHFP5Us1ZSZ+IiGdsXyJph+0tRe1zEfHpzrUHoC5TGZ99RNJI8fik7f2SFna6MQD1ekvH7LYXS7pO0rZi0r22d9leb3vCayPZHrI9bHv4jE5V6xZA26YcdtuzJT0m6b6IOCHp85KukrRUY1v+Cb+gHRHrImIwIgb71F+9YwBtmVLYbfdpLOgbI+LrkhQRRyPitYg4J+kLksp/rQGgUZOG3bYlPSRpf0R8dtz0BeNedoekPfW3B6AuU/k0/kZJH5a02/bOYtoDklbZXqqxsy8HJX2sA/2hor88dk1p/d9+9crSeozsrrMdNGgqn8Y/LckTlDinDlxA+AYdkARhB5Ig7EAShB1IgrADSRB2IAlHF4fcvdRz43rf3LXlAdlsi606EccnOlXOlh3IgrADSRB2IAnCDiRB2IEkCDuQBGEHkujqeXbbP5T0wrhJl0k61rUG3ppe7a1X+5LorV119vauiHjnRIWuhv1NC7eHI2KwsQZK9GpvvdqXRG/t6lZv7MYDSRB2IImmw76u4eWX6dXeerUvid7a1ZXeGj1mB9A9TW/ZAXQJYQeSaCTstm+1/Z+2n7N9fxM9tGL7oO3dtnfaHm64l/W2R23vGTdt
ru0ttg8U9xOOsddQb2tsHy7W3U7btzXU2yLb37K9z/Ze2x8vpje67kr66sp66/oxu+1pkv5L0q9IOiRpu6RVEbGvq420YPugpMGIaPwLGLZ/UdLLkr4cEe8tpv21pOMRsbb4QzknIv6wR3pbI+nlpofxLkYrWjB+mHFJt0v6iBpcdyV9rVQX1lsTW/Zlkp6LiOcj4rSkRyStaKCPnhcRT0k6/obJKyRtKB5v0Nh/lq5r0VtPiIiRiHimeHxS0vlhxhtddyV9dUUTYV8o6cVxzw+pt8Z7D0nftL3D9lDTzUxgfkSMFI+PSJrfZDMTmHQY7256wzDjPbPu2hn+vCo+oHuzmyLiFyR9QNI9xe5qT4qxY7BeOnc6pWG8u2WCYcZ/osl11+7w51U1EfbDkhaNe35FMa0nRMTh4n5U0uPqvaGoj54fQbe4H224n5/opWG8JxpmXD2w7poc/ryJsG+XdLXtK23PkHSnpE0N9PEmtgeKD05ke0DSLeq9oag3SVpdPF4t6YkGe3mdXhnGu9Uw42p43TU+/HlEdP0m6TaNfSL/fUl/3EQPLfr6WUnPFre9Tfcm6WGN7dad0dhnG3dJmidpq6QDkv5Z0twe6u0rknZL2qWxYC1oqLebNLaLvkvSzuJ2W9PrrqSvrqw3vi4LJMEHdEAShB1IgrADSRB2IAnCDiRB2IEkCDuQxP8Dbbfr7R1ieb0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# C: feature map of the occluded query; (i, j): the blacked-out grid cell\n",
    "C, i, j = blackout_C()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.640625"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Share of C's channels at the blacked-out cell (i, j) that are unchanged (assumes a batch of one)\n",
    "unchanged_at_cell = query_C[:, :, i:i+1, j:j+1] == C[:, :, i:i+1, j:j+1]\n",
    "unchanged_at_cell.sum().item() / C.shape[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(89)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of zero-valued channels at (i, j) in the original query's feature map\n",
    "(query_C[:, :, i:i+1, j:j+1] == 0).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(94)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of zero-valued channels at (i, j) in the blacked-out image's feature map\n",
    "(C[:, :, i:i+1, j:j+1] == 0).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8418367346938775"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fraction of all entries of C that the blackout left unchanged\n",
    "(query_C == C).sum().item() / C.numel()"
   ]
  },
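  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save Blackout Dataset\n",
    "\n",
    "For each test query, record the training-set index of every explanation neighbour together with the feature-map coordinate that `get_box_xp` returns for the query's selected box feature in that neighbour."
   ]
  },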
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_NEIGHBORS = 10  # explanation neighbours retrieved per test query (overrides the value set at the top)\n",
    "FEATURE_NUM = 5  # 1-based position in query_nb_boxes of the query box to look up in each neighbour\n",
    "\n",
    "# One row per (query, neighbour): training instance index and located feature coordinate\n",
    "blackout_df = pd.DataFrame(columns=['instance', 'coord'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|█▋        | 1711/10000 [04:02<19:27,  7.10it/s]"
     ]
    }
   ],
   "source": [
    "blackout_rows = []\n",
    "\n",
    "try:\n",
    "    # One pass per test instance (assumes test_loader yields one instance per batch)\n",
    "    for query_idx in tqdm(range(len(test_loader))):\n",
    "\n",
    "        # Get query information\n",
    "        query_label = y_test[query_idx].item()\n",
    "        query_pred  = test_preds[query_idx].item()\n",
    "        query_img_trans = get_transformed_data([query_idx], test_loader)[0]\n",
    "        query_cont = X_test_c[query_idx]\n",
    "\n",
    "        # Get query boxes\n",
    "        query_logits, query_x, query_C = netC(query_img_trans)\n",
    "        query_nb_boxes = get_box_query(query_C, query_logits, net_classifier)\n",
    "\n",
    "        # Explanation neighbours of *this* query's representation\n",
    "        xp_idxs = twin.kneighbors(X=[query_cont], n_neighbors=NUM_NEIGHBORS, return_distance=False)\n",
    "        xps_imgs_trans = get_transformed_data(xp_idxs[0], train_loader)\n",
    "\n",
    "        # Feature-map cell of the query's FEATURE_NUM-th box (FEATURE_NUM is 1-based)\n",
    "        xp_window_idx = query_nb_boxes[FEATURE_NUM-1][1]\n",
    "        xp_feature = query_C[:, :, xp_window_idx[0]:xp_window_idx[0]+1, xp_window_idx[1]:xp_window_idx[1]+1]\n",
    "\n",
    "        # Locate that feature in each neighbour's feature map\n",
    "        for nn_pos, nn_img_trans in enumerate(xps_imgs_trans):\n",
    "            _, _, nn_C = netC(nn_img_trans)\n",
    "            coord, max_dist = get_box_xp(nn_C, xp_feature)\n",
    "            blackout_rows.append({'instance': xp_idxs[0][nn_pos], 'coord': coord})\n",
    "finally:\n",
    "    # Build the frame once (row-wise DataFrame.append is quadratic and was removed in\n",
    "    # pandas 2.0); doing it in `finally` keeps partial results if the loop is interrupted\n",
    "    blackout_df = pd.DataFrame(blackout_rows, columns=['instance', 'coord'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "blackout_df.to_csv('blackout_data_feature_5.csv', index_label=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeded so the displayed rows are reproducible; min() avoids an error on frames with < 10 rows\n",
    "blackout_df.sample(min(10, len(blackout_df)), random_state=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct training instances that appear as explanation neighbours\n",
    "blackout_df['instance'].nunique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (number of rows, number of columns)\n",
    "(len(blackout_df.index), len(blackout_df.columns))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "neurips_img",
   "language": "python",
   "name": "neurips_img"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
