{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules automatically before each cell executes, so edits\n",
    "# to the project's .py files are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Work from the parent directory so that relative paths used later in the\n",
    "# notebook (e.g. the results/ folder) resolve from there.\n",
    "# The starting directory is remembered on the first run, so re-executing this\n",
    "# cell does not climb one more level each time.\n",
    "if '_NOTEBOOK_START_DIR' not in globals():\n",
    "    _NOTEBOOK_START_DIR = os.getcwd()\n",
    "os.chdir(os.path.join(_NOTEBOOK_START_DIR, os.pardir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Directory added to the import path for the ConceptBottleneck code.\n",
    "# The default is a machine-specific location; set the CONCEPT_BOTTLENECK_PATH\n",
    "# environment variable to point at your own checkout instead of editing this cell.\n",
    "concept_bottleneck_path = os.environ.get(\n",
    "    'CONCEPT_BOTTLENECK_PATH',\n",
    "    '/home/njr61/rds/hpc-work/spurious-concepts/ConceptBottleneck',\n",
    ")\n",
    "\n",
    "# Avoid piling up duplicate sys.path entries when the cell is re-run.\n",
    "if concept_bottleneck_path not in sys.path:\n",
    "    sys.path.append(concept_bottleneck_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import json\n",
    "import pickle\n",
    "from copy import copy\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "from PIL import Image\n",
    "from captum.attr import visualization as viz\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ConceptBottleneck.CUB.models import ModelXtoC, ModelOracleCtoY"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.images import *\n",
    "from src.util import *\n",
    "from src.models import *\n",
    "from src.plot import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## General Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment settings consumed by the comparison cell below: `noisy` is passed\n",
    "# to get_data and get_synthetic_model; `weight_decay` and `optimizer` are\n",
    "# passed to get_synthetic_model.\n",
    "noisy=False\n",
    "weight_decay = 0.0004\n",
    "optimizer = 'sgd'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# num_objects: passed to get_data / get_synthetic_model, sets the length of the\n",
    "# binary concept combinations, and is part of the results folder name.\n",
    "# seed: seeds the NumPy and PyTorch RNGs, is passed to get_synthetic_model, and\n",
    "# is part of the results folder name.\n",
    "num_objects = 1\n",
    "seed = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the NumPy and PyTorch global random number generators with `seed`.\n",
    "np.random.seed(seed)\n",
    "# torch.manual_seed returns a Generator whose repr contains a memory address\n",
    "# that differs on every run; the trailing semicolon keeps it out of the saved\n",
    "# notebook output.\n",
    "torch.manual_seed(seed);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder that receives the evaluation output for this (num_objects, seed) pair.\n",
    "results_folder = \"results/explanations/objects={}_seed={}\".format(\n",
    "    num_objects,seed\n",
    ")\n",
    "\n",
    "# exist_ok=True creates the folder (and any missing parents) when needed and is\n",
    "# a no-op when it already exists, without a separate existence check.\n",
    "os.makedirs(results_folder, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "concept_num = 0\n",
    "\n",
    "train_loader, val_loader, train_pkl, val_pkl = get_data(num_objects, noisy)\n",
    "val_images, val_y, val_c = unroll_data(val_loader)\n",
    "\n",
    "# For every binary assignment of the objects, take the index of the first row\n",
    "# of val_c that matches the corresponding concept vector exactly.\n",
    "data_points = []\n",
    "binary_combos = list(itertools.product([0, 1], repeat=num_objects))\n",
    "for combo in binary_combos:\n",
    "    # Each entry k of the assignment expands to the pair (k, 1-k).\n",
    "    as_tensor = [value for k in combo for value in (k, 1 - k)]\n",
    "    matches = torch.where(torch.all(val_c == torch.Tensor(as_tensor), dim=1))[0]\n",
    "    # Raises IndexError when no row of val_c matches this assignment.\n",
    "    data_points.append(matches[0].item())\n",
    "\n",
    "joint_model_small = get_synthetic_model(num_objects, 'small3', noisy, weight_decay, optimizer, seed)\n",
    "joint_model_large = get_synthetic_model(num_objects, 'small7', noisy, weight_decay, optimizer, seed)\n",
    "\n",
    "# Attribution methods to compare, keyed by the name used in the saved results.\n",
    "attribution_methods = {\n",
    "    'saliency': plot_saliency,\n",
    "    'gradcam': plot_gradcam,\n",
    "    'integrated gradients': plot_integrated_gradients,\n",
    "}\n",
    "\n",
    "\n",
    "def normalized_patches(model, method, index):\n",
    "    \"\"\"Return get_patches(...) of one min-max normalised attribution map.\n",
    "\n",
    "    Calls `method` for `model` with `val_images` and `index` (plotting\n",
    "    disabled), shifts the result so its minimum is 0, divides by the\n",
    "    resulting maximum, and passes the rescaled map to `get_patches`.\n",
    "\n",
    "    NOTE: a map holding a single constant value has maximum 0 after the\n",
    "    shift, so the division is undefined for it (same as before this refactor).\n",
    "    \"\"\"\n",
    "    intensities = method(model, run_joint_model, concept_num, val_images, index, val_pkl, plot=False)\n",
    "    intensities = intensities - np.min(intensities)\n",
    "    intensities = intensities / np.max(intensities)\n",
    "    # NOTE(review): 64 is presumably the patch size -- confirm against get_patches.\n",
    "    return get_patches(intensities, 64)\n",
    "\n",
    "\n",
    "clean_intensities = {name: [] for name in attribution_methods}\n",
    "dirty_intensities = {name: [] for name in attribution_methods}\n",
    "distances = {name: [] for name in attribution_methods}\n",
    "\n",
    "for str_method, method in attribution_methods.items():\n",
    "    for i in data_points:\n",
    "        # 'clean' uses the small3 model, 'dirty' uses the small7 model.\n",
    "        clean_patches = normalized_patches(joint_model_small, method, i)\n",
    "        dirty_patches = normalized_patches(joint_model_large, method, i)\n",
    "\n",
    "        # Share of the summed patch values found in the first two entries\n",
    "        # along the second axis.\n",
    "        clean_intensities[str_method].append(np.sum(clean_patches[:, :2]) / np.sum(clean_patches))\n",
    "        dirty_intensities[str_method].append(np.sum(dirty_patches[:, :2]) / np.sum(dirty_patches))\n",
    "\n",
    "        distances[str_method].append(compute_wasserstein_distance(clean_patches, dirty_patches))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the per-method distances and intensity ratios to\n",
    "# <results_folder>/evaluation.json. The context manager closes (and therefore\n",
    "# flushes) the file even if serialisation fails part-way through.\n",
    "with open(\"{}/{}.json\".format(results_folder,'evaluation'),\"w\") as results_file:\n",
    "    json.dump({\n",
    "        'distances': distances,\n",
    "        'small_intensities': clean_intensities,\n",
    "        'large_intensities': dirty_intensities,\n",
    "    }, results_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cem",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
